"""
Rule-Based Baseline Agent for Adaptive Alert Triage
====================================================

Implements two rule-based policies that serve as reproducible baselines against
which RL agents can be measured.  Both agents expose the same interface:

    agent.act(observation: Observation) -> Action
    agent.reset() -> None
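
As a minimal sketch (not part of this module), the shared interface can be written as a `typing.Protocol`. The `Alert`, `Observation`, and `Action` dataclasses are simplified stand-ins assumed for illustration, not the real `adaptive_alert_triage.models` classes, and `TriageAgent` and `AlwaysDelayAgent` are hypothetical names.

```python
from dataclasses import dataclass
from typing import List, Optional, Protocol


# Simplified stand-ins for adaptive_alert_triage.models (assumed fields only).
@dataclass
class Alert:
    id: str
    visible_severity: float
    confidence: float
    alert_type: str = "CPU"
    age: int = 0


@dataclass
class Observation:
    alerts: List[Alert]
    system_load: float = 0.5
    resource_budget: Optional[int] = None


@dataclass
class Action:
    alert_id: str
    action_type: str


class TriageAgent(Protocol):
    # Hypothetical structural type for anything evaluate_agent() can drive.
    def act(self, observation: Observation) -> Action: ...
    def reset(self) -> None: ...


class AlwaysDelayAgent:
    # Hypothetical trivial agent: DELAY the most severe visible alert.
    def act(self, observation: Observation) -> Action:
        alert = max(observation.alerts, key=lambda a: a.visible_severity)
        return Action(alert_id=alert.id, action_type="DELAY")

    def reset(self) -> None:
        pass  # no per-episode state to clear


agent: TriageAgent = AlwaysDelayAgent()
obs = Observation(alerts=[Alert("a1", 0.40, 0.90), Alert("a2", 0.80, 0.60)])
print(agent.act(obs))  # Action(alert_id='a2', action_type='DELAY')
```

Because `Protocol` checks structure rather than inheritance, a static type checker accepts any object with matching `act()` and `reset()` methods as a `TriageAgent`.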

Agents
------
RuleBasedAgent
    Threshold-based policy using visible_severity and confidence, plus an
    optional guard against investigating with an exhausted resource_budget.
    Designed to be the weakest baseline, leaving headroom for the RL agent.

ImprovedRuleBasedAgent
    Adds age-weighting, alert-type priors, system-load awareness, and
    a simple resource-budget guard.  Competitive on the easy task but still
    well below the hard-task success threshold (≥ 0.50).
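
The decision rules can be reproduced as plain functions for quick experimentation. This is a standalone sketch that assumes only the fields shown matter: the hypothetical `rule_based_action()` mirrors `RuleBasedAgent._decide_action()` and the hypothetical `improved_score()` mirrors `ImprovedRuleBasedAgent._select_alert()`, both using the default policy constants.

```python
from typing import Optional

# Default values copied from the module's policy constants.
INVESTIGATE_SEV, INVESTIGATE_CONF = 0.75, 0.70
IGNORE_CONF, ESCALATE_SEV = 0.30, 0.55
SECURITY_BOOST, AGE_WEIGHT = 0.05, 0.08


def rule_based_action(sev: float, conf: float,
                      budget: Optional[int] = None,
                      resource_aware: bool = True) -> str:
    # Mirrors RuleBasedAgent._decide_action: first matching rule wins.
    if sev > INVESTIGATE_SEV and conf > INVESTIGATE_CONF:
        if resource_aware and budget is not None and budget <= 0:
            return "ESCALATE"          # no budget left to investigate
        return "INVESTIGATE"
    if conf < IGNORE_CONF:
        return "IGNORE"                # likely false positive
    if sev > ESCALATE_SEV:
        return "ESCALATE"
    return "DELAY"


def improved_score(sev: float, age: int, alert_type: str) -> float:
    # Mirrors ImprovedRuleBasedAgent._select_alert scoring.
    score = sev * 2.0 + age * AGE_WEIGHT
    if alert_type == "SECURITY":
        score += SECURITY_BOOST * 2
    elif alert_type in ("APPLICATION", "NETWORK"):
        score += SECURITY_BOOST
    return score


for sev, conf, budget in [(0.90, 0.85, None), (0.50, 0.20, None),
                          (0.65, 0.60, None), (0.30, 0.50, None),
                          (0.90, 0.85, 0)]:
    print(sev, conf, budget, "->", rule_based_action(sev, conf, budget))

print(round(improved_score(0.70, 0, "SECURITY"), 2))  # 1.5
print(round(improved_score(0.85, 0, "CPU"), 2))       # 1.7
```

The last two lines show why a SECURITY alert does not automatically win selection: its 0.10 boost cannot offset the 0.30 score gap that a 0.15 severity difference produces.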

Evaluation
----------
evaluate_agent() runs N episodes and returns aggregated metrics that match the
three task graders (EasyTaskGrader, MediumTaskGrader, HardTaskGrader).
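
The task modules perform the actual aggregation. As a rough, hypothetical sketch of how per-episode scores could be reduced to the metrics listed in evaluate_agent()'s docstring, the block assumes an episode counts as a success when its score reaches a task threshold (for example the 0.50 hard-task bar mentioned above) and uses the population standard deviation; the real graders may define both differently.

```python
import statistics
from typing import Any, Dict, List


def aggregate_scores(episode_scores: List[float],
                     success_threshold: float) -> Dict[str, Any]:
    # Hypothetical: "success" means score >= success_threshold.
    successes = sum(1 for s in episode_scores if s >= success_threshold)
    return {
        "mean_score": statistics.fmean(episode_scores),
        "std_score": statistics.pstdev(episode_scores),  # population std
        "min_score": min(episode_scores),
        "max_score": max(episode_scores),
        "success_rate": successes / len(episode_scores),
        "episode_scores": list(episode_scores),
    }


summary = aggregate_scores([0.42, 0.55, 0.61, 0.38], success_threshold=0.50)
print(summary["success_rate"])  # 0.5
```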

Usage
-----
    from agents.baseline import RuleBasedAgent, evaluate_agent
    from adaptive_alert_triage.env import AdaptiveAlertTriageEnv

    env   = AdaptiveAlertTriageEnv(task_id="easy")
    agent = RuleBasedAgent()
    results = evaluate_agent(agent, env, num_episodes=10, task_id="easy")
    print(results)
"""

from __future__ import annotations

from typing import Any, Dict, List, Optional

from adaptive_alert_triage.models import Action, Alert, Observation

# Task-grader imports (absolute: the directory containing `tasks/` must be on sys.path)
from tasks.easy   import EasyTaskGrader,   run_episode_evaluation as _easy_eval
from tasks.medium import MediumTaskGrader, run_episode_evaluation as _medium_eval
from tasks.hard   import HardTaskGrader,   run_episode_evaluation as _hard_eval


# ---------------------------------------------------------------------------
# Policy constants  (kept separate from the environment thresholds so the
# baseline agent cannot accidentally "see" hidden ground-truth constants)
# ---------------------------------------------------------------------------
_INVESTIGATE_SEV_THRESHOLD:  float = 0.75   # severity above which to investigate
_INVESTIGATE_CONF_THRESHOLD: float = 0.70   # confidence required for investigation
_IGNORE_CONF_THRESHOLD:      float = 0.30   # confidence below which → likely FP → IGNORE
_ESCALATE_SEV_THRESHOLD:     float = 0.55   # severity above which to escalate
_SECURITY_SEV_BOOST:         float = 0.05   # selection-score boost unit (x2 SECURITY, x1 APPLICATION/NETWORK)
_AGE_WEIGHT:                 float = 0.08   # scoring weight per time-step of age


# ---------------------------------------------------------------------------
# RuleBasedAgent
# ---------------------------------------------------------------------------

class RuleBasedAgent:
    """
    Simple threshold-based agent for alert triage.

    Policy (applied in order, first match wins):
        1. visible_severity > 0.75 AND confidence > 0.70  → INVESTIGATE
           (ESCALATE instead if resource_aware and resource_budget <= 0)
        2. confidence < 0.30                              → IGNORE   (likely false positive)
        3. visible_severity > 0.55                        → ESCALATE
        4. default                                        → DELAY

    Alert selection: the alert with the highest visible_severity.

    Limitations (intentional; they motivate RL):
        - Cannot detect correlated chains (no memory across steps)
        - Fixed thresholds ignore system_load and queue_length
        - No adaptation to changing alert distributions
        - DELAY is almost never used, hurting medium-task efficiency

    Attributes:
        investigate_sev_threshold:  severity required to trigger INVESTIGATE
        investigate_conf_threshold: confidence required alongside severity
        ignore_conf_threshold:      confidence below which to IGNORE
        escalate_sev_threshold:     severity above which to ESCALATE (fallback)
        resource_aware:             if True, ESCALATE instead of INVESTIGATE when resource_budget <= 0
    """

    def __init__(
        self,
        investigate_sev_threshold:  float = _INVESTIGATE_SEV_THRESHOLD,
        investigate_conf_threshold: float = _INVESTIGATE_CONF_THRESHOLD,
        ignore_conf_threshold:      float = _IGNORE_CONF_THRESHOLD,
        escalate_sev_threshold:     float = _ESCALATE_SEV_THRESHOLD,
        resource_aware:             bool  = True,
    ) -> None:
        self.investigate_sev_threshold  = investigate_sev_threshold
        self.investigate_conf_threshold = investigate_conf_threshold
        self.ignore_conf_threshold      = ignore_conf_threshold
        self.escalate_sev_threshold     = escalate_sev_threshold
        self.resource_aware             = resource_aware

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------

    def act(self, observation: Observation) -> Action:
        """
        Choose an action for the highest-priority alert.

        Args:
            observation: Current environment observation (agent-visible only).

        Returns:
            Action targeting one alert in observation.alerts.

        Raises:
            ValueError: If observation.alerts is empty.
        """
        if not observation.alerts:
            raise ValueError("No alerts in observation; cannot act.")

        alert = self._select_alert(observation.alerts)
        action_type = self._decide_action(alert, observation)
        return Action(alert_id=alert.id, action_type=action_type)

    def reset(self) -> None:
        """Reset any per-episode state (stateless baseline; no-op)."""
        pass

    # ------------------------------------------------------------------
    # Alert selection
    # ------------------------------------------------------------------

    def _select_alert(self, alerts: List[Alert]) -> Alert:
        """Pick the alert with the highest visible_severity."""
        return max(alerts, key=lambda a: a.visible_severity)

    # ------------------------------------------------------------------
    # Action decision
    # ------------------------------------------------------------------

    def _decide_action(self, alert: Alert, obs: Observation) -> str:
        """
        Apply the rule-based policy.

        Args:
            alert: The alert selected for action.
            obs:   Full observation (for resource_budget access).

        Returns:
            Action type string.
        """
        sev  = alert.visible_severity
        conf = alert.confidence

        # Rule 1: high severity + high confidence → INVESTIGATE
        if sev > self.investigate_sev_threshold and conf > self.investigate_conf_threshold:
            if self.resource_aware and obs.resource_budget is not None and obs.resource_budget <= 0:
                # Budget exhausted: escalate rather than block
                return "ESCALATE"
            return "INVESTIGATE"

        # Rule 2: low confidence → likely false positive → IGNORE
        if conf < self.ignore_conf_threshold:
            return "IGNORE"

        # Rule 3: medium-high severity → ESCALATE
        if sev > self.escalate_sev_threshold:
            return "ESCALATE"

        # Default: DELAY, letting the alert age for potential future reclassification
        return "DELAY"

    # ------------------------------------------------------------------
    # Dunder
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}("
            f"inv_sev={self.investigate_sev_threshold}, "
            f"inv_conf={self.investigate_conf_threshold}, "
            f"ign_conf={self.ignore_conf_threshold})"
        )


# ---------------------------------------------------------------------------
# ImprovedRuleBasedAgent
# ---------------------------------------------------------------------------

class ImprovedRuleBasedAgent(RuleBasedAgent):
    """
    Enhanced rule-based agent with multi-factor scoring and context awareness.

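    Improvements over RuleBasedAgent:
        - Multi-factor alert selection: severity + age + alert-type prior
        - System-load awareness: above 0.85 load, raise the INVESTIGATE bar
          and DELAY more often
        - Age urgency: alerts with age ≥ 3 and visible_severity > 0.70 are
          promoted to INVESTIGATE regardless of confidence
        - Resource-budget guard: with ≤ 1 unit left, INVESTIGATE only alerts
          that clear the parent thresholds
        - SECURITY alerts (and, to a lesser degree, APPLICATION and NETWORK
          alerts) receive a selection-score boost

    Still limited (no learning, no chain detection), but achieves higher
    scores on the easy and medium tasks than the plain threshold baseline.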
    """

    def _select_alert(self, alerts: List[Alert]) -> Alert:
        """
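        Score alerts on a combined severity + age + type metric.

        Score = visible_severity * 2 + age * _AGE_WEIGHT + type_boost, where
        type_boost is 2 * _SECURITY_SEV_BOOST for SECURITY alerts,
        _SECURITY_SEV_BOOST for APPLICATION and NETWORK alerts, and 0 otherwise.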
        """
        def _score(a: Alert) -> float:
            s = a.visible_severity * 2.0
            s += a.age * _AGE_WEIGHT
            if a.alert_type == "SECURITY":
                s += _SECURITY_SEV_BOOST * 2
            elif a.alert_type in ("APPLICATION", "NETWORK"):
                s += _SECURITY_SEV_BOOST
            return s

        return max(alerts, key=_score)

    def _decide_action(self, alert: Alert, obs: Observation) -> str:
        """
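        Enhanced policy with age-urgency, system-load, and budget guards.

        Overrides (checked in order before the parent rules):
            - Aged potential-critical alerts (age ≥ 3, sev > 0.70) → INVESTIGATE
              regardless of confidence (ESCALATE if the budget is exhausted)
            - Very high system load (> 0.85): INVESTIGATE only if sev > 0.85 and
              conf > 0.80, IGNORE if sev < 0.35, otherwise DELAY
            - Budget nearly exhausted (≤ 1): INVESTIGATE only alerts that clear
              the parent thresholds, ESCALATE if sev > 0.50, else IGNORE/DELAY
            - Otherwise: fall back to the parent policy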
        """
        sev       = alert.visible_severity
        conf      = alert.confidence
        age       = alert.age
        sys_load  = obs.system_load
        budget    = obs.resource_budget

        # Rule A: aged potential-critical → INVESTIGATE regardless of confidence
        if age >= 3 and sev > 0.70:
            if self.resource_aware and budget is not None and budget <= 0:
                return "ESCALATE"
            return "INVESTIGATE"

        # Rule B: very high system load → conservative strategy
        if sys_load > 0.85:
            if sev > 0.85 and conf > 0.80:
                if self.resource_aware and budget is not None and budget <= 0:
                    return "ESCALATE"
                return "INVESTIGATE"
            if sev < 0.35:
                return "IGNORE"
            return "DELAY"

        # Rule C: resource budget nearly exhausted → reserve INVESTIGATE for
        # alerts that clear the parent thresholds; ESCALATE medium severity
        if self.resource_aware and budget is not None and budget <= 1:
            if sev > self.investigate_sev_threshold and conf > self.investigate_conf_threshold:
                # Spend the last slot only on a truly critical alert; with no
                # budget left at all, escalate instead (as Rules A and B do)
                return "INVESTIGATE" if budget > 0 else "ESCALATE"
            if sev > 0.50:
                return "ESCALATE"
            if conf < self.ignore_conf_threshold:
                return "IGNORE"
            return "DELAY"

        # Fallback to parent rules
        return super()._decide_action(alert, obs)


# ---------------------------------------------------------------------------
# Evaluation harness
# ---------------------------------------------------------------------------

def evaluate_agent(
    agent: RuleBasedAgent,
    env,
    num_episodes: int = 10,
    task_id: str = "easy",
    seed_offset: int = 0,
    verbose: bool = False,
) -> Dict[str, Any]:
    """
    Evaluate an agent across multiple episodes using the task graders.

    This function uses the same graders that produce the official
    leaderboard scores, so results are directly comparable to RL baselines.

    Args:
        agent:        Agent instance with .act(observation) and .reset() methods.
        env:          AdaptiveAlertTriageEnv instance (must match task_id).
        num_episodes: Number of evaluation episodes.
        task_id:      One of "easy", "medium", "hard".
        seed_offset:  Added to episode index to form the reset seed.
        verbose:      Print per-episode summary if True.

    Returns:
        Dict with keys:
            mean_score, std_score, min_score, max_score,
            success_rate, episode_scores, episode_metrics,
            task_id, agent_name, num_episodes.
    """
    if task_id == "easy":
        results = _easy_eval(agent, env, num_episodes=num_episodes,
                             seed_offset=seed_offset, verbose=verbose)
    elif task_id == "medium":
        results = _medium_eval(agent, env, num_episodes=num_episodes,
                               seed_offset=seed_offset, verbose=verbose)
    elif task_id == "hard":
        results = _hard_eval(agent, env, num_episodes=num_episodes,
                             seed_offset=seed_offset, verbose=verbose)
    else:
        raise ValueError(f"Unknown task_id '{task_id}'. Must be easy/medium/hard.")

    results["task_id"]    = task_id
    results["agent_name"] = repr(agent)
    results["num_episodes"] = num_episodes
    return results


# ---------------------------------------------------------------------------
# Self-test / CLI entry-point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    from adaptive_alert_triage.env import AdaptiveAlertTriageEnv

    print("=" * 65)
    print("Rule-Based Baseline Agent: Self-Test")
    print("=" * 65)

    # ── Unit test: basic act() behaviour ─────────────────────────────────

    def _obs(alerts: List[Alert], budget: Optional[int] = None) -> Observation:
        return Observation(
            alerts=alerts,
            system_load=0.5,
            queue_length=len(alerts),
            time_remaining=20,
            episode_step=1,
            resource_budget=budget,
        )

    def _alert(aid: str, sev: float, conf: float,
               atype: str = "CPU", age: int = 0) -> Alert:
        return Alert(id=aid, visible_severity=sev, confidence=conf,
                     alert_type=atype, age=age)

    cases = [
        # description, alerts, budget, expected_action
        ("High sev+conf → INVESTIGATE",
         [_alert("a1", 0.90, 0.85)], None, "INVESTIGATE"),
        ("Low confidence → IGNORE",
         [_alert("a2", 0.50, 0.20)], None, "IGNORE"),
        ("Medium sev → ESCALATE",
         [_alert("a3", 0.65, 0.60)], None, "ESCALATE"),
        ("Low sev → DELAY",
         [_alert("a4", 0.30, 0.50)], None, "DELAY"),
        ("High sev, budget=0 → ESCALATE (resource_aware)",
         [_alert("a5", 0.90, 0.85)], 0, "ESCALATE"),
    ]

    agent_basic = RuleBasedAgent()
    all_pass = True
    print("\n── Basic RuleBasedAgent ─────────────────────────────────────")
    for desc, alerts, budget, expected in cases:
        obs    = _obs(alerts, budget)
        action = agent_basic.act(obs)
        ok     = action.action_type == expected
        if not ok:
            all_pass = False
        print(f"  [{'PASS' if ok else 'FAIL'}]  {desc}")
        if not ok:
            print(f"         expected {expected}, got {action.action_type}")

    # ── Test ImprovedRuleBasedAgent ──────────────────────────────────────
    print("\n── ImprovedRuleBasedAgent ──────────────────────────────────────")
    agent_improved = ImprovedRuleBasedAgent()

    # Aged critical should get INVESTIGATE
    aged_critical = [_alert("a6", 0.75, 0.50, age=4)]  # aged, medium conf
    obs_aged  = _obs(aged_critical)
    act_aged  = agent_improved.act(obs_aged)
    ok_aged   = act_aged.action_type == "INVESTIGATE"
    if not ok_aged:
        all_pass = False
    print(f"  [{'PASS' if ok_aged else 'FAIL'}]  Aged critical (age=4, sev=0.75) → INVESTIGATE  (got {act_aged.action_type})")

    # Selection score: the SECURITY type boost (+0.10) does not outweigh a
    # 0.15 severity gap (SECURITY: 0.70*2 + 0.10 = 1.50; CPU: 0.85*2 = 1.70)
    multi = [
        _alert("sec",  0.70, 0.80, "SECURITY"),
        _alert("cpu",  0.85, 0.80, "CPU"),
    ]
    sel_multi = agent_improved._select_alert(multi)
    ok_multi  = sel_multi.id == "cpu"
    if not ok_multi:
        all_pass = False
    print(f"  [{'PASS' if ok_multi else 'FAIL'}]  Multi-alert selection → 'cpu'  (got '{sel_multi.id}')")

    # ── Episode evaluation (skipped with a note if the env is unavailable) ──
    print("\n── Episode evaluation ──────────────────────────────────────────")
    try:
        env = AdaptiveAlertTriageEnv(task_id="easy")
        results = evaluate_agent(agent_basic, env, num_episodes=3,
                                 task_id="easy", seed_offset=0, verbose=True)
        print(f"\n  mean_score   : {results['mean_score']:.3f}")
        print(f"  success_rate : {results['success_rate']:.3f}")
        print(f"  agent        : {results['agent_name']}")
    except Exception as exc:
        print(f"  [SKIP] Could not run episode evaluation: {exc}")
        print("         Run from the project root with the full package installed.")

    print("\n" + "=" * 65)
    print("All unit tests passed!" if all_pass else "SOME UNIT TESTS FAILED: see above.")