File size: 6,624 Bytes
4afcb3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
"""
risk_scoring.py
===============
Aggregates signals from all detection layers into a single risk score
and determines the final verdict for a request.

Risk score: float in [0, 1], banded into half-open intervals:
  [0.00, 0.30) → LOW      (safe)
  [0.30, 0.60) → MEDIUM   (flagged for review)
  [0.60, 0.80) → HIGH     (suspicious, sanitise or block)
  [0.80, 1.00] → CRITICAL (block)

Status strings: "safe" | "flagged" | "blocked"
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

logger = logging.getLogger("ai_firewall.risk_scoring")


class RiskLevel(str, Enum):
    """Categorical risk band for a request; compares equal to its string value."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"


class RequestStatus(str, Enum):
    """Final verdict for a request; compares equal to its string value."""

    SAFE = "safe"
    FLAGGED = "flagged"
    BLOCKED = "blocked"


@dataclass
class RiskReport:
    """Full risk assessment for a single request.

    Bundles the verdict (status / score / level) with the raw per-layer
    detector scores, optional attack metadata, and pipeline timing.
    """

    status: RequestStatus
    risk_score: float
    risk_level: RiskLevel

    # Per-layer scores
    injection_score: float = 0.0
    adversarial_score: float = 0.0
    output_score: float = 0.0     # filled in after generation

    # Attack metadata
    attack_type: Optional[str] = None
    attack_category: Optional[str] = None
    flags: list = field(default_factory=list)

    # Timing
    latency_ms: float = 0.0

    def to_dict(self) -> dict:
        """Serialise to a JSON-friendly dict.

        Scores are rounded to 4 decimal places, latency to 2; the attack
        metadata keys are included only when their values are truthy.
        """
        payload = {
            "status": self.status.value,
            "risk_score": round(self.risk_score, 4),
            "risk_level": self.risk_level.value,
            "injection_score": round(self.injection_score, 4),
            "adversarial_score": round(self.adversarial_score, 4),
            "output_score": round(self.output_score, 4),
            "flags": self.flags,
            "latency_ms": round(self.latency_ms, 2),
        }
        for key, value in (
            ("attack_type", self.attack_type),
            ("attack_category", self.attack_category),
        ):
            if value:
                payload[key] = value
        return payload


def _level_from_score(score: float) -> RiskLevel:
    """Map a numeric risk score to its categorical band (upper bounds exclusive)."""
    bands = (
        (0.30, RiskLevel.LOW),
        (0.60, RiskLevel.MEDIUM),
        (0.80, RiskLevel.HIGH),
    )
    for upper, level in bands:
        if score < upper:
            return level
    return RiskLevel.CRITICAL


class RiskScorer:
    """
    Aggregates injection and adversarial scores into a unified risk report.

    The weighting reflects the relative danger of each signal:
      - Injection score carries 60% weight  (direct attack)
      - Adversarial score carries 40% weight (indirect / evasion)

    Additional modifier: if the injection detector fires AND the
    adversarial detector fires, the combined score is boosted by a
    small multiplicative factor to account for compound attacks.

    Parameters
    ----------
    block_threshold : float
        Score >= this → status BLOCKED (default 0.70).
    flag_threshold : float
        Score >= this → status FLAGGED (default 0.40).
    injection_weight : float
        Weight for injection score (default 0.60).
    adversarial_weight : float
        Weight for adversarial score (default 0.40).
    compound_boost : float
        Multiplier applied when both detectors fire (default 1.15).
    """

    # Weight applied to the post-generation output-guardrail score.
    _OUTPUT_WEIGHT = 0.20

    def __init__(
        self,
        block_threshold: float = 0.70,
        flag_threshold: float = 0.40,
        injection_weight: float = 0.60,
        adversarial_weight: float = 0.40,
        compound_boost: float = 1.15,
    ) -> None:
        self.block_threshold = block_threshold
        self.flag_threshold = flag_threshold
        self.injection_weight = injection_weight
        self.adversarial_weight = adversarial_weight
        self.compound_boost = compound_boost

    def score(
        self,
        injection_score: float,
        adversarial_score: float,
        injection_is_flagged: bool = False,
        adversarial_is_flagged: bool = False,
        attack_type: Optional[str] = None,
        attack_category: Optional[str] = None,
        flags: Optional[list] = None,
        output_score: float = 0.0,
        latency_ms: float = 0.0,
    ) -> RiskReport:
        """
        Compute the unified risk report.

        Parameters
        ----------
        injection_score : float
            Confidence score from InjectionDetector (0-1).
        adversarial_score : float
            Risk score from AdversarialDetector (0-1).
        injection_is_flagged : bool
            Whether InjectionDetector marked the input as injection.
        adversarial_is_flagged : bool
            Whether AdversarialDetector marked input as adversarial.
        attack_type : str, optional
            Human-readable attack type label.
        attack_category : str, optional
            Injection attack category enum value.
        flags : list, optional
            All flags raised by detectors (copied into the report).
        output_score : float
            Risk score from OutputGuardrail (added post-generation).
        latency_ms : float
            Total pipeline latency.

        Returns
        -------
        RiskReport
        """
        t0 = time.perf_counter()

        # Weighted combination of the two pre-generation signals.
        combined = (
            injection_score * self.injection_weight
            + adversarial_score * self.adversarial_weight
        )

        # Compound attacks (both detectors firing) are more dangerous than
        # either signal alone — boost, but never past 1.0.
        if injection_is_flagged and adversarial_is_flagged:
            combined = min(combined * self.compound_boost, 1.0)

        # Output guardrail is a secondary signal, folded in at lower weight.
        if output_score > 0:
            combined = min(combined + output_score * self._OUTPUT_WEIGHT, 1.0)

        # BUGFIX: with non-default weights the plain weighted sum could leave
        # [0, 1] (it was previously capped only when a boost fired).  Clamp
        # so the documented invariant risk_score ∈ [0, 1] always holds.
        combined = min(max(combined, 0.0), 1.0)

        risk_score = round(combined, 4)
        level = _level_from_score(risk_score)

        if risk_score >= self.block_threshold:
            status = RequestStatus.BLOCKED
        elif risk_score >= self.flag_threshold:
            status = RequestStatus.FLAGGED
        else:
            status = RequestStatus.SAFE

        # Scoring overhead plus the latency the caller already accumulated.
        elapsed = (time.perf_counter() - t0) * 1000 + latency_ms

        report = RiskReport(
            status=status,
            risk_score=risk_score,
            risk_level=level,
            injection_score=injection_score,
            adversarial_score=adversarial_score,
            output_score=output_score,
            # Attack metadata is only meaningful on non-safe verdicts.
            attack_type=attack_type if status != RequestStatus.SAFE else None,
            attack_category=attack_category if status != RequestStatus.SAFE else None,
            # BUGFIX: copy the caller's list so the report never aliases it
            # (mutating one previously mutated the other).
            flags=list(flags) if flags else [],
            latency_ms=elapsed,
        )

        logger.info(
            "Risk report | status=%s score=%.3f level=%s",
            status.value, risk_score, level.value,
        )

        return report