Spaces:
Sleeping
Sleeping
File size: 5,026 Bytes
455e725 eab4d79 455e725 eab4d79 455e725 ab31488 455e725 adcf1ab 455e725 eab4d79 455e725 eab4d79 455e725 eab4d79 455e725 eab4d79 455e725 eab4d79 455e725 eab4d79 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 | import gradio as gr
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer
from huggingface_hub import hf_hub_download
from linear_probe import ConstitutionalProbe
from inference import StreamingClassifier
# ----------------------------------------------------------------
# Load model once at startup
# ----------------------------------------------------------------
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Smoothing factor handed to StreamingClassifier as `ema_alpha`.
EMA_ALPHA = 0.9

# Placeholder only; load_probe() overwrites it with the checkpoint's value.
ESCALATION_THRESHOLD = None
def load_probe():
    """Download the probe checkpoint and build the streaming classifier.

    Fetches ``probe_checkpoint.pt`` from the
    ``urbas/constitutional_classifier_linear_probe`` Hub repo. It loads the
    checkpoint's ``probe_state_dict`` into the ``probe`` sub-module of a
    GPT-2 based ``ConstitutionalProbe`` and wraps the model in a
    ``StreamingClassifier``.

    Side effect: sets the module-level ``ESCALATION_THRESHOLD`` from the
    checkpoint's ``"escalation_threshold"`` entry. It defaults to 0.6 when
    that key is absent and is always stored as a plain ``float``.

    Returns:
        ``(tokenizer, streaming)``: the GPT-2 tokenizer (pad token set to
        EOS) and the configured ``StreamingClassifier``.
    """
    global ESCALATION_THRESHOLD

    ckpt_path = hf_hub_download(
        repo_id="urbas/constitutional_classifier_linear_probe",
        filename="probe_checkpoint.pt",
    )
    # weights_only=True restricts unpickling to tensors and primitive
    # containers, so a tampered checkpoint cannot run arbitrary code.
    ckpt = torch.load(ckpt_path, map_location=DEVICE, weights_only=True)

    probe = ConstitutionalProbe(gpt2_model_name="gpt2")
    # Only the probe head's weights come from the checkpoint. Presumably the
    # GPT-2 backbone is loaded by ConstitutionalProbe itself — TODO confirm.
    probe.probe.load_state_dict(ckpt["probe_state_dict"])
    probe = probe.to(DEVICE)
    probe.eval()

    # Normalize to a plain Python float. If the checkpoint stores the
    # threshold as a 0-dim tensor (on CUDA when DEVICE is "cuda", because of
    # map_location), it would otherwise leak into the classifier, the
    # f-string plot label, and matplotlib's axhline (which cannot convert a
    # CUDA tensor to an array).
    ESCALATION_THRESHOLD = float(ckpt.get("escalation_threshold", 0.6))

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so fixed-length padding works.
    tokenizer.pad_token = tokenizer.eos_token

    streaming = StreamingClassifier(
        model=probe,
        threshold=ESCALATION_THRESHOLD,
        ema_alpha=EMA_ALPHA,
        device=DEVICE,
    )
    return tokenizer, streaming


