# pytorch-training-debugger/tests/test_reward_engine.py
"""Test reward engine — all 7 components. THE MOST CRITICAL TEST FILE."""
from __future__ import annotations
import pytest
from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.reward_engine import (
    CONTEXT_GATED_PENALTY,
    CORRECT_DIAGNOSIS_REWARD,
    INVALID_ACTION_PENALTY,
    INVESTIGATION_BONUS,
    STEP_PENALTY,
    TERMINAL_CONVERGENCE_REWARD,
    WRONG_CODE_FIX_PENALTY,
    WRONG_DIAGNOSIS_PENALTY,
    compute_reward,
)
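
# Assumed signature, inferred from the call sites in this file (not verified
# against reward_engine itself):
#   compute_reward(action, state, scenario, is_valid_action=True,
#                  is_correct_fix=True, convergence_confirmed=False) -> float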
from ml_training_debugger.scenarios import sample_scenario


@pytest.fixture
def scenario():
    return sample_scenario("task_001", seed=42)


@pytest.fixture
def scenario_005():
    return sample_scenario("task_005", seed=42)
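
# Inferred from the test docstrings below (not from scenarios.py): task_001
# is the exploding-gradients / lr_too_high scenario, while task_005 presents
# normal gradients.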


class TestStepPenalty:
    def test_flat_step_penalty(self, scenario):
        state = EpisodeState()
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_step_penalty_not_multiplied_by_step_count(self, scenario):
        state = EpisodeState(step_count=30)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario)
        # Must be flat -0.01, NOT -0.01 * 30
        assert reward == pytest.approx(-0.01)


class TestInvestigationBonus:
    def test_first_time_bonus(self, scenario):
        state = EpisodeState(gradients_inspected=False)
        action = MLTrainingAction(action_type="inspect_gradients")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY + INVESTIGATION_BONUS)

    def test_no_bonus_on_repeat(self, scenario):
        state = EpisodeState(gradients_inspected=True)
        action = MLTrainingAction(action_type="inspect_gradients")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_each_inspection_type_gives_bonus(self, scenario):
        for action_type, field in [
            ("inspect_gradients", "gradients_inspected"),
            ("inspect_data_batch", "data_inspected"),
            ("inspect_model_modes", "model_modes_inspected"),
            ("inspect_model_weights", "model_weights_inspected"),
            ("inspect_code", "code_inspected"),
        ]:
            state = EpisodeState(**{field: False})
            action = MLTrainingAction(action_type=action_type)
            reward = compute_reward(action, state, scenario)
            assert reward == pytest.approx(
                STEP_PENALTY + INVESTIGATION_BONUS
            ), f"Failed for {action_type}"


class TestContextGatedPenalty:
    """The project's primary innovation — must be exact."""
    def test_no_penalty_before_inspection(self, scenario_005):
        """add_callback at step 1 (no prior inspection) -> NO penalty."""
        state = EpisodeState()
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario_005)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_penalty_after_normal_gradients(self, scenario_005):
        """inspect_gradients (normal) then add_callback -> -0.20 penalty."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario_005)
        assert reward == pytest.approx(STEP_PENALTY + CONTEXT_GATED_PENALTY)

    def test_no_penalty_after_abnormal_gradients(self, scenario):
        """inspect_gradients (exploding) then add_callback -> no context penalty."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=False)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY)

    def test_penalty_only_for_add_callback(self, scenario_005):
        """Other fix actions don't trigger context-gated penalty."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
        for action_type in ["modify_config", "fix_model_mode", "patch_data_loader"]:
            action = MLTrainingAction(
                action_type=action_type, target="learning_rate", value=0.001
            )
            reward = compute_reward(action, state, scenario_005)
            assert reward == pytest.approx(
                STEP_PENALTY
            ), f"Unexpected penalty for {action_type}"


class TestDiagnosisReward:
    def test_correct_diagnosis(self, scenario):
        state = EpisodeState()
        action = MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high")
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY + CORRECT_DIAGNOSIS_REWARD)

    def test_wrong_diagnosis(self, scenario):
        state = EpisodeState()
        action = MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="data_leakage"
        )
        reward = compute_reward(action, state, scenario)
        assert reward == pytest.approx(STEP_PENALTY + WRONG_DIAGNOSIS_PENALTY)


class TestTerminalConvergence:
    def test_convergence_after_fix_and_restart(self, scenario):
        state = EpisodeState(fix_action_taken=True)
        action = MLTrainingAction(action_type="restart_run")
        reward = compute_reward(action, state, scenario, convergence_confirmed=True)
        assert reward == pytest.approx(STEP_PENALTY + TERMINAL_CONVERGENCE_REWARD)

    def test_no_convergence_without_fix(self, scenario):
        state = EpisodeState(fix_action_taken=False)
        action = MLTrainingAction(action_type="restart_run")
        reward = compute_reward(action, state, scenario, convergence_confirmed=True)
        # fix_action_taken is False, so no convergence reward
        assert reward == pytest.approx(STEP_PENALTY)


class TestInvalidAction:
    def test_invalid_action_penalty(self, scenario):
        state = EpisodeState()
        action = MLTrainingAction(action_type="restart_run")
        reward = compute_reward(action, state, scenario, is_valid_action=False)
        assert reward == pytest.approx(STEP_PENALTY + INVALID_ACTION_PENALTY)


class TestWrongCodeFix:
    def test_wrong_code_fix_penalty(self, scenario):
        state = EpisodeState(code_inspected=True)
        action = MLTrainingAction(action_type="fix_code", line=1, replacement="pass")
        reward = compute_reward(action, state, scenario, is_correct_fix=False)
        assert reward == pytest.approx(STEP_PENALTY + WRONG_CODE_FIX_PENALTY)


class TestRewardCap:
    def test_reward_capped_at_one(self, scenario):
        # Theoretical max would exceed 1.0 in some scenarios
        reward = compute_reward(
            MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high"),
            EpisodeState(),
            scenario,
        )
        assert reward <= 1.0

    def test_reward_capped_at_negative_one(self, scenario):
        reward = compute_reward(
            MLTrainingAction(action_type="mark_diagnosed", diagnosis="wrong"),
            EpisodeState(),
            scenario,
        )
        assert reward >= -1.0
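

# A hedged end-to-end sketch, not part of the original suite: it only chains
# two cases already asserted above (first-time inspection bonus, then a
# correct diagnosis from a fresh state), so it assumes nothing beyond the
# behavior those tests pin down.
class TestTwoStepSketch:
    def test_inspect_then_diagnose(self, scenario):
        # Step 1: the first gradient inspection earns the one-time bonus.
        r1 = compute_reward(
            MLTrainingAction(action_type="inspect_gradients"),
            EpisodeState(gradients_inspected=False),
            scenario,
        )
        assert r1 == pytest.approx(STEP_PENALTY + INVESTIGATION_BONUS)
        # Step 2: a correct diagnosis for task_001 from a default state.
        r2 = compute_reward(
            MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high"),
            EpisodeState(),
            scenario,
        )
        assert r2 == pytest.approx(STEP_PENALTY + CORRECT_DIAGNOSIS_REWARD)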