Spaces:
Runtime error
Runtime error
| import json | |
| import torch | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
# Folder holding the fine-tuned model (upload this folder to the Hugging Face Space)
MODEL_DIR = "./saved_mbert_prompt_injection"
# Token limit applied when prompts are tokenized
MAX_LENGTH = 128

# Label names and decision threshold were written out during training
with open(f"{MODEL_DIR}/label_config.json", encoding="utf-8") as f:
    config = json.load(f)
LABELS = config["labels"]
THRESHOLD = config.get("threshold", 0.5)  # 0.5 when the config has no threshold

# Use the GPU when one is available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer and classifier are both restored from the same folder
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
model.to(device)
model.eval()  # inference mode
def predict(prompt, threshold=THRESHOLD):
    """Score one prompt against every configured attack label.

    Args:
        prompt: Text to classify. ``None``, empty and whitespace-only
            input are all treated as "nothing entered".
        threshold: Minimum sigmoid probability for a label to count as
            detected. ``None`` falls back to the module-level ``THRESHOLD``.

    Returns:
        A ``(summary, scores)`` tuple. ``summary`` is a human-readable
        string; ``scores`` maps each entry of ``LABELS`` to its
        probability. For empty input ``scores`` is an empty dict.
    """
    # None has no .strip(); check for it before looking at the text itself.
    if not prompt or not prompt.strip():
        return "Please enter text.", {}
    # Comparing a probability against None would raise TypeError.
    if threshold is None:
        threshold = THRESHOLD

    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Sigmoid scores each label independently; [0] selects the only row
    # because a single prompt was tokenized.
    probs = torch.sigmoid(logits).cpu().numpy()[0]

    pred_dict = {label: float(probs[i]) for i, label in enumerate(LABELS)}
    detected = [label for i, label in enumerate(LABELS) if probs[i] >= threshold]
    if not detected:
        detected = ["Benign / No Attack Detected"]
    return "Detected: " + ", ".join(detected), pred_dict
# Gradio interface served by the Hugging Face Space
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page heading and the list of attack categories
    gr.Markdown(
        """
        # Prompt Injection Attack Detector
        Multilingual BERT multi-label classifier for:
        - Direct Injection
        - Goal Hijacking
        - Information Leakage
        """
    )

    # Inputs: the prompt text and the detection threshold, side by side
    with gr.Row():
        prompt_input = gr.Textbox(
            label="Prompt",
            lines=5,
            placeholder="Enter user prompt here...",
        )
        threshold_slider = gr.Slider(
            minimum=0.1,
            maximum=0.9,
            value=THRESHOLD,
            step=0.05,
            label="Threshold",
        )

    # Outputs: one-line verdict plus per-label scores
    summary_box = gr.Textbox(label="Prediction")
    scores_label = gr.Label(label="Confidence Scores", num_top_classes=3)

    # Clicking the button runs predict() on the two inputs
    analyze_button = gr.Button("Analyze Prompt", variant="primary")
    analyze_button.click(
        fn=predict,
        inputs=[prompt_input, threshold_slider],
        outputs=[summary_box, scores_label],
    )

demo.launch()