Spaces:
Runtime error
Runtime error
| # visuals.py | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| from uei_core.models import ModelPortfolio | |
| from uei_core.uncertainty import UncertaintyEstimator | |
| from uei_core.energy import EnergyProfiler | |
# Shared singletons used by every plotting callback in this module.
device: str = "cpu"  # device string handed to the model portfolio
models = ModelPortfolio(device=device)  # provides preprocess / infer_small / infer_large / small
unc = UncertaintyEstimator()  # estimate(logits) -> value convertible to float
energy = EnergyProfiler()  # measure(fn, x) -> (fn output, energy reading)
| # ------------------------------ | |
| # 1️⃣ Plot: Uncertainty vs Energy Curve | |
| # ------------------------------ | |
def plot_unc_energy(img):
    """Plot the small and large model as points on an energy-vs-uncertainty chart.

    Both models from the shared ``models`` portfolio are run on the
    preprocessed image through ``energy.measure``. Each model's logits are
    scored with ``unc.estimate``. The two (energy, uncertainty) points are
    drawn, joined by a dashed line, and labelled.

    Args:
        img: The image delivered by the Gradio ``Image`` component, or
            ``None`` when the input has been cleared.

    Returns:
        A matplotlib ``Figure``, or ``None`` (an empty plot) when ``img`` is
        ``None``.
    """
    # Clearing the image hands us None; show an empty plot rather than
    # failing inside preprocess.
    if img is None:
        return None

    x = models.preprocess(img)
    logits_s, E_s = energy.measure(models.infer_small, x)
    logits_l, E_l = energy.measure(models.infer_large, x)
    U_s = float(unc.estimate(logits_s))
    U_l = float(unc.estimate(logits_l))

    fig, ax = plt.subplots(figsize=(5, 4), dpi=120)
    xs = [E_s, E_l]
    ys = [U_s, U_l]
    labels = ["Small Model", "Large Model"]
    colors = ["#1f77b4", "#ff7f0e"]

    ax.scatter(xs, ys, s=150, color=colors)
    # Dashed segment linking the two models' operating points.
    ax.plot(xs, ys, linestyle="--", color="#888")
    for label, xi, yi in zip(labels, xs, ys):
        ax.annotate(label, (xi, yi), textcoords="offset points",
                    xytext=(8, 5), ha="left", fontsize=10)

    ax.set_xlabel("Energy (proxy units)")
    ax.set_ylabel("Estimated Uncertainty")
    ax.set_title("Uncertainty vs Energy")
    ax.grid(True, alpha=0.3)

    # Each call creates a new pyplot figure. Unregister it from pyplot's
    # global figure manager so figures don't accumulate across requests.
    # The returned Figure object itself can still be saved/rendered.
    plt.close(fig)
    return fig
| # ------------------------------ | |
| # 2️⃣ Plot: Layer Activation Heatmap | |
| # ------------------------------ | |
def activation_heatmap(img):
    """Show the channel-averaged output of ``models.small.features[0]`` as a heatmap.

    A temporary forward hook captures the output of the first entry of the
    small model's ``features`` (presumably its first conv stage; confirm
    against ``ModelPortfolio``). That output is read during one forward pass.
    The first batch item is averaged over its leading (channel) dimension and
    displayed with ``imshow``.

    Args:
        img: The image delivered by the Gradio ``Image`` component, or
            ``None`` when the input has been cleared.

    Returns:
        A matplotlib ``Figure``, or ``None`` (an empty plot) when ``img`` is
        ``None``.
    """
    # Clearing the image hands us None; show an empty plot rather than
    # failing inside preprocess.
    if img is None:
        return None

    x = models.preprocess(img)

    activations = {}

    def hook(module, inputs, output):
        # Keep a detached CPU copy so it can be converted to NumPy below.
        activations["feat"] = output.detach().cpu()

    handle = models.small.features[0].register_forward_hook(hook)
    try:
        # Only the hooked activation is used, so skip recording the autograd
        # graph for this forward pass.
        with torch.no_grad():
            models.small(x)
    finally:
        # Always unregister. A hook left behind after a failed forward pass
        # would keep firing on every later call of models.small.
        handle.remove()

    # NOTE(review): assumes the hooked output is (batch, channels, H, W), so
    # that [0] then mean over dim 0 yields a 2-D (H, W) map. TODO confirm.
    feat = activations["feat"][0]  # first batch item
    heat = feat.mean(dim=0).numpy()  # average over channels -> 2-D heatmap

    fig, ax = plt.subplots(figsize=(4, 4), dpi=120)
    ax.imshow(heat, cmap="viridis")
    ax.set_title("Early Layer Activation Heatmap")
    ax.axis("off")

    # Unregister from pyplot's global figure manager to avoid accumulating
    # figures across requests. The returned Figure is still renderable.
    plt.close(fig)
    return fig
