"""Gradio app that serves a fine-tuned multi-label prompt-injection classifier.

At import time the module loads the label configuration, tokenizer and model
from ``MODEL_DIR`` and builds the Gradio UI. The web server is started only
when the file is executed as a script.
"""

import json

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Path to saved fine-tuned model (upload this folder to Hugging Face Space)
MODEL_DIR = "./saved_mbert_prompt_injection"
# Maximum number of tokens passed to the model; longer prompts are truncated,
# so text beyond this limit is never seen by the classifier.
MAX_LENGTH = 128

# Load label names and threshold saved during training
with open(f"{MODEL_DIR}/label_config.json", "r", encoding="utf-8") as f:
    config = json.load(f)

LABELS = config["labels"]
THRESHOLD = config.get("threshold", 0.5)

# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def predict(prompt: str, threshold: float = THRESHOLD):
    """Score one prompt against every label in ``LABELS``.

    Args:
        prompt: Text to classify. ``None``, empty and whitespace-only values
            are rejected without running the model.
        threshold: Minimum sigmoid score for a label to count as detected.
            ``None`` falls back to the configured ``THRESHOLD``.

    Returns:
        A ``(summary, scores)`` tuple. ``summary`` is a human-readable string
        listing the detected labels (or a benign notice when none reach the
        threshold). ``scores`` maps each label name to its sigmoid score as a
        Python float; it is empty when the prompt was rejected.
    """
    # Guard against None as well as blank text so a missing value does not
    # raise AttributeError on ``.strip()``.
    if not prompt or not prompt.strip():
        return "Please enter text.", {}
    if threshold is None:
        threshold = THRESHOLD

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Independent sigmoid per label (multi-label). Convert to plain Python
    # floats once so the displayed scores and the threshold comparison use
    # exactly the same values.
    probs = torch.sigmoid(logits)[0].cpu().tolist()

    pred_dict = {label: probs[i] for i, label in enumerate(LABELS)}
    detected = [label for label, score in pred_dict.items() if score >= threshold]

    if not detected:
        detected = ["Benign / No Attack Detected"]

    return "Detected: " + ", ".join(detected), pred_dict


# Professional Gradio UI for Hugging Face Spaces
# NOTE(review): the original indentation was lost; the Row is assumed to hold
# the prompt box and the threshold slider, with the outputs and button below
# it. Confirm against the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Prompt Injection Attack Detector
        Multilingual BERT multi-label classifier for:
        - Direct Injection
        - Goal Hijacking
        - Information Leakage
        """
    )

    with gr.Row():
        prompt_box = gr.Textbox(
            label="Prompt",
            lines=5,
            placeholder="Enter user prompt here...",
        )
        threshold = gr.Slider(
            0.1,
            0.9,
            value=THRESHOLD,
            step=0.05,
            label="Threshold",
        )

    summary = gr.Textbox(label="Prediction")
    scores = gr.Label(label="Confidence Scores", num_top_classes=3)
    run_btn = gr.Button("Analyze Prompt", variant="primary")

    # Event listeners must be registered inside the Blocks context.
    run_btn.click(
        fn=predict,
        inputs=[prompt_box, threshold],
        outputs=[summary, scores],
    )

# Start the server only when run as a script, so importing this module
# (e.g. to reuse ``predict``) does not launch the app.
if __name__ == "__main__":
    demo.launch()