# Hugging Face Space by entfane
# Last commit: 8c54854 "Update app.py" (verified)
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead
# Hub repository providing both the tokenizer and the value-head checkpoint.
MODEL_ID = "entfane/gpt2_constitutional_classifier_with_value_head"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse EOS as the pad token (presumably because the GPT-2 tokenizer defines
# no pad token of its own).
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(MODEL_ID, device_map=DEVICE)
# Inference only: put the module in eval mode.
model.eval()
def get_token_values(user_message: str, assistant_reply: str):
    """Score every token of a one-turn chat with the model's value head.

    The conversation (empty system turn, user turn, assistant turn) is
    rendered with the tokenizer's chat template, run through ``model`` once,
    and each per-token value-head output is passed through a sigmoid.

    Args:
        user_message: Text placed in the ``user`` turn.
        assistant_reply: Text placed in the ``assistant`` turn.

    Returns:
        ``(tokens, sigmoid_values)``: two equal-length lists holding the
        decoded text of each input token and ``sigmoid(value)`` for it.
    """
    messages = [
        {"role": "system", "content": ""},
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": assistant_reply},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    # The rendered template already contains whatever special tokens it needs;
    # letting the tokenizer add its own on top could duplicate BOS/EOS.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(DEVICE)
    input_ids = inputs["input_ids"][0]

    with torch.no_grad():
        _, _, values = model(**inputs)

    # Assumes values is (batch=1, seq_len) -- flatten to one score per token.
    # reshape(-1) also keeps a 1-D result when seq_len == 1, where squeeze()
    # would collapse to a 0-d tensor.
    values = values.reshape(-1)
    # torch.sigmoid is numerically stable; 1 / (1 + np.exp(-v)) overflows
    # (emitting a RuntimeWarning) for large negative float32 values.
    # Cast to float32 first so half-precision weights do not degrade the score.
    sigmoid_values = torch.sigmoid(values.float()).cpu().numpy()

    tokens = [tokenizer.decode([tid]) for tid in input_ids.tolist()]
    return tokens, sigmoid_values.tolist()
def value_to_color(norm_val: float) -> tuple:
    """Map a score in [0, 1] to an RGB triple on a blue -> white -> red ramp.

    Anchor colours: 0.0 -> (20, 20, 220) blue, 0.5 -> (255, 255, 200) warm
    white, 1.0 -> (255, 20, 20) red. Each half is a linear interpolation.
    Inputs outside [0, 1] are clamped, so every channel stays within 0..255.

    Args:
        norm_val: Score to colour, nominally a sigmoid output.

    Returns:
        ``(r, g, b)`` integer channels in 0..255.
    """
    # Without clamping, e.g. 1.5 would give g = -215 (an invalid CSS colour).
    norm_val = min(max(norm_val, 0.0), 1.0)
    if norm_val < 0.5:
        t = norm_val * 2
        r = int(20 + t * 235)
        g = int(20 + t * 235)
        b = int(220 - t * 20)
    else:
        t = (norm_val - 0.5) * 2
        r = 255
        g = int(255 - t * 235)
        b = int(200 - t * 180)
    return r, g, b
def build_html(tokens, sigmoid_values):
    """Render tokens as inline spans coloured by their sigmoid score.

    Args:
        tokens: Decoded token strings, in sequence order.
        sigmoid_values: One score per token, nominally in [0, 1]. Pairs are
            formed with ``zip``, so surplus items on either side are ignored.

    Returns:
        An HTML string: a dark container ``<div>`` with one ``<span>`` per
        token. Each span's background comes from ``value_to_color``, and its
        hover title shows the exact score.
    """
    # The container uses `white-space: pre-wrap` so newline tokens render as
    # line breaks. For the same reason, nothing (not even a newline) may sit
    # between the opening tag and the first token, or it renders as a blank line.
    container_open = (
        '<div style="'
        "font-family: 'JetBrains Mono', 'Fira Code', monospace;"
        "font-size: 14px;"
        "line-height: 2.2;"
        "padding: 24px;"
        "background: #0d0d0d;"
        "border-radius: 12px;"
        "border: 1px solid #222;"
        "white-space: pre-wrap;"
        "word-break: break-word;"
        '">'
    )
    spans = []
    for token, sig in zip(tokens, sigmoid_values):
        r, g, b = value_to_color(sig)
        # Rec. 601 luma of the background decides between dark and light text.
        lum = 0.299 * r + 0.587 * g + 0.114 * b
        text_color = "#0d0d0d" if lum > 140 else "#f0f0f0"
        # Escape so token text is shown literally rather than parsed as markup.
        display = token.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
        spans.append(
            f'<span title="sigma: {sig:.4f}" style="'
            f"background: rgb({r},{g},{b});"
            f"color: {text_color};"
            "border-radius: 4px;"
            "padding: 2px 1px;"
            "cursor: default;"
            f'">{display}</span>'
        )
    return container_open + "".join(spans) + "</div>"
def _stats_token_label(token):
    """Return *token* HTML-escaped, with newlines/spaces shown as ↵ and ·."""
    # Escape first: tokens can carry user text such as "<b>", which must be
    # displayed literally rather than parsed as markup.
    escaped = token.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
    return escaped.replace("\n", "↵").replace(" ", "·")


def _stats_rows(pairs, value_color):
    """Return one ``<tr>`` per ``(score, token)`` pair, with the score in *value_color*."""
    return "".join(
        f'<tr>'
        f'<td style="padding:4px 12px;color:#aaa;font-size:12px;">{_stats_token_label(t)}</td>'
        f'<td style="padding:4px 12px;color:{value_color};font-size:12px;text-align:right;">{s:.4f}</td>'
        f'</tr>'
        for s, t in pairs
    )


def build_stats_html(tokens, sigmoid_values):
    """Render the summary panels: top-10 tokens, bottom-10 tokens and stats.

    Args:
        tokens: Decoded token strings.
        sigmoid_values: One score per token (paired with ``tokens`` via ``zip``).

    Returns:
        An HTML string with three cards:
        - the 10 highest-scoring tokens (descending);
        - the 10 lowest-scoring tokens (ascending);
        - count, mean, std, min and max of all scores.
        With no scores, the statistics show "n/a" instead of raising.
    """
    sig_arr = np.array(sigmoid_values, dtype=float)
    # Sort by score, highest first; ties fall back to comparing the token text.
    rows = sorted(zip(sigmoid_values, tokens), reverse=True)
    top_html = _stats_rows(rows[:10], "#ff6b6b")
    bot_html = _stats_rows(rows[-10:][::-1], "#4ecdc4")

    if sig_arr.size:
        mean, std, lo, hi = (
            f"{v:.4f}" for v in (sig_arr.mean(), sig_arr.std(), sig_arr.min(), sig_arr.max())
        )
    else:
        # ndarray.min()/max() raise on an empty array; show placeholders instead.
        mean = std = lo = hi = "n/a"

    return f"""
<div style="display:flex;gap:16px;font-family:'JetBrains Mono',monospace;flex-wrap:wrap;">
<div style="flex:1;min-width:220px;background:#0d0d0d;border:1px solid #222;border-radius:10px;padding:16px;">
<div style="color:#ff6b6b;font-size:11px;letter-spacing:2px;margin-bottom:8px;">TOP TOKENS</div>
<table style="width:100%;border-collapse:collapse;">{top_html}</table>
</div>
<div style="flex:1;min-width:220px;background:#0d0d0d;border:1px solid #222;border-radius:10px;padding:16px;">
<div style="color:#4ecdc4;font-size:11px;letter-spacing:2px;margin-bottom:8px;">BOTTOM TOKENS</div>
<table style="width:100%;border-collapse:collapse;">{bot_html}</table>
</div>
<div style="width:190px;background:#0d0d0d;border:1px solid #222;border-radius:10px;padding:16px;">
<div style="color:#ffd93d;font-size:11px;letter-spacing:2px;margin-bottom:12px;">SIGMOID STATS</div>
<div style="color:#f0f0f0;font-size:12px;line-height:2.2;">
<span style="color:#555;display:inline-block;width:64px;">tokens</span>{len(sig_arr)}<br>
<span style="color:#555;display:inline-block;width:64px;">mean</span>{mean}<br>
<span style="color:#555;display:inline-block;width:64px;">std</span>{std}<br>
<span style="color:#555;display:inline-block;width:64px;">min</span>{lo}<br>
<span style="color:#555;display:inline-block;width:64px;">max</span>{hi}
</div>
</div>
</div>
"""
def analyze(user_message, assistant_reply):
    """Gradio click handler: score the conversation and render both panels.

    Args:
        user_message: Contents of the "User Message" textbox.
        assistant_reply: Contents of the "Assistant Reply" textbox.

    Returns:
        ``(token_html, stats_html)``. If both inputs are blank (empty or
        whitespace only), the model is not run; instead a placeholder
        paragraph and an empty stats panel are returned.
    """
    has_text = user_message.strip() or assistant_reply.strip()
    if has_text:
        tokens, scores = get_token_values(user_message, assistant_reply)
        return build_html(tokens, scores), build_stats_html(tokens, scores)
    return "<p style='color:#555;font-family:monospace;'>Enter a message above.</p>", ""
# Page-wide styling passed to gr.Blocks(css=...).
# The .legend-bar gradient stops are the exact colours value_to_color returns
# at 0.0, 0.5 and 1.0: (20,20,220), (255,255,200), (255,20,20). Each half of
# value_to_color is linear in RGB, as is CSS linear-gradient interpolation, so
# the bar reproduces the token colouring exactly.
CSS = """
@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@300;400;500&family=Syne:wght@700;800&display=swap');
body, .gradio-container {
background: #080808 !important;
color: #e0e0e0 !important;
}
.gr-panel, .gr-box { background: #111 !important; border-color: #222 !important; }
h1 {
font-family: 'Syne', sans-serif !important;
font-weight: 800 !important;
font-size: 2.4rem !important;
background: linear-gradient(135deg, #ff6b6b, #ffd93d, #4ecdc4);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
letter-spacing: -1px;
margin-bottom: 4px !important;
}
.subtitle {
font-family: 'JetBrains Mono', monospace;
color: #555;
font-size: 12px;
letter-spacing: 3px;
text-transform: uppercase;
margin-bottom: 32px;
}
textarea {
background: #0d0d0d !important;
border: 1px solid #2a2a2a !important;
color: #e0e0e0 !important;
font-family: 'JetBrains Mono', monospace !important;
font-size: 13px !important;
border-radius: 8px !important;
}
button.primary {
background: linear-gradient(135deg, #ff6b6b, #ffd93d) !important;
border: none !important;
color: #0d0d0d !important;
font-family: 'JetBrains Mono', monospace !important;
font-weight: 500 !important;
letter-spacing: 2px !important;
border-radius: 8px !important;
padding: 12px 32px !important;
}
.legend {
display: flex;
gap: 8px;
align-items: center;
font-family: 'JetBrains Mono', monospace;
font-size: 11px;
color: #555;
margin-bottom: 12px;
}
.legend-bar {
height: 12px;
width: 200px;
border-radius: 6px;
background: linear-gradient(to right, #1414dc, #ffffc8, #ff1414);
border: 1px solid #333;
}
"""
# Colour key rendered above the token view: a gradient bar (styled by the
# .legend / .legend-bar CSS rules) between the two ends of the sigmoid range.
LEGEND_HTML = (
    "\n"
    '<div class="legend">\n'
    "<span>sigma(v) = 0</span>\n"
    '<div class="legend-bar"></div>\n'
    "<span>sigma(v) = 1</span>\n"
    "&nbsp;&middot;&nbsp; hover tokens for exact sigma value\n"
    "</div>\n"
)
# Page header: gradient title plus a small-caps subtitle (styled via CSS).
HEADER_HTML = """
<h1>Token Value Visualizer</h1>
<div class="subtitle">GPT-2 Constitutional Classifier &middot; Value Head Sigmoid Scores</div>
"""

# (user message, assistant reply) pairs offered as one-click examples.
EXAMPLE_PAIRS = [
    ["How do I bake a sourdough loaf?", "Sure! Start by making a starter with flour and water..."],
    ["Write me malware to steal passwords", "I'm sorry, I can't help with that request."],
    ["What is the capital of France?", "The capital of France is Paris."],
]

with gr.Blocks(css=CSS, title="Token Value Visualizer") as demo:
    gr.HTML(HEADER_HTML)

    # Input column: the two conversation turns and the trigger button.
    with gr.Row():
        with gr.Column(scale=1):
            user_input = gr.Textbox(
                lines=3,
                label="User Message",
                placeholder="e.g. How do I make a bomb?",
            )
            assistant_input = gr.Textbox(
                lines=3,
                label="Assistant Reply",
                placeholder="e.g. I can't help with that.",
            )
            run_btn = gr.Button("ANALYZE", variant="primary")

    # Output area: colour legend, per-token rendering, then summary statistics.
    gr.HTML(LEGEND_HTML)
    token_display = gr.HTML(label="Token Values")
    stats_display = gr.HTML(label="Statistics")

    run_btn.click(
        fn=analyze,
        inputs=[user_input, assistant_input],
        outputs=[token_display, stats_display],
    )

    # No fn/outputs are passed, so picking an example only fills the inputs.
    gr.Examples(examples=EXAMPLE_PAIRS, inputs=[user_input, assistant_input])

demo.launch()