| # ------------------------------ | |
| # 3️⃣ Plot: Model Comparison Bars | |
| # ------------------------------ | |
def model_comparison(img):
    """Draw grouped bars comparing energy and uncertainty for both models.

    Both models from the shared ``models`` portfolio are run on the
    preprocessed image through ``energy.measure``. Each model's logits are
    scored with ``unc.estimate``. Energy and uncertainty are measured in
    unrelated units, so each series gets its own y-axis: energy on the left,
    uncertainty on the right.

    Args:
        img: The image delivered by the Gradio ``Image`` component, or
            ``None`` when the input has been cleared.

    Returns:
        A matplotlib ``Figure``, or ``None`` (an empty plot) when ``img`` is
        ``None``.
    """
    # Clearing the image hands us None; show an empty plot rather than
    # failing inside preprocess.
    if img is None:
        return None

    x = models.preprocess(img)
    logits_s, E_s = energy.measure(models.infer_small, x)
    logits_l, E_l = energy.measure(models.infer_large, x)
    U_s = float(unc.estimate(logits_s))
    U_l = float(unc.estimate(logits_l))

    fig, ax = plt.subplots(figsize=(6, 4), dpi=120)
    labels = ["Small Model", "Large Model"]
    energy_vals = [E_s, E_l]
    unc_vals = [U_s, U_l]
    x_axis = np.arange(len(labels))
    w = 0.35

    # A single shared y-axis would let the larger-magnitude series flatten
    # the other, so uncertainty is plotted against a twin axis.
    ax_unc = ax.twinx()
    bars_e = ax.bar(x_axis - w / 2, energy_vals, w, label="Energy", color="#2ca02c")
    bars_u = ax_unc.bar(x_axis + w / 2, unc_vals, w, label="Uncertainty", color="#d62728")

    ax.set_ylabel("Energy (proxy units)")
    ax_unc.set_ylabel("Estimated Uncertainty")
    ax.set_xticks(x_axis)
    ax.set_xticklabels(labels)
    ax.set_title("Model Energy & Uncertainty Comparison")
    ax.grid(alpha=0.2)
    # Each axis only knows its own artists, so build one legend for both.
    # It goes on the twin axis, which is drawn on top.
    ax_unc.legend([bars_e, bars_u], ["Energy", "Uncertainty"], loc="upper left")
    # Keep the right-hand y-label from being clipped.
    fig.tight_layout()

    # Unregister from pyplot's global figure manager to avoid accumulating
    # figures across requests. The returned Figure is still renderable.
    plt.close(fig)
    return fig
| # ------------------------------ | |
| # 🔥 Gradio Interface | |
| # ------------------------------ | |
def get_visual_ui():
    """Build the UEI visualization dashboard as a Gradio ``Blocks`` app.

    The layout is one image input above a set of tabs, each holding a plot.
    Every plot is refreshed through the image's ``change`` event, so
    uploading an image updates all three tabs. Clearing the image empties
    them.

    Returns:
        The constructed (not launched) ``gr.Blocks`` instance.
    """
    # (tab title, plot label, callback producing the figure)
    panels = [
        ("Uncertainty vs Energy", "Chart", plot_unc_energy),
        ("Layer Activations", "Activation Heatmap", activation_heatmap),
        ("Model Comparison", "Energy & Uncertainty Bars", model_comparison),
    ]

    def _skip_empty(fn):
        # The change event also fires when the image is cleared (value None).
        # Return an empty plot in that case instead of calling the model.
        def wrapped(image):
            return None if image is None else fn(image)
        return wrapped

    with gr.Blocks() as demo:
        gr.Markdown("## 🔍 UEI Visualization Dashboard")
        gr.Markdown("Explore how UEI behaves internally with colorful charts")
        img = gr.Image(type="pil", label="Upload Image")
        with gr.Tabs():
            for tab_title, plot_label, fn in panels:
                with gr.Tab(tab_title):
                    # Components created inside a Blocks context are rendered
                    # automatically. render() accepts no fn/inputs, so the
                    # callback is wired through an event listener instead.
                    plot = gr.Plot(label=plot_label)
                img.change(fn=_skip_empty(fn), inputs=img, outputs=plot)
    return demo