Spaces:
Sleeping
Sleeping
| # ================== SAFE IMPORTS ================== | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torch.optim as optim | |
| from scipy.special import beta as beta_func | |
| from scipy.stats import beta, norm | |
| import io | |
| import gradio as gr | |
| from PIL import Image | |
| # ========== 1. MATHEMATICAL CORE ========== | |
def exact_bayesian_inference(prior_a, prior_b, heads, tails):
    """Return the closed-form Beta posterior for a coin-flip experiment.

    A Beta(prior_a, prior_b) prior is conjugate to Bernoulli observations, so the
    posterior is simply Beta(prior_a + heads, prior_b + tails).

    Args:
        prior_a, prior_b: Beta prior shape parameters.
        heads, tails: observed counts.

    Returns:
        tuple: ``(x, prior_pdf, posterior_pdf, posterior_a, posterior_b)`` where
        ``x`` is a 400-point grid on [0.001, 0.999] (the endpoints are avoided
        because Beta densities can diverge at 0 and 1), and the two pdf arrays
        are the prior and posterior densities evaluated on that grid.
    """
    grid = np.linspace(0.001, 0.999, 400)
    post_alpha, post_beta = prior_a + heads, prior_b + tails
    return (
        grid,
        beta.pdf(grid, prior_a, prior_b),
        beta.pdf(grid, post_alpha, post_beta),
        post_alpha,
        post_beta,
    )
def variational_inference(prior_a, prior_b, heads, tails, num_iterations=800, lr=0.01,
                          seed=None):
    """Fit a logit-normal variational posterior for the coin bias θ.

    The variational family is z ~ N(mu, sigma²) with θ = sigmoid(z). Each Adam
    step maximises a single-sample reparameterised estimate of the ELBO

        log p(heads, tails | θ) + log Beta(θ; prior_a, prior_b)
        + log |dθ/dz| - log q(z).

    The Jacobian term log|dθ/dz| = log θ + log(1 - θ) is required because q is
    a density over z while the prior is a density over θ. Without it, the fit
    targets Beta(a + h - 1, b + t - 1) instead of the true posterior.

    Args:
        prior_a, prior_b: Beta prior shape parameters (> 0).
        heads, tails: observed counts (the likelihood omits the binomial
            coefficient, which is constant in θ).
        num_iterations: number of Adam steps.
        lr: Adam learning rate.
        seed: optional int that seeds a private ``torch.Generator`` for the
            Monte Carlo noise. ``None`` (the default) draws from torch's global RNG.

    Returns:
        tuple ``(theta_grid, q_theta, elbo_history, mu, sigma)``:
            theta_grid: 400 θ values, the sigmoid of an evenly spaced z grid
                over mu ± 3 sigma.
            q_theta: the induced density over θ on that grid, renormalised to
                integrate to 1 over the grid.
            elbo_history: per-step single-sample ELBO estimates (noisy).
            mu, sigma: fitted parameters of q(z), as Python floats.
    """
    import torch.nn.functional as F
    from scipy.special import betaln, expit

    generator = torch.Generator().manual_seed(int(seed)) if seed is not None else None

    mu = torch.tensor(0.0, requires_grad=True)
    log_sigma = torch.tensor(0.0, requires_grad=True)
    optimizer = optim.Adam([mu, log_sigma], lr=lr)

    # Loop-invariant constants. betaln avoids the underflow of log(B(a, b)) for large a, b.
    log_beta_norm = float(betaln(prior_a, prior_b))
    log_two_pi = float(np.log(2 * np.pi))

    elbo_history = []
    for _ in range(num_iterations):
        optimizer.zero_grad()
        sigma = torch.exp(log_sigma)
        eps = torch.randn(1, generator=generator)
        z = mu + sigma * eps  # reparameterisation: gradients flow to mu and log_sigma
        # log θ and log(1-θ) are computed directly from z, so they stay finite when the sigmoid saturates.
        log_theta = F.logsigmoid(z)
        log_one_minus_theta = F.logsigmoid(-z)
        log_likelihood = heads * log_theta + tails * log_one_minus_theta
        log_prior = ((prior_a - 1) * log_theta
                     + (prior_b - 1) * log_one_minus_theta
                     - log_beta_norm)
        log_jacobian = log_theta + log_one_minus_theta  # log dθ/dz = log θ(1-θ)
        log_q = -0.5 * log_two_pi - log_sigma - 0.5 * eps ** 2
        elbo = log_likelihood + log_prior + log_jacobian - log_q
        (-elbo).sum().backward()
        optimizer.step()
        elbo_history.append(elbo.item())

    mu_f = mu.item()
    sigma_f = torch.exp(log_sigma).item()
    z_grid = np.linspace(mu_f - 3 * sigma_f, mu_f + 3 * sigma_f, 400)
    q_z = norm.pdf(z_grid, mu_f, sigma_f)
    theta_grid = expit(z_grid)
    # Change of variables: q(θ) = q(z) / |dθ/dz| = q(z) / (θ(1-θ)).
    q_theta = q_z / (theta_grid * expit(-z_grid) + 1e-10)
    # np.trapz is deprecated since NumPy 2.0; prefer np.trapezoid when it exists.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    q_theta /= trapezoid(q_theta, theta_grid)
    return theta_grid, q_theta, elbo_history, mu_f, sigma_f
