# JailBreakDefense — src/detector.py
# Uploaded by kriti0608 ("Create src/detector.py", commit c4f4657, verified)
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from .rules import match_rules
@dataclass
class DetectionResult:
    """Outcome of one jailbreak-detection pass over a prompt/output pair."""

    # Normalized risk in [0.0, 1.0]; computed by capping summed rule weights.
    risk_score: float
    # Rule dicts that matched; each carries at least a "weight" key
    # (summed by JailbreakDetector.score).
    fired_rules: List[Dict[str, Any]]
    # Diagnostic breakdown: {"prompt_rules": [...], "output_rules": [...]}
    # showing which rules fired on which text source.
    metadata: Dict[str, Any]
class JailbreakDetector:
    """
    Lightweight, rule-based jailbreak detector.

    - Looks at prompt, output, or both.
    - Returns a normalized risk score 0–1 + which patterns fired.
    """

    def __init__(self, consider_output: bool = True, scale: float = 3.0):
        """
        Args:
            consider_output: If True, rules are also matched against the
                model output passed to :meth:`score`.
            scale: Total matched-rule weight at which the risk score
                saturates to 1.0. Defaults to 3.0 (the original, admittedly
                arbitrary, normalization constant).

        Raises:
            ValueError: If ``scale`` is not positive (would divide by zero
                or invert the score's sign).
        """
        if scale <= 0:
            raise ValueError("scale must be positive")
        self.consider_output = consider_output
        self.scale = scale

    def score(self, prompt: str, output: Optional[str] = None) -> DetectionResult:
        """
        Score a prompt (and optionally a model output) for jailbreak risk.

        Args:
            prompt: User prompt text; ``None`` is treated as "".
            output: Optional model response; only scanned when
                ``consider_output`` is True and ``output`` is non-empty.

        Returns:
            DetectionResult with a risk score in [0, 1], the combined list
            of matched rule dicts, and per-source metadata.
        """
        prompt_hits = match_rules(prompt or "")
        source_flags: Dict[str, Any] = {
            "prompt_rules": prompt_hits,
            "output_rules": [],
        }
        all_hits = list(prompt_hits)

        if self.consider_output and output:
            out_hits = match_rules(output)
            source_flags["output_rules"] = out_hits
            all_hits.extend(out_hits)

        # Sum the matched rule weights, then cap to [0, 1]; `scale` is the
        # total weight that saturates the score at 1.0.
        total_weight = sum(h["weight"] for h in all_hits)
        risk_score = min(1.0, total_weight / self.scale)

        return DetectionResult(
            risk_score=risk_score,
            fired_rules=all_hits,
            metadata=source_flags,
        )