"""
risk_scoring.py
===============

Aggregates signals from all detection layers into a single risk score
and determines the final verdict for a request.

Risk score: float in [0, 1]
    0.00 - 0.30  ->  LOW      (safe)
    0.30 - 0.60  ->  MEDIUM   (flagged for review)
    0.60 - 0.80  ->  HIGH     (suspicious, sanitise or block)
    0.80 - 1.00  ->  CRITICAL (block)

Status strings: "safe" | "flagged" | "blocked"
"""
from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger("ai_firewall.risk_scoring")
class RiskLevel(str, Enum):
    """Qualitative severity bucket derived from the numeric risk score.

    Mixes in ``str`` so members compare equal to, and serialise as,
    their plain string values.
    """

    LOW = "low"            # score < 0.30
    MEDIUM = "medium"      # 0.30 <= score < 0.60
    HIGH = "high"          # 0.60 <= score < 0.80
    CRITICAL = "critical"  # score >= 0.80
class RequestStatus(str, Enum):
    """Final verdict for a request after risk aggregation.

    Mixes in ``str`` so members compare equal to, and serialise as,
    their plain string values.
    """

    SAFE = "safe"        # below the flag threshold: pass through
    FLAGGED = "flagged"  # at/above the flag threshold: surfaced for review
    BLOCKED = "blocked"  # at/above the block threshold: rejected
@dataclass
class RiskReport:
    """Comprehensive risk assessment for a single request.

    Produced by ``RiskScorer.score``.  The ``@dataclass`` decorator was
    missing: the class body relies on field annotations with defaults and
    ``field(default_factory=list)``, and callers construct it with keyword
    arguments — none of which works on a plain class (construction would
    raise ``TypeError`` and ``flags`` would be a bare ``Field`` object).
    """

    # Final verdict and aggregate score.
    status: RequestStatus
    risk_score: float
    risk_level: RiskLevel
    # Per-layer scores
    injection_score: float = 0.0
    adversarial_score: float = 0.0
    output_score: float = 0.0  # filled in after generation
    # Attack metadata
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list = field(default_factory=list)
    # Timing
    latency_ms: float = 0.0

    def to_dict(self) -> dict:
        """Serialise the report to a JSON-friendly dict.

        Float scores are rounded for stable, compact output; attack
        metadata keys are included only when they are truthy.
        """
        d = {
            "status": self.status.value,
            "risk_score": round(self.risk_score, 4),
            "risk_level": self.risk_level.value,
            "injection_score": round(self.injection_score, 4),
            "adversarial_score": round(self.adversarial_score, 4),
            "output_score": round(self.output_score, 4),
            "flags": self.flags,
            "latency_ms": round(self.latency_ms, 2),
        }
        if self.attack_type:
            d["attack_type"] = self.attack_type
        if self.attack_category:
            d["attack_category"] = self.attack_category
        return d
def _level_from_score(score: float) -> RiskLevel:
    """Map a [0, 1] risk score onto its qualitative RiskLevel bucket."""
    # Ordered (exclusive upper bound, level) pairs; the first bound the
    # score falls under wins.  Anything >= 0.80 is CRITICAL.
    buckets = (
        (0.30, RiskLevel.LOW),
        (0.60, RiskLevel.MEDIUM),
        (0.80, RiskLevel.HIGH),
    )
    for upper_bound, level in buckets:
        if score < upper_bound:
            return level
    return RiskLevel.CRITICAL
class RiskScorer:
    """
    Aggregates injection and adversarial scores into a unified risk report.

    The weighting reflects the relative danger of each signal:

    - Injection score carries 60% weight (direct attack)
    - Adversarial score carries 40% weight (indirect / evasion)

    Additional modifier: if the injection detector fires AND the
    adversarial detector fires, the combined score is boosted by a
    small multiplicative factor to account for compound attacks.

    Parameters
    ----------
    block_threshold : float
        Score >= this -> status BLOCKED (default 0.70).
    flag_threshold : float
        Score >= this -> status FLAGGED (default 0.40).
    injection_weight : float
        Weight for injection score (default 0.60).
    adversarial_weight : float
        Weight for adversarial score (default 0.40).
    compound_boost : float
        Multiplier applied when both detectors fire (default 1.15).
    """

    def __init__(
        self,
        block_threshold: float = 0.70,
        flag_threshold: float = 0.40,
        injection_weight: float = 0.60,
        adversarial_weight: float = 0.40,
        compound_boost: float = 1.15,
    ) -> None:
        self.block_threshold = block_threshold
        self.flag_threshold = flag_threshold
        self.injection_weight = injection_weight
        self.adversarial_weight = adversarial_weight
        self.compound_boost = compound_boost

    def score(
        self,
        injection_score: float,
        adversarial_score: float,
        injection_is_flagged: bool = False,
        adversarial_is_flagged: bool = False,
        attack_type: Optional[str] = None,
        attack_category: Optional[str] = None,
        flags: Optional[list] = None,
        output_score: float = 0.0,
        latency_ms: float = 0.0,
    ) -> RiskReport:
        """
        Compute the unified risk report.

        Parameters
        ----------
        injection_score : float
            Confidence score from InjectionDetector (0-1).
        adversarial_score : float
            Risk score from AdversarialDetector (0-1).
        injection_is_flagged : bool
            Whether InjectionDetector marked the input as injection.
        adversarial_is_flagged : bool
            Whether AdversarialDetector marked input as adversarial.
        attack_type : str, optional
            Human-readable attack type label.
        attack_category : str, optional
            Injection attack category enum value.
        flags : list, optional
            All flags raised by detectors.
        output_score : float
            Risk score from OutputGuardrail (added post-generation).
        latency_ms : float
            Total pipeline latency.

        Returns
        -------
        RiskReport
        """
        start = time.perf_counter()

        # Base signal: weighted sum of the two detector scores.
        blended = (
            self.injection_weight * injection_score
            + self.adversarial_weight * adversarial_score
        )

        # Compound attacks (both detectors firing) get a small boost,
        # capped so the score never leaves [0, 1].
        if injection_is_flagged and adversarial_is_flagged:
            blended = min(1.0, blended * self.compound_boost)

        # The output-guardrail signal is secondary: additive at 20% weight,
        # again capped at 1.0.  Only applied when a positive score exists.
        if output_score > 0:
            blended = min(1.0, blended + 0.20 * output_score)

        final_score = round(blended, 4)
        level = _level_from_score(final_score)

        # Verdict: compare against the block threshold first, then flag.
        if final_score >= self.block_threshold:
            verdict = RequestStatus.BLOCKED
        elif final_score >= self.flag_threshold:
            verdict = RequestStatus.FLAGGED
        else:
            verdict = RequestStatus.SAFE

        # Report latency = caller-supplied pipeline latency + scoring time.
        total_latency = latency_ms + (time.perf_counter() - start) * 1000

        # Attack metadata is only meaningful on non-safe verdicts.
        is_safe = verdict is RequestStatus.SAFE
        report = RiskReport(
            status=verdict,
            risk_score=final_score,
            risk_level=level,
            injection_score=injection_score,
            adversarial_score=adversarial_score,
            output_score=output_score,
            attack_type=None if is_safe else attack_type,
            attack_category=None if is_safe else attack_category,
            flags=flags or [],
            latency_ms=total_latency,
        )
        logger.info(
            "Risk report | status=%s score=%.3f level=%s",
            verdict.value, final_score, level.value,
        )
        return report