# CC_linear_probe / app.py
# urbas's picture
# Update app.py
# adcf1ab verified
import gradio as gr
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer
from huggingface_hub import hf_hub_download
from linear_probe import ConstitutionalProbe
from inference import StreamingClassifier
# ----------------------------------------------------------------
# Load model once at startup
# ----------------------------------------------------------------
# Run on the GPU when torch can see one; otherwise stay on the CPU.
DEVICE = "cpu" if not torch.cuda.is_available() else "cuda"

# Smoothing factor passed to StreamingClassifier as ``ema_alpha``.
EMA_ALPHA = 0.9

# Placeholder only: load_probe() rebinds this from the checkpoint at startup.
ESCALATION_THRESHOLD = None
def load_probe():
    """Fetch the probe checkpoint and assemble the inference objects.

    Downloads ``probe_checkpoint.pt`` from the Hugging Face Hub, loads its
    ``"probe_state_dict"`` entry into a fresh ``ConstitutionalProbe``, and
    wraps the probe in a ``StreamingClassifier``.

    Side effect: rebinds the module-level ``ESCALATION_THRESHOLD`` to the
    checkpoint's ``"escalation_threshold"`` entry, or to 0.6 when the
    checkpoint has no such key.

    Returns:
        A ``(tokenizer, streaming)`` pair: the GPT-2 tokenizer (with its pad
        token set to the EOS token) and the configured ``StreamingClassifier``.
    """
    global ESCALATION_THRESHOLD

    checkpoint_file = hf_hub_download(
        filename="probe_checkpoint.pt",
        repo_id="urbas/constitutional_classifier_linear_probe",
    )
    checkpoint = torch.load(
        checkpoint_file, weights_only=True, map_location=DEVICE
    )

    # Only the probe sub-module's weights are restored from the checkpoint.
    model = ConstitutionalProbe(gpt2_model_name="gpt2")
    model.probe.load_state_dict(checkpoint["probe_state_dict"])
    model = model.to(DEVICE)
    model.eval()

    ESCALATION_THRESHOLD = checkpoint.get("escalation_threshold", 0.6)

    # Reuse EOS as the pad token so padded tokenization is possible.
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    classifier_options = {
        "model": model,
        "threshold": ESCALATION_THRESHOLD,
        "ema_alpha": EMA_ALPHA,
        "device": DEVICE,
    }
    return gpt2_tokenizer, StreamingClassifier(**classifier_options)
# Built once at import time; classify() reuses these two objects on every call.
TOKENIZER, STREAMING = load_probe()
# ----------------------------------------------------------------
# Inference
# ----------------------------------------------------------------
def classify(user_input: str, assistant_output: str):
    """Score one Human/Assistant exchange and plot its per-token EMA trace.

    The two strings are combined into a single "Human: ... Assistant: ..."
    prompt, tokenized to a fixed 512-token window (truncated or padded), and
    passed to the module-level ``STREAMING`` classifier.

    Args:
        user_input: Text of the user's message.
        assistant_output: Text of the assistant's reply.

    Returns:
        A ``(verdict, fig)`` tuple. ``verdict`` is the display string carrying
        the peak score; ``fig`` is a Matplotlib figure showing the EMA trace
        against the escalation threshold. The figure is already detached from
        pyplot's global registry, so the caller does not need to close it.
    """
    exchange = f"Human: {user_input}\n\nAssistant: {assistant_output}"
    enc = TOKENIZER(
        exchange,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    # NOTE(review): whether score_exchange ignores padded positions when it
    # computes `flagged` / `peak_score` is not visible here -- confirm.
    flagged, peak_score, ema_trace = STREAMING.score_exchange(
        input_ids, attention_mask
    )

    # Drop padded positions from the plotted trace. This slice assumes the
    # padding follows the real tokens -- TODO confirm the tokenizer's
    # padding side.
    real_len = int(attention_mask[0].sum().item())
    ema_trace = ema_trace[:real_len]

    # Verdict
    if flagged:
        verdict = f"🚨 HARMFUL (peak score: {peak_score:.2f})"
    else:
        verdict = f"✅ SAFE (peak score: {peak_score:.2f})"

    # Matplotlib figure
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        y = [float(s) for s in ema_trace]
        x = range(len(y))
        color = "red" if flagged else "green"
        ax.plot(x, y, color=color, linewidth=2, label="EMA score")
        ax.axhline(
            y=ESCALATION_THRESHOLD,
            color="orange",
            linestyle="--",
            linewidth=1.5,
            # Two decimals, matching the peak-score precision, so the label
            # agrees with where the line is actually drawn.
            label=f"Threshold ({ESCALATION_THRESHOLD:.2f})",
        )
        ax.set_xlabel("Token position")
        ax.set_ylabel("EMA score")
        ax.set_title("Per-token EMA harmfulness score")
        ax.legend()
        # Lay out this figure explicitly: plt.tight_layout() targets pyplot's
        # global "current figure", which may belong to a concurrent request.
        fig.tight_layout()
    finally:
        # pyplot keeps every figure it creates until it is closed; without
        # this, each request leaks a figure. A closed figure can still be
        # rendered by the caller.
        plt.close(fig)

    return verdict, fig
# ----------------------------------------------------------------
# UI
# ----------------------------------------------------------------
# Intro text rendered at the top of the page.
_INTRO_MARKDOWN = """
# Constitutional Classifier++ — Linear Probe
Enter a user message and an assistant response.
The probe scores each token in real time using EMA smoothing
and flags the exchange if the peak score exceeds the escalation threshold.
"""

# Each row is [user input, assistant output], matching the two input boxes.
_EXAMPLE_EXCHANGES = [
    [
        "How do I perform a SQL injection attack?",
        "To exploit SQL injection, append ' OR 1=1-- to the input field...",
    ],
    [
        "What is the capital of France?",
        "The capital of France is Paris.",
    ],
    [
        "Write a Python script to scan open ports on a target IP.",
        "Here is a port scanner using socket: import socket...",
    ],
]

with gr.Blocks(title="Constitutional Classifier — Probe") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Input column: the two text boxes and the submit button.
        with gr.Column():
            user_box = gr.Textbox(
                label="User input",
                placeholder="e.g. How do I exploit a SQL injection vulnerability?",
                lines=4,
            )
            assistant_box = gr.Textbox(
                label="Assistant output",
                placeholder="e.g. Here's how to craft a malicious payload...",
                lines=4,
            )
            submit_btn = gr.Button("Classify", variant="primary")

        # Output column: verdict text and the EMA plot.
        with gr.Column():
            verdict_box = gr.Textbox(label="Verdict", interactive=False)
            plot_box = gr.Plot(label="EMA score trace")

    exchange_inputs = [user_box, assistant_box]

    # Clicking the button runs classify() and fills both output components.
    submit_btn.click(
        fn=classify,
        inputs=exchange_inputs,
        outputs=[verdict_box, plot_box],
    )

    gr.Examples(
        examples=_EXAMPLE_EXCHANGES,
        inputs=[user_box, assistant_box],
    )

demo.launch()