File size: 7,262 Bytes
e2f8b29 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 | """Test reward engine — all 7 components. THE MOST CRITICAL TEST FILE."""
from __future__ import annotations
import pytest
from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.reward_engine import (
CONTEXT_GATED_PENALTY,
CORRECT_DIAGNOSIS_REWARD,
INVALID_ACTION_PENALTY,
INVESTIGATION_BONUS,
STEP_PENALTY,
TERMINAL_CONVERGENCE_REWARD,
WRONG_CODE_FIX_PENALTY,
WRONG_DIAGNOSIS_PENALTY,
compute_reward,
)
from ml_training_debugger.scenarios import sample_scenario
@pytest.fixture
def scenario():
    """Return the task_001 scenario, sampled with a fixed seed (42).

    Most tests below treat this as the scenario whose correct diagnosis
    is ``lr_too_high``.
    """
    task_id = "task_001"
    return sample_scenario(task_id, seed=42)
@pytest.fixture
def scenario_005():
    """Return the task_005 scenario, sampled with a fixed seed (42).

    Used by the context-gated penalty tests, which model an episode where
    gradients were inspected and found normal.
    """
    task_id = "task_005"
    return sample_scenario(task_id, seed=42)
class TestStepPenalty:
    """Every action pays the flat STEP_PENALTY, regardless of episode length."""

    def test_flat_step_penalty(self, scenario):
        """add_callback on a fresh episode earns exactly STEP_PENALTY."""
        state = EpisodeState()
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_step_penalty_not_multiplied_by_step_count(self, scenario):
        """Late in an episode the step penalty is still the flat constant."""
        step_count = 30
        state = EpisodeState(step_count=step_count)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario)
        # Compare against the constant rather than a hard-coded -0.01, so this
        # test tracks STEP_PENALTY exactly as test_flat_step_penalty does.
        assert reward == pytest.approx(STEP_PENALTY)
        # State the regression being guarded against explicitly: the penalty
        # must not scale with the number of steps taken so far.
        assert reward != pytest.approx(STEP_PENALTY * step_count)
class TestInvestigationBonus:
    """INVESTIGATION_BONUS is paid only the first time an inspection is run."""

    def test_first_time_bonus(self, scenario):
        """Inspecting gradients for the first time earns the bonus."""
        state = EpisodeState(gradients_inspected=False)
        action = MLTrainingAction(action_type="inspect_gradients")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY + INVESTIGATION_BONUS)

    def test_no_bonus_on_repeat(self, scenario):
        """Re-inspecting gradients earns only the step penalty."""
        state = EpisodeState(gradients_inspected=True)
        action = MLTrainingAction(action_type="inspect_gradients")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    # Parametrized rather than looped, so every inspection type is reported
    # as its own test case and one failure does not hide the others.
    @pytest.mark.parametrize(
        ("action_type", "state_flag"),
        [
            ("inspect_gradients", "gradients_inspected"),
            ("inspect_data_batch", "data_inspected"),
            ("inspect_model_modes", "model_modes_inspected"),
            ("inspect_model_weights", "model_weights_inspected"),
            ("inspect_code", "code_inspected"),
        ],
    )
    def test_each_inspection_type_gives_bonus(
        self, scenario, action_type, state_flag
    ):
        """Each inspect_* action earns the bonus while its flag is still False.

        ``state_flag`` is the EpisodeState field that records whether the
        matching inspection has already happened.
        """
        state = EpisodeState(**{state_flag: False})
        action = MLTrainingAction(action_type=action_type)
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(
            STEP_PENALTY + INVESTIGATION_BONUS
        ), f"Failed for {action_type}"
