# Hugging Face Space file: app.py (author: umrr)
# Last commit: cd47df5 "Update app.py" (verified), file size 2.42 kB
import json
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Path to saved fine-tuned model (upload this folder to Hugging Face Space)
MODEL_DIR = "./saved_mbert_prompt_injection"
# Maximum number of tokens passed to the model; longer prompts are truncated.
MAX_LENGTH = 128

# Load label names and threshold saved during training
with open(f"{MODEL_DIR}/label_config.json", "r", encoding="utf-8") as f:
    config = json.load(f)
# Label names in model-output order: predict() pairs LABELS[i] with logit i.
LABELS = config["labels"]
# Fall back to 0.5 when the training run did not save a tuned threshold.
THRESHOLD = config.get("threshold", 0.5)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()  # disable dropout etc. so inference is deterministic

# Fail fast at startup if the label file and the classification head disagree.
# Otherwise every request would either raise IndexError (too many labels) or
# silently ignore some model outputs (too few labels).
if len(LABELS) != model.config.num_labels:
    raise ValueError(
        f"label_config.json lists {len(LABELS)} labels but the model in "
        f"{MODEL_DIR} has {model.config.num_labels} outputs"
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def predict(prompt, threshold=THRESHOLD):
    """Score one prompt against every attack label in ``LABELS``.

    Args:
        prompt: Text to classify. ``None``, empty, or whitespace-only input
            is rejected with a message instead of being sent to the model.
        threshold: Per-label probability cut-off. A label is reported as
            detected when its sigmoid probability is >= this value. Defaults
            to the threshold read from ``label_config.json``.

    Returns:
        A ``(summary, scores)`` tuple. ``summary`` is a human-readable string
        naming the detected labels (or the benign message when none pass the
        threshold). ``scores`` maps each label name to its probability.
        For empty input, returns ``("Please enter text.", {})``.
    """
    # Guard against None as well as blank text; None.strip() would raise.
    if not prompt or not prompt.strip():
        return "Please enter text.", {}

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Multi-label head: an independent sigmoid per label (not a softmax),
    # so several attack types can be flagged for the same prompt.
    probs = torch.sigmoid(logits).cpu().numpy()[0]
    pred_dict = {label: float(probs[i]) for i, label in enumerate(LABELS)}

    detected = [label for label, score in pred_dict.items() if score >= threshold]
    if not detected:
        detected = ["Benign / No Attack Detected"]
    return "Detected: " + ", ".join(detected), pred_dict
# Professional Gradio UI for Hugging Face Spaces
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # NOTE(review): this label list is static text; keep it in sync with
    # the "labels" entry of label_config.json.
    gr.Markdown(
        """
# Prompt Injection Attack Detector
Multilingual BERT multi-label classifier for:
- Direct Injection
- Goal Hijacking
- Information Leakage
"""
    )
    # NOTE(review): nesting reconstructed from a paste with lost indentation;
    # confirm which widgets were meant to share this row.
    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            lines=5,
            placeholder="Enter user prompt here...",
        )
        # Widen the default 0.1-0.9 range when needed so the threshold tuned
        # during training is always a valid slider value.
        threshold = gr.Slider(
            min(0.1, THRESHOLD),
            max(0.9, THRESHOLD),
            value=THRESHOLD,
            step=0.05,
            label="Threshold",
        )
    summary = gr.Textbox(label="Prediction")
    # Show every label's score rather than a hard-coded top 3, so no class is
    # hidden if the config defines more labels.
    scores = gr.Label(label="Confidence Scores", num_top_classes=len(LABELS))
    run_btn = gr.Button("Analyze Prompt", variant="primary")
    run_btn.click(fn=predict, inputs=[prompt_box, threshold], outputs=[summary, scores])

demo.launch()