# app.py — Medium LLM Visualizer for Hugging Face Spaces (Gradio)
# Safe for free CPU Spaces (distilgpt2 / gpt2). No heavy patching/residual features.

import time

import gradio as gr
import torch
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.decomposition import PCA

# ------- Config -------
DEFAULT_MODEL = "distilgpt2"  # fits free CPU Spaces
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Simple cache so the model/tokenizer aren't reloaded on subsequent calls.
_MODEL_CACHE = {}


def load_model(model_name):
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, output_attentions=True, output_hidden_states=True
    )
    model.to(DEVICE)
    model.eval()
    _MODEL_CACHE[model_name] = (model, tokenizer)
    return model, tokenizer


# ------- Helpers -------
def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum(axis=-1, keepdims=True)


def compute_pca(hidden_layer):
    # hidden_layer: (seq_len, hidden_dim)
    try:
        return PCA(n_components=2).fit_transform(hidden_layer)
    except Exception:
        # Fallback for degenerate inputs (e.g. a single token): take the first
        # two raw dimensions instead of a fitted projection (ugly but safe).
        seq = hidden_layer.shape[0]
        d0 = hidden_layer[:, 0] if hidden_layer.shape[1] > 0 else np.zeros(seq)
        d1 = hidden_layer[:, 1] if hidden_layer.shape[1] > 1 else np.zeros(seq)
        return np.vstack([d0, d1]).T


def _axis_labels(tokens):
    # Prefix each token with its position so repeated tokens (e.g. two "the"s)
    # stay distinct on Plotly's categorical axes instead of being merged.
    return [f"{i}: {t}" for i, t in enumerate(tokens)]


def make_attention_figure(attn_matrix, tokens, title=None):
    labels = _axis_labels(tokens)
    fig = px.imshow(
        attn_matrix,
        x=labels,
        y=labels,
        labels={"x": "Key token", "y": "Query token", "color": "Attention"},
        title=title or "Attention",
    )
    fig.update_layout(height=420, margin=dict(l=60, r=20, t=40, b=40))
    return fig


def make_pca_figure(points, tokens, highlight_idx=None, title=None):
    fig = px.scatter(x=points[:, 0], y=points[:, 1], text=tokens, title=title or "PCA (2D)")
    fig.update_traces(textposition="top center", marker=dict(size=10))
    if highlight_idx is not None:
        fig.add_trace(go.Scatter(
            x=[points[highlight_idx, 0]],
            y=[points[highlight_idx, 1]],
            mode="markers+text",
            text=[tokens[highlight_idx]],
            marker=dict(size=18, color="red"),
            name="selected token",
        ))
    fig.update_layout(height=420, margin=dict(l=40, r=40, t=40, b=40))
    return fig


def make_probs_figure(top_tokens, top_scores, title=None):
    fig = go.Figure(data=[go.Bar(x=top_tokens, y=top_scores)])
    fig.update_layout(
        title=title or "Next-token top predictions",
        yaxis_title="Probability",
        height=360,
        margin=dict(l=40, r=20, t=40, b=40),
    )
    return fig


# ------- Core analysis (fast, safe) -------
def analyze_text(text, model_name, explain_simple):
    # Basic validation
    if not text or len(text.strip()) == 0:
        return {"error": "Please enter some text."}

    try:
        model, tokenizer = load_model(model_name)
    except Exception as e:
        return {"error": f"Failed to load model '{model_name}': {e}"}

    # Tokenize and forward
    try:
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
    except Exception as e:
        return {"error": f"Tokenization error: {e}"}

    with torch.no_grad():
        try:
            outputs = model(**inputs)
        except Exception as e:
            return {"error": f"Model forward error: {e}"}

    # Extract internals
    try:
        input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
    except Exception:
        return {"error": "Failed to extract tokens."}

    attentions = [a[0].cpu().numpy() for a in outputs.attentions] if outputs.attentions is not None else None
    hidden = [h[0].cpu().numpy() for h in outputs.hidden_states] if outputs.hidden_states is not None else None
    logits = outputs.logits[0].cpu().numpy()  # shape (seq_len, vocab_size)
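    # Shape notes (a sketch; these follow the documented Transformers output
    # shapes for GPT-2-style models, after dropping the batch dimension):
    #   attentions: num_layers entries, each (num_heads, seq_len, seq_len)
    #   hidden:     num_layers + 1 entries (index 0 is the embedding output,
    #               index i the output of block i), each (seq_len, hidden_dim)
    #   logits:     (seq_len, vocab_size)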
    # Precompute a 2D PCA for every layer (cheap for short sequences)
    pca_layers = None
    if hidden is not None:
        pca_layers = [compute_pca(layer_h) for layer_h in hidden]

    # Next-token top-k
    last_logits = logits[-1]
    probs = softmax(last_logits)
    topk = 20
    idx = np.argsort(probs)[-topk:][::-1]
    top_tokens = [tokenizer.decode([int(i)]) for i in idx]
    top_scores = probs[idx].tolist()

    # Default selections. Note: hidden_states has one more entry than
    # attentions (the embedding output), so this index is valid for both.
    default_layer = len(attentions) - 1 if attentions is not None else (len(pca_layers) - 1 if pca_layers else 0)
    default_head = 0

    # Figures
    fig_attn = None
    if attentions is not None:
        fig_attn = make_attention_figure(
            attentions[default_layer][default_head], tokens,
            title=f"Layer {default_layer} Head {default_head}",
        )
    fig_pca = None
    if pca_layers is not None:
        fig_pca = make_pca_figure(pca_layers[default_layer], tokens, highlight_idx=None,
                                  title=f"PCA (layer {default_layer})")
    fig_probs = make_probs_figure(top_tokens, top_scores)

    explanation = (
        "Simple: the model splits the text into pieces (tokens), decides which tokens to pay attention to, and guesses the next word."
        if explain_simple else
        "Technical: showing tokens, attention maps (query-key), hidden states projected to 2D, and next-token probabilities."
    )

    return {
        "tokens": tokens,
        "attentions": attentions,
        "pca_layers": pca_layers,
        "logits": logits,
        "fig_attn": fig_attn,
        "fig_pca": fig_pca,
        "fig_probs": fig_probs,
        "default_layer": default_layer,
        "default_head": default_head,
        "token_display": " ".join(f"[{t}]" for t in tokens),
        "explanation": explanation,
    }


# ------- UI functions -------
def run_analysis(text, model_name, explain_simple):
    # Runs the analysis and returns values in the same order as the outputs
    # wired up in the UI below.
    start = time.time()
    res = analyze_text(text, model_name, explain_simple)
    if "error" in res:
        # Error outputs, matching the output order below.
        return (
            "",                                 # token_display
            res.get("error", "Unknown error"),  # explanation_md
            f"Model: {model_name}",             # model_info
            None, None, None,                   # attn / pca / probs plots
            gr.update(maximum=0, value=0),      # layer slider
            gr.update(maximum=0, value=0),      # head slider
            gr.update(maximum=0, value=0),      # token-step slider
            None,                               # state
        )

    tokens = res["tokens"]
    num_layers = len(res["attentions"]) if res["attentions"] is not None else (len(res["pca_layers"]) - 1 if res["pca_layers"] else 0)
    num_heads = res["attentions"][0].shape[0] if res["attentions"] is not None else 1
    max_token_idx = len(tokens) - 1

    token_display_text = f"**Tokens:** {res['token_display']}"
    explanation_text = res["explanation"]
    elapsed = time.time() - start
    model_info = (f"Model: {model_name} • layers: {num_layers} • heads: {num_heads} "
                  f"• tokens: {len(tokens)} • {elapsed:.2f}s")

    # Slider updates (gr.update changes bounds without replacing the component)
    layer_update = gr.update(maximum=max(0, num_layers - 1), value=res["default_layer"])
    head_update = gr.update(maximum=max(0, num_heads - 1), value=res["default_head"])
    token_step_update = gr.update(maximum=max_token_idx, value=0)

    # Return in the order matching the outputs declared in the UI wiring below.
    return (
        token_display_text,
        explanation_text,
        model_info,
        res["fig_attn"],
        res["fig_pca"],
        res["fig_probs"],
        layer_update,
        head_update,
        token_step_update,
        res,  # store the whole result in state for slider-driven updates
    )


def update_visuals(state_obj, layer, head, token_idx):
    # Given the cached analysis state, return updated
    # (attn_fig, pca_fig, step_attn_fig).
    if not state_obj:
        return None, None, None
    res = state_obj
    tokens = res["tokens"]

    # Bounds safety
    layer = int(min(max(0, layer), (len(res["attentions"]) - 1) if res["attentions"] is not None else 0))
    head = int(min(max(0, head), (res["attentions"][0].shape[0] - 1) if res["attentions"] is not None else 0))
    token_idx = int(min(max(0, token_idx), len(tokens) - 1))

    # Attention figure for the requested layer/head
    attn_fig = None
    if res["attentions"] is not None:
        attn_fig = make_attention_figure(res["attentions"][layer][head], tokens,
                                         title=f"Layer {layer} Head {head}")

    # PCA for the requested layer, with token_idx highlighted
    pca_fig = None
    if res["pca_layers"] is not None:
        pts = res["pca_layers"][layer]
        pca_fig = make_pca_figure(pts, tokens, highlight_idx=token_idx, title=f"PCA (layer {layer})")

    # Attention row: who token_idx attends to
    step_attn_fig = None
    if res["attentions"] is not None:
        row = res["attentions"][layer][head][token_idx]  # shape: (seq_len,)
        step_attn_fig = go.Figure(data=[go.Bar(x=_axis_labels(tokens), y=row)])
        step_attn_fig.update_layout(
            title=f"Token {token_idx} attends to (layer {layer}, head {head})",
            height=300,
            margin=dict(l=40, r=20, t=30, b=40),
        )

    return attn_fig, pca_fig, step_attn_fig
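
# Optional smoke test: a sketch for local debugging, not wired into the UI.
# Call _smoke_test() here (temporarily) to sanity-check the analysis path
# before the Gradio app below starts; it assumes DEFAULT_MODEL can be
# downloaded in the current environment.
def _smoke_test():
    res = analyze_text("Hello world", DEFAULT_MODEL, explain_simple=True)
    assert "error" not in res, res
    assert len(res["tokens"]) >= 1
    assert res["attentions"] is not None and res["pca_layers"] is not None
    print(f"smoke test OK: {len(res['tokens'])} tokens, "
          f"{len(res['attentions'])} attention layers")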
# ------- Gradio UI -------
with gr.Blocks(title="LLM Visualizer — Medium (HF Spaces)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 LLM Visualizer — Medium (safe for free HF Spaces)")

    with gr.Row():
        with gr.Column(scale=3):
            model_input = gr.Textbox(label="Model (Hugging Face name)", value=DEFAULT_MODEL)
            text_input = gr.Textbox(label="Input text", value="Hello world, this is a test.", lines=3)
            explain_simple = gr.Checkbox(label="Explain simply (kid/elder mode)", value=True)
            run_btn = gr.Button("Run visualizer", variant="primary")
            gr.Markdown("**Presets:**")
            with gr.Row():
                gr.Button("Greeting").click(lambda: "Hello! How are you today?", None, text_input)
                gr.Button("Story start").click(lambda: "Once upon a time, there was a small robot...", None, text_input)
                gr.Button("Question").click(lambda: "Why is the sky blue?", None, text_input)
        with gr.Column(scale=2):
            token_display = gr.Markdown("Tokens will appear here.")
            explanation_md = gr.Markdown("Explanation will appear here.")
            model_info = gr.Markdown("Model info: —")

    with gr.Row():
        with gr.Column():
            layer_slider = gr.Slider(label="Layer", minimum=0, maximum=0, step=1, value=0)
            head_slider = gr.Slider(label="Head", minimum=0, maximum=0, step=1, value=0)
            token_step = gr.Slider(label="Token index (step through tokens)", minimum=0, maximum=0, step=1, value=0)
            attn_plot = gr.Plot(label="Attention heatmap")
        with gr.Column():
            pca_plot = gr.Plot(label="PCA hidden states (2D)")
            step_attn_plot = gr.Plot(label="Attention row for selected token")
            probs_plot = gr.Plot(label="Next-token top predictions")

    # State storage (cached analysis)
    state = gr.State()

    # Wire up events
    run_btn.click(
        fn=run_analysis,
        inputs=[text_input, model_input, explain_simple],
        outputs=[
            token_display,
            explanation_md,
            model_info,
            attn_plot,
            pca_plot,
            probs_plot,
            layer_slider,
            head_slider,
            token_step,
            state,
        ],
    )

    # Sliders -> update visuals (using the stored state)
    slider_inputs = [state, layer_slider, head_slider, token_step]
    slider_outputs = [attn_plot, pca_plot, step_attn_plot]
    layer_slider.change(fn=update_visuals, inputs=slider_inputs, outputs=slider_outputs)
    head_slider.change(fn=update_visuals, inputs=slider_inputs, outputs=slider_outputs)
    token_step.change(fn=update_visuals, inputs=slider_inputs, outputs=slider_outputs)

demo.launch()
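
# ---------------------------------------------------------------------------
# Deployment note (an assumption, not part of the original app): on a Hugging
# Face Space this file expects its dependencies listed in a requirements.txt
# next to app.py. A minimal, unpinned list inferred from the imports above:
#
#   gradio
#   torch
#   transformers
#   plotly
#   scikit-learn
#   numpy
# ---------------------------------------------------------------------------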