| # ========== 2. SAFE FIGURE → PIL ========== | |
def fig_to_pil(fig):
    """Render a Matplotlib figure to an in-memory PNG and return it as a PIL image.

    The figure is closed even if rendering fails, so pyplot does not keep
    references to figures that will never be shown.

    Args:
        fig: the Matplotlib figure to render (saved at 100 dpi).

    Returns:
        PIL.Image.Image: a fully decoded copy of the PNG that does not depend
        on the temporary buffer.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", dpi=100)
    finally:
        plt.close(fig)
    buf.seek(0)
    # Image.open is lazy; copy() forces a full decode before the source image is closed.
    with Image.open(buf) as img:
        return img.copy()
| # ========== 3. VISUALIZATION ========== | |
def create_plot(prior_a, prior_b, heads, tails, show_vi=True):
    """Render the coin-flip posterior plot(s) as a PIL image.

    The left (or only) panel shows the Beta prior and the exact conjugate
    posterior. When ``show_vi`` is true, a right panel adds the variational
    approximation from ``variational_inference`` next to the prior and the
    exact posterior.

    Args:
        prior_a, prior_b: Beta prior shape parameters.
        heads, tails: observed counts.
        show_vi: whether to run VI and draw the second panel.

    Returns:
        PIL.Image.Image produced by ``fig_to_pil``.
    """
    n_panels = 2 if show_vi else 1
    # squeeze=False always returns a 2-D array of Axes, so one panel needs no special case.
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 4), squeeze=False)
    try:
        axes = axes[0]
        x, prior_pdf, posterior_pdf, post_a, post_b = exact_bayesian_inference(
            prior_a, prior_b, heads, tails
        )

        ax = axes[0]
        ax.plot(x, prior_pdf, label=f"Prior Beta({prior_a},{prior_b})")
        ax.plot(x, posterior_pdf, linewidth=3,
                label=f"Posterior Beta({post_a:.1f},{post_b:.1f})")
        ax.fill_between(x, 0, posterior_pdf, alpha=0.2)
        ax.set_title("Exact Bayesian Inference")

        if show_vi:
            theta_grid, q_theta, _elbo_hist, mu_f, sigma_f = variational_inference(
                prior_a, prior_b, heads, tails
            )
            ax = axes[1]
            ax.plot(x, prior_pdf, label="Prior")
            ax.plot(x, posterior_pdf, label="Exact Posterior")
            # q is Normal over z with θ = sigmoid(z); the curve drawn is the induced
            # (logit-normal) density over θ, so it is labelled accordingly.
            ax.plot(theta_grid, q_theta, "--", linewidth=3,
                    label=f"VI logit-N({mu_f:.2f},{sigma_f:.2f})")
            ax.fill_between(theta_grid, 0, q_theta, alpha=0.2)
            ax.set_title("Variational Approximation")

        for panel in axes:
            panel.set_xlabel("θ")
            panel.set_ylabel("Density")
            panel.legend()
            panel.grid(alpha=0.3)

        # Target this figure explicitly instead of pyplot's global "current" figure.
        fig.tight_layout()
    except Exception:
        # fig_to_pil normally closes the figure; close it here if drawing fails.
        plt.close(fig)
        raise
    return fig_to_pil(fig)
| # ========== 4. BRAIN HEALTH DEMO ========== | |
def brain_health_demo():
    """Render a toy ASL-MRI illustration as a PIL image.

    Left panel: the noiseless curve
    ``cbf * (1 - exp(-t / att)) * exp(-t / 1.6)`` with true CBF=60 and ATT=1.5,
    sampled at 50 points on t in [0, 3], plus the same curve with N(0, 3²)
    noise added.

    Right panel: Gaussian-KDE contours of 1000 samples drawn from independent
    normals, CBF ~ N(58, 5²) and ATT ~ N(1.6, 0.3²), with the true (CBF, ATT)
    marked by a red star.

    NOTE(review): the right panel is titled "VI Parameter Posterior", but its
    samples come from fixed normals and are not fitted to the noisy signal.

    Returns:
        PIL.Image.Image produced by ``fig_to_pil``.
    """
    from scipy.stats import gaussian_kde

    # A private legacy RNG seeded with 42 reproduces the same stream as np.random.seed(42)
    # without resetting NumPy's global random state for the rest of the process.
    rng = np.random.RandomState(42)

    t = np.linspace(0, 3, 50)
    true_cbf, true_att = 60, 1.5
    signal = true_cbf * (1 - np.exp(-t / true_att)) * np.exp(-t / 1.6)
    # Draw order (noise, then CBF, then ATT) matches the original so the plot is identical.
    noisy = signal + rng.normal(0, 3, len(t))
    cbf = rng.normal(58, 5, 1000)
    att = rng.normal(1.6, 0.3, 1000)

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))
    try:
        ax[0].plot(t, signal, linewidth=3, label="True Signal")
        ax[0].scatter(t, noisy, alpha=0.6, label="Noisy")
        ax[0].set_title("Simulated ASL-MRI")
        ax[0].legend()
        ax[0].grid(alpha=0.3)

        x = np.linspace(40, 80, 100)
        y = np.linspace(0.8, 2.4, 100)
        X, Y = np.meshgrid(x, y)
        kde = gaussian_kde(np.vstack([cbf, att]))
        Z = kde(np.vstack([X.ravel(), Y.ravel()])).reshape(X.shape)
        ax[1].contour(X, Y, Z, levels=5)
        ax[1].scatter([true_cbf], [true_att], color="red", s=200, marker="*")
        ax[1].set_title("VI Parameter Posterior")
        ax[1].grid(alpha=0.3)

        # Target this figure explicitly instead of pyplot's global "current" figure.
        fig.tight_layout()
    except Exception:
        # fig_to_pil normally closes the figure; close it here if drawing fails.
        plt.close(fig)
        raise
    return fig_to_pil(fig)
| # ========== 5. GRADIO UI ========== | |
# Module-level Gradio app: a "Coin Flip" tab for the Beta/VI plots and a
# "Brain Health" tab for the ASL-MRI demo. It is launched when run as a script.
with gr.Blocks(title="Variational Inference Playground") as demo:
    gr.Markdown("# 🧠 Variational Inference Playground")

    with gr.Tab("🎯 Coin Flip"):
        with gr.Row():
            with gr.Column():
                # All options are passed by keyword: Gradio's Slider signature makes
                # `step` keyword-only, so passing it positionally raises TypeError.
                prior_a = gr.Slider(minimum=0.1, maximum=10, value=2, step=0.1, label="α")
                prior_b = gr.Slider(minimum=0.1, maximum=10, value=2, step=0.1, label="β")
                heads = gr.Slider(minimum=0, maximum=100, value=8, step=1, label="Heads")
                tails = gr.Slider(minimum=0, maximum=100, value=4, step=1, label="Tails")
                show_vi = gr.Checkbox(value=True, label="Show VI")
                run_btn = gr.Button("Update", variant="primary")
            with gr.Column():
                plot_output = gr.Image(type="pil")

    with gr.Tab("🧠 Brain Health"):
        brain_btn = gr.Button("Run Demo", variant="primary")
        brain_output = gr.Image(type="pil")

    # Order matches create_plot's positional parameters.
    coin_inputs = [prior_a, prior_b, heads, tails, show_vi]
    run_btn.click(create_plot, inputs=coin_inputs, outputs=plot_output)
    brain_btn.click(brain_health_demo, outputs=brain_output)
    # Draw the initial plot from the sliders' own values so it cannot drift from their defaults.
    demo.load(create_plot, inputs=coin_inputs, outputs=plot_output)

if __name__ == "__main__":
    demo.launch()