PraneshJs committed on
Commit
caf977e
·
verified ·
1 Parent(s): aa71186

added full screen

Browse files
Files changed (1) hide show
  1. app.py +86 -18
app.py CHANGED
@@ -1,8 +1,3 @@
1
- # FULL LLM VISUALIZER β€” OPTION A (ADVANCED)
2
- # stable + patched + safe for HuggingFace Spaces (CPU or GPU)
3
- # recommended models: distilgpt2, gpt2
4
- # author: ChatGPT
5
-
6
  import gradio as gr
7
  import torch
8
  import numpy as np
@@ -12,11 +7,74 @@ import pandas as pd
12
  from sklearn.decomposition import PCA
13
  from transformers import AutoTokenizer, AutoModelForCausalLM
14
  import html
 
15
 
16
  DEFAULT_MODEL = "distilgpt2"
17
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
18
  MODEL_CACHE = {}
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # ---------------- CORE UTILS ----------------
21
 
22
  def load_model(model_name):
@@ -197,8 +255,9 @@ def compute_residuals_safe(model, inputs):
197
  mlp_out = block.mlp(ln2)
198
  x = x + mlp_out
199
 
200
- attn_norms.append(float(torch.norm(attn_out).cpu()))
201
- mlp_norms.append(float(torch.norm(mlp_out).cpu()))
 
202
  except:
203
  # fallback safe zero
204
  attn_norms.append(0.0)
@@ -311,15 +370,19 @@ with gr.Blocks(title="LLM Visualizer β€” Full", theme=gr.themes.Soft()) as demo:
311
  layer_slider = gr.Slider(0, 0, value=0, step=1, label="Layer")
312
  head_slider = gr.Slider(0, 0, value=0, step=1, label="Head")
313
  token_step = gr.Slider(0, 0, value=0, step=1, label="Token index")
314
- attn_plot = gr.Plot()
315
-
316
  with gr.Column():
317
- pca_plot = gr.Plot()
318
- step_attn_plot = gr.Plot()
319
- probs_plot = gr.Plot()
 
 
 
320
 
321
  # Panel 3 β€” Residuals
322
- residual_plot = gr.Plot()
 
323
 
324
  # Panel 4 β€” Neuron explorer
325
  with gr.Row():
@@ -334,7 +397,8 @@ with gr.Blocks(title="LLM Visualizer β€” Full", theme=gr.themes.Soft()) as demo:
334
  patch_from = gr.Slider(0, 0, value=0, step=1, label="Copy from position")
335
  patch_scale = gr.Number(label="Scale", value=1.0)
336
  patch_btn = gr.Button("Run patch")
337
- patch_output = gr.Plot()
 
338
 
339
  state = gr.State()
340
 
@@ -358,7 +422,9 @@ with gr.Blocks(title="LLM Visualizer β€” Full", theme=gr.themes.Soft()) as demo:
358
  patch_layer: gr.update(maximum=0),
359
  patch_pos: gr.update(maximum=0),
360
  patch_from: gr.update(maximum=0),
361
- state: res
 
 
362
  }
363
 
364
  tokens = res["tokens"]
@@ -397,7 +463,9 @@ with gr.Blocks(title="LLM Visualizer β€” Full", theme=gr.themes.Soft()) as demo:
397
  patch_layer: gr.update(maximum=L-1, value=0),
398
  patch_pos: gr.update(maximum=T, value=0),
399
  patch_from: gr.update(maximum=T, value=0),
400
- state: res
 
 
401
  }
402
 
403
 
@@ -410,7 +478,7 @@ with gr.Blocks(title="LLM Visualizer β€” Full", theme=gr.themes.Soft()) as demo:
410
  layer_slider, head_slider, token_step,
411
  residual_plot, neuron_table,
412
  patch_layer, patch_pos, patch_from,
413
- state
414
  ]
415
  )
416
 
@@ -494,4 +562,4 @@ with gr.Blocks(title="LLM Visualizer β€” Full", theme=gr.themes.Soft()) as demo:
494
  [state, patch_layer, patch_pos, patch_from, patch_scale, model_name],
495
  [patch_output])
496
 
497
- demo.launch()
 
 
 
 
 
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
 
7
  from sklearn.decomposition import PCA
8
  from transformers import AutoTokenizer, AutoModelForCausalLM
9
  import html
10
+ import time
11
 
12
  DEFAULT_MODEL = "distilgpt2"
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
  MODEL_CACHE = {}
15
 
