# UEI / visuals.py
# Created by GirishaBuilds01 in commit 716e4bc ("Create visuals.py").
# visuals.py
import torch
import matplotlib.pyplot as plt
import numpy as np
import gradio as gr
from PIL import Image
from uei_core.models import ModelPortfolio
from uei_core.uncertainty import UncertaintyEstimator
from uei_core.energy import EnergyProfiler
# Module-level objects shared by every plot function in this file; they are
# constructed once, when this module is imported.
device = "cpu"  # device string handed to ModelPortfolio below
# Used below via .preprocess(img), .infer_small / .infer_large, and .small
# (the small model module itself, hooked in activation_heatmap).
models = ModelPortfolio(device=device)
# Used below as float(unc.estimate(logits)) -> scalar uncertainty per model.
unc = UncertaintyEstimator()
# Used below as energy.measure(infer_fn, x), unpacked as (logits, energy).
energy = EnergyProfiler()
# ------------------------------
# 1️⃣ Plot: Uncertainty vs Energy Curve
# ------------------------------
def plot_unc_energy(img):
    """Plot estimated uncertainty against measured energy for both models.

    The image is preprocessed once, then run through the small and the large
    model of the shared ``ModelPortfolio``. Each inference is wrapped by
    ``EnergyProfiler.measure`` and its logits are scored with
    ``UncertaintyEstimator.estimate``.

    Args:
        img: The uploaded image (a PIL image when it comes from the
            ``gr.Image(type="pil")`` input), or ``None`` when the input has
            been cleared.

    Returns:
        A matplotlib ``Figure`` with one labelled point per model joined by a
        dashed line, or ``None`` if ``img`` is ``None`` (clears the plot).
    """
    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive in a global registry until plt.close(), so a
    # long-running Gradio server would leak one figure per request, and
    # pyplot's global "current figure" state is shared across handler threads.
    from matplotlib.figure import Figure

    if img is None:
        return None

    x = models.preprocess(img)
    logits_s, E_s = energy.measure(models.infer_small, x)
    logits_l, E_l = energy.measure(models.infer_large, x)
    U_s = float(unc.estimate(logits_s))
    U_l = float(unc.estimate(logits_l))

    fig = Figure(figsize=(5, 4), dpi=120)
    ax = fig.subplots()
    xs = [E_s, E_l]
    ys = [U_s, U_l]
    labels = ["Small Model", "Large Model"]
    colors = ["#1f77b4", "#ff7f0e"]
    ax.scatter(xs, ys, s=150, color=colors)
    ax.plot(xs, ys, linestyle="--", color="#888")
    for label, px, py in zip(labels, xs, ys):
        ax.annotate(label, (px, py), textcoords="offset points",
                    xytext=(8, 5), ha="left", fontsize=10)
    ax.set_xlabel("Energy (proxy units)")
    ax.set_ylabel("Estimated Uncertainty")
    ax.set_title("Uncertainty vs Energy")
    ax.grid(True, alpha=0.3)
    return fig
# ------------------------------
# 2️⃣ Plot: Layer Activation Heatmap
# ------------------------------
def activation_heatmap(img):
    """Show the channel-averaged output of the small model's first feature layer.

    A temporary forward hook is attached to ``models.small.features[0]``. The
    small model is run once on the preprocessed image, and the captured
    output of the first batch item is averaged over its leading axis to give
    a 2-D heatmap.

    Args:
        img: The uploaded image (a PIL image when it comes from the
            ``gr.Image(type="pil")`` input), or ``None`` when the input has
            been cleared.

    Returns:
        A matplotlib ``Figure`` showing the heatmap, or ``None`` if ``img`` is
        ``None`` (clears the plot).
    """
    # Figure is built directly (not via pyplot) so it is not kept alive in
    # pyplot's global figure registry after the request finishes.
    from matplotlib.figure import Figure

    if img is None:
        return None

    x = models.preprocess(img)

    activations = {}

    def hook(module, inputs, output):
        activations["feat"] = output.detach().cpu()

    handle = models.small.features[0].register_forward_hook(hook)
    try:
        # Inference only: no_grad avoids recording an autograd graph for a
        # forward pass whose gradients are never used.
        with torch.no_grad():
            models.small(x)
    finally:
        # Always detach the hook, even if the forward pass raises; otherwise
        # it would stay registered and fire on every later call.
        handle.remove()

    # NOTE(review): assumes the layer output is (batch, channels, H, W) so
    # that [0] picks the first image and mean(dim=0) averages channels into
    # an (H, W) map -- TODO confirm against the small model's architecture.
    feat = activations["feat"][0]
    heat = feat.mean(dim=0).numpy()

    fig = Figure(figsize=(4, 4), dpi=120)
    ax = fig.subplots()
    ax.imshow(heat, cmap="viridis")
    ax.set_title("Early Layer Activation Heatmap")
    ax.axis("off")
    return fig