class TestContextGatedPenalty:
    """The project's primary innovation — must be exact.

    add_callback is penalised with CONTEXT_GATED_PENALTY only after the agent
    has inspected gradients and seen that they were normal.
    """

    def test_no_penalty_before_inspection(self, scenario_005):
        """add_callback at step 1 (no prior inspection) -> NO penalty."""
        state = EpisodeState()
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario_005)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_penalty_after_normal_gradients(self, scenario_005):
        """inspect_gradients (normal) then add_callback -> CONTEXT_GATED_PENALTY."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario_005)
        assert reward == pytest.approx(STEP_PENALTY + CONTEXT_GATED_PENALTY)

    def test_no_penalty_after_abnormal_gradients(self, scenario):
        """inspect_gradients (exploding) then add_callback -> no context penalty."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=False)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    # Each action type gets its own test case and a freshly built EpisodeState.
    # Previously one state instance was shared across loop iterations, which
    # coupled the cases if compute_reward mutates its state argument.
    @pytest.mark.parametrize(
        "action_type", ["modify_config", "fix_model_mode", "patch_data_loader"]
    )
    def test_penalty_only_for_add_callback(self, scenario_005, action_type):
        """Other fix actions don't trigger context-gated penalty."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
        action = MLTrainingAction(
            action_type=action_type, target="learning_rate", value=0.001
        )
        reward = compute_reward(action, state, scenario_005)
        assert reward == pytest.approx(
            STEP_PENALTY
        ), f"Unexpected penalty for {action_type}"
class TestDiagnosisReward:
    """mark_diagnosed is rewarded or penalised by whether the diagnosis matches."""

    def test_correct_diagnosis(self, scenario):
        """Diagnosing task_001 as lr_too_high earns CORRECT_DIAGNOSIS_REWARD."""
        expected = STEP_PENALTY + CORRECT_DIAGNOSIS_REWARD
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="lr_too_high"
        )
        assert compute_reward(
            diagnose, EpisodeState(), scenario
        ) == pytest.approx(expected)

    def test_wrong_diagnosis(self, scenario):
        """A mismatched diagnosis (data_leakage) costs WRONG_DIAGNOSIS_PENALTY."""
        expected = STEP_PENALTY + WRONG_DIAGNOSIS_PENALTY
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="data_leakage"
        )
        assert compute_reward(
            diagnose, EpisodeState(), scenario
        ) == pytest.approx(expected)
class TestTerminalConvergence:
    """Terminal convergence reward on restart_run, gated on a prior fix."""

    def test_convergence_after_fix_and_restart(self, scenario):
        """Confirmed convergence after a fix earns TERMINAL_CONVERGENCE_REWARD."""
        restart = MLTrainingAction(action_type="restart_run")
        after_fix = EpisodeState(fix_action_taken=True)
        reward = compute_reward(
            restart, after_fix, scenario, convergence_confirmed=True
        )
        assert reward == pytest.approx(STEP_PENALTY + TERMINAL_CONVERGENCE_REWARD)

    def test_no_convergence_without_fix(self, scenario):
        """Without a prior fix, convergence_confirmed alone adds nothing."""
        restart = MLTrainingAction(action_type="restart_run")
        no_fix_yet = EpisodeState(fix_action_taken=False)
        reward = compute_reward(
            restart, no_fix_yet, scenario, convergence_confirmed=True
        )
        # The convergence reward is withheld because fix_action_taken is False.
        assert reward == pytest.approx(STEP_PENALTY)
class TestInvalidAction:
    """Actions flagged invalid by the caller incur INVALID_ACTION_PENALTY."""

    def test_invalid_action_penalty(self, scenario):
        """is_valid_action=False adds the invalid-action penalty to the step cost."""
        expected = STEP_PENALTY + INVALID_ACTION_PENALTY
        reward = compute_reward(
            MLTrainingAction(action_type="restart_run"),
            EpisodeState(),
            scenario,
            is_valid_action=False,
        )
        assert reward == pytest.approx(expected)
class TestWrongCodeFix:
    """An incorrect fix_code patch incurs WRONG_CODE_FIX_PENALTY."""

    def test_wrong_code_fix_penalty(self, scenario):
        """A fix_code flagged is_correct_fix=False after inspecting code is penalised."""
        inspected = EpisodeState(code_inspected=True)
        bad_patch = MLTrainingAction(action_type="fix_code", line=1, replacement="pass")
        expected = STEP_PENALTY + WRONG_CODE_FIX_PENALTY
        assert compute_reward(
            bad_patch, inspected, scenario, is_correct_fix=False
        ) == pytest.approx(expected)
class TestRewardCap:
    """Per-step rewards must stay within [-1.0, 1.0].

    These are bound checks. They pass whether the bound holds through
    clamping or because the combined components are small. Each test sweeps
    the highest- or lowest-reward action/state combinations used elsewhere in
    this file, instead of a single ordinary action that cannot approach the
    bound. The first case in each list is the original single-action check.
    """

    @pytest.mark.parametrize(
        ("action_kwargs", "state_kwargs", "reward_kwargs"),
        [
            ({"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}, {}, {}),
            (
                {"action_type": "restart_run"},
                {"fix_action_taken": True},
                {"convergence_confirmed": True},
            ),
            ({"action_type": "inspect_gradients"}, {"gradients_inspected": False}, {}),
        ],
        ids=["correct_diagnosis", "terminal_convergence", "first_inspection"],
    )
    def test_reward_capped_at_one(
        self, scenario, action_kwargs, state_kwargs, reward_kwargs
    ):
        """The largest single-step rewards seen in this file never exceed 1.0."""
        reward = compute_reward(
            MLTrainingAction(**action_kwargs),
            EpisodeState(**state_kwargs),
            scenario,
            **reward_kwargs,
        )
        assert reward <= 1.0

    @pytest.mark.parametrize(
        ("action_kwargs", "state_kwargs", "reward_kwargs"),
        [
            ({"action_type": "mark_diagnosed", "diagnosis": "wrong"}, {}, {}),
            (
                {"action_type": "mark_diagnosed", "diagnosis": "wrong"},
                {},
                {"is_valid_action": False},
            ),
            (
                {"action_type": "fix_code", "line": 1, "replacement": "pass"},
                {"code_inspected": True},
                {"is_correct_fix": False, "is_valid_action": False},
            ),
            ({"action_type": "restart_run"}, {}, {"is_valid_action": False}),
        ],
        ids=[
            "wrong_diagnosis",
            "wrong_diagnosis_invalid",
            "wrong_code_fix_invalid",
            "invalid_restart",
        ],
    )
    def test_reward_capped_at_negative_one(
        self, scenario, action_kwargs, state_kwargs, reward_kwargs
    ):
        """Stacked penalties on a single step never drop below -1.0."""
        reward = compute_reward(
            MLTrainingAction(**action_kwargs),
            EpisodeState(**state_kwargs),
            scenario,
            **reward_kwargs,
        )
        assert reward >= -1.0
|