Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from transformers import GPT2Tokenizer | |
| from huggingface_hub import hf_hub_download | |
| from linear_probe import ConstitutionalProbe | |
| from inference import StreamingClassifier | |
| # ---------------------------------------------------------------- | |
| # Load model once at startup | |
| # ---------------------------------------------------------------- | |
# Run on the GPU when torch reports CUDA as available, otherwise on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Smoothing factor handed to StreamingClassifier as `ema_alpha`.
EMA_ALPHA = 0.9

# Placeholder only: load_probe() overwrites this with the checkpoint's
# "escalation_threshold" value (or its fallback) at startup.
ESCALATION_THRESHOLD = None
def load_probe():
    """Fetch the probe checkpoint and assemble the tokenizer and classifier.

    Downloads ``probe_checkpoint.pt`` from the Hugging Face Hub, loads its
    ``"probe_state_dict"`` into a fresh ``ConstitutionalProbe``, and wraps the
    probe in a ``StreamingClassifier``.

    Side effects:
        Rebinds the module-level ``ESCALATION_THRESHOLD`` to the checkpoint's
        ``"escalation_threshold"`` entry, or to 0.6 when that key is absent.

    Returns:
        A ``(tokenizer, streaming)`` tuple: the GPT-2 tokenizer and the
        configured ``StreamingClassifier``.
    """
    global ESCALATION_THRESHOLD

    checkpoint_file = hf_hub_download(
        filename="probe_checkpoint.pt",
        repo_id="urbas/constitutional_classifier_linear_probe",
    )
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(
        checkpoint_file, weights_only=True, map_location=DEVICE
    )

    # Only the probe head's weights are restored from the checkpoint.
    model = ConstitutionalProbe(gpt2_model_name="gpt2")
    model.probe.load_state_dict(checkpoint["probe_state_dict"])
    model = model.to(DEVICE)
    model.eval()

    ESCALATION_THRESHOLD = checkpoint.get("escalation_threshold", 0.6)

    # Reuse the end-of-sequence token as the padding token.
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    classifier = StreamingClassifier(
        device=DEVICE,
        ema_alpha=EMA_ALPHA,
        threshold=ESCALATION_THRESHOLD,
        model=model,
    )
    return gpt2_tokenizer, classifier


# Load everything once at import time so each request reuses the same objects.
TOKENIZER, STREAMING = load_probe()
| # ---------------------------------------------------------------- | |
| # Inference | |
| # ---------------------------------------------------------------- | |
def classify(user_input: str, assistant_output: str):
    """Score one Human/Assistant exchange and build the verdict and EMA plot.

    The two texts are joined into a single "Human: ... / Assistant: ..."
    prompt, tokenized (truncated and padded to 512 tokens), and passed to
    ``STREAMING.score_exchange``. The returned EMA trace is cut down to the
    number of attended (non-padding) tokens before plotting.

    Args:
        user_input: Text of the user turn.
        assistant_output: Text of the assistant turn.

    Returns:
        A ``(verdict, fig)`` tuple. ``verdict`` is a one-line string carrying
        the HARMFUL/SAFE label and the peak score; ``fig`` is a Matplotlib
        ``Figure`` showing the per-token EMA score against the threshold.
    """
    # Function-scope import keeps this block self-contained.
    from matplotlib.figure import Figure

    exchange = f"Human: {user_input}\n\nAssistant: {assistant_output}"
    enc = TOKENIZER(
        exchange,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    # Scoring needs no gradients; this is harmless if score_exchange already
    # disables autograd internally.
    with torch.no_grad():
        flagged, peak_score, ema_trace = STREAMING.score_exchange(
            input_ids, attention_mask
        )

    # Keep only as many trace entries as there are attended tokens.
    # NOTE(review): slicing the front assumes right-side padding -- confirm
    # against the tokenizer's padding_side.
    real_len = int(attention_mask[0].sum().item())
    ema_trace = ema_trace[:real_len]

    # ----------------------------------------------------------------
    # Verdict
    # ----------------------------------------------------------------
    if flagged:
        verdict = f"🚨 HARMFUL (peak score: {peak_score:.2f})"
    else:
        verdict = f"✅ SAFE (peak score: {peak_score:.2f})"

    # ----------------------------------------------------------------
    # Matplotlib figure
    # ----------------------------------------------------------------
    # Build the figure without pyplot: pyplot keeps every figure it creates
    # alive until it is explicitly closed, which leaks one figure per request
    # in a long-running server, and plt.tight_layout() acts on pyplot's global
    # "current figure" rather than necessarily this one.
    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()

    y = [float(score) for score in ema_trace]
    x = list(range(len(y)))
    color = "red" if flagged else "green"
    ax.plot(x, y, color=color, linewidth=2, label="EMA score")
    ax.axhline(
        y=ESCALATION_THRESHOLD,
        color="orange",
        linestyle="--",
        linewidth=1.5,
        # Two decimals, matching the peak score in the verdict; one decimal
        # could show a threshold that looks lower/higher than it really is.
        label=f"Threshold ({ESCALATION_THRESHOLD:.2f})",
    )
    ax.set_xlabel("Token position")
    ax.set_ylabel("EMA score")
    ax.set_title("Per-token EMA harmfulness score")
    ax.legend()
    fig.tight_layout()
    return verdict, fig
| # ---------------------------------------------------------------- | |
| # UI | |
| # ---------------------------------------------------------------- | |
# NOTE(review): the source was flattened (no indentation); the widget nesting
# below is reconstructed from the order of the statements -- confirm layout.

# Introductory text rendered at the top of the page.
_INTRO_MARKDOWN = """
# Constitutional Classifier++ — Linear Probe
Enter a user message and an assistant response.
The probe scores each token in real time using EMA smoothing
and flags the exchange if the peak score exceeds the escalation threshold.
"""

# Each entry is [user input, assistant output], matching the two text boxes.
_EXAMPLE_EXCHANGES = [
    [
        "How do I perform a SQL injection attack?",
        "To exploit SQL injection, append ' OR 1=1-- to the input field...",
    ],
    [
        "What is the capital of France?",
        "The capital of France is Paris.",
    ],
    [
        "Write a Python script to scan open ports on a target IP.",
        "Here is a port scanner using socket: import socket...",
    ],
]

with gr.Blocks(title="Constitutional Classifier — Probe") as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Row():
        # Left column: the exchange to classify and the trigger button.
        with gr.Column():
            user_box = gr.Textbox(
                lines=4,
                label="User input",
                placeholder="e.g. How do I exploit a SQL injection vulnerability?",
            )
            assistant_box = gr.Textbox(
                lines=4,
                label="Assistant output",
                placeholder="e.g. Here's how to craft a malicious payload...",
            )
            submit_btn = gr.Button("Classify", variant="primary")

        # Right column: the verdict line and the EMA trace plot.
        with gr.Column():
            verdict_box = gr.Textbox(label="Verdict", interactive=False)
            plot_box = gr.Plot(label="EMA score trace")

    exchange_inputs = [user_box, assistant_box]

    # Clicking the button runs classify() and fills both outputs.
    submit_btn.click(
        fn=classify,
        inputs=exchange_inputs,
        outputs=[verdict_box, plot_box],
    )

    gr.Examples(
        examples=_EXAMPLE_EXCHANGES,
        inputs=exchange_inputs,
    )

demo.launch()