16
+ # ---------------- Fullscreen helper (injected JS per plot) ----------------
17
+
18
def fullscreen_plot_js(plot_id):
    """Return an HTML ``<script>`` snippet defining ``openFull_<safe_id>()``.

    The generated JS function looks up the DOM element whose id is
    *plot_id* (the Gradio ``elem_id`` of a plot component), finds the
    rendered ``<img>`` or ``<canvas>`` inside it, and opens a
    click-to-dismiss full-screen overlay containing an enlarged clone of
    that plot.  If the plot has not rendered yet, a transient hint is
    shown instead.

    NOTE(review): some Gradio versions sanitize ``<script>`` tags out of
    ``gr.HTML`` content — confirm the snippet actually executes in the
    deployed Gradio release.
    """
    import re

    # Build a valid JS identifier suffix.  The original code replaced
    # only "-", which yields an invalid function name for ids containing
    # ".", " ", etc.; \W covers every non-identifier character
    # (hyphens included), so behavior is unchanged for existing ids.
    safe_id = re.sub(r"\W", "_", plot_id)
    return f"""
<script>
function openFull_{safe_id}() {{
    // find the plot element (img or canvas) inside the element with elem_id
    const container = document.getElementById("{plot_id}");
    if (!container) {{
        return;
    }}
    const img = container.querySelector("img, canvas");
    if (!img) {{
        // no canvas or img yet, inform user
        const hint = document.createElement('div');
        hint.style.position='fixed';
        hint.style.bottom='20px';
        hint.style.left='20px';
        hint.style.padding='8px 12px';
        hint.style.background='#222';
        hint.style.color='white';
        hint.style.borderRadius='8px';
        hint.style.zIndex = 99999;
        hint.innerText = 'Plot not rendered yet. Please run analysis first.';
        document.body.appendChild(hint);
        setTimeout(()=>hint.remove(), 2500);
        return;
    }}
    const modal = document.createElement('div');
    modal.style.position = 'fixed';
    modal.style.top = '0';
    modal.style.left = '0';
    modal.style.width = '100vw';
    modal.style.height = '100vh';
    modal.style.background = 'rgba(0,0,0,0.88)';
    modal.style.zIndex = '999999';
    modal.style.display = 'flex';
    modal.style.alignItems = 'center';
    modal.style.justifyContent = 'center';
    // clicking anywhere on the overlay dismisses it
    modal.onclick = () => modal.remove();

    const clone = img.cloneNode(true);
    clone.style.maxWidth = '95%';
    clone.style.maxHeight = '95%';
    clone.style.border = '2px solid rgba(255,255,255,0.85)';
    clone.style.borderRadius = '10px';
    clone.style.boxShadow = '0 8px 40px rgba(0,0,0,0.8)';
    modal.appendChild(clone);

    document.body.appendChild(modal);
}}
</script>
"""
72
+
73
def fullscreen_button_html(plot_id, label="🔍 Full Screen"):
    """Return the fullscreen ``<script>`` helper plus a button wired to it.

    Combines :func:`fullscreen_plot_js` output with a ``<button>`` whose
    ``onclick`` invokes the generated ``openFull_*`` function for
    *plot_id*.  *label* is the button text (default repairs the
    mojibake'd magnifier emoji in the original source).
    """
    import re

    # Must mirror the sanitization in fullscreen_plot_js so the onclick
    # target matches the generated function name for any id.
    fn_name = "openFull_" + re.sub(r"\W", "_", plot_id)
    button = (
        f'<button onclick="{fn_name}()" '
        'style="margin-top:6px;padding:6px 10px;border-radius:8px;'
        f'border:1px solid #ddd;background:white;">{label}</button>'
    )
    return fullscreen_plot_js(plot_id) + button
76
+
77
+
78
  # ---------------- CORE UTILS ----------------
79
 
80
  def load_model(model_name):
 
255
  mlp_out = block.mlp(ln2)
256
  x = x + mlp_out
257
 
258
+ # detach to avoid requires_grad warning
259
+ attn_norms.append(float(torch.norm(attn_out.detach()).cpu()))
260
+ mlp_norms.append(float(torch.norm(mlp_out.detach()).cpu()))
261
  except:
262
  # fallback safe zero
263
  attn_norms.append(0.0)
 
370
  layer_slider = gr.Slider(0, 0, value=0, step=1, label="Layer")
371
  head_slider = gr.Slider(0, 0, value=0, step=1, label="Head")
372
  token_step = gr.Slider(0, 0, value=0, step=1, label="Token index")
373
+ attn_plot = gr.Plot(elem_id="attn_plot")
374
+ attn_fs = gr.HTML(fullscreen_button_html("attn_plot"))
375
  with gr.Column():
376
+ pca_plot = gr.Plot(elem_id="pca_plot")
377
+ pca_fs = gr.HTML(fullscreen_button_html("pca_plot"))
378
+ step_attn_plot = gr.Plot(elem_id="step_attn_plot")
379
+ step_fs = gr.HTML(fullscreen_button_html("step_attn_plot"))
380
+ probs_plot = gr.Plot(elem_id="probs_plot")
381
+ probs_fs = gr.HTML(fullscreen_button_html("probs_plot"))
382
 
383
  # Panel 3 β€” Residuals
384
+ residual_plot = gr.Plot(elem_id="residual_plot")
385
+ residual_fs = gr.HTML(fullscreen_button_html("residual_plot"))
386
 
387
  # Panel 4 β€” Neuron explorer
388
  with gr.Row():
 
397
  patch_from = gr.Slider(0, 0, value=0, step=1, label="Copy from position")
398
  patch_scale = gr.Number(label="Scale", value=1.0)
399
  patch_btn = gr.Button("Run patch")
400
+ patch_output = gr.Plot(elem_id="patch_plot")
401
+ patch_fs = gr.HTML(fullscreen_button_html("patch_plot"))
402
 
403
  state = gr.State()
404
 
 
422
  patch_layer: gr.update(maximum=0),
423
  patch_pos: gr.update(maximum=0),
424
  patch_from: gr.update(maximum=0),
425
+ state: res,
426
+ step_attn_plot: gr.update(value=None),
427
+ patch_output: gr.update(value=None),
428
  }
429
 
430
  tokens = res["tokens"]
 
463
  patch_layer: gr.update(maximum=L-1, value=0),
464
  patch_pos: gr.update(maximum=T, value=0),
465
  patch_from: gr.update(maximum=T, value=0),
466
+ state: res,
467
+ step_attn_plot: gr.update(value=None),
468
+ patch_output: gr.update(value=None),
469
  }
470
 
471
 
 
478
  layer_slider, head_slider, token_step,
479
  residual_plot, neuron_table,
480
  patch_layer, patch_pos, patch_from,
481
+ state, step_attn_plot, patch_output
482
  ]
483
  )
484
 
 
562
  [state, patch_layer, patch_pos, patch_from, patch_scale, model_name],
563
  [patch_output])
564
 
565
+ demo.launch()