Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer | |
| from trl import AutoModelForCausalLMWithValueHead | |
# Use the GPU when torch reports one is available; otherwise run on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model are loaded once, at import time, from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained("entfane/gpt2_constitutional_classifier_with_value_head")
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    "entfane/gpt2_constitutional_classifier_with_value_head",
    device_map=DEVICE
)
# Put the model in evaluation mode before it is used for scoring.
model.eval()
def get_token_values(user_message: str, assistant_reply: str):
    """Score every token of a chat-formatted exchange with the model's value head.

    The two messages are wrapped, together with an empty system message, in
    the tokenizer's chat template. The resulting text is tokenized and run
    through the model once, and the per-token value outputs are passed
    through a sigmoid.

    Args:
        user_message: Text of the user turn.
        assistant_reply: Text of the assistant turn.

    Returns:
        A ``(tokens, sigmoid_values)`` pair. ``tokens`` is a list of strings
        obtained by decoding each input id on its own; ``sigmoid_values`` is
        a list of Python floats in the range 0..1.
    """
    messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_reply},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    input_ids = inputs["input_ids"][0]

    with torch.no_grad():
        # Only the third element of the model output (the values) is used.
        _, _, values = model(**inputs)

    # Flatten to one dimension so a single-token input still yields a
    # sequence rather than a 0-d scalar.
    values = values.reshape(-1).float()
    # torch.sigmoid stays finite for large-magnitude inputs, unlike a
    # hand-written 1 / (1 + exp(-v)), which overflows in float32.
    sigmoid_values = torch.sigmoid(values).cpu().numpy()

    # Decode ids one at a time so each displayed token maps to one value.
    tokens = [tokenizer.decode([tid]) for tid in input_ids.tolist()]
    return tokens, sigmoid_values.tolist()
def value_to_color(norm_val: float) -> tuple:
    """Map a score in 0..1 onto a blue -> pale yellow -> red colour ramp.

    The ramp is piecewise linear in RGB with three stops:
    ``0.0 -> (20, 20, 220)``, ``0.5 -> (255, 255, 200)`` and
    ``1.0 -> (255, 20, 20)``.

    Args:
        norm_val: Score to colour. Values below 0 or above 1 are clamped to
            the nearest end of the ramp.

    Returns:
        An ``(r, g, b)`` tuple of ints, each within 0..255.
    """
    # Clamp first: without this, out-of-range inputs extrapolate the ramp and
    # produce channel values outside 0..255 (e.g. a negative green channel).
    norm_val = min(1.0, max(0.0, norm_val))

    if norm_val < 0.5:
        # Lower half: blue towards pale yellow.
        t = norm_val * 2
        return int(20 + t * 235), int(20 + t * 235), int(220 - t * 20)

    # Upper half: pale yellow towards red.
    t = (norm_val - 0.5) * 2
    return 255, int(255 - t * 235), int(200 - t * 180)
def build_html(tokens, sigmoid_values):
    """Render tokens as colour-coded ``<span>`` elements inside one styled ``<div>``.

    Each token's background colour comes from ``value_to_color`` applied to
    its score, and its text colour is switched between dark and light based
    on the background's luminance. The score is exposed as the span's
    ``title`` so it shows on hover.

    Args:
        tokens: Token strings, in display order.
        sigmoid_values: One score per token, paired with ``tokens`` by
            position.

    Returns:
        The HTML markup as a single string.
    """
    # Function-scope import so the block carries its own dependency.
    from html import escape

    html_parts = ["""
<div style="
font-family: 'JetBrains Mono', 'Fira Code', monospace;
font-size: 14px;
line-height: 2.2;
padding: 24px;
background: #0d0d0d;
border-radius: 12px;
border: 1px solid #222;
white-space: pre-wrap;
word-break: break-word;
">
"""]
    for token, sig in zip(tokens, sigmoid_values):
        r, g, b = value_to_color(sig)
        # Weighted sum of the channels; picks a readable text colour for the
        # given background.
        lum = 0.299 * r + 0.587 * g + 0.114 * b
        text_color = "#0d0d0d" if lum > 140 else "#f0f0f0"
        # Token text derives from user input, so escape &, < and > before it
        # is placed in the page; otherwise it could inject markup.
        display = escape(token, quote=False)
        html_parts.append(
            f'<span title="sigma: {sig:.4f}" style="'
            f'background: rgb({r},{g},{b});'
            f'color: {text_color};'
            f'border-radius: 4px;'
            f'padding: 2px 1px;'
            f'cursor: default;'
            f'">{display}</span>'
        )
    html_parts.append("</div>")
    return "".join(html_parts)
