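"""Gradio demo that looks inside a small causal language model.

For a given input text it shows the tokenization, per-layer / per-head
attention heatmaps, a 2D PCA projection of the hidden states, and the
model's top next-token predictions. Defaults to distilgpt2 so it stays
light enough for a CPU-only Hugging Face Space.
"""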
import time

import gradio as gr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, AutoModelForCausalLM

DEFAULT_MODEL = "distilgpt2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_MODEL_CACHE = {}
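

# Model/tokenizer pairs are cached in-process so repeated runs with the same
# model name skip the slow from_pretrained load.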
def load_model(model_name):
    if model_name in _MODEL_CACHE:
        return _MODEL_CACHE[model_name]
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name, output_attentions=True, output_hidden_states=True
    )
    model.to(DEVICE)
    model.eval()
    _MODEL_CACHE[model_name] = (model, tokenizer)
    return model, tokenizer
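

# Numerically stable softmax over the last axis; used to turn the final
# position's logits into next-token probabilities.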
def softmax(x):
    e = np.exp(x - np.max(x))
    return e / e.sum(axis=-1, keepdims=True)
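

# Project one layer's hidden states (seq_len x hidden_dim) down to 2D.
# Falls back to the first two raw dimensions when PCA cannot be fitted
# (e.g. a single-token input).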
def compute_pca(hidden_layer):
    try:
        return PCA(n_components=2).fit_transform(hidden_layer)
    except Exception:
        seq = hidden_layer.shape[0]
        d0 = hidden_layer[:, 0] if hidden_layer.shape[1] > 0 else np.zeros(seq)
        d1 = hidden_layer[:, 1] if hidden_layer.shape[1] > 1 else np.zeros(seq)
        return np.vstack([d0, d1]).T
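

# Heatmap of one attention head: rows are query tokens, columns are key tokens.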
def make_attention_figure(attn_matrix, tokens, title=None):
    # Prefix each token with its position so repeated tokens (e.g. " a")
    # keep their own rows/columns on Plotly's categorical axes.
    axis_labels = [f"{i}: {t}" for i, t in enumerate(tokens)]
    fig = px.imshow(
        attn_matrix,
        x=axis_labels,
        y=axis_labels,
        labels={"x": "Key token", "y": "Query token", "color": "Attention"},
        title=title or "Attention",
    )
    fig.update_layout(height=420, margin=dict(l=60, r=20, t=40, b=40))
    return fig
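

# 2D scatter of per-token hidden states, optionally highlighting one token in red.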
def make_pca_figure(points, tokens, highlight_idx=None, title=None):
    fig = px.scatter(x=points[:, 0], y=points[:, 1], text=tokens, title=title or "PCA (2D)")
    fig.update_traces(textposition="top center", marker=dict(size=10))
    if highlight_idx is not None:
        fig.add_trace(go.Scatter(
            x=[points[highlight_idx, 0]],
            y=[points[highlight_idx, 1]],
            mode="markers+text",
            text=[tokens[highlight_idx]],
            marker=dict(size=18, color="red"),
            name="selected token",
        ))
    fig.update_layout(height=420, margin=dict(l=40, r=40, t=40, b=40))
    return fig
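

# Bar chart of the model's top next-token candidates and their probabilities.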
def make_probs_figure(top_tokens, top_scores, title=None):
    fig = go.Figure(data=[go.Bar(x=top_tokens, y=top_scores)])
    fig.update_layout(
        title=title or "Next-token top predictions",
        yaxis_title="Probability",
        height=360,
        margin=dict(l=40, r=20, t=40, b=40),
    )
    return fig
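

# Run one forward pass and package everything the UI needs: tokens, attention
# maps, PCA projections of the hidden states, next-token probabilities, the
# default figures, and a plain-language explanation. Any failure is returned
# as {"error": ...} rather than raised.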
def analyze_text(text, model_name, explain_simple):
    if not text or not text.strip():
        return {"error": "Please enter some text."}

    try:
        model, tokenizer = load_model(model_name)
    except Exception as e:
        return {"error": f"Failed to load model '{model_name}': {e}"}

    try:
        inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
    except Exception as e:
        return {"error": f"Tokenization error: {e}"}

    with torch.no_grad():
        try:
            outputs = model(**inputs)
        except Exception as e:
            return {"error": f"Model forward error: {e}"}

    try:
        input_ids = inputs["input_ids"][0].cpu().numpy().tolist()
        tokens = tokenizer.convert_ids_to_tokens(input_ids)
    except Exception:
        return {"error": "Failed to extract tokens."}

    # Drop the batch dimension and move everything to numpy for plotting.
    attentions = [a[0].cpu().numpy() for a in outputs.attentions] if outputs.attentions is not None else None
    hidden = [h[0].cpu().numpy() for h in outputs.hidden_states] if outputs.hidden_states is not None else None
    logits = outputs.logits[0].cpu().numpy()

    # One 2D projection per hidden-state layer (the first entry is the embedding output).
    pca_layers = None
    if hidden is not None:
        pca_layers = [compute_pca(layer_h) for layer_h in hidden]

    # Top-k next-token predictions from the final position's logits.
    last_logits = logits[-1]
    probs = softmax(last_logits)
    topk = 20
    idx = np.argsort(probs)[-topk:][::-1]
    top_tokens = [tokenizer.decode([int(i)]) for i in idx]
    top_scores = probs[idx].tolist()

    # Start the UI on the last layer and the first head.
    default_layer = len(attentions) - 1 if attentions is not None else (len(pca_layers) - 1 if pca_layers else 0)
    default_head = 0

    fig_attn = None
    if attentions is not None:
        fig_attn = make_attention_figure(
            attentions[default_layer][default_head],
            tokens,
            title=f"Layer {default_layer} Head {default_head}",
        )

    fig_pca = None
    if pca_layers is not None:
        fig_pca = make_pca_figure(
            pca_layers[default_layer], tokens, highlight_idx=None, title=f"PCA (layer {default_layer})"
        )

    fig_probs = make_probs_figure(top_tokens, top_scores)

    explanation = (
        "Simple: The model splits text into pieces (tokens), looks at which tokens it should pay attention to, and guesses the next word."
        if explain_simple else
        "Technical: showing tokens, attention maps (query-key), hidden states projected to 2D, and next-token probabilities."
    )

    return {
        "tokens": tokens,
        "attentions": attentions,
        "pca_layers": pca_layers,
        "logits": logits,
        "fig_attn": fig_attn,
        "fig_pca": fig_pca,
        "fig_probs": fig_probs,
        "default_layer": default_layer,
        "default_head": default_head,
        "token_display": " ".join([f"[{t}]" for t in tokens]),
        "explanation": explanation,
    }
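

# Gradio callback for the "Run visualizer" button. Unpacks the analyze_text
# result into the output components, configures the slider ranges, and stores
# the full result dict in gr.State for the interactive callbacks below.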
def run_analysis(text, model_name, explain_simple):
    start = time.time()
    res = analyze_text(text, model_name, explain_simple)

    if "error" in res:
        token_display_text = ""
        explanation_text = res.get("error", "Unknown error")
        model_info = f"Model: {model_name}"
        return (
            token_display_text, explanation_text, model_info,
            None, None, None,
            gr.update(maximum=0, value=0), gr.update(maximum=0, value=0), gr.update(maximum=0, value=0),
            None,
        )

    tokens = res["tokens"]
    num_layers = len(res["attentions"]) if res["attentions"] is not None else (len(res["pca_layers"]) - 1 if res["pca_layers"] else 0)
    num_heads = res["attentions"][0].shape[0] if res["attentions"] is not None else 1
    max_token_idx = len(tokens) - 1

    token_display_text = f"**Tokens:** {res['token_display']}"
    explanation_text = res["explanation"]
    # Report how long the forward pass and figure construction took.
    elapsed = time.time() - start
    model_info = f"Model: {model_name} • layers: {num_layers} • heads: {num_heads} • tokens: {len(tokens)} • {elapsed:.2f}s"

    # Resize the sliders to match this model/input before the user touches them.
    layer_update = gr.update(maximum=max(0, num_layers - 1), value=res["default_layer"])
    head_update = gr.update(maximum=max(0, num_heads - 1), value=res["default_head"])
    token_step_update = gr.update(maximum=max_token_idx, value=0)

    return (
        token_display_text, explanation_text, model_info,
        res["fig_attn"], res["fig_pca"], res["fig_probs"],
        layer_update, head_update, token_step_update,
        res,
    )
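

# Callback for the layer / head / token sliders: re-renders the figures from
# the cached analysis state without running the model again.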
def update_visuals(state_obj, layer, head, token_idx):
    if not state_obj:
        return None, None, None
    res = state_obj
    tokens = res["tokens"]

    # Clamp slider values to valid indices for this model/input.
    layer = int(min(max(0, layer), (len(res["attentions"]) - 1) if res["attentions"] is not None else 0))
    head = int(min(max(0, head), (res["attentions"][0].shape[0] - 1) if res["attentions"] is not None else 0))
    token_idx = int(min(max(0, token_idx), len(tokens) - 1))

    attn_fig = None
    if res["attentions"] is not None:
        attn_fig = make_attention_figure(res["attentions"][layer][head], tokens, title=f"Layer {layer} Head {head}")

    pca_fig = None
    if res["pca_layers"] is not None:
        pts = res["pca_layers"][layer]
        pca_fig = make_pca_figure(pts, tokens, highlight_idx=token_idx, title=f"PCA (layer {layer})")

    step_attn_fig = None
    if res["attentions"] is not None:
        row = res["attentions"][layer][head][token_idx]
        # Position-prefixed labels keep repeated tokens from sharing one bar.
        bar_labels = [f"{i}: {t}" for i, t in enumerate(tokens)]
        step_attn_fig = go.Figure(data=[go.Bar(x=bar_labels, y=row)])
        step_attn_fig.update_layout(
            title=f"Token {token_idx} attends to (layer {layer}, head {head})",
            height=300,
            margin=dict(l=40, r=20, t=30, b=40),
        )

    return attn_fig, pca_fig, step_attn_fig
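

# Gradio UI: inputs and presets on the left, token/explanation panels on the
# right, then the interactive plots with layer / head / token controls.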
with gr.Blocks(title="LLM Visualizer — Medium (HF Spaces)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 LLM Visualizer — Medium (safe for free HF Spaces)")

    with gr.Row():
        with gr.Column(scale=3):
            model_input = gr.Textbox(label="Model (Hugging Face name)", value=DEFAULT_MODEL)
            text_input = gr.Textbox(label="Input text", value="Hello world, this is a test.", lines=3)
            explain_simple = gr.Checkbox(label="Explain simply (kid/elder mode)", value=True)
            run_btn = gr.Button("Run visualizer", variant="primary")
            gr.Markdown("**Presets:**")
            with gr.Row():
                gr.Button("Greeting").click(lambda: "Hello! How are you today?", None, text_input)
                gr.Button("Story start").click(lambda: "Once upon a time, there was a small robot...", None, text_input)
                gr.Button("Question").click(lambda: "Why is the sky blue?", None, text_input)

        with gr.Column(scale=2):
            token_display = gr.Markdown("Tokens will appear here.")
            explanation_md = gr.Markdown("Explanation will appear here.")
            model_info = gr.Markdown("Model info: —")

    with gr.Row():
        with gr.Column():
            layer_slider = gr.Slider(label="Layer", minimum=0, maximum=0, step=1, value=0)
            head_slider = gr.Slider(label="Head", minimum=0, maximum=0, step=1, value=0)
            token_step = gr.Slider(label="Token index (step through tokens)", minimum=0, maximum=0, step=1, value=0)
            attn_plot = gr.Plot(label="Attention heatmap")
        with gr.Column():
            pca_plot = gr.Plot(label="PCA hidden states (2D)")
            step_attn_plot = gr.Plot(label="Attention row for selected token")
            probs_plot = gr.Plot(label="Next-token top predictions")

    state = gr.State()

    run_btn.click(
        fn=run_analysis,
        inputs=[text_input, model_input, explain_simple],
        outputs=[
            token_display, explanation_md, model_info,
            attn_plot, pca_plot, probs_plot,
            layer_slider, head_slider, token_step,
            state,
        ],
    )

    layer_slider.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step], outputs=[attn_plot, pca_plot, step_attn_plot])
    head_slider.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step], outputs=[attn_plot, pca_plot, step_attn_plot])
    token_step.change(fn=update_visuals, inputs=[state, layer_slider, head_slider, token_step], outputs=[attn_plot, pca_plot, step_attn_plot])


demo.launch()