TOKENIZER, STREAMING = load_probe()
# ----------------------------------------------------------------
# Inference
# ----------------------------------------------------------------
def classify(user_input: str, assistant_output: str):
    """Score one Human/Assistant exchange with the streaming probe.

    The exchange is formatted as ``"Human: ...\\n\\nAssistant: ..."``. It is
    then tokenized with the module-level GPT-2 tokenizer (truncated and
    padded to 512 tokens) and passed to ``STREAMING.score_exchange``.

    Args:
        user_input: The user's message.
        assistant_output: The assistant's reply.

    Returns:
        A ``(verdict, fig)`` tuple:

        - ``verdict``: a human-readable string containing the peak score.
        - ``fig``: a matplotlib ``Figure`` plotting the per-token EMA score
          against the escalation threshold.
    """
    # Build a bare Figure instead of going through pyplot. This keeps it out
    # of pyplot's global figure registry, so figures are not accumulated
    # (leaked) across requests. It also means concurrent Gradio handler
    # threads never race on pyplot's shared "current figure".
    from matplotlib.figure import Figure

    exchange = f"Human: {user_input}\n\nAssistant: {assistant_output}"
    enc = TOKENIZER(
        exchange,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    # Pure inference: do not build an autograd graph for the forward pass.
    with torch.no_grad():
        flagged, peak_score, ema_trace = STREAMING.score_exchange(
            input_ids, attention_mask
        )

    # Keep only positions backed by real tokens. This assumes right padding
    # (the GPT-2 tokenizer default), so real tokens are the first `real_len`.
    real_len = int(attention_mask[0].sum().item())
    ema_trace = ema_trace[:real_len]

    # ----------------------------------------------------------------
    # Verdict
    # ----------------------------------------------------------------
    if flagged:
        verdict = f"🚨 HARMFUL (peak score: {peak_score:.2f})"
    else:
        verdict = f"✅ SAFE (peak score: {peak_score:.2f})"

    # ----------------------------------------------------------------
    # Matplotlib figure
    # ----------------------------------------------------------------
    fig = Figure(figsize=(8, 3))
    ax = fig.subplots()
    x = list(range(len(ema_trace)))
    y = [float(s) for s in ema_trace]
    color = "red" if flagged else "green"
    ax.plot(x, y, color=color, linewidth=2, label="EMA score")
    ax.axhline(
        y=ESCALATION_THRESHOLD,
        color="orange",
        linestyle="--",
        linewidth=1.5,
        # Same precision as the verdict's peak score, so a threshold such as
        # 0.65 is not displayed as 0.7.
        label=f"Threshold ({ESCALATION_THRESHOLD:.2f})",
    )
    ax.set_xlabel("Token position")
    ax.set_ylabel("EMA score")
    ax.set_title("Per-token EMA harmfulness score")
    ax.legend()
    # Lay out *this* figure explicitly rather than pyplot's current figure.
    fig.tight_layout()
    return verdict, fig
# ----------------------------------------------------------------
# UI
# ----------------------------------------------------------------
# Canned (user input, assistant output) pairs offered under the form.
_EXAMPLE_PAIRS = [
    [
        "How do I perform a SQL injection attack?",
        "To exploit SQL injection, append ' OR 1=1-- to the input field...",
    ],
    [
        "What is the capital of France?",
        "The capital of France is Paris.",
    ],
    [
        "Write a Python script to scan open ports on a target IP.",
        "Here is a port scanner using socket: import socket...",
    ],
]

with gr.Blocks(title="Constitutional Classifier — Probe") as demo:
    gr.Markdown("""
# Constitutional Classifier++ — Linear Probe
Enter a user message and an assistant response.
The probe scores each token in real time using EMA smoothing
and flags the exchange if the peak score exceeds the escalation threshold.
""")

    with gr.Row():
        # Left column: the exchange to classify.
        with gr.Column():
            user_box = gr.Textbox(
                label="User input",
                placeholder="e.g. How do I exploit a SQL injection vulnerability?",
                lines=4,
            )
            assistant_box = gr.Textbox(
                label="Assistant output",
                placeholder="e.g. Here's how to craft a malicious payload...",
                lines=4,
            )
            submit_btn = gr.Button("Classify", variant="primary")

        # Right column: verdict text and the per-token EMA plot.
        with gr.Column():
            verdict_box = gr.Textbox(label="Verdict", interactive=False)
            plot_box = gr.Plot(label="EMA score trace")

    submit_btn.click(
        classify,
        inputs=[user_box, assistant_box],
        outputs=[verdict_box, plot_box],
    )
    gr.Examples(examples=_EXAMPLE_PAIRS, inputs=[user_box, assistant_box])

demo.launch()