# visuals.py
"""UEI visualization dashboard.

Builds matplotlib charts comparing the small and large models held by a
``ModelPortfolio`` (energy vs. uncertainty, early-layer activations, side-by-side
bars) and exposes them through a Gradio Blocks UI built by ``get_visual_ui``.
"""
import torch
import matplotlib.pyplot as plt  # noqa: F401  (not used directly below; retained import)
import numpy as np
import gradio as gr
from matplotlib.figure import Figure
from PIL import Image  # noqa: F401
from uei_core.models import ModelPortfolio
from uei_core.uncertainty import UncertaintyEstimator
from uei_core.energy import EnergyProfiler

device = "cpu"
models = ModelPortfolio(device=device)
unc = UncertaintyEstimator()
energy = EnergyProfiler()

# Display names for the two models, in (small, large) order everywhere below.
_LABELS = ["Small Model", "Large Model"]


# ------------------------------
# Shared measurement
# ------------------------------
def _measure(img):
    """Run both models on *img* once and return ``(E_s, E_l, U_s, U_l)``.

    ``E_*`` is the second element returned by ``energy.measure`` for each
    model's inference call; ``U_*`` is ``unc.estimate`` of that model's logits,
    cast to ``float``.
    """
    x = models.preprocess(img)
    logits_s, E_s = energy.measure(models.infer_small, x)
    logits_l, E_l = energy.measure(models.infer_large, x)
    return E_s, E_l, float(unc.estimate(logits_s)), float(unc.estimate(logits_l))


# ------------------------------
# 1️⃣ Plot: Uncertainty vs Energy Curve
# ------------------------------
def _unc_energy_figure(E_s, E_l, U_s, U_l):
    """Scatter both models in (energy, uncertainty) space, joined by a dashed line."""
    # Figure() instead of plt.subplots(): pyplot keeps every figure it creates
    # alive in a global registry, which leaks memory in a long-running server.
    fig = Figure(figsize=(5, 4), dpi=120)
    ax = fig.subplots()

    xs = [E_s, E_l]
    ys = [U_s, U_l]
    ax.scatter(xs, ys, s=150, color=["#1f77b4", "#ff7f0e"])
    ax.plot(xs, ys, linestyle="--", color="#888")
    for label, px, py in zip(_LABELS, xs, ys):
        ax.annotate(label, (px, py), textcoords="offset points",
                    xytext=(8, 5), ha="left", fontsize=10)

    ax.set_xlabel("Energy (proxy units)")
    ax.set_ylabel("Estimated Uncertainty")
    ax.set_title("Uncertainty vs Energy")
    ax.grid(True, alpha=0.3)
    return fig


def plot_unc_energy(img):
    """Return the uncertainty-vs-energy figure for *img*, or ``None`` if no image."""
    if img is None:  # Gradio fires change events with None when the image is cleared
        return None
    return _unc_energy_figure(*_measure(img))


# ------------------------------
# 2️⃣ Plot: Layer Activation Heatmap
# ------------------------------
def activation_heatmap(img):
    """Return a heatmap of the small model's first-layer output for *img*.

    The output of ``models.small.features[0]`` is captured with a forward hook,
    the first batch item is taken, and it is averaged over dim 0 to get a 2-D
    map. Assumes that layer emits (batch, channels, H, W) — TODO confirm for
    the model in use. Returns ``None`` if no image is given.
    """
    if img is None:
        return None
    x = models.preprocess(img)

    captured = {}

    def hook(_module, _inputs, output):
        captured["feat"] = output.detach().cpu()

    handle = models.small.features[0].register_forward_hook(hook)
    try:
        # Only the activations are needed, so skip building an autograd graph.
        with torch.no_grad():
            models.small(x)
    finally:
        # Always detach the hook, even if the forward pass raises; otherwise
        # hooks accumulate on the shared model across calls.
        handle.remove()

    heat = captured["feat"][0].mean(dim=0).numpy()  # first batch item, channel mean

    fig = Figure(figsize=(4, 4), dpi=120)
    ax = fig.subplots()
    ax.imshow(heat, cmap="viridis")
    ax.set_title("Early Layer Activation Heatmap")
    ax.axis("off")
    return fig


# ------------------------------
# 3️⃣ Plot: Model Comparison Bars
# ------------------------------
def _comparison_figure(E_s, E_l, U_s, U_l):
    """Grouped bars of energy (left axis) and uncertainty (right axis) per model."""
    fig = Figure(figsize=(6, 4))
    ax_e = fig.subplots()
    # Energy and uncertainty have unrelated units/scales; a shared y-axis can
    # flatten one series to invisibility, so uncertainty gets its own axis.
    ax_u = ax_e.twinx()

    pos = np.arange(len(_LABELS))
    w = 0.35
    bars_e = ax_e.bar(pos - w / 2, [E_s, E_l], w, label="Energy", color="#2ca02c")
    bars_u = ax_u.bar(pos + w / 2, [U_s, U_l], w, label="Uncertainty", color="#d62728")

    ax_e.set_xticks(pos)
    ax_e.set_xticklabels(_LABELS)
    ax_e.set_ylabel("Energy (proxy units)")
    ax_u.set_ylabel("Estimated Uncertainty")
    ax_e.set_title("Model Energy & Uncertainty Comparison")
    # Bars live on two axes, so pass both handles to a single legend.
    ax_e.legend(handles=[bars_e, bars_u], loc="upper left")
    ax_e.grid(alpha=0.2)
    return fig


def model_comparison(img):
    """Return the energy/uncertainty bar chart for *img*, or ``None`` if no image."""
    if img is None:
        return None
    return _comparison_figure(*_measure(img))


def _render_all(img):
    """Build all three dashboard figures, measuring both models only once.

    Sharing one measurement keeps the two energy/uncertainty charts consistent
    with each other and avoids running each model twice per upload.
    """
    if img is None:
        return None, None, None
    measured = _measure(img)
    return (
        _unc_energy_figure(*measured),
        activation_heatmap(img),
        _comparison_figure(*measured),
    )


# ------------------------------
# 🔥 Gradio Interface
# ------------------------------
def get_visual_ui():
    """Build and return the Gradio Blocks dashboard (not launched)."""
    with gr.Blocks() as demo:
        gr.Markdown("## 🔍 UEI Visualization Dashboard")
        gr.Markdown("Explore how UEI behaves internally with colorful charts")
        img = gr.Image(type="pil", label="Upload Image")

        with gr.Tabs():
            with gr.Tab("Uncertainty vs Energy"):
                unc_plot = gr.Plot(label="Chart")
            with gr.Tab("Layer Activations"):
                heat_plot = gr.Plot(label="Activation Heatmap")
            with gr.Tab("Model Comparison"):
                bars_plot = gr.Plot(label="Energy & Uncertainty Bars")

        # Component.render() only places a component; plots must be driven by
        # an event listener that maps the input image to the plot outputs.
        img.change(fn=_render_all, inputs=img,
                   outputs=[unc_plot, heat_plot, bars_plot])
    return demo