Spaces:
Running
Running
added full screen
Browse files
app.py
CHANGED
|
@@ -1,8 +1,3 @@
|
|
| 1 |
-
# FULL LLM VISUALIZER β OPTION A (ADVANCED)
|
| 2 |
-
# stable + patched + safe for HuggingFace Spaces (CPU or GPU)
|
| 3 |
-
# recommended models: distilgpt2, gpt2
|
| 4 |
-
# author: ChatGPT
|
| 5 |
-
|
| 6 |
import gradio as gr
|
| 7 |
import torch
|
| 8 |
import numpy as np
|
|
@@ -12,11 +7,74 @@ import pandas as pd
|
|
| 12 |
from sklearn.decomposition import PCA
|
| 13 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 14 |
import html
|
|
|
|
| 15 |
|
| 16 |
DEFAULT_MODEL = "distilgpt2"
|
| 17 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
MODEL_CACHE = {}
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
# ---------------- CORE UTILS ----------------
|
| 21 |
|
| 22 |
def load_model(model_name):
|
|
@@ -197,8 +255,9 @@ def compute_residuals_safe(model, inputs):
|
|
| 197 |
mlp_out = block.mlp(ln2)
|
| 198 |
x = x + mlp_out
|
| 199 |
|
| 200 |
-
|
| 201 |
-
|
|
|
|
| 202 |
except:
|
| 203 |
# fallback safe zero
|
| 204 |
attn_norms.append(0.0)
|
|
@@ -311,15 +370,19 @@ with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
|
|
| 311 |
layer_slider = gr.Slider(0, 0, value=0, step=1, label="Layer")
|
| 312 |
head_slider = gr.Slider(0, 0, value=0, step=1, label="Head")
|
| 313 |
token_step = gr.Slider(0, 0, value=0, step=1, label="Token index")
|
| 314 |
-
attn_plot = gr.Plot()
|
| 315 |
-
|
| 316 |
with gr.Column():
|
| 317 |
-
pca_plot = gr.Plot()
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
# Panel 3 β Residuals
|
| 322 |
-
residual_plot = gr.Plot()
|
|
|
|
| 323 |
|
| 324 |
# Panel 4 β Neuron explorer
|
| 325 |
with gr.Row():
|
|
@@ -334,7 +397,8 @@ with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
|
|
| 334 |
patch_from = gr.Slider(0, 0, value=0, step=1, label="Copy from position")
|
| 335 |
patch_scale = gr.Number(label="Scale", value=1.0)
|
| 336 |
patch_btn = gr.Button("Run patch")
|
| 337 |
-
patch_output = gr.Plot()
|
|
|
|
| 338 |
|
| 339 |
state = gr.State()
|
| 340 |
|
|
@@ -358,7 +422,9 @@ with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
|
|
| 358 |
patch_layer: gr.update(maximum=0),
|
| 359 |
patch_pos: gr.update(maximum=0),
|
| 360 |
patch_from: gr.update(maximum=0),
|
| 361 |
-
state: res
|
|
|
|
|
|
|
| 362 |
}
|
| 363 |
|
| 364 |
tokens = res["tokens"]
|
|
@@ -397,7 +463,9 @@ with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
|
|
| 397 |
patch_layer: gr.update(maximum=L-1, value=0),
|
| 398 |
patch_pos: gr.update(maximum=T, value=0),
|
| 399 |
patch_from: gr.update(maximum=T, value=0),
|
| 400 |
-
state: res
|
|
|
|
|
|
|
| 401 |
}
|
| 402 |
|
| 403 |
|
|
@@ -410,7 +478,7 @@ with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
|
|
| 410 |
layer_slider, head_slider, token_step,
|
| 411 |
residual_plot, neuron_table,
|
| 412 |
patch_layer, patch_pos, patch_from,
|
| 413 |
-
state
|
| 414 |
]
|
| 415 |
)
|
| 416 |
|
|
@@ -494,4 +562,4 @@ with gr.Blocks(title="LLM Visualizer β Full", theme=gr.themes.Soft()) as demo:
|
|
| 494 |
[state, patch_layer, patch_pos, patch_from, patch_scale, model_name],
|
| 495 |
[patch_output])
|
| 496 |
|
| 497 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
|
|
|
| 7 |
from sklearn.decomposition import PCA
|
| 8 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 9 |
import html
|
| 10 |
+
import time
|
| 11 |
|
| 12 |
DEFAULT_MODEL = "distilgpt2"
|
| 13 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
MODEL_CACHE = {}
|
| 15 |
|
| 16 |
+
# ---------------- Fullscreen helper (injected JS per plot) ----------------
|
| 17 |
+
|
| 18 |
+
def fullscreen_plot_js(plot_id):
    """Return a ``<script>`` block defining ``openFull_<safe_id>()``.

    The generated JS function finds the first ``img`` or ``canvas`` inside the
    element whose ``elem_id`` is *plot_id* and shows it in a fullscreen modal
    overlay; clicking the overlay removes it.  If the plot has not rendered
    yet, a short transient hint is shown instead.

    Parameters
    ----------
    plot_id : str
        The Gradio ``elem_id`` of the plot container.  Hyphens are replaced
        with underscores to form a valid JS identifier for the function name.

    Returns
    -------
    str
        HTML ``<script>`` markup, suitable for embedding via ``gr.HTML``.
    """
    safe_id = plot_id.replace("-", "_")
    return f"""
    <script>
    function openFull_{safe_id}() {{
        // find the plot element (img or canvas) inside the element with elem_id
        const container = document.getElementById("{plot_id}");
        if (!container) {{
            return;
        }}
        const img = container.querySelector("img, canvas");
        if (!img) {{
            // no canvas or img yet, inform user
            const hint = document.createElement('div');
            hint.style.position='fixed';
            hint.style.bottom='20px';
            hint.style.left='20px';
            hint.style.padding='8px 12px';
            hint.style.background='#222';
            hint.style.color='white';
            hint.style.borderRadius='8px';
            hint.style.zIndex = 99999;
            hint.innerText = 'Plot not rendered yet. Please run analysis first.';
            document.body.appendChild(hint);
            setTimeout(()=>hint.remove(), 2500);
            return;
        }}
        const modal = document.createElement('div');
        modal.style.position = 'fixed';
        modal.style.top = '0';
        modal.style.left = '0';
        modal.style.width = '100vw';
        modal.style.height = '100vh';
        modal.style.background = 'rgba(0,0,0,0.88)';
        modal.style.zIndex = '999999';
        modal.style.display = 'flex';
        modal.style.alignItems = 'center';
        modal.style.justifyContent = 'center';
        modal.onclick = () => modal.remove();

        // BUG FIX: cloneNode(true) on a <canvas> yields a BLANK canvas — the
        // DOM clone does not copy pixel data.  Snapshot canvases into an <img>
        // via toDataURL() so canvas-rendered plots actually show fullscreen.
        let clone;
        if (img.tagName === 'CANVAS') {{
            clone = document.createElement('img');
            clone.src = img.toDataURL('image/png');
        }} else {{
            clone = img.cloneNode(true);
        }}
        clone.style.maxWidth = '95%';
        clone.style.maxHeight = '95%';
        clone.style.border = '2px solid rgba(255,255,255,0.85)';
        clone.style.borderRadius = '10px';
        clone.style.boxShadow = '0 8px 40px rgba(0,0,0,0.8)';
        modal.appendChild(clone);

        document.body.appendChild(modal);
    }}
    </script>
    """
