Spaces:
Runtime error
Runtime error
File size: 2,421 Bytes
cd47df5 4eb3a08 cd47df5 4eb3a08 cd47df5 4eb3a08 cd47df5 4eb3a08 cd47df5 4eb3a08 cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 fdfd1fb cd47df5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | import json
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# --- Model artefacts ---------------------------------------------------------
# Folder holding the fine-tuned checkpoint, tokenizer files and
# label_config.json (upload this folder alongside the Space). The path is
# relative, so it is resolved against the current working directory.
MODEL_DIR = "./saved_mbert_prompt_injection"
# Token budget per prompt; longer prompts are truncated by the tokenizer.
MAX_LENGTH = 128

# label_config.json is produced at training time: "labels" is required,
# "threshold" is optional and falls back to 0.5 when absent.
with open(MODEL_DIR + "/label_config.json", mode="r", encoding="utf-8") as f:
    config = json.loads(f.read())
LABELS = config["labels"]
THRESHOLD = config.get("threshold", 0.5)

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and classifier are both restored from the same checkpoint folder.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.to(device)
model.eval()  # evaluation mode (e.g. dropout disabled) for inference
def predict(prompt, threshold=THRESHOLD):
    """Score one prompt against every configured attack label.

    Args:
        prompt: Text to classify. ``None`` (which a Gradio Textbox may pass,
            e.g. from API clients) and empty/whitespace-only strings are
            treated as missing input.
        threshold: Minimum sigmoid probability for a label to count as
            detected. Defaults to the threshold read from label_config.json.

    Returns:
        A ``(summary, scores)`` tuple. ``summary`` is a human-readable string
        naming the detected labels, or a benign message when none pass the
        threshold. ``scores`` maps every label in ``LABELS`` to its
        probability as a Python float. For missing input the summary asks
        the user to enter text and ``scores`` is an empty dict.
    """
    # Check for None/"" first: calling .strip() on None would raise
    # AttributeError instead of returning the friendly message.
    if not prompt or not prompt.strip():
        return "Please enter text.", {}

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Multi-label head: an independent sigmoid per label (not a softmax), so
    # several labels can fire at once. [0] selects the single batch item.
    probs = torch.sigmoid(logits).cpu().numpy()[0]

    pred_dict = {label: float(probs[i]) for i, label in enumerate(LABELS)}
    detected = [label for i, label in enumerate(LABELS) if probs[i] >= threshold]
    if not detected:
        detected = ["Benign / No Attack Detected"]

    return "Detected: " + ", ".join(detected), pred_dict
# Professional Gradio UI for Hugging Face Spaces: prompt + threshold inputs,
# a text summary and a per-label confidence panel, wired to predict().
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Implicit string concatenation keeps the markdown text flush-left even
    # though this code is indented (leading spaces would change rendering).
    gr.Markdown(
        "\n"
        "# Prompt Injection Attack Detector\n"
        "Multilingual BERT multi-label classifier for:\n"
        "- Direct Injection\n"
        "- Goal Hijacking\n"
        "- Information Leakage\n"
    )

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            lines=5,
            placeholder="Enter user prompt here...",
        )
        # Lets the user override the decision threshold loaded from config.
        threshold = gr.Slider(
            minimum=0.1,
            maximum=0.9,
            value=THRESHOLD,
            step=0.05,
            label="Threshold",
        )

    summary = gr.Textbox(label="Prediction")
    # predict() returns a score for every configured label, so show all of
    # them instead of a hard-coded top 3.
    scores = gr.Label(label="Confidence Scores", num_top_classes=len(LABELS))
    run_btn = gr.Button("Analyze Prompt", variant="primary")

    # Event wiring must happen inside the Blocks context.
    run_btn.click(fn=predict, inputs=[prompt_box, threshold], outputs=[summary, scores])

demo.launch()
|