File size: 7,262 Bytes
e2f8b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
"""Test reward engine — all 7 components. THE MOST CRITICAL TEST FILE."""

from __future__ import annotations

import pytest

from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.reward_engine import (
    CONTEXT_GATED_PENALTY,
    CORRECT_DIAGNOSIS_REWARD,
    INVALID_ACTION_PENALTY,
    INVESTIGATION_BONUS,
    STEP_PENALTY,
    TERMINAL_CONVERGENCE_REWARD,
    WRONG_CODE_FIX_PENALTY,
    WRONG_DIAGNOSIS_PENALTY,
    compute_reward,
)
from ml_training_debugger.scenarios import sample_scenario


@pytest.fixture
def scenario():
    """Provide the ``task_001`` scenario sampled with a fixed seed of 42."""
    sampled = sample_scenario("task_001", seed=42)
    return sampled


@pytest.fixture
def scenario_005():
    """Provide the ``task_005`` scenario sampled with a fixed seed of 42."""
    sampled = sample_scenario("task_005", seed=42)
    return sampled


class TestStepPenalty:
    """Checks that a plain ``add_callback`` action costs only the step penalty."""

    def test_flat_step_penalty(self, scenario):
        """On a default ``EpisodeState`` the reward equals ``STEP_PENALTY``."""
        fresh_state = EpisodeState()
        callback_action = MLTrainingAction(action_type="add_callback")
        expected = pytest.approx(STEP_PENALTY)
        reward = compute_reward(callback_action, fresh_state, scenario)
        assert reward == expected

    def test_step_penalty_not_multiplied_by_step_count(self, scenario):
        """A state with ``step_count=30`` still yields the flat per-step value."""
        late_state = EpisodeState(step_count=30)
        callback_action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(callback_action, late_state, scenario)
        # The literal is intentional: it pins the flat value -0.01 and would
        # fail if the penalty were scaled to -0.01 * 30.
        flat_penalty = -0.01
        assert reward == pytest.approx(flat_penalty)


class TestInvestigationBonus:
    """Checks the bonus granted for an inspection whose state flag is still False."""

    def test_first_time_bonus(self, scenario):
        """``inspect_gradients`` with the flag unset earns step penalty plus bonus."""
        untouched = EpisodeState(gradients_inspected=False)
        inspect = MLTrainingAction(action_type="inspect_gradients")
        expected = STEP_PENALTY + INVESTIGATION_BONUS
        reward = compute_reward(inspect, untouched, scenario)
        assert reward == pytest.approx(expected)

    def test_no_bonus_on_repeat(self, scenario):
        """``inspect_gradients`` with the flag already set earns only the step penalty."""
        already_seen = EpisodeState(gradients_inspected=True)
        inspect = MLTrainingAction(action_type="inspect_gradients")
        reward = compute_reward(inspect, already_seen, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_each_inspection_type_gives_bonus(self, scenario):
        """Every inspection action earns the bonus when its own flag is False."""
        # Maps each inspection action to the EpisodeState flag paired with it
        # in this test; dict order keeps the cases in their original sequence.
        flag_by_action = {
            "inspect_gradients": "gradients_inspected",
            "inspect_data_batch": "data_inspected",
            "inspect_model_modes": "model_modes_inspected",
            "inspect_model_weights": "model_weights_inspected",
            "inspect_code": "code_inspected",
        }
        expected = pytest.approx(STEP_PENALTY + INVESTIGATION_BONUS)
        for action_type, flag_name in flag_by_action.items():
            state = EpisodeState(**{flag_name: False})
            action = MLTrainingAction(action_type=action_type)
            reward = compute_reward(action, state, scenario)
            assert reward == expected, f"Failed for {action_type}"


class TestContextGatedPenalty:
    """Checks when ``add_callback`` incurs ``CONTEXT_GATED_PENALTY``.

    The penalty is expected only once gradients have been inspected and were
    recorded as normal; the expected values here must match exactly.
    """

    def test_no_penalty_before_inspection(self, scenario_005):
        """``add_callback`` on a default state costs only the step penalty."""
        fresh_state = EpisodeState()
        callback = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(callback, fresh_state, scenario_005)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_penalty_after_normal_gradients(self, scenario_005):
        """Gradients inspected and normal, then ``add_callback``: penalty applies."""
        normal_seen = EpisodeState(
            gradients_inspected=True,
            gradients_were_normal=True,
        )
        callback = MLTrainingAction(action_type="add_callback")
        expected = STEP_PENALTY + CONTEXT_GATED_PENALTY
        reward = compute_reward(callback, normal_seen, scenario_005)
        assert reward == pytest.approx(expected)

    def test_no_penalty_after_abnormal_gradients(self, scenario):
        """Gradients inspected but not normal, then ``add_callback``: no penalty."""
        abnormal_seen = EpisodeState(
            gradients_inspected=True,
            gradients_were_normal=False,
        )
        callback = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(callback, abnormal_seen, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_penalty_only_for_add_callback(self, scenario_005):
        """Fix actions other than ``add_callback`` cost only the step penalty."""
        normal_seen = EpisodeState(
            gradients_inspected=True,
            gradients_were_normal=True,
        )
        other_fixes = ("modify_config", "fix_model_mode", "patch_data_loader")
        expected = pytest.approx(STEP_PENALTY)
        for action_type in other_fixes:
            # The same target/value pair is supplied for every action type.
            action = MLTrainingAction(
                action_type=action_type,
                target="learning_rate",
                value=0.001,
            )
            reward = compute_reward(action, normal_seen, scenario_005)
            assert reward == expected, f"Unexpected penalty for {action_type}"


class TestDiagnosisReward:
    """Checks the reward for ``mark_diagnosed`` with a right and a wrong label."""

    def test_correct_diagnosis(self, scenario):
        """Diagnosing ``lr_too_high`` earns step penalty plus the correct reward."""
        fresh_state = EpisodeState()
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )
        expected = STEP_PENALTY + CORRECT_DIAGNOSIS_REWARD
        reward = compute_reward(diagnose, fresh_state, scenario)
        assert reward == pytest.approx(expected)

    def test_wrong_diagnosis(self, scenario):
        """Diagnosing ``data_leakage`` earns step penalty plus the wrong penalty."""
        fresh_state = EpisodeState()
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="data_leakage",
        )
        expected = STEP_PENALTY + WRONG_DIAGNOSIS_PENALTY
        reward = compute_reward(diagnose, fresh_state, scenario)
        assert reward == pytest.approx(expected)


