entfane committed on
Commit
3ba5f1a
·
verified ·
1 Parent(s): d7cb09b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +111 -40
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import torch
3
  import numpy as np
4
- import pandas as pd
5
  from transformers import AutoTokenizer
6
  from trl import AutoModelForCausalLMWithValueHead
7
 
@@ -18,71 +17,143 @@ model.eval()
18
 
19
 
20
  # ── Core inference ───────────────────────────────────────────────────────────
21
- def analyze(user_message, assistant_reply):
22
  messages = [
23
- {"role": "system", "content": ""},
24
  {"role": "user", "content": user_message},
25
  {"role": "assistant", "content": assistant_reply},
26
  ]
27
 
28
- text = tokenizer.apply_chat_template(messages, tokenize=False)
29
- input_ids = tokenizer(text, return_tensors="pt").input_ids.to(DEVICE)
 
30
 
31
- tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
32
 
33
  with torch.no_grad():
34
  _, _, values = model(input_ids)
35
 
36
- scores = torch.sigmoid(values[0]).cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
- # Only keep tokens that belong to the assistant reply
39
- # Find where the assistant reply starts in the token list
40
- reply_tokens = tokenizer(assistant_reply, return_tensors="pt").input_ids[0].tolist()
41
- n_reply = len(reply_tokens)
42
- tokens = tokens[-n_reply:]
43
- scores = scores[-n_reply:]
44
 
45
- def clean(tok):
46
- return tok.replace("Ġ", " ").replace("Ċ", "\\n").strip() or tok
 
47
 
48
- labels = [f"{clean(tok)} [{i}]" for i, tok in enumerate(tokens)]
49
- df = pd.DataFrame({"token": labels, "value score": scores.tolist(), "order": list(range(len(tokens)))})
50
- df = df.sort_values("order").drop(columns="order")
51
 
52
- stats = (
53
  f"**Tokens:** {len(tokens)} | "
54
  f"**Min:** {scores.min():.4f} | "
55
  f"**Max:** {scores.max():.4f} | "
56
- f"**Mean:** {scores.mean():.4f}"
 
57
  )
58
 
59
- return df, stats
60
 
61
 
62
  # ── UI ───────────────────────────────────────────────────────────────────────
