# app.py — Medium LLM Visualizer for Hugging Face Spaces (Gradio)
# Safe for free CPU Spaces (distilgpt2 / gpt2). No heavy patching/residual features.
import gradio as gr
import torch
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.decomposition import PCA
import time
# ------- Config -------
DEFAULT_MODEL = "distilgpt2" # fits free CPU Spaces
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# simple cache so model/tokenizer aren't reloaded on subsequent calls
_MODEL_CACHE = {}
def load_model(model_name):
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
    model.to(DEVICE)
    model.eval()
    _MODEL_CACHE[model_name] = (model, tokenizer)
    return model, tokenizer
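# Example usage (a minimal sketch; assumes the model name can be downloaded from the Hub):
#   model, tok = load_model("distilgpt2")   # first call downloads and caches the pair
#   model, tok = load_model("distilgpt2")   # later calls return the cached pair instantly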
# ------- Helpers -------
def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum(axis=-1, keepdims=True)
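# Worked example (values rounded): softmax(np.array([1.0, 2.0, 3.0]))
# gives roughly [0.090, 0.245, 0.665], which sums to 1.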
def compute_pca(hidden_layer):
    # hidden_layer: (seq_len, hidden_dim)
    try:
        p = PCA(n_components=2).fit_transform(hidden_layer)
        return p
    except Exception:
        # fallback: take first two dimensions (ugly but safe)
        seq = hidden_layer.shape[0]
        d0 = hidden_layer[:, 0] if hidden_layer.shape[1] > 0 else np.zeros(seq)
        d1 = hidden_layer[:, 1] if hidden_layer.shape[1] > 1 else np.zeros(seq)
        return np.vstack([d0, d1]).T
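# Note: PCA needs at least two rows, so a single-token input (shape (1, hidden_dim))
# raises inside fit_transform and falls back to returning the first two raw dimensions.
# Example: a (7, 768) layer of hidden states maps to a (7, 2) array of 2-D points.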
def make_attention_figure(attn_matrix, tokens, title=None):
    fig = px.imshow(
        attn_matrix,
        x=tokens,
        y=tokens,
        labels={"x": "Key token", "y": "Query token", "color": "Attention"},
        title=title or "Attention"
    )
    fig.update_layout(height=420, margin=dict(l=60, r=20, t=40, b=40))
    return fig
def make_pca_figure(points, tokens, highlight_idx=None, title=None):
    fig = px.scatter(x=points[:, 0], y=points[:, 1], text=tokens, title=title or "PCA (2D)")
    fig.update_traces(textposition="top center", marker=dict(size=10))
    if highlight_idx is not None:
        fig.add_trace(go.Scatter(
            x=[points[highlight_idx, 0]],
            y=[points[highlight_idx, 1]],
            mode="markers+text",
            text=[tokens[highlight_idx]],
            marker=dict(size=18, color="red"),
            name="selected token"
        ))
    fig.update_layout(height=420, margin=dict(l=40, r=40, t=40, b=40))
    return fig
def make_probs_figure(top_tokens, top_scores, title=None):
    fig = go.Figure(data=[go.Bar(x=top_tokens, y=top_scores)])
    fig.update_layout(title=title or "Next-token top predictions", yaxis_title="Probability", height=360, margin=dict(l=40, r=20, t=40, b=40))
    return fig
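# Example: make_probs_figure(["the", " a", " an"], [0.31, 0.12, 0.05]) renders a
# three-bar chart with the candidate tokens on the x-axis and their probabilities on the y-axis.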
# ------- Core analysis (fast, safe) -------
def analyze_text(text, model_name, explain_simple):
    # Basic validation
    if not text or len(text.strip()) == 0:
        return {"error": "Please enter some text."}
    try:
        model, tokenizer = load_model(model_name)
    except Exception as e:
        return {"error": f"Failed to load model '{model_name}': {e}"}
    # Tokenize and forward
    try:
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
    except Exception as e:
        return {"error": f"Tokenization error: {e}"}
    with torch.no_grad():
        try:
            outputs = model(**inputs)
        except Exception as e:
            return {"error": f"Model forward error: {e}"}
    # Extract internals
    try:
        input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
    except Exception:
        return {"error": "Failed to extract tokens."}
    # per-layer attention: (num_heads, seq_len, seq_len); per-layer hidden states: (seq_len, hidden_dim)
    attentions = [a[0].cpu().numpy() for a in outputs.attentions] if outputs.attentions is not None else None
    hidden = [h[0].cpu().numpy() for h in outputs.hidden_states] if outputs.hidden_states is not None else None
    logits = outputs.logits[0].cpu().numpy()  # shape (seq_len, vocab_size)
    # Precompute PCA for each layer (small sequences only)
    pca_layers = []
    if hidden is not None:
        for layer_h in hidden:
            pca_layers.append(compute_pca(layer_h))
    else:
        pca_layers = None
    # Next-token top-k
    last_logits = logits[-1]
    probs = softmax(last_logits)
    topk = 20
    idx = np.argsort(probs)[-topk:][::-1]
    top_tokens = [tokenizer.decode([int(i)]) for i in idx]
    top_scores = probs[idx].tolist()
    # default selections
    default_layer = len(attentions) - 1 if attentions is not None else (len(pca_layers) - 1 if pca_layers else 0)
    default_head = 0
    # Figures
    fig_attn = None
    if attentions is not None:
        fig_attn = make_attention_figure(attentions[default_layer][default_head], tokens, title=f"Layer {default_layer} Head {default_head}")
    fig_pca = None
    if pca_layers is not None:
        fig_pca = make_pca_figure(pca_layers[default_layer], tokens, highlight_idx=None, title=f"PCA (layer {default_layer})")
    fig_probs = make_probs_figure(top_tokens, top_scores)
    explanation = (
        "Simple: The model splits text into pieces (tokens), looks at which tokens it should pay attention to, and guesses the next word."
        if explain_simple else
        "Technical: showing tokens, attention maps (query-key), hidden states projected to 2D, and next-token probabilities."
    )
    return {
        "tokens": tokens,
        "attentions": attentions,
        "pca_layers": pca_layers,
        "logits": logits,
        "fig_attn": fig_attn,
        "fig_pca": fig_pca,
        "fig_probs": fig_probs,
        "default_layer": default_layer,
        "default_head": default_head,
        "token_display": " ".join([f"[{t}]" for t in tokens]),
        "explanation": explanation,
    }
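# Quick sanity check outside the UI (a minimal sketch; distilgpt2 has 6 transformer
# layers, so with 0-indexing the default layer would be 5):
#   res = analyze_text("Hello world", "distilgpt2", explain_simple=True)
#   res["tokens"]         # GPT-2 BPE tokens, e.g. ['Hello', 'Ġworld']
#   res["default_layer"]  # 5 for distilgpt2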
# ------- UI Functions -------
def run_analysis(text, model_name, explain_simple):
    # runs analysis and returns values in the same order as the outputs wired up below
    start = time.time()
    res = analyze_text(text, model_name, explain_simple)
    if "error" in res:
        # return simple error outputs matching the output order
        token_display_text = ""
        explanation_text = res.get("error", "Unknown error")
        model_info = f"Model: {model_name}"
        return (
            token_display_text, explanation_text, model_info,
            None, None, None,
            gr.update(maximum=0, value=0), gr.update(maximum=0, value=0), gr.update(maximum=0, value=0),
            None
        )
    tokens = res["tokens"]
    num_layers = len(res["attentions"]) if res["attentions"] is not None else (len(res["pca_layers"]) - 1 if res["pca_layers"] else 0)
    num_heads = res["attentions"][0].shape[0] if res["attentions"] is not None else 1
    max_token_idx = len(tokens) - 1
    token_display_text = f"**Tokens:** {res['token_display']}"
    explanation_text = res["explanation"]
    elapsed = time.time() - start
    model_info = f"Model: {model_name} • layers: {num_layers} • heads: {num_heads} • tokens: {len(tokens)} • analysis: {elapsed:.2f}s"
    # slider updates (use gr.update so ranges match the loaded model)
    layer_update = gr.update(maximum=max(0, num_layers - 1), value=res["default_layer"])
    head_update = gr.update(maximum=max(0, num_heads - 1), value=res["default_head"])
    token_step_update = gr.update(maximum=max_token_idx, value=0)
    # figures
    attn_fig = res["fig_attn"]
    pca_fig = res["fig_pca"]
    probs_fig = res["fig_probs"]
    # return in the order matching the outputs declared in the UI wiring below
    return (
        token_display_text, explanation_text, model_info,
        attn_fig, pca_fig, probs_fig,
        layer_update, head_update, token_step_update,
        res  # store the whole result in state for slider-driven updates
    )
def update_visuals(state_obj, layer, head, token_idx):
    # given cached analysis state, return updated (attn_fig, pca_fig, step_attn_fig)
    if not state_obj:
        return None, None, None
    res = state_obj
    tokens = res["tokens"]
    # bounds safety
    layer = int(min(max(0, layer), (len(res["attentions"]) - 1) if res["attentions"] is not None else 0))
    head = int(min(max(0, head), (res["attentions"][0].shape[0] - 1) if res["attentions"] is not None else 0))
    token_idx = int(min(max(0, token_idx), len(tokens) - 1))
    # attention figure for requested layer/head
    attn_fig = None
    if res["attentions"] is not None:
        attn_fig = make_attention_figure(res["attentions"][layer][head], tokens, title=f"Layer {layer} Head {head}")
    # PCA for requested layer, highlight token_idx
    pca_fig = None
    if res["pca_layers"] is not None:
        pts = res["pca_layers"][layer]
        pca_fig = make_pca_figure(pts, tokens, highlight_idx=token_idx, title=f"PCA (layer {layer})")
    # attention row (which tokens token_idx attends to)
    step_attn_fig = None
    if res["attentions"] is not None:
        row = res["attentions"][layer][head][token_idx]  # shape: (seq_len,)
        step_attn_fig = go.Figure(data=[go.Bar(x=tokens, y=row)])
        step_attn_fig.update_layout(title=f"Token {token_idx} attends to (layer {layer}, head {head})", height=300, margin=dict(l=40, r=20, t=30, b=40))
    return attn_fig, pca_fig, step_attn_fig
# ------- Gradio UI -------
with gr.Blocks(title="LLM Visualizer — Medium (HF Spaces)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 LLM Visualizer — Medium (safe for free HF Spaces)")
    with gr.Row():
        with gr.Column(scale=3):
            model_input = gr.Textbox(label="Model (Hugging Face name)", value=DEFAULT_MODEL)
            text_input = gr.Textbox(label="Input text", value="Hello world, this is a test.", lines=3)
            explain_simple = gr.Checkbox(label="Explain simply (kid/elder mode)", value=True)
            run_btn = gr.Button("Run visualizer", variant="primary")
            gr.Markdown("**Presets:**")
            with gr.Row():
                gr.Button("Greeting").click(lambda: "Hello! How are you today?", None, text_input)
                gr.Button("Story start").click(lambda: "Once upon a time, there was a small robot...", None, text_input)
                gr.Button("Question").click(lambda: "Why is the sky blue?", None, text_input)
        with gr.Column(scale=2):
            token_display = gr.Markdown("Tokens will appear here.")
            explanation_md = gr.Markdown("Explanation will appear here.")
            model_info = gr.Markdown("Model info: —")
    with gr.Row():
        with gr.Column():
            layer_slider = gr.Slider(label="Layer", minimum=0, maximum=0, step=1, value=0)
            head_slider = gr.Slider(label="Head", minimum=0, maximum=0, step=1, value=0)
            token_step = gr.Slider(label="Token index (step through tokens)", minimum=0, maximum=0, step=1, value=0)
            attn_plot = gr.Plot(label="Attention heatmap")
        with gr.Column():
            pca_plot = gr.Plot(label="PCA hidden states (2D)")
            step_attn_plot = gr.Plot(label="Attention row for selected token")
            probs_plot = gr.Plot(label="Next-token top predictions")
    # state storage (cached analysis)
    state = gr.State()
    # Wire up events
    run_btn.click(
        fn=run_analysis,
        inputs=[text_input, model_input, explain_simple],
        outputs=[
            token_display, explanation_md, model_info,
            attn_plot, pca_plot, probs_plot,
            layer_slider, head_slider, token_step,
            state
        ],
    )
    # sliders -> update visuals (use the stored state)
    layer_slider.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step], outputs=[attn_plot, pca_plot, step_attn_plot])
    head_slider.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step], outputs=[attn_plot, pca_plot, step_attn_plot])
    token_step.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step], outputs=[attn_plot, pca_plot, step_attn_plot])

demo.launch()
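# Note: on a free CPU Space the default launch() settings are enough. When running locally,
# standard Gradio launch arguments such as demo.launch(server_name="0.0.0.0") or
# demo.launch(share=True) can be used to expose the app outside localhost.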