# ================== SAFE IMPORTS ==================
import matplotlib

matplotlib.use("Agg")  # headless backend: figures are rendered off-screen to PNG buffers

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from scipy.special import beta as beta_func
from scipy.stats import beta, norm
import io
import gradio as gr
from PIL import Image


# ========== 1. MATHEMATICAL CORE ==========
def exact_bayesian_inference(prior_a, prior_b, heads, tails, num_points=400):
    """Compute the exact conjugate Beta posterior for a coin-flip experiment.

    With a Beta(prior_a, prior_b) prior on the heads probability θ and a
    binomial likelihood, the posterior is Beta(prior_a + heads, prior_b + tails).

    Args:
        prior_a, prior_b: Beta prior shape parameters (must be > 0).
        heads, tails: observed counts of heads and tails.
        num_points: number of grid points in (0, 1) at which the densities
            are evaluated.

    Returns:
        Tuple ``(x, prior_pdf, posterior_pdf, posterior_a, posterior_b)``
        where ``x`` is the θ grid and both pdfs are evaluated on it.
    """
    posterior_a = prior_a + heads
    posterior_b = prior_b + tails
    # Stay strictly inside (0, 1): Beta pdfs with a shape < 1 diverge at the ends.
    x = np.linspace(0.001, 0.999, num_points)
    prior_pdf = beta.pdf(x, prior_a, prior_b)
    posterior_pdf = beta.pdf(x, posterior_a, posterior_b)
    return x, prior_pdf, posterior_pdf, posterior_a, posterior_b


def variational_inference(prior_a, prior_b, heads, tails, num_iterations=800,
                          lr=0.01, seed=None):
    """Fit a logit-normal variational posterior to the coin-flip model.

    The variational family is q(z) = N(mu, sigma^2) over the logit z, with
    θ = sigmoid(z). Each iteration estimates the ELBO from a single
    reparameterised sample and takes one Adam step on (mu, log_sigma).

    Because q lives in z-space, the target density must be expressed in
    z-space as well, so the ELBO includes the log-Jacobian
    log|dθ/dz| = log θ + log(1 - θ). Without it, the fit approximates
    Beta(a + heads - 1, b + tails - 1) instead of the true posterior (and an
    improper target for a uniform prior with no data).

    Args:
        prior_a, prior_b: Beta prior shape parameters (must be > 0).
        heads, tails: observed counts of heads and tails.
        num_iterations: number of optimisation steps.
        lr: Adam learning rate.
        seed: optional integer seed for the Monte-Carlo noise. ``None`` uses
            torch's global RNG.

    Returns:
        Tuple ``(theta_grid, q_theta, elbo_history, mu, sigma)``: the
        approximate density q(θ) evaluated on a θ grid covering mu ± 3 sigma
        in logit space (renormalised to integrate to 1 on that grid), the
        per-iteration single-sample ELBO estimates, and the fitted logit-space
        mean and standard deviation.
    """
    logsigmoid = torch.nn.functional.logsigmoid
    # Constant terms, hoisted out of the loop (they do not affect gradients).
    log_beta_norm = float(np.log(beta_func(prior_a, prior_b)))
    half_log_2pi = 0.5 * float(np.log(2 * np.pi))
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    mu = torch.tensor(0.0, requires_grad=True)
    log_sigma = torch.tensor(0.0, requires_grad=True)
    optimizer = optim.Adam([mu, log_sigma], lr=lr)

    elbo_history = []
    for _ in range(num_iterations):
        optimizer.zero_grad()
        sigma = torch.exp(log_sigma)
        eps = torch.randn(1, generator=generator)
        z = mu + sigma * eps  # reparameterisation trick

        # log θ and log(1 - θ), computed stably from the logit z.
        log_theta = logsigmoid(z)
        log_one_minus_theta = logsigmoid(-z)

        log_likelihood = heads * log_theta + tails * log_one_minus_theta
        log_prior = ((prior_a - 1) * log_theta
                     + (prior_b - 1) * log_one_minus_theta
                     - log_beta_norm)
        log_jacobian = log_theta + log_one_minus_theta  # log |dθ/dz|
        log_q = -half_log_2pi - log_sigma - 0.5 * eps ** 2

        elbo = log_likelihood + log_prior + log_jacobian - log_q
        (-elbo).backward()
        optimizer.step()
        elbo_history.append(elbo.item())

    mu_f = mu.item()
    sigma_f = torch.exp(log_sigma).item()

    # Evaluate q on mu ± 3 sigma in logit space and change variables to θ:
    # q(θ) = q(z) / |dθ/dz| = q(z) / (θ (1 - θ)).
    z_grid = np.linspace(mu_f - 3 * sigma_f, mu_f + 3 * sigma_f, 400)
    q_z = norm.pdf(z_grid, mu_f, sigma_f)
    theta_grid = 1 / (1 + np.exp(-z_grid))
    q_theta = q_z / (theta_grid * (1 - theta_grid) + 1e-10)
    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    # Renormalise over the truncated grid so the plotted curve integrates to 1.
    q_theta /= trapezoid(q_theta, theta_grid)
    return theta_grid, q_theta, elbo_history, mu_f, sigma_f
