middleware / prompt_injection.py
PARTHA181098's picture
Create prompt_injection.py
e818330 verified
# import re
# from typing import List
# # Common prompt injection patterns
# INJECTION_PATTERNS: List[str] = [
# r"ignore\s+previous\s+instructions",
# r"disregard\s+all\s+rules",
# r"you\s+are\s+no\s+longer",
# r"system\s+prompt",
# r"developer\s+message",
# r"act\s+as\s+.*",
# r"jailbreak",
# r"bypass\s+security",
# r"reveal\s+hidden\s+instructions",
# r"follow\s+my\s+instructions\s+only",
# ]
import os
import requests
from dotenv import load_dotenv
load_dotenv() # loads .env into os.environ
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
raise RuntimeError("HF_TOKEN not set. Check your .env file.")
API_URL = "https://router.huggingface.co/hf-inference/models/meta-llama/Llama-Prompt-Guard-2-86M"
headers = {
"Authorization": f"Bearer {HF_TOKEN}",
}
# Detector function
def is_prompt_injection(text: str) -> bool:
"""
Returns True if prompt injection is detected.
"""
# normalized = text.lower()
# for pattern in INJECTION_PATTERNS:
# if re.search(pattern, normalized):
# return True
payload = {"inputs": text}
try:
response = requests.post(API_URL, headers=headers, json=payload, timeout=15)
response.raise_for_status()
data = response.json()
print(data)
except Exception as e:
# If the detector is unavailable, don't hard-fail the whole chat.
# (Fail-open; you can change this to `return True` if you prefer fail-closed.)
print(f"[prompt_injection] detector error: {e}")
return False
# HF inference can return either:
# - [{label, score}, ...]
# - [[{label, score}, ...]]
candidates = []
if isinstance(data, list) and data:
if isinstance(data[0], list):
candidates = data[0]
else:
candidates = data
if not candidates:
# Unexpected response shape; fail-open.
print(f"[prompt_injection] unexpected response: {data}")
return False
def _score(item) -> float:
try:
return float(item.get("score", 0.0))
except Exception:
return 0.0
# Prefer an explicit "safe" label if present.
# safe_entry = next(
# (c for c in candidates if isinstance(c, dict) and "safe" in str(c.get("label", "")).lower()),
# None,
# )
# print('safe_entry...', safe_entry)
threshold = 0.5
# if safe_entry is not None:
# safe_score = _score(safe_entry)
# return safe_score < threshold
# Fallback: keep the original intent (treat the first score as the "safe" score).
first = candidates[0] if isinstance(candidates[0], dict) else {}
safe_score = _score(first)
label = first.get('label', None)
print('safe_score...', safe_score)
if label == 'LABEL_0':
return safe_score < threshold
else:
return safe_score > threshold