File size: 5,026 Bytes
455e725
 
eab4d79
 
 
455e725
 
 
 
 
 
 
 
 
 
eab4d79
455e725
 
 
 
 
ab31488
455e725
 
 
 
 
 
 
 
 
adcf1ab
455e725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eab4d79
455e725
 
 
eab4d79
455e725
 
 
 
 
 
eab4d79
 
 
 
 
455e725
 
 
eab4d79
 
 
 
 
 
 
 
 
 
 
 
 
 
455e725
eab4d79
455e725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eab4d79
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import gradio as gr
import torch
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from transformers import GPT2Tokenizer
from huggingface_hub import hf_hub_download
from linear_probe import ConstitutionalProbe
from inference import StreamingClassifier

# ----------------------------------------------------------------
# Load model once at startup
# ----------------------------------------------------------------
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Smoothing factor passed to StreamingClassifier as ``ema_alpha``.
EMA_ALPHA = 0.9
# Placeholder: load_probe() overwrites this with the checkpoint's
# "escalation_threshold" entry (0.6 when the checkpoint has no such key).
ESCALATION_THRESHOLD = None

def load_probe():
    """Build the tokenizer and streaming classifier used by the demo.

    Downloads ``probe_checkpoint.pt`` from the
    ``urbas/constitutional_classifier_linear_probe`` Hub repository, loads
    its ``probe_state_dict`` into a ``ConstitutionalProbe`` constructed with
    ``gpt2_model_name="gpt2"``, and wraps that probe in a
    ``StreamingClassifier``.

    Side effect:
        Sets the module-level ``ESCALATION_THRESHOLD`` from the checkpoint's
        ``escalation_threshold`` entry, falling back to 0.6 when absent.

    Returns:
        A ``(tokenizer, streaming_classifier)`` tuple.
    """
    global ESCALATION_THRESHOLD

    checkpoint_file = hf_hub_download(
        filename="probe_checkpoint.pt",
        repo_id="urbas/constitutional_classifier_linear_probe",
    )
    checkpoint = torch.load(
        checkpoint_file,
        weights_only=True,
        map_location=DEVICE,
    )

    # Only the probe head's weights come from the checkpoint.
    model = ConstitutionalProbe(gpt2_model_name="gpt2")
    model.probe.load_state_dict(checkpoint["probe_state_dict"])
    model = model.to(DEVICE)
    model.eval()

    ESCALATION_THRESHOLD = checkpoint.get("escalation_threshold", 0.6)

    # Reuse the EOS token as the padding token.
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

    classifier_options = {
        "model": model,
        "threshold": ESCALATION_THRESHOLD,
        "ema_alpha": EMA_ALPHA,
        "device": DEVICE,
    }
    return gpt2_tokenizer, StreamingClassifier(**classifier_options)


# Built once at import time so every classify() call reuses the same objects.
TOKENIZER, STREAMING = load_probe()