|
| 72 |
+
|
| 73 |
+
def fullscreen_button_html(plot_id, label="π Full Screen"):
    """Build the per-plot fullscreen widget for *plot_id*.

    Concatenates the injected JS helper (from ``fullscreen_plot_js``) with a
    styled button whose ``onclick`` invokes that helper.  Returned as one HTML
    string so it can be dropped into a single ``gr.HTML`` component.
    """
    # Mirror the identifier sanitisation done inside fullscreen_plot_js so the
    # button targets the exact function name the script defines.
    fn_name = "openFull_" + plot_id.replace("-", "_")
    button = (
        f'<button onclick="{fn_name}()" '
        'style="margin-top:6px;padding:6px 10px;border-radius:8px;'
        'border:1px solid #ddd;background:white;">'
        f'{label}</button>'
    )
    return fullscreen_plot_js(plot_id) + button
|
| 76 |
+
|
| 77 |
+
|
| 78 |
# ---------------- CORE UTILS ----------------
|
| 79 |
|
| 80 |
def load_model(model_name):
|
|
|
|
| 255 |
mlp_out = block.mlp(ln2)
|
| 256 |
x = x + mlp_out
|
| 257 |
|
| 258 |
+
# detach to avoid requires_grad warning
|
| 259 |
+
attn_norms.append(float(torch.norm(attn_out.detach()).cpu()))
|
| 260 |
+
mlp_norms.append(float(torch.norm(mlp_out.detach()).cpu()))
|
| 261 |
except:
|
| 262 |
# fallback safe zero
|
| 263 |
attn_norms.append(0.0)
|
|
|
|
| 370 |
layer_slider = gr.Slider(0, 0, value=0, step=1, label="Layer")
|
| 371 |
head_slider = gr.Slider(0, 0, value=0, step=1, label="Head")
|
| 372 |
token_step = gr.Slider(0, 0, value=0, step=1, label="Token index")
|
| 373 |
+
attn_plot = gr.Plot(elem_id="attn_plot")
|
| 374 |
+
attn_fs = gr.HTML(fullscreen_button_html("attn_plot"))
|
| 375 |
with gr.Column():
|
| 376 |
+
pca_plot = gr.Plot(elem_id="pca_plot")
|
| 377 |
+
pca_fs = gr.HTML(fullscreen_button_html("pca_plot"))
|
| 378 |
+
step_attn_plot = gr.Plot(elem_id="step_attn_plot")
|
| 379 |
+
step_fs = gr.HTML(fullscreen_button_html("step_attn_plot"))
|
| 380 |
+
probs_plot = gr.Plot(elem_id="probs_plot")
|
| 381 |
+
probs_fs = gr.HTML(fullscreen_button_html("probs_plot"))
|
| 382 |
|
| 383 |
# Panel 3 β Residuals
|
| 384 |
+
residual_plot = gr.Plot(elem_id="residual_plot")
|
| 385 |
+
residual_fs = gr.HTML(fullscreen_button_html("residual_plot"))
|
| 386 |
|
| 387 |
# Panel 4 β Neuron explorer
|
| 388 |
with gr.Row():
|
|
|
|
| 397 |
patch_from = gr.Slider(0, 0, value=0, step=1, label="Copy from position")
|
| 398 |
patch_scale = gr.Number(label="Scale", value=1.0)
|
| 399 |
patch_btn = gr.Button("Run patch")
|
| 400 |
+
patch_output = gr.Plot(elem_id="patch_plot")
|
| 401 |
+
patch_fs = gr.HTML(fullscreen_button_html("patch_plot"))
|
| 402 |
|
| 403 |
state = gr.State()
|
| 404 |
|
|
|
|
| 422 |
patch_layer: gr.update(maximum=0),
|
| 423 |
patch_pos: gr.update(maximum=0),
|
| 424 |
patch_from: gr.update(maximum=0),
|
| 425 |
+
state: res,
|
| 426 |
+
step_attn_plot: gr.update(value=None),
|
| 427 |
+
patch_output: gr.update(value=None),
|
| 428 |
}
|
| 429 |
|
| 430 |
tokens = res["tokens"]
|
|
|
|
| 463 |
patch_layer: gr.update(maximum=L-1, value=0),
|
| 464 |
patch_pos: gr.update(maximum=T, value=0),
|
| 465 |
patch_from: gr.update(maximum=T, value=0),
|
| 466 |
+
state: res,
|
| 467 |
+
step_attn_plot: gr.update(value=None),
|
| 468 |
+
patch_output: gr.update(value=None),
|
| 469 |
}
|
| 470 |
|
| 471 |
|
|
|
|
| 478 |
layer_slider, head_slider, token_step,
|
| 479 |
residual_plot, neuron_table,
|
| 480 |
patch_layer, patch_pos, patch_from,
|
| 481 |
+
state, step_attn_plot, patch_output
|
| 482 |
]
|
| 483 |
)
|
| 484 |
|
|
|
|
| 562 |
[state, patch_layer, patch_pos, patch_from, patch_scale, model_name],
|
| 563 |
[patch_output])
|
| 564 |
|
| 565 |
+
demo.launch()
|