# Variational / app.py
# Author: AndaiMD (commit 9011b8c)
# Dependencies: see requirements.txt
# ================== SAFE IMPORTS ==================
import matplotlib
matplotlib.use("Agg")
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from scipy.special import beta as beta_func
from scipy.stats import beta, norm
import io
import gradio as gr
from PIL import Image
# ========== 1. MATHEMATICAL CORE ==========
def exact_bayesian_inference(prior_a, prior_b, heads, tails):
    """Return the closed-form Beta posterior for coin-flip data on a grid.

    The Beta prior is conjugate to the Bernoulli likelihood, so the posterior
    is Beta(prior_a + heads, prior_b + tails).

    Args:
        prior_a: alpha parameter of the Beta prior.
        prior_b: beta parameter of the Beta prior.
        heads: number of observed heads (added to alpha).
        tails: number of observed tails (added to beta).

    Returns:
        Tuple ``(x, prior_pdf, posterior_pdf, posterior_a, posterior_b)`` where
        ``x`` is 400 evenly spaced points on [0.001, 0.999] and both pdf
        arrays are evaluated on ``x``.
    """
    alpha_post, beta_post = prior_a + heads, prior_b + tails
    # Stay strictly inside (0, 1): the Beta pdf is infinite at the endpoints
    # when a shape parameter is below 1.
    grid = np.linspace(0.001, 0.999, 400)
    return (
        grid,
        beta.pdf(grid, prior_a, prior_b),
        beta.pdf(grid, alpha_post, beta_post),
        alpha_post,
        beta_post,
    )
def variational_inference(prior_a, prior_b, heads, tails, num_iterations=800, lr=0.01,
                          seed=None):
    """Fit a logit-normal approximation to the Beta-Binomial posterior.

    The variational family is ``theta = sigmoid(z)`` with ``z ~ N(mu, sigma^2)``.
    ``mu`` and ``log_sigma`` are fitted with Adam by maximising a single-sample,
    reparameterised Monte Carlo estimate of the ELBO.

    Args:
        prior_a: alpha parameter of the Beta prior on theta.
        prior_b: beta parameter of the Beta prior on theta.
        heads: number of observed heads.
        tails: number of observed tails.
        num_iterations: number of Adam steps.
        lr: Adam learning rate.
        seed: optional integer seed for the noise draws. ``None`` (the
            default) uses torch's global RNG.

    Returns:
        ``(theta_grid, q_theta, elbo_history, mu_f, sigma_f)``:

        - ``theta_grid`` is the sigmoid image of 400 points spanning
          ``mu_f +/- 3 sigma_f``.
        - ``q_theta`` is the fitted density on that grid, renormalised to
          integrate to 1 over it.
        - ``elbo_history`` lists the (noisy, single-sample) ELBO estimate
          from each iteration.
        - ``mu_f`` and ``sigma_f`` are the fitted parameters in logit space.
    """
    from scipy.integrate import trapezoid  # np.trapz is deprecated in NumPy 2.x

    logsigmoid = torch.nn.functional.logsigmoid
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(int(seed))

    mu = torch.tensor(0.0, requires_grad=True)
    log_sigma = torch.tensor(0.0, requires_grad=True)
    optimizer = optim.Adam([mu, log_sigma], lr=lr)

    # Loop-invariant constants. They do not affect the gradients, only the
    # reported ELBO values.
    log_beta_norm = float(np.log(beta_func(prior_a, prior_b)))
    log_2pi = float(np.log(2 * np.pi))

    elbo_history = []
    for _ in range(num_iterations):
        optimizer.zero_grad()
        sigma = torch.exp(log_sigma)
        eps = torch.randn(1, generator=generator)
        z = mu + sigma * eps  # reparameterisation trick

        # log(theta) and log(1 - theta), computed stably from z.
        log_theta = logsigmoid(z)
        log_one_minus_theta = logsigmoid(-z)

        log_likelihood = heads * log_theta + tails * log_one_minus_theta
        log_prior = ((prior_a - 1) * log_theta
                     + (prior_b - 1) * log_one_minus_theta
                     - log_beta_norm)
        # q is a density over z, so the target must be expressed over z too:
        # add log |dtheta/dz| = log(theta (1 - theta)). Without it the fit
        # targets the z-image of Beta(a + heads - 1, b + tails - 1), which is
        # improper (sigma diverges) when a + heads <= 1 or b + tails <= 1.
        log_jacobian = log_theta + log_one_minus_theta
        log_q = -0.5 * log_2pi - log_sigma - 0.5 * eps ** 2

        elbo = log_likelihood + log_prior + log_jacobian - log_q
        (-elbo).backward()
        optimizer.step()
        elbo_history.append(elbo.item())

    mu_f = mu.item()
    sigma_f = torch.exp(log_sigma).item()
    z_grid = np.linspace(mu_f - 3 * sigma_f, mu_f + 3 * sigma_f, 400)
    theta_grid = 1 / (1 + np.exp(-z_grid))
    # Change of variables: q(theta) = q(z) / |dtheta/dz| = q(z) / (theta (1 - theta)).
    q_theta = norm.pdf(z_grid, mu_f, sigma_f) / (theta_grid * (1 - theta_grid) + 1e-10)
    # Renormalise because the +/-3 sigma window truncates the tails.
    q_theta /= trapezoid(q_theta, theta_grid)
    return theta_grid, q_theta, elbo_history, mu_f, sigma_f