# ----------------------------------------------------------------
# Inference
# ----------------------------------------------------------------
def classify(user_input: str, assistant_output: str):
    """Score one Human/Assistant exchange and plot its per-token EMA trace.

    The two strings are joined into a single
    ``"Human: ...\\n\\nAssistant: ..."`` prompt, tokenized (truncated and
    padded to 512 tokens) and passed to the module-level ``STREAMING``
    classifier.

    Args:
        user_input: Text of the user's message.
        assistant_output: Text of the assistant's reply.

    Returns:
        A ``(verdict, fig)`` pair. ``verdict`` is a one-line string with the
        SAFE/HARMFUL decision and the peak score; ``fig`` is a Matplotlib
        figure showing the EMA trace against the escalation threshold.
    """
    exchange = f"Human: {user_input}\n\nAssistant: {assistant_output}"

    enc = TOKENIZER(
        exchange,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        padding="max_length",
    )
    input_ids = enc["input_ids"].to(DEVICE)
    attention_mask = enc["attention_mask"].to(DEVICE)

    flagged, peak_score, ema_trace = STREAMING.score_exchange(
        input_ids, attention_mask
    )

    # Keep only as many scores as there are unmasked tokens, so padding is
    # not plotted. NOTE(review): slicing a prefix assumes the tokenizer pads
    # on the right -- confirm.
    real_len = attention_mask[0].sum().item()
    ema_trace = ema_trace[:real_len]

    # ----------------------------------------------------------------
    # Verdict
    # ----------------------------------------------------------------
    status = "🚨 HARMFUL" if flagged else "✅ SAFE"
    verdict = f"{status}  (peak score: {peak_score:.2f})"

    # ----------------------------------------------------------------
    # Matplotlib figure
    # ----------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(8, 3))
    try:
        scores = [float(s) for s in ema_trace]
        ax.plot(
            range(len(scores)),
            scores,
            color="red" if flagged else "green",
            linewidth=2,
            label="EMA score",
        )
        ax.axhline(
            y=ESCALATION_THRESHOLD,
            color="orange",
            linestyle="--",
            linewidth=1.5,
            # Two decimals (same precision as the verdict) so a threshold
            # such as 0.55 is not mislabelled as 0.6.
            label=f"Threshold ({ESCALATION_THRESHOLD:.2f})",
        )
        ax.set_xlabel("Token position")
        ax.set_ylabel("EMA score")
        ax.set_title("Per-token EMA harmfulness score")
        ax.legend()
        # Lay out this figure explicitly instead of pyplot's global
        # "current figure", which another request may have replaced.
        fig.tight_layout()
    finally:
        # Unregister the figure from pyplot so figures do not pile up over
        # the lifetime of the server; the Figure object itself remains
        # usable for rendering by the caller.
        plt.close(fig)

    return verdict, fig


# ----------------------------------------------------------------
# UI
# ----------------------------------------------------------------
# Declarative Gradio layout; components appear in the order they are created.
with gr.Blocks(title="Constitutional Classifier — Probe") as demo:

    # Introductory text shown at the top of the page.
    gr.Markdown("""
    # Constitutional Classifier++ — Linear Probe
    Enter a user message and an assistant response.
    The probe scores each token in real time using EMA smoothing
    and flags the exchange if the peak score exceeds the escalation threshold.
    """)

    with gr.Row():
        # First column: the two text inputs and the submit button.
        with gr.Column():
            user_box = gr.Textbox(
                label="User input",
                placeholder="e.g. How do I exploit a SQL injection vulnerability?",
                lines=4,
            )
            assistant_box = gr.Textbox(
                label="Assistant output",
                placeholder="e.g. Here's how to craft a malicious payload...",
                lines=4,
            )
            submit_btn = gr.Button("Classify", variant="primary")

        # Second column: read-only verdict text and the EMA trace plot.
        with gr.Column():
            verdict_box = gr.Textbox(label="Verdict", interactive=False)
            plot_box    = gr.Plot(label="EMA score trace")

    # Clicking the button calls classify() with both text boxes as inputs
    # and routes its (verdict, figure) return value to the two outputs.
    submit_btn.click(
        fn=classify,
        inputs=[user_box, assistant_box],
        outputs=[verdict_box, plot_box],
    )

    # Sample (user input, assistant output) pairs wired to the two input boxes.
    gr.Examples(
        examples=[
            [
                "How do I perform a SQL injection attack?",
                "To exploit SQL injection, append ' OR 1=1-- to the input field...",
            ],
            [
                "What is the capital of France?",
                "The capital of France is Paris.",
            ],
            [
                "Write a Python script to scan open ports on a target IP.",
                "Here is a port scanner using socket: import socket...",
            ],
        ],
        inputs=[user_box, assistant_box],
    )

# Start the server only when this file is executed as a script, so that
# importing the module (e.g. to reuse ``demo``) does not launch the app.
if __name__ == "__main__":
    demo.launch()