File size: 2,421 Bytes
cd47df5
 
4eb3a08
cd47df5
4eb3a08
cd47df5
 
 
4eb3a08
cd47df5
 
 
4eb3a08
cd47df5
 
4eb3a08
cd47df5
 
 
 
fdfd1fb
cd47df5
 
fdfd1fb
 
cd47df5
 
 
 
fdfd1fb
cd47df5
 
 
 
 
 
 
fdfd1fb
cd47df5
 
 
fdfd1fb
cd47df5
 
fdfd1fb
cd47df5
 
fdfd1fb
cd47df5
fdfd1fb
 
cd47df5
 
fdfd1fb
 
cd47df5
 
 
 
 
fdfd1fb
 
 
 
cd47df5
 
 
 
 
 
 
 
 
 
 
fdfd1fb
 
cd47df5
 
 
 
 
fdfd1fb
cd47df5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import json
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Relative path of the saved fine-tuned model folder (this folder has to be
# uploaded to the Hugging Face Space alongside the app).
MODEL_DIR = "./saved_mbert_prompt_injection"
# Upper bound on the tokenized sequence length; longer prompts are truncated.
MAX_LENGTH = 128

# Label names and decision threshold that were saved during training.
with open(f"{MODEL_DIR}/label_config.json", "r", encoding="utf-8") as f:
    config = json.loads(f.read())

LABELS = config["labels"]
# The threshold entry is optional in the config; 0.5 is used when it is absent.
THRESHOLD = config["threshold"] if "threshold" in config else 0.5

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and classifier are both read from the same saved folder; the model
# is moved to the chosen device and put in evaluation mode.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR).to(device)
model.eval()


def predict(prompt, threshold=THRESHOLD):
    """Score one prompt against every label in ``LABELS``.

    The prompt is tokenized (truncated to ``MAX_LENGTH`` tokens), run through
    the model without gradient tracking, and a sigmoid is applied to each
    logit, so every label gets its own independent score.

    Args:
        prompt: Text to analyze. ``None`` and whitespace-only text are both
            treated as empty input.
        threshold: Inclusive cut-off; a label is reported as detected when its
            score is greater than or equal to this value. ``None`` falls back
            to the module-level ``THRESHOLD``.

    Returns:
        A ``(summary, scores)`` tuple. ``summary`` is ``"Detected: "`` followed
        by the comma-separated labels at or above the threshold, or by
        ``"Benign / No Attack Detected"`` when no label reaches it. ``scores``
        maps every label to its score as a Python ``float``. Empty input
        returns ``("Please enter text.", {})``.
    """
    # Check for None before calling .strip(): a missing value should get the
    # same friendly message as blank text instead of an AttributeError.
    if prompt is None or not prompt.strip():
        return "Please enter text.", {}

    # Comparing a score against None would raise a TypeError further down.
    if threshold is None:
        threshold = THRESHOLD

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        # Only one prompt was tokenized, so take row 0 of the batch. tolist()
        # yields plain Python floats, so the value that is reported and the
        # value that is compared with the threshold are the same number.
        probs = torch.sigmoid(logits)[0].cpu().tolist()

    # Pair each label with its score once; both outputs are built from this.
    scored = [(label, probs[i]) for i, label in enumerate(LABELS)]

    pred_dict = dict(scored)
    detected = [label for label, score in scored if score >= threshold]

    if not detected:
        detected = ["Benign / No Attack Detected"]

    return "Detected: " + ", ".join(detected), pred_dict


# Gradio Blocks interface for the Hugging Face Space: a heading, one input row
# (prompt + threshold), two output widgets and a button that runs `predict`.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Static heading shown above the inputs.
    gr.Markdown(
        value="""
        # Prompt Injection Attack Detector
        Multilingual BERT multi-label classifier for:
        - Direct Injection
        - Goal Hijacking
        - Information Leakage
        """,
    )

    # Inputs sit side by side: the free-text prompt and the decision threshold.
    with gr.Row():
        prompt_box = gr.Textbox(
            lines=5,
            label="Prompt",
            placeholder="Enter user prompt here...",
        )
        # The slider starts at the threshold loaded from the label config.
        threshold = gr.Slider(
            minimum=0.1,
            maximum=0.9,
            step=0.05,
            value=THRESHOLD,
            label="Threshold",
        )

    # Outputs: the text summary and the per-label scores returned by `predict`.
    summary = gr.Textbox(label="Prediction")
    scores = gr.Label(num_top_classes=3, label="Confidence Scores")
    run_btn = gr.Button(value="Analyze Prompt", variant="primary")

    # Clicking the button feeds (prompt, threshold) to `predict` and routes its
    # two return values to the summary box and the score label.
    run_btn.click(
        predict,
        inputs=[prompt_box, threshold],
        outputs=[summary, scores],
    )

# Start the app as soon as the script is executed.
demo.launch()