# ------------------------------
# 3️⃣ Plot: Model Comparison Bars
# ------------------------------
def model_comparison(img):
    """Compare energy and uncertainty of the small and large models as grouped bars.

    Both models are run on the preprocessed image. Each inference is wrapped
    by ``EnergyProfiler.measure`` and its logits are scored with
    ``UncertaintyEstimator.estimate``.

    Args:
        img: The uploaded image (a PIL image when it comes from the
            ``gr.Image(type="pil")`` input), or ``None`` when the input has
            been cleared.

    Returns:
        A matplotlib ``Figure`` with an energy bar (left y-axis) and an
        uncertainty bar (right y-axis) per model, or ``None`` if ``img`` is
        ``None`` (clears the plot).
    """
    # Figure is built directly (not via pyplot) so it is not kept alive in
    # pyplot's global figure registry after the request finishes.
    from matplotlib.figure import Figure

    if img is None:
        return None

    x = models.preprocess(img)
    logits_s, E_s = energy.measure(models.infer_small, x)
    logits_l, E_l = energy.measure(models.infer_large, x)
    U_s = float(unc.estimate(logits_s))
    U_l = float(unc.estimate(logits_l))

    labels = ["Small Model", "Large Model"]
    energy_vals = [E_s, E_l]
    unc_vals = [U_s, U_l]
    x_axis = np.arange(len(labels))
    w = 0.35

    fig = Figure(figsize=(6, 4))
    ax = fig.subplots()
    # Energy (proxy units) and uncertainty are different quantities whose
    # magnitudes need not be comparable; on a shared y-axis the smaller series
    # can be flattened to invisibility, so uncertainty gets its own axis.
    ax_unc = ax.twinx()
    energy_bars = ax.bar(x_axis - w / 2, energy_vals, w,
                         label="Energy", color="#2ca02c")
    unc_bars = ax_unc.bar(x_axis + w / 2, unc_vals, w,
                          label="Uncertainty", color="#d62728")
    ax.set_xticks(x_axis)
    ax.set_xticklabels(labels)
    ax.set_ylabel("Energy (proxy units)")
    ax_unc.set_ylabel("Estimated Uncertainty")
    ax.set_title("Model Energy & Uncertainty Comparison")
    # The two bar series live on different Axes; collect both in one legend.
    ax.legend(handles=[energy_bars, unc_bars])
    ax.grid(alpha=0.2)
    return fig
# ------------------------------
# 🔥 Gradio Interface
# ------------------------------
def get_visual_ui():
    """Build (but do not launch) the UEI visualization dashboard.

    The dashboard has one image input and three tabs. Each tab holds a plot
    that is recomputed whenever the image changes:
    ``plot_unc_energy``, ``activation_heatmap`` and ``model_comparison``.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for ``.launch()``.
    """

    def _skip_empty(plot_fn):
        # `change` also fires with None when the image is cleared; blank the
        # plot instead of running the models on a missing image.
        def handler(image):
            return None if image is None else plot_fn(image)

        return handler

    with gr.Blocks() as demo:
        gr.Markdown("## 🔍 UEI Visualization Dashboard")
        gr.Markdown("Explore how UEI behaves internally with colorful charts")
        img = gr.Image(type="pil", label="Upload Image")
        with gr.Tabs():
            with gr.Tab("Uncertainty vs Energy"):
                unc_energy_plot = gr.Plot(label="Chart")
            with gr.Tab("Layer Activations"):
                heatmap_plot = gr.Plot(label="Activation Heatmap")
            with gr.Tab("Model Comparison"):
                comparison_plot = gr.Plot(label="Energy & Uncertainty Bars")

        # Components have no render(fn=..., inputs=...) hook; plots are wired
        # to the input through event listeners, one per output component.
        img.change(fn=_skip_empty(plot_unc_energy), inputs=img,
                   outputs=unc_energy_plot)
        img.change(fn=_skip_empty(activation_heatmap), inputs=img,
                   outputs=heatmap_plot)
        img.change(fn=_skip_empty(model_comparison), inputs=img,
                   outputs=comparison_plot)
    return demo