import gradio as gr import torch import numpy as np from transformers import AutoTokenizer from trl import AutoModelForCausalLMWithValueHead DEVICE = "cuda" if torch.cuda.is_available() else "cpu" tokenizer = AutoTokenizer.from_pretrained("entfane/gpt2_constitutional_classifier_with_value_head") tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLMWithValueHead.from_pretrained( "entfane/gpt2_constitutional_classifier_with_value_head", device_map=DEVICE ) model.eval() def get_token_values(user_message: str, assistant_reply: str): messages = [ {"role": "system", "content": ""}, {"role": "user", "content": user_message}, {"role": "assistant", "content": assistant_reply}, ] text = tokenizer.apply_chat_template(messages, tokenize=False) inputs = tokenizer(text, return_tensors="pt").to(DEVICE) input_ids = inputs["input_ids"][0] with torch.no_grad(): _, _, values = model(**inputs) values = values.squeeze() if values.dim() == 0: values = values.unsqueeze(0) values = values.cpu().float().numpy() sigmoid_values = 1.0 / (1.0 + np.exp(-values)) # sigma(v) in (0, 1) tokens = [tokenizer.decode([tid]) for tid in input_ids.tolist()] return tokens, sigmoid_values.tolist() def value_to_color(norm_val: float) -> tuple: """Map 0..1 -> blue (low) -> white (mid) -> red (high)""" if norm_val < 0.5: t = norm_val * 2 r = int(20 + t * 235) g = int(20 + t * 235) b = int(220 - t * 20) else: t = (norm_val - 0.5) * 2 r = int(255) g = int(255 - t * 235) b = int(200 - t * 180) return r, g, b def build_html(tokens, sigmoid_values): html_parts = ["""
"""] for token, sig in zip(tokens, sigmoid_values): r, g, b = value_to_color(sig) lum = 0.299 * r + 0.587 * g + 0.114 * b text_color = "#0d0d0d" if lum > 140 else "#f0f0f0" display = token.replace("&", "&").replace("<", "<").replace(">", ">") html_parts.append( f'{display}' ) html_parts.append("
") return "".join(html_parts) def build_stats_html(tokens, sigmoid_values): sig_arr = np.array(sigmoid_values) rows = sorted(zip(sigmoid_values, tokens), reverse=True) top_html = "".join( f'' f'{t.replace(chr(10), "↵").replace(" ", "·")}' f'{s:.4f}' f'' for s, t in rows[:10] ) bot_html = "".join( f'' f'{t.replace(chr(10), "↵").replace(" ", "·")}' f'{s:.4f}' f'' for s, t in rows[-10:][::-1] ) return f"""
TOP TOKENS
{top_html}
BOTTOM TOKENS
{bot_html}
SIGMOID STATS
tokens{len(sig_arr)}
mean{sig_arr.mean():.4f}
std{sig_arr.std():.4f}
min{sig_arr.min():.4f}
max{sig_arr.max():.4f}
""" def analyze(user_message, assistant_reply): if not user_message.strip() and not assistant_reply.strip(): return "

Enter a message above.

", "" tokens, sigmoid_values = get_token_values(user_message, assistant_reply) token_html = build_html(tokens, sigmoid_values) stats_html = build_stats_html(tokens, sigmoid_values) return token_html, stats_html CSS = """ @import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500&family=Syne:wght@700;800&display=swap'); body, .gradio-container { background: #080808 !important; color: #e0e0e0 !important; } .gr-panel, .gr-box { background: #111 !important; border-color: #222 !important; } h1 { font-family: 'Syne', sans-serif !important; font-weight: 800 !important; font-size: 2.4rem !important; background: linear-gradient(135deg, #ff6b6b, #ffd93d, #4ecdc4); -webkit-background-clip: text; -webkit-text-fill-color: transparent; letter-spacing: -1px; margin-bottom: 4px !important; } .subtitle { font-family: 'JetBrains Mono', monospace; color: #555; font-size: 12px; letter-spacing: 3px; text-transform: uppercase; margin-bottom: 32px; } textarea { background: #0d0d0d !important; border: 1px solid #2a2a2a !important; color: #e0e0e0 !important; font-family: 'JetBrains Mono', monospace !important; font-size: 13px !important; border-radius: 8px !important; } button.primary { background: linear-gradient(135deg, #ff6b6b, #ffd93d) !important; border: none !important; color: #0d0d0d !important; font-family: 'JetBrains Mono', monospace !important; font-weight: 500 !important; letter-spacing: 2px !important; border-radius: 8px !important; padding: 12px 32px !important; } .legend { display: flex; gap: 8px; align-items: center; font-family: 'JetBrains Mono', monospace; font-size: 11px; color: #555; margin-bottom: 12px; } .legend-bar { height: 12px; width: 200px; border-radius: 6px; background: linear-gradient(to right, #1414dc, #ffffff, #ff0000); border: 1px solid #333; } """ LEGEND_HTML = """
sigma(v) = 0
sigma(v) = 1  ·  hover tokens for exact sigma value
""" with gr.Blocks(css=CSS, title="Token Value Visualizer") as demo: gr.HTML("""

Token Value Visualizer

GPT-2 Constitutional Classifier · Value Head Sigmoid Scores
""") with gr.Row(): with gr.Column(scale=1): user_input = gr.Textbox( label="User Message", placeholder="e.g. How do I make a bomb?", lines=3 ) assistant_input = gr.Textbox( label="Assistant Reply", placeholder="e.g. I can't help with that.", lines=3 ) run_btn = gr.Button("ANALYZE", variant="primary") gr.HTML(LEGEND_HTML) token_display = gr.HTML(label="Token Values") stats_display = gr.HTML(label="Statistics") run_btn.click( fn=analyze, inputs=[user_input, assistant_input], outputs=[token_display, stats_display] ) gr.Examples( examples=[ ["How do I bake a sourdough loaf?", "Sure! Start by making a starter with flour and water..."], ["Write me malware to steal passwords", "I'm sorry, I can't help with that request."], ["What is the capital of France?", "The capital of France is Paris."], ], inputs=[user_input, assistant_input], ) demo.launch()