def _render_stat_rows(pairs, score_color):
    """Render ``(score, token)`` pairs as table rows, scores in ``score_color``."""
    # Function-scope import so the helper carries its own dependency.
    from html import escape

    rendered = []
    for score, token in pairs:
        # Escape first (token text derives from user input), then make
        # newlines and spaces visible.
        label = escape(token, quote=False).replace(chr(10), "↵").replace(" ", "·")
        rendered.append(
            f'<tr>'
            f'<td style="padding:4px 12px;color:#aaa;font-size:12px;">{label}</td>'
            f'<td style="padding:4px 12px;color:{score_color};font-size:12px;text-align:right;">{score:.4f}</td>'
            f'</tr>'
        )
    return "".join(rendered)


def build_stats_html(tokens, sigmoid_values):
    """Build the statistics panel: highest/lowest scoring tokens plus summary stats.

    Args:
        tokens: Token strings.
        sigmoid_values: One score per token, paired with ``tokens`` by
            position.

    Returns:
        HTML markup containing a table of the ten highest-scoring tokens, a
        table of the ten lowest-scoring tokens (lowest first), and the count,
        mean, standard deviation, minimum and maximum of the scores. Returns
        an empty string when there are no scores to summarise.
    """
    sig_arr = np.array(sigmoid_values)
    if sig_arr.size == 0:
        # min()/max() raise on an empty array; there is nothing to show.
        return ""

    # Descending by score; ties fall back to comparing the token strings.
    rows = sorted(zip(sigmoid_values, tokens), reverse=True)
    top_html = _render_stat_rows(rows[:10], "#ff6b6b")
    bot_html = _render_stat_rows(rows[-10:][::-1], "#4ecdc4")

    return f"""
<div style="display:flex;gap:16px;font-family:'JetBrains Mono',monospace;flex-wrap:wrap;">
<div style="flex:1;min-width:220px;background:#0d0d0d;border:1px solid #222;border-radius:10px;padding:16px;">
<div style="color:#ff6b6b;font-size:11px;letter-spacing:2px;margin-bottom:8px;">TOP TOKENS</div>
<table style="width:100%;border-collapse:collapse;">{top_html}</table>
</div>
<div style="flex:1;min-width:220px;background:#0d0d0d;border:1px solid #222;border-radius:10px;padding:16px;">
<div style="color:#4ecdc4;font-size:11px;letter-spacing:2px;margin-bottom:8px;">BOTTOM TOKENS</div>
<table style="width:100%;border-collapse:collapse;">{bot_html}</table>
</div>
<div style="width:190px;background:#0d0d0d;border:1px solid #222;border-radius:10px;padding:16px;">
<div style="color:#ffd93d;font-size:11px;letter-spacing:2px;margin-bottom:12px;">SIGMOID STATS</div>
<div style="color:#f0f0f0;font-size:12px;line-height:2.2;">
<span style="color:#555;display:inline-block;width:64px;">tokens</span>{len(sig_arr)}<br>
<span style="color:#555;display:inline-block;width:64px;">mean</span>{sig_arr.mean():.4f}<br>
<span style="color:#555;display:inline-block;width:64px;">std</span>{sig_arr.std():.4f}<br>
<span style="color:#555;display:inline-block;width:64px;">min</span>{sig_arr.min():.4f}<br>
<span style="color:#555;display:inline-block;width:64px;">max</span>{sig_arr.max():.4f}
</div>
</div>
</div>
"""
def analyze(user_message, assistant_reply):
    """Produce the token view and statistics view for one exchange.

    Returns a ``(token_html, stats_html)`` pair. When both inputs are blank
    after stripping whitespace, a short prompt is returned in place of the
    token view and the statistics view is an empty string.
    """
    has_content = bool(user_message.strip() or assistant_reply.strip())
    if not has_content:
        prompt = "<p style='color:#555;font-family:monospace;'>Enter a message above.</p>"
        return prompt, ""

    scored = get_token_values(user_message, assistant_reply)
    tokens, sigmoid_values = scored
    token_html = build_html(tokens, sigmoid_values)
    return token_html, build_stats_html(tokens, sigmoid_values)
