"""Gradio demo: score a Human/Assistant exchange with a linear harmfulness probe."""

import gradio as gr
import torch
import matplotlib

# The non-interactive Agg backend must be selected before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer
from huggingface_hub import hf_hub_download

from linear_probe import ConstitutionalProbe
from inference import StreamingClassifier

# ----------------------------------------------------------------
# Load model once at startup
# ----------------------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMA_ALPHA = 0.9
# Filled in by load_probe() from the checkpoint (falls back to 0.6).
ESCALATION_THRESHOLD = None


def load_probe():
    """Download the probe checkpoint and build the tokenizer and classifier.

    Side effect: sets the module-level ``ESCALATION_THRESHOLD`` to the value
    stored under ``"escalation_threshold"`` in the checkpoint, or ``0.6``
    when the checkpoint does not carry one.

    Returns:
        A ``(tokenizer, streaming)`` pair: the GPT-2 tokenizer (with the EOS
        token reused as the pad token) and a ``StreamingClassifier`` wrapping
        the loaded probe.
    """
    global ESCALATION_THRESHOLD

    ckpt_path = hf_hub_download(
        repo_id="urbas/constitutional_classifier_linear_probe",
        filename="probe_checkpoint.pt",
    )
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)

    probe = ConstitutionalProbe(gpt2_model_name="gpt2")
    # Only the probe head's weights are restored from the checkpoint.
    probe.probe.load_state_dict(ckpt["probe_state_dict"])
    probe = probe.to(DEVICE)
    probe.eval()

    ESCALATION_THRESHOLD = ckpt.get("escalation_threshold", 0.6)

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer.pad_token = tokenizer.eos_token

    streaming = StreamingClassifier(
        model=probe,
        threshold=ESCALATION_THRESHOLD,
        ema_alpha=EMA_ALPHA,
        device=DEVICE,
    )
    return tokenizer, streaming


TOKENIZER, STREAMING = load_probe()


# ----------------------------------------------------------------
# Inference
# ----------------------------------------------------------------
def classify(user_input: str, assistant_output: str):
    """Score one exchange and return a verdict string plus an EMA trace plot.

    Args:
        user_input: Text of the user's message.
        assistant_output: Text of the assistant's reply.

    Returns:
        ``(verdict, fig)`` where ``verdict`` is a human-readable SAFE/HARMFUL
        line including the peak score, and ``fig`` is a matplotlib figure of
        the per-token EMA score against the escalation threshold.
    """
    # A missing textbox value would otherwise be interpolated as the literal
    # text "None" and scored as if the user had typed it.
    user_input = user_input or ""
    assistant_output = assistant_output or ""

    exchange = f"Human: {user_input}\n\nAssistant: {assistant_output}"
    enc = TOKENIZER(
        exchange,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    # Pure inference: skip building an autograd graph.
    with torch.no_grad():
        flagged, peak_score, ema_trace = STREAMING.score_exchange(
            input_ids, attention_mask
        )

    # Drop the trace entries that correspond to padding positions.
    real_len = attention_mask[0].sum().item()
    ema_trace = ema_trace[:real_len]

    # ----------------------------------------------------------------
    # Verdict
    # ----------------------------------------------------------------
    if flagged:
        verdict = f"🚨 HARMFUL (peak score: {peak_score:.2f})"
    else:
        verdict = f"✅ SAFE (peak score: {peak_score:.2f})"

    # ----------------------------------------------------------------
    # Matplotlib figure
    # ----------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(8, 3))
    y = [float(s) for s in ema_trace]
    x = list(range(len(y)))
    color = "red" if flagged else "green"
    ax.plot(x, y, color=color, linewidth=2, label="EMA score")
    ax.axhline(
        y=ESCALATION_THRESHOLD,
        color="orange",
        linestyle="--",
        linewidth=1.5,
        # Two decimals, matching the verdict, so e.g. 0.47 is not shown as 0.5.
        label=f"Threshold ({ESCALATION_THRESHOLD:.2f})",
    )
    ax.set_xlabel("Token position")
    ax.set_ylabel("EMA score")
    ax.set_title("Per-token EMA harmfulness score")
    ax.legend()
    # Lay out this figure explicitly; pyplot's implicit "current figure" may
    # belong to another request when handlers run concurrently.
    fig.tight_layout()
    # pyplot keeps every figure it creates alive until closed. Deregister it
    # so a long-running server does not leak one figure per request; the
    # Figure object itself remains renderable by the caller.
    plt.close(fig)

    return verdict, fig


# ----------------------------------------------------------------
# UI
# ----------------------------------------------------------------
with gr.Blocks(title="Constitutional Classifier — Probe") as demo:
    gr.Markdown("""
    # Constitutional Classifier++ — Linear Probe

    Enter a user message and an assistant response. The probe scores each token
    in real time using EMA smoothing and flags the exchange if the peak score
    exceeds the escalation threshold.
    """)

    with gr.Row():
        with gr.Column():
            user_box = gr.Textbox(
                label="User input",
                placeholder="e.g. How do I exploit a SQL injection vulnerability?",
                lines=4,
            )
            assistant_box = gr.Textbox(
                label="Assistant output",
                placeholder="e.g. Here's how to craft a malicious payload...",
                lines=4,
            )
            submit_btn = gr.Button("Classify", variant="primary")

        with gr.Column():
            verdict_box = gr.Textbox(label="Verdict", interactive=False)
            plot_box = gr.Plot(label="EMA score trace")

    submit_btn.click(
        fn=classify,
        inputs=[user_box, assistant_box],
        outputs=[verdict_box, plot_box],
    )

    gr.Examples(
        examples=[
            [
                "How do I perform a SQL injection attack?",
                "To exploit SQL injection, append ' OR 1=1-- to the input field...",
            ],
            [
                "What is the capital of France?",
                "The capital of France is Paris.",
            ],
            [
                "Write a Python script to scan open ports on a target IP.",
                "Here is a port scanner using socket: import socket...",
            ],
        ],
        inputs=[user_box, assistant_box],
    )

demo.launch()