63
- with gr.Blocks(theme=gr.themes.Soft(), title="Value Head Visualizer") as demo:
64
-
65
- gr.Markdown("# 🧠 Value Head Visualizer")
66
- gr.Markdown("Per-token sigmoid value scores from `entfane/gpt2_constitutional_classifier_with_value_head`.")
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  with gr.Row():
69
- user_in = gr.Textbox(label="User message", value="How are you doing?", lines=2)
70
- asst_in = gr.Textbox(label="Assistant reply", value="I am good", lines=2)
71
-
72
- run_btn = gr.Button("▶ Analyze", variant="primary")
73
- stats_out = gr.Markdown()
74
- bar_out = gr.BarPlot(
75
- x="token",
76
- y="value score",
77
- title="Per-token value scores",
78
- tooltip=["token", "value score"],
79
- height=500,
80
- y_lim=[0, 1],
81
- x_label_angle=-45, # angled labels so they don't overlap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  )
83
 
84
- run_btn.click(fn=analyze, inputs=[user_in, asst_in], outputs=[bar_out, stats_out])
85
- demo.load(fn=analyze, inputs=[user_in, asst_in], outputs=[bar_out, stats_out])
 
 
 
 
86
 
87
  if __name__ == "__main__":
88
  demo.launch()
 
1
  import gradio as gr
2
  import torch
3
  import numpy as np
 
4
  from transformers import AutoTokenizer
5
  from trl import AutoModelForCausalLMWithValueHead
6
 
 
17
 
18
 
19
  # ── Core inference ───────────────────────────────────────────────────────────
20
def get_value_scores(system_prompt: str, user_message: str, assistant_reply: str):
    """Run the value-head model over a templated chat and score every token.

    Args:
        system_prompt: System message placed at the start of the chat.
        user_message: Content of the user turn.
        assistant_reply: Content of the assistant turn.

    Returns:
        A ``(tokens, scores)`` pair: the tokenizer's token strings for the
        full templated chat, and a numpy array of sigmoid-activated
        value-head outputs, one score per token (shape: ``(seq_len,)``).
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_reply},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages, tokenize=True, return_tensors="pt"
    ).to(DEVICE)

    # convert_ids_to_tokens expects plain Python ints, not a torch tensor
    # row; fast tokenizers raise a TypeError otherwise, so convert first.
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

    with torch.no_grad():
        # Value-head models return (lm_logits, loss, values); only the
        # per-position values are needed here.
        _, _, values = model(input_ids)

    scores = torch.sigmoid(values[0]).cpu().numpy()  # shape: (seq_len,)
    return tokens, scores
38
+
39
+
40
+ # ── Build the HTML heatmap ───────────────────────────────────────────────────
41
def lerp_color(lo, hi, t):
    """Linearly interpolate between two RGB triples at fraction *t*."""
    return tuple(int(c0 + (c1 - c0) * t) for c0, c1 in zip(lo, hi))
43
+
44
+
45
def tokens_to_html(tokens, scores):
    """Render tokens as an inline HTML heatmap colored by value score.

    Each token becomes a ``<span>`` whose background interpolates from dark
    slate (low score) to sky blue (high score); the exact score appears in
    the hover tooltip. Scores are assumed to lie in [0, 1] (sigmoid output).
    """
    lo_rgb = (15, 23, 42)     # dark slate (low value)
    hi_rgb = (56, 189, 248)   # sky-400 (high value)
    bg_rgb = (30, 41, 59)     # slate-800

    rows = []
    for tok, sc in zip(tokens, scores):
        t = float(sc)
        r, g, b = lerp_color(lo_rgb, hi_rgb, t)
        # Pick dark or light text based on perceived background luminance.
        lum = 0.299 * r + 0.587 * g + 0.114 * b
        fg = "#0f172a" if lum > 140 else "#e2e8f0"
        # Escape '&' FIRST so raw ampersands in tokens cannot break the
        # markup or get double-escaped by the '<'/'>' replacements below.
        label = (
            tok.replace("&", "&amp;")
            .replace("<", "&lt;")
            .replace(">", "&gt;")
            .replace("Ġ", "·")
        )
        rows.append(
            f'<span title="score: {t:.4f}" style="'
            f'background:rgb({r},{g},{b});color:{fg};'
            f'padding:3px 6px;margin:2px;border-radius:4px;'
            f'display:inline-block;font-family:monospace;font-size:13px;'
            f'cursor:default;">{label}</span>'
        )

    body = " ".join(rows)
    return (
        f'<div style="background:rgb{bg_rgb};padding:16px;border-radius:10px;'
        f'line-height:2.2;word-break:break-word;">{body}</div>'
    )
70
+
71
+
72
+ # ── Bar-chart data for Gradio BarPlot ────────────────────────────────────────
73
def build_bar_data(tokens, scores):
    """Build the two-column ("token", "value score") DataFrame for BarPlot.

    Each token label carries its position index so duplicated tokens remain
    distinct categories on the chart's x-axis.
    """
    import pandas as pd

    labels = [
        "{} [{}]".format(tok.replace("Ġ", "·"), idx)
        for idx, tok in enumerate(tokens)
    ]
    return pd.DataFrame({"token": labels, "value score": scores.tolist()})
77
 
 
 
 
 
 
 
78
 
79
+ # ── Main handler ─────────────────────────────────────────────────────────────
80
def analyze(system_prompt, user_message, assistant_reply):
    """Score the chat, then produce the heatmap HTML, bar data, and stats."""
    tokens, scores = get_value_scores(system_prompt, user_message, assistant_reply)

    # Summary line shown above the visualizations.
    stats_md = " | ".join(
        [
            f"**Tokens:** {len(tokens)}",
            f"**Min:** {scores.min():.4f}",
            f"**Max:** {scores.max():.4f}",
            f"**Mean:** {scores.mean():.4f}",
            f"**Std:** {scores.std():.4f}",
        ]
    )

    return tokens_to_html(tokens, scores), build_bar_data(tokens, scores), stats_md
95
 
96
 
97
# ── UI ───────────────────────────────────────────────────────────────────────
CSS = """
body { font-family: 'IBM Plex Mono', monospace; }
#title { text-align: center; margin-bottom: 0.5rem; }
#subtitle { text-align: center; color: #94a3b8; margin-top: 0; }
.gr-button-primary { background: #0ea5e9 !important; border: none !important; }
"""

with gr.Blocks(theme=gr.themes.Base(), css=CSS, title="Value Head Visualizer") as demo:

    gr.Markdown("# 🧠 GPT-2 Value Head Visualizer", elem_id="title")
    gr.Markdown(
        "Inspect per-token **value scores** (sigmoid-activated) from a "
        "`AutoModelForCausalLMWithValueHead` GPT-2 model.",
        elem_id="subtitle",
    )

    with gr.Row():
        # Left column: chat inputs and the trigger button.
        with gr.Column(scale=1):
            system_in = gr.Textbox(label="System prompt", placeholder="(optional)", lines=2)
            user_in = gr.Textbox(label="User message", value="How are you doing?", lines=3)
            asst_in = gr.Textbox(label="Assistant reply", value="I am good", lines=3)
            run_btn = gr.Button("▶ Analyze", variant="primary")

        # Right column: stats line, token heatmap, and bar chart.
        with gr.Column(scale=2):
            stats_out = gr.Markdown()
            heatmap_out = gr.HTML(label="Token heatmap (hover for exact score)")
            bar_out = gr.BarPlot(
                x="token",
                y="value score",
                title="Per-token value scores",
                tooltip=["token", "value score"],
                height=300,
                y_lim=[0, 1],
            )

    # Same handler wiring for the button click and the initial page load,
    # so the default example renders immediately.
    handler_inputs = [system_in, user_in, asst_in]
    handler_outputs = [heatmap_out, bar_out, stats_out]
    run_btn.click(fn=analyze, inputs=handler_inputs, outputs=handler_outputs)
    demo.load(fn=analyze, inputs=handler_inputs, outputs=handler_outputs)

if __name__ == "__main__":
    demo.launch()