File size: 3,263 Bytes
4058302
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Reward engine invariants."""

from __future__ import annotations

from server.domain.incidents import build_incident_library, instantiate_incident
from server.domain.reward import (
    CLOSURE_CORRECT_BASE,
    CLUE_CAP_PER_INCIDENT,
    CLUE_REWARD,
    HANDOFF_CORRECT_REWARD,
    MITIGATION_CORRECT_REWARD,
    RewardEngine,
)
from server.domain.rng import SeededRNG


# Incident library built once at import time; `_sample_incident` below reads
# its templates from here so every test shares the same library object.
LIBRARY = build_incident_library()


def _sample_incident(task: str = "easy", idx: int = 0):
    """Instantiate the ``idx``-th template registered for ``task``.

    The template is looked up in the module-level ``LIBRARY`` and is
    instantiated with ``SeededRNG(1)``, i.e. the same seed on every call.
    """
    return instantiate_incident(LIBRARY.templates_for(task)[idx], SeededRNG(1))


def test_step_cost_applied_for_inspect() -> None:
    """The step cost for ``inspect_logs`` must total strictly below zero."""
    breakdown = RewardEngine().step_cost("inspect_logs")
    assert breakdown.total() < 0


def test_wrong_actor_penalty_applied_only_when_disallowed() -> None:
    """``wrong_actor`` penalises only when called with ``allowed=False``.

    With ``allowed=False`` the total must be negative; with ``allowed=True``
    the total must be exactly ``0.0``.
    """
    reward_engine = RewardEngine()

    penalised = reward_engine.wrong_actor(
        "triage_agent",
        "close_incident",
        allowed=False,
    )
    permitted = reward_engine.wrong_actor(
        "triage_agent",
        "inspect_logs",
        allowed=True,
    )

    assert penalised.total() < 0
    assert permitted.total() == 0.0


def test_correct_handoff_is_positive() -> None:
    """Handing off to the incident's own ``good_handoff`` earns the full reward.

    The total must be at least ``HANDOFF_CORRECT_REWARD``.
    """
    reward_engine = RewardEngine()
    sample = _sample_incident()
    target = sample.good_handoff
    breakdown = reward_engine.handoff(sample, target)
    assert breakdown.total() >= HANDOFF_CORRECT_REWARD


def test_mitigation_keyword_match() -> None:
    """A matching mitigation text is rewarded; an unrelated one is penalised."""
    reward_engine = RewardEngine()
    # First "easy" template; the original author notes this is the redis pool case.
    sample = _sample_incident("easy", 0)

    # Matching text: flagged ok and worth at least the correct-mitigation reward.
    good_breakdown, good_flag = reward_engine.mitigation(
        sample,
        "increase redis pool size and recycle connections",
    )
    assert good_flag
    assert good_breakdown.total() >= MITIGATION_CORRECT_REWARD

    # Unrelated text: flagged not-ok and with a negative total.
    poor_breakdown, poor_flag = reward_engine.mitigation(
        sample,
        "delete caches randomly",
    )
    assert not poor_flag
    assert poor_breakdown.total() < 0


def test_clue_reward_capped_and_deduped() -> None:
    """Submitting the same clue text repeatedly stays within the clue cap.

    The same text is offered ten times. Only results reported as new with a
    non-``None`` matched key are recorded. The number of recorded keys must
    not exceed ``CLUE_CAP_PER_INCIDENT`` and their summed reward must not
    exceed ``CLUE_CAP_PER_INCIDENT * CLUE_REWARD`` (with a small tolerance).
    """
    reward_engine = RewardEngine()
    sample = _sample_incident("easy", 0)
    seen_keys: list[str] = []
    reward_sum = 0.0

    for _attempt in range(10):
        breakdown, is_new, key = reward_engine.clue_reward(
            sample,
            "redis pool exhaustion in checkout-worker",
            already_used_keys=seen_keys,
            current_clue_count=len(seen_keys),
        )
        # Skip anything that is not a freshly matched clue.
        if not is_new or key is None:
            continue
        seen_keys.append(key)
        reward_sum += breakdown.total()

    assert len(seen_keys) <= CLUE_CAP_PER_INCIDENT
    assert reward_sum <= CLUE_CAP_PER_INCIDENT * CLUE_REWARD + 1e-6


def test_closure_correct_scales_with_tier() -> None:
    """A fully correct closure is flagged correct and earns at least the base.

    The closure uses the incident's own root cause, its required number of
    investigations as the clue count, and its own postmortem requirement.
    """
    reward_engine = RewardEngine()
    # First "medium" template; the original author notes this is premium tier.
    sample = _sample_incident("medium", 0)

    closure_args = {
        "predicted_root_cause": sample.root_cause,
        "mitigation_applied": True,
        "clues_count": sample.required_investigations,
        "steps_on_incident": 3,
        "postmortem_submitted": sample.postmortem_required,
    }
    breakdown, is_correct = reward_engine.closure(sample, **closure_args)

    assert is_correct
    assert breakdown.total() >= CLOSURE_CORRECT_BASE


def test_closure_wrong_is_negative() -> None:
    """Closing with an unrelated root cause is flagged wrong and costs reward.

    No mitigation, no clues and no postmortem are supplied; the total must be
    strictly negative.
    """
    reward_engine = RewardEngine()
    sample = _sample_incident("easy", 0)

    closure_args = {
        "predicted_root_cause": "completely unrelated guess",
        "mitigation_applied": False,
        "clues_count": 0,
        "steps_on_incident": 1,
        "postmortem_submitted": False,
    }
    breakdown, is_correct = reward_engine.closure(sample, **closure_args)

    assert not is_correct
    assert breakdown.total() < 0