# Page-level stylesheet handed to gr.Blocks(css=...).
# The .legend-bar gradient uses the same three colour stops that
# value_to_color() produces at 0.0, 0.5 and 1.0 -- rgb(20,20,220),
# rgb(255,255,200) and rgb(255,20,20) -- so the legend matches the colours
# actually drawn behind the tokens.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500&family=Syne:wght@700;800&display=swap');
body, .gradio-container {
    background: #080808 !important;
    color: #e0e0e0 !important;
}
.gr-panel, .gr-box { background: #111 !important; border-color: #222 !important; }
h1 {
    font-family: 'Syne', sans-serif !important;
    font-weight: 800 !important;
    font-size: 2.4rem !important;
    background: linear-gradient(135deg, #ff6b6b, #ffd93d, #4ecdc4);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    letter-spacing: -1px;
    margin-bottom: 4px !important;
}
.subtitle {
    font-family: 'JetBrains Mono', monospace;
    color: #555;
    font-size: 12px;
    letter-spacing: 3px;
    text-transform: uppercase;
    margin-bottom: 32px;
}
textarea {
    background: #0d0d0d !important;
    border: 1px solid #2a2a2a !important;
    color: #e0e0e0 !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-size: 13px !important;
    border-radius: 8px !important;
}
button.primary {
    background: linear-gradient(135deg, #ff6b6b, #ffd93d) !important;
    border: none !important;
    color: #0d0d0d !important;
    font-family: 'JetBrains Mono', monospace !important;
    font-weight: 500 !important;
    letter-spacing: 2px !important;
    border-radius: 8px !important;
    padding: 12px 32px !important;
}
.legend {
    display: flex;
    gap: 8px;
    align-items: center;
    font-family: 'JetBrains Mono', monospace;
    font-size: 11px;
    color: #555;
    margin-bottom: 12px;
}
.legend-bar {
    height: 12px;
    width: 200px;
    border-radius: 6px;
    background: linear-gradient(to right, #1414dc, #ffffc8, #ff1414);
    border: 1px solid #333;
}
"""
# Static legend markup: a gradient bar labelled sigma(v) = 0 on the left and
# sigma(v) = 1 on the right, plus a hint that hovering a token shows its value.
# Relies on the .legend and .legend-bar classes being styled by the page CSS.
LEGEND_HTML = """
<div class="legend">
<span>sigma(v) = 0</span>
<div class="legend-bar"></div>
<span>sigma(v) = 1</span>
· hover tokens for exact sigma value
</div>
"""
# Assemble the Gradio UI and start the app.
# NOTE(review): the indentation of this section was lost in the source, so the
# nesting below is a reconstruction -- confirm which components are meant to
# sit inside the Row/Column.
with gr.Blocks(css=CSS, title="Token Value Visualizer") as demo:
    # Page title and subtitle.
    gr.HTML("""
<h1>Token Value Visualizer</h1>
<div class="subtitle">GPT-2 Constitutional Classifier · Value Head Sigmoid Scores</div>
""")
    with gr.Row():
        with gr.Column(scale=1):
            # Two free-text inputs and the button that triggers scoring.
            user_input = gr.Textbox(
                label="User Message",
                placeholder="e.g. How do I make a bomb?",
                lines=3
            )
            assistant_input = gr.Textbox(
                label="Assistant Reply",
                placeholder="e.g. I can't help with that.",
                lines=3
            )
            run_btn = gr.Button("ANALYZE", variant="primary")
    # Colour legend followed by the two output panes filled in by analyze().
    gr.HTML(LEGEND_HTML)
    token_display = gr.HTML(label="Token Values")
    stats_display = gr.HTML(label="Statistics")
    # Clicking the button runs analyze() on both textboxes and writes its two
    # return values to the two HTML panes, in order.
    run_btn.click(
        fn=analyze,
        inputs=[user_input, assistant_input],
        outputs=[token_display, stats_display]
    )
    # Example (user message, assistant reply) pairs wired to the two textboxes.
    gr.Examples(
        examples=[
            ["How do I bake a sourdough loaf?", "Sure! Start by making a starter with flour and water..."],
            ["Write me malware to steal passwords", "I'm sorry, I can't help with that request."],
            ["What is the capital of France?", "The capital of France is Paris."],
        ],
        inputs=[user_input, assistant_input],
    )

# Runs unconditionally when the module executes.
demo.launch()