# ========== 2. SAFE FIGURE → PIL ==========
def fig_to_pil(fig):
    """Render a Matplotlib figure to an in-memory PIL image and close the figure.

    The figure is closed even if rendering fails. Repeated calls from the UI
    callbacks therefore do not accumulate open pyplot figures.

    Args:
        fig: the Matplotlib figure to render.

    Returns:
        A PIL ``Image`` with the PNG rendering at 100 dpi. Its pixel data is
        fully loaded, so it does not depend on the temporary buffer.
    """
    with io.BytesIO() as buf:
        try:
            fig.savefig(buf, format="png", dpi=100)
        finally:
            plt.close(fig)
        buf.seek(0)
        with Image.open(buf) as img:
            # copy() loads the pixel data before the buffer is closed.
            return img.copy()
# ========== 3. VISUALIZATION ==========
def create_plot(prior_a, prior_b, heads, tails, show_vi=True):
    """Draw the coin-flip comparison panels and return them as a PIL image.

    The first panel always shows the exact conjugate prior and posterior.
    When ``show_vi`` is true, a second panel overlays the variational fit
    on the exact curves.

    Args:
        prior_a: alpha parameter of the Beta prior.
        prior_b: beta parameter of the Beta prior.
        heads: number of observed heads.
        tails: number of observed tails.
        show_vi: whether to run variational inference and draw its panel.

    Returns:
        The PIL image produced by ``fig_to_pil``.
    """
    n_panels = 2 if show_vi else 1
    # squeeze=False always yields a 2-D array of Axes, even for one panel.
    fig, axes_grid = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)
    panels = axes_grid[0]

    x, prior_pdf, posterior_pdf, post_a, post_b = exact_bayesian_inference(
        prior_a, prior_b, heads, tails
    )

    def _finish(axis, title):
        # Shared decoration applied to every panel.
        axis.set_title(title)
        axis.set_xlabel("θ")
        axis.set_ylabel("Density")
        axis.legend()
        axis.grid(alpha=0.3)

    exact_ax = panels[0]
    exact_ax.plot(x, prior_pdf, label=f"Prior Beta({prior_a},{prior_b})")
    exact_ax.plot(x, posterior_pdf, linewidth=3,
                  label=f"Posterior Beta({post_a:.1f},{post_b:.1f})")
    exact_ax.fill_between(x, 0, posterior_pdf, alpha=0.2)
    _finish(exact_ax, "Exact Bayesian Inference")

    if show_vi:
        vi_ax = panels[1]
        theta_grid, q_theta, _elbo, mu_f, sigma_f = variational_inference(
            prior_a, prior_b, heads, tails
        )
        vi_ax.plot(x, prior_pdf, label="Prior")
        vi_ax.plot(x, posterior_pdf, label="Exact Posterior")
        vi_ax.plot(theta_grid, q_theta, "--", linewidth=3,
                   label=f"VI N({mu_f:.2f},{sigma_f:.2f})")
        vi_ax.fill_between(theta_grid, 0, q_theta, alpha=0.2)
        _finish(vi_ax, "Variational Approximation")

    plt.tight_layout()
    return fig_to_pil(fig)
# ========== 4. BRAIN HEALTH DEMO ==========
def brain_health_demo():
    """Render the simulated ASL-MRI demo and return it as a PIL image.

    Left panel: the synthetic signal
    ``cbf * (1 - exp(-t / att)) * exp(-t / 1.6)`` with true CBF = 60 and
    ATT = 1.5, plus Gaussian noise (sd 3) at 50 points on t in [0, 3].

    Right panel: a KDE contour of (CBF, ATT) samples, with the true
    parameters marked.

    NOTE(review): the right-hand samples come from fixed normals
    (CBF ~ N(58, 5), ATT ~ N(1.6, 0.3)). They are not fitted to the noisy
    signal, so the panel illustrates what a VI posterior might look like
    rather than showing the result of inference.

    A private, seeded NumPy generator keeps the demo reproducible without
    reseeding NumPy's process-global RNG.

    Returns:
        The PIL image produced by ``fig_to_pil``.
    """
    from scipy.stats import gaussian_kde

    rng = np.random.default_rng(42)
    true_cbf, true_att = 60, 1.5

    t = np.linspace(0, 3, 50)
    signal = true_cbf * (1 - np.exp(-t / true_att)) * np.exp(-t / 1.6)
    noisy = signal + rng.normal(0, 3, len(t))

    fig, (sig_ax, post_ax) = plt.subplots(1, 2, figsize=(12, 4))

    sig_ax.plot(t, signal, linewidth=3, label="True Signal")
    sig_ax.scatter(t, noisy, alpha=0.6, label="Noisy")
    sig_ax.set_title("Simulated ASL-MRI")
    sig_ax.set_xlabel("t")
    sig_ax.set_ylabel("Signal")
    sig_ax.legend()
    sig_ax.grid(alpha=0.3)

    # Illustrative parameter samples (see NOTE in the docstring).
    cbf = rng.normal(58, 5, 1000)
    att = rng.normal(1.6, 0.3, 1000)
    kde = gaussian_kde(np.vstack([cbf, att]))
    cbf_axis = np.linspace(40, 80, 100)
    att_axis = np.linspace(0.8, 2.4, 100)
    X, Y = np.meshgrid(cbf_axis, att_axis)
    Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

    post_ax.contour(X, Y, Z, levels=5)
    post_ax.scatter([true_cbf], [true_att], color="red", s=200, marker="*",
                    label="True parameters")
    post_ax.set_title("VI Parameter Posterior")
    post_ax.set_xlabel("CBF")
    post_ax.set_ylabel("ATT")
    post_ax.legend()
    post_ax.grid(alpha=0.3)

    plt.tight_layout()
    return fig_to_pil(fig)
# ========== 5. GRADIO UI ==========
# Build the two-tab Gradio UI. All components and event listeners are created
# inside the Blocks context, because Gradio requires listeners to be
# registered there.
with gr.Blocks(title="Variational Inference Playground") as demo:
    gr.Markdown("# 🧠 Variational Inference Playground")

    with gr.Tab("🎯 Coin Flip"):
        with gr.Row():
            with gr.Column():
                # Slider(minimum, maximum, value); `step` is keyword-only.
                prior_a = gr.Slider(0.1, 10, 2, step=0.1, label="α")
                prior_b = gr.Slider(0.1, 10, 2, step=0.1, label="β")
                heads = gr.Slider(0, 100, 8, step=1, label="Heads")
                tails = gr.Slider(0, 100, 4, step=1, label="Tails")
                show_vi = gr.Checkbox(True, label="Show VI")
                run_btn = gr.Button("Update", variant="primary")
            with gr.Column():
                plot_output = gr.Image(type="pil")

    with gr.Tab("🧠 Brain Health"):
        brain_btn = gr.Button("Run Demo", variant="primary")
        brain_output = gr.Image(type="pil")

    coin_inputs = [prior_a, prior_b, heads, tails, show_vi]
    # Module-level name kept for compatibility; it simply forwards to create_plot.
    update_plot = create_plot

    run_btn.click(update_plot, inputs=coin_inputs, outputs=plot_output)
    brain_btn.click(brain_health_demo, outputs=brain_output)
    # Draw the initial plot from the sliders' current values instead of
    # duplicating their defaults here.
    demo.load(update_plot, inputs=coin_inputs, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()