class TestTerminalConvergence:
    """Checks the terminal reward on ``restart_run`` with convergence confirmed."""

    def test_convergence_after_fix_and_restart(self, scenario):
        """With ``fix_action_taken=True`` the terminal convergence reward is added."""
        fixed_state = EpisodeState(fix_action_taken=True)
        restart = MLTrainingAction(action_type="restart_run")
        expected = STEP_PENALTY + TERMINAL_CONVERGENCE_REWARD
        reward = compute_reward(
            restart,
            fixed_state,
            scenario,
            convergence_confirmed=True,
        )
        assert reward == pytest.approx(expected)

    def test_no_convergence_without_fix(self, scenario):
        """With ``fix_action_taken=False`` only the step penalty is returned."""
        unfixed_state = EpisodeState(fix_action_taken=False)
        restart = MLTrainingAction(action_type="restart_run")
        reward = compute_reward(
            restart,
            unfixed_state,
            scenario,
            convergence_confirmed=True,
        )
        # No fix was taken, so confirmed convergence must not be rewarded.
        assert reward == pytest.approx(STEP_PENALTY)


class TestInvalidAction:
    """Checks the penalty applied when the caller flags the action as invalid."""

    def test_invalid_action_penalty(self, scenario):
        """``is_valid_action=False`` adds ``INVALID_ACTION_PENALTY`` to the step cost."""
        fresh_state = EpisodeState()
        restart = MLTrainingAction(action_type="restart_run")
        expected = STEP_PENALTY + INVALID_ACTION_PENALTY
        reward = compute_reward(restart, fresh_state, scenario, is_valid_action=False)
        assert reward == pytest.approx(expected)


class TestWrongCodeFix:
    """Checks the penalty for a ``fix_code`` action flagged as an incorrect fix."""

    def test_wrong_code_fix_penalty(self, scenario):
        """``is_correct_fix=False`` adds ``WRONG_CODE_FIX_PENALTY`` to the step cost."""
        inspected_state = EpisodeState(code_inspected=True)
        bad_patch = MLTrainingAction(
            action_type="fix_code",
            line=1,
            replacement="pass",
        )
        expected = STEP_PENALTY + WRONG_CODE_FIX_PENALTY
        reward = compute_reward(bad_patch, inspected_state, scenario, is_correct_fix=False)
        assert reward == pytest.approx(expected)


class TestRewardCap:
    """Checks that single-step rewards stay within the closed range [-1.0, 1.0].

    NOTE(review): these cases only assert the bounds; whether either one
    actually reaches a clamp depends on constant values not visible here.
    """

    def test_reward_capped_at_one(self, scenario):
        """A ``lr_too_high`` diagnosis on a default state never exceeds 1.0."""
        # Theoretical max would exceed 1.0 in some scenarios
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )
        fresh_state = EpisodeState()
        reward = compute_reward(diagnose, fresh_state, scenario)
        assert reward <= 1.0

    def test_reward_capped_at_negative_one(self, scenario):
        """A ``wrong`` diagnosis on a default state never drops below -1.0."""
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="wrong",
        )
        fresh_state = EpisodeState()
        reward = compute_reward(diagnose, fresh_state, scenario)
        assert reward >= -1.0