# ========== 2. SAFE FIGURE → PIL ==========
def fig_to_pil(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    The figure is always closed, even if rendering fails, so pyplot does not
    accumulate open figures across repeated callbacks.
    """
    try:
        with io.BytesIO() as buf:
            fig.savefig(buf, format="png", dpi=100)
            buf.seek(0)
            with Image.open(buf) as img:
                # Image.open is lazy; copy() forces a full decode so the
                # returned image stays valid after the buffer is closed.
                return img.copy()
    finally:
        plt.close(fig)


# ========== 3. VISUALIZATION ==========
def _finish_axes(ax, title):
    """Apply the shared title, θ/density axis labels, legend and grid to a panel."""
    ax.set_title(title)
    ax.set_xlabel("θ")
    ax.set_ylabel("Density")
    ax.legend()
    ax.grid(alpha=0.3)


def create_plot(prior_a, prior_b, heads, tails, show_vi=True):
    """Plot the exact Beta posterior and, optionally, the VI approximation.

    Args:
        prior_a, prior_b: Beta prior shape parameters.
        heads, tails: observed coin-flip counts.
        show_vi: when True, add a second panel that overlays the variational
            approximation on the prior and exact posterior.

    Returns:
        A PIL image of the rendered figure.
    """
    ncols = 2 if show_vi else 1
    # squeeze=False always yields a 2-D array of Axes, so one panel needs no special case.
    fig, axes = plt.subplots(1, ncols, figsize=(6 * ncols, 4), squeeze=False)
    axes = axes[0]
    try:
        x, prior_pdf, posterior_pdf, post_a, post_b = exact_bayesian_inference(
            prior_a, prior_b, heads, tails
        )

        ax = axes[0]
        ax.plot(x, prior_pdf, label=f"Prior Beta({prior_a},{prior_b})")
        ax.plot(x, posterior_pdf, linewidth=3,
                label=f"Posterior Beta({post_a:.1f},{post_b:.1f})")
        ax.fill_between(x, 0, posterior_pdf, alpha=0.2)
        _finish_axes(ax, "Exact Bayesian Inference")

        if show_vi:
            theta_grid, q_theta, _elbo_history, mu_f, sigma_f = variational_inference(
                prior_a, prior_b, heads, tails
            )
            ax = axes[1]
            ax.plot(x, prior_pdf, label="Prior")
            ax.plot(x, posterior_pdf, label="Exact Posterior")
            # q is Gaussian over the logit z (θ = sigmoid(z)), not over θ itself.
            ax.plot(theta_grid, q_theta, "--", linewidth=3,
                    label=f"VI logit-N({mu_f:.2f},{sigma_f:.2f})")
            ax.fill_between(theta_grid, 0, q_theta, alpha=0.2)
            _finish_axes(ax, "Variational Approximation")

        # Lay out this figure explicitly: pyplot's "current figure" is global
        # state shared by every callback.
        fig.tight_layout()
    except BaseException:
        plt.close(fig)  # do not leak the figure when inference or plotting fails
        raise
    return fig_to_pil(fig)
# ========== 4. BRAIN HEALTH DEMO ==========
def brain_health_demo():
    """Render the "Brain Health" demo figure as a PIL image.

    Left panel: a simulated ASL-MRI signal curve built from ``true_cbf`` and
    ``true_att``, plus a noisy copy. Right panel: KDE contours of a mock
    (CBF, ATT) posterior, with the true values marked by a red star.

    Output is deterministic (seed 42) and does not reseed NumPy's global RNG.
    """
    from scipy.stats import gaussian_kde

    rng = np.random.default_rng(42)  # local generator: no global RNG side effect

    t = np.linspace(0, 3, 50)
    true_cbf, true_att = 60, 1.5
    signal = true_cbf * (1 - np.exp(-t / true_att)) * np.exp(-t / 1.6)
    noisy = signal + rng.normal(0, 3, len(t))

    # NOTE(review): this "posterior" is illustrative only — the samples come
    # from fixed Gaussians and are NOT fitted to `noisy`.
    cbf = rng.normal(58, 5, 1000)
    att = rng.normal(1.6, 0.3, 1000)

    cbf_axis = np.linspace(40, 80, 100)
    att_axis = np.linspace(0.8, 2.4, 100)
    X, Y = np.meshgrid(cbf_axis, att_axis)
    kde = gaussian_kde(np.vstack([cbf, att]))
    Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    try:
        ax[0].plot(t, signal, linewidth=3, label="True Signal")
        ax[0].scatter(t, noisy, alpha=0.6, label="Noisy")
        ax[0].set_title("Simulated ASL-MRI")
        ax[0].legend()
        ax[0].grid(alpha=0.3)

        ax[1].contour(X, Y, Z, levels=5)
        ax[1].scatter([true_cbf], [true_att], color="red", s=200, marker="*")
        ax[1].set_title("VI Parameter Posterior")
        ax[1].set_xlabel("CBF")
        ax[1].set_ylabel("ATT")
        ax[1].grid(alpha=0.3)

        fig.tight_layout()
    except BaseException:
        plt.close(fig)  # do not leak the figure if plotting fails
        raise
    return fig_to_pil(fig)


# ========== 5. GRADIO UI ==========
# Callback for the coin-flip plot, used by the Update button and on page load.
update_plot = create_plot

with gr.Blocks(title="Variational Inference Playground") as demo:
    gr.Markdown("# 🧠 Variational Inference Playground")

    with gr.Tab("🎯 Coin Flip"):
        with gr.Row():
            with gr.Column():
                # gr.Slider(minimum, maximum, value, *, step=...) takes `step`
                # by keyword only, so every option is passed by name.
                prior_a = gr.Slider(minimum=0.1, maximum=10, value=2, step=0.1, label="α")
                prior_b = gr.Slider(minimum=0.1, maximum=10, value=2, step=0.1, label="β")
                heads = gr.Slider(minimum=0, maximum=100, value=8, step=1, label="Heads")
                tails = gr.Slider(minimum=0, maximum=100, value=4, step=1, label="Tails")
                show_vi = gr.Checkbox(value=True, label="Show VI")
                run_btn = gr.Button("Update", variant="primary")
            with gr.Column():
                plot_output = gr.Image(type="pil")

    with gr.Tab("🧠 Brain Health"):
        brain_btn = gr.Button("Run Demo", variant="primary")
        brain_output = gr.Image(type="pil")

    plot_inputs = [prior_a, prior_b, heads, tails, show_vi]
    run_btn.click(update_plot, inputs=plot_inputs, outputs=plot_output)
    brain_btn.click(brain_health_demo, outputs=brain_output)
    # Draw the initial plot from the components' current values so it always
    # matches the defaults declared above.
    demo.load(update_plot, inputs=plot_inputs, outputs=plot_output)


if __name__ == "__main__":
    demo.launch()