# pytorch-training-debugger/tests/test_graders.py
# omkarrr88
# Tasks 6 and 7 made hard
# 02e58fe
"""Test grader functions — each returns 0.0-1.0."""
from __future__ import annotations
import pytest
from ml_training_debugger.graders import (
_submitted_diagnosis,
grade_episode,
grade_task_001,
grade_task_003,
grade_task_005,
grade_task_006,
grade_task_007,
)
from ml_training_debugger.models import EpisodeState
from ml_training_debugger.scenarios import sample_scenario
@pytest.fixture
def scenario_001():
    """Provide the ``task_001`` scenario sampled with the fixed seed 42."""
    scenario = sample_scenario("task_001", seed=42)
    return scenario
@pytest.fixture
def scenario_003():
    """Provide the ``task_003`` scenario sampled with the fixed seed 42."""
    scenario = sample_scenario("task_003", seed=42)
    return scenario
@pytest.fixture
def scenario_005():
    """Provide the ``task_005`` scenario sampled with the fixed seed 42."""
    scenario = sample_scenario("task_005", seed=42)
    return scenario
class TestGradeTask001:
    """Checks for ``grade_task_001`` against the seeded ``task_001`` scenario."""

    def test_perfect_score(self, scenario_001):
        """Full investigation, fix, restart and ``lr_too_high`` diagnosis score 1.0."""
        state = EpisodeState(
            gradients_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=[
                "inspect_gradients",
                "modify_config",
                "restart_run",
                "mark_diagnosed:lr_too_high",
            ],
        )
        score = grade_task_001(state, scenario_001)
        # Tolerant float comparison, consistent with the other perfect-score
        # tests in this module; an exact ``== 1.0`` can fail on rounding of
        # summed fractional credits.
        assert score == pytest.approx(1.0)

    def test_wrong_diagnosis(self, scenario_001):
        """Same actions but a ``data_leakage`` diagnosis must score below 0.7."""
        state = EpisodeState(
            gradients_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=[
                "inspect_gradients",
                "modify_config",
                "restart_run",
                "mark_diagnosed:data_leakage",
            ],
        )
        score = grade_task_001(state, scenario_001)
        assert score < 0.7  # Missing diagnosis credit

    def test_no_investigation(self, scenario_001):
        """A diagnosis with no inspection earns some credit, but not full marks."""
        state = EpisodeState(
            diagnosis_submitted=True,
            actions_taken=["mark_diagnosed:lr_too_high"],
        )
        score = grade_task_001(state, scenario_001)
        assert 0.0 < score < 1.0

    def test_score_in_range(self, scenario_001):
        """A default (empty) episode state still grades within [0.0, 1.0]."""
        state = EpisodeState()
        score = grade_task_001(state, scenario_001)
        assert 0.0 <= score <= 1.0
class TestGradeTask003:
    """Checks for ``grade_task_003`` against the seeded ``task_003`` scenario."""

    def test_perfect_score(self, scenario_003):
        """Data inspection, loader patch, restart and ``data_leakage`` score 1.0."""
        actions = [
            "inspect_data_batch",
            "patch_data_loader",
            "restart_run",
            "mark_diagnosed:data_leakage",
        ]
        episode = EpisodeState(
            data_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_task_003(episode, scenario_003)
        assert result == pytest.approx(1.0)

    def test_wrong_diagnosis(self, scenario_003):
        """An ``overfitting`` diagnosis after data inspection scores below 0.5."""
        actions = [
            "inspect_data_batch",
            "mark_diagnosed:overfitting",
        ]
        episode = EpisodeState(
            data_inspected=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_task_003(episode, scenario_003)
        assert result < 0.5
class TestGradeTask005:
    """Checks for ``grade_task_005`` against the seeded ``task_005`` scenario."""

    @staticmethod
    def _quick_fix_state(actions):
        """Build the episode state shared by the partial-credit tests.

        Gradient (flagged normal) and model-mode inspection, fix, restart and
        diagnosis flags are set; the weight and data inspection flags are not
        passed. ``actions`` becomes ``actions_taken``.
        """
        return EpisodeState(
            gradients_inspected=True,
            gradients_were_normal=True,
            model_modes_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )

    def test_perfect_score_thorough(self, scenario_005):
        """Thorough agent inspects weights AND data — gets perfect score."""
        actions = [
            "inspect_gradients",
            "inspect_data_batch",
            "inspect_model_weights",
            "inspect_model_modes",
            "fix_model_mode",
            "restart_run",
            "mark_diagnosed:batchnorm_eval_mode",
        ]
        episode = EpisodeState(
            gradients_inspected=True,
            gradients_were_normal=True,
            model_modes_inspected=True,
            model_weights_inspected=True,
            data_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_task_005(episode, scenario_005)
        assert result == pytest.approx(1.0)

    def test_quick_fix_partial_credit(self, scenario_005):
        """Agent that skips weight inspection gets partial credit."""
        episode = self._quick_fix_state(
            [
                "inspect_gradients",
                "inspect_model_modes",
                "fix_model_mode",
                "restart_run",
                "mark_diagnosed:batchnorm_eval_mode",
            ]
        )
        result = grade_task_005(episode, scenario_005)
        # No weights, no data → 0.05+0.05+0.08+0.10+0.40 = 0.68
        assert result == pytest.approx(0.68)

    def test_red_herring_chaser(self, scenario_005):
        """Agent that chases gradient red herring loses 0.20."""
        episode = self._quick_fix_state(
            [
                "inspect_gradients",
                "add_callback",  # Wrong: chases red herring
                "inspect_model_modes",
                "fix_model_mode",
                "restart_run",
                "mark_diagnosed:batchnorm_eval_mode",
            ]
        )
        result = grade_task_005(episode, scenario_005)
        # -0.20 penalty for add_callback after normal gradients
        assert result == pytest.approx(0.48)

    def test_wrong_fix_penalty(self, scenario_005):
        """Agent that tries modify_config (wrong fix) gets penalized."""
        episode = self._quick_fix_state(
            [
                "inspect_gradients",
                "modify_config",  # Wrong: LR isn't the problem
                "inspect_model_modes",
                "fix_model_mode",
                "restart_run",
                "mark_diagnosed:batchnorm_eval_mode",
            ]
        )
        result = grade_task_005(episode, scenario_005)
        # -0.10 penalty for modify_config on task 5
        assert result == pytest.approx(0.58)

    def test_double_trap_devastates_score(self, scenario_005):
        """Agent that falls for BOTH traps (add_callback + modify_config) scores poorly."""
        episode = self._quick_fix_state(
            [
                "inspect_gradients",
                "add_callback",
                "modify_config",
                "inspect_model_modes",
                "fix_model_mode",
                "restart_run",
                "mark_diagnosed:batchnorm_eval_mode",
            ]
        )
        result = grade_task_005(episode, scenario_005)
        # -0.20 (add_callback) + -0.10 (modify_config) = -0.30 penalties
        assert result == pytest.approx(0.38)
class TestGradeEpisode:
    """Checks for the ``grade_episode`` entry point."""

    def test_dispatch_to_correct_grader(self, scenario_001):
        """Grading ``task_001`` through ``grade_episode`` yields a score in [0.0, 1.0]."""
        actions = [
            "inspect_gradients",
            "mark_diagnosed:lr_too_high",
        ]
        episode = EpisodeState(
            gradients_inspected=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_episode("task_001", episode, scenario_001)
        assert 0.0 <= result <= 1.0

    def test_unknown_task_returns_zero(self, scenario_001):
        """An unrecognised task id (``task_999``) grades to exactly 0.0."""
        episode = EpisodeState()
        result = grade_episode("task_999", episode, scenario_001)
        assert result == 0.0
class TestGradeTask006:
    """Checks for ``grade_task_006`` against a seeded ``task_006`` scenario."""

    @pytest.fixture
    def scenario_006(self):
        """Provide the ``task_006`` scenario sampled with the fixed seed 42."""
        scenario = sample_scenario("task_006", seed=42)
        return scenario

    def test_perfect_score_thorough(self, scenario_006):
        """Thorough agent inspects ALL systems including weights — gets perfect score."""
        actions = [
            "inspect_gradients",
            "inspect_data_batch",
            "inspect_model_weights",
            "inspect_model_modes",
            "inspect_code",
            "fix_code",
            "restart_run",
            "mark_diagnosed:code_bug",
        ]
        episode = EpisodeState(
            code_inspected=True,
            gradients_inspected=True,
            model_modes_inspected=True,
            model_weights_inspected=True,
            data_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_task_006(episode, scenario_006)
        assert result == pytest.approx(1.0)

    def test_no_weights_inspection_partial(self, scenario_006):
        """Agent that skips weight inspection gets reduced fix/restart credit."""
        actions = [
            "inspect_gradients",
            "inspect_data_batch",
            "inspect_model_modes",
            "inspect_code",
            "fix_code",
            "restart_run",
            "mark_diagnosed:code_bug",
        ]
        episode = EpisodeState(
            code_inspected=True,
            gradients_inspected=True,
            model_modes_inspected=True,
            data_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_task_006(episode, scenario_006)
        # 0.05*4 + 0.08 + 0.08 + 0.45 = 0.81
        assert result == pytest.approx(0.81)
        assert result < 1.0  # Must not be perfect without weights

    def test_minimal_investigation(self, scenario_006):
        """Agent that only inspects code, fixes, and diagnoses."""
        actions = [
            "inspect_code",
            "fix_code",
            "restart_run",
            "mark_diagnosed:code_bug",
        ]
        episode = EpisodeState(
            code_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_task_006(episode, scenario_006)
        # 0.05 + 0.08 + 0.08 + 0.45 = 0.66
        assert result == pytest.approx(0.66)

    def test_wrong_diagnosis(self, scenario_006):
        """Submitting batchnorm_eval_mode on a code_bug task fails."""
        actions = [
            "inspect_code",
            "mark_diagnosed:batchnorm_eval_mode",
        ]
        episode = EpisodeState(
            code_inspected=True,
            diagnosis_submitted=True,
            actions_taken=actions,
        )
        result = grade_task_006(episode, scenario_006)
        assert result < 0.2  # Only gets code_inspected bonus

    def test_score_in_range(self, scenario_006):
        """A default (empty) episode state still grades within [0.0, 1.0]."""
        episode = EpisodeState()
        result = grade_task_006(episode, scenario_006)
        assert 0.0 <= result <= 1.0
class TestGradeTask007:
    """Checks for ``grade_task_007`` against a seeded ``task_007`` scenario."""

    @pytest.fixture
    def scenario_007(self):
        """Provide the ``task_007`` scenario sampled with the fixed seed 42.

        Replaces the per-test ``sample_scenario`` calls so the scenario setup
        lives in one place; the default function scope still gives every test
        its own freshly sampled scenario.
        """
        return sample_scenario("task_007", seed=42)

    def test_perfect_score_thorough(self, scenario_007):
        """Thorough agent inspects weights — gets perfect score."""
        state = EpisodeState(
            gradients_inspected=True,
            data_inspected=True,
            model_weights_inspected=True,
            model_modes_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=[
                "inspect_gradients",
                "inspect_data_batch",
                "inspect_model_weights",
                "inspect_model_modes",
                "modify_config",
                "restart_run",
                "mark_diagnosed:scheduler_misconfigured",
            ],
        )
        score = grade_task_007(state, scenario_007)
        assert score == pytest.approx(1.0)

    def test_no_weights_partial(self, scenario_007):
        """Agent without weight inspection gets reduced fix/restart credit."""
        state = EpisodeState(
            gradients_inspected=True,
            data_inspected=True,
            model_modes_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=[
                "inspect_gradients",
                "inspect_data_batch",
                "inspect_model_modes",
                "modify_config",
                "restart_run",
                "mark_diagnosed:scheduler_misconfigured",
            ],
        )
        score = grade_task_007(state, scenario_007)
        # 0.05*3 + 0.12 + 0.12 + 0.40 = 0.79
        assert score == pytest.approx(0.79)
        assert score < 1.0

    def test_wrong_fix_penalty(self, scenario_007):
        """Agent that patches data loader (wrong fix) gets penalized."""
        state = EpisodeState(
            gradients_inspected=True,
            data_inspected=True,
            fix_action_taken=True,
            restart_after_fix=True,
            diagnosis_submitted=True,
            actions_taken=[
                "inspect_gradients",
                "inspect_data_batch",
                "patch_data_loader",
                "modify_config",
                "restart_run",
                "mark_diagnosed:scheduler_misconfigured",
            ],
        )
        score = grade_task_007(state, scenario_007)
        # Normal partial score minus 0.10 penalty
        assert score < 0.75

    def test_wrong_diagnosis(self, scenario_007):
        """An ``overfitting`` diagnosis with no investigation scores below 0.5."""
        state = EpisodeState(
            diagnosis_submitted=True,
            actions_taken=["mark_diagnosed:overfitting"],
        )
        score = grade_task_007(state, scenario_007)
        assert score < 0.5

    def test_score_in_range(self, scenario_007):
        """A default (empty) episode state still grades within [0.0, 1.0]."""
        state = EpisodeState()
        score = grade_task_007(state, scenario_007)
        assert 0.0 <= score <= 1.0
class TestSubmittedDiagnosis:
    """Checks for the ``_submitted_diagnosis`` helper."""

    def test_finds_diagnosis(self):
        """The label after ``mark_diagnosed:`` is returned when one is present."""
        actions = ["inspect_gradients", "mark_diagnosed:lr_too_high"]
        episode = EpisodeState(actions_taken=actions)
        found = _submitted_diagnosis(episode)
        assert found == "lr_too_high"

    def test_no_diagnosis(self):
        """``None`` is returned when no ``mark_diagnosed`` action was taken."""
        episode = EpisodeState(actions_taken=["inspect_gradients"])
        found = _submitted_diagnosis(episode)
        assert found is None

    def test_latest_diagnosis(self):
        """With several diagnoses recorded, the last one listed is returned."""
        actions = [
            "mark_diagnosed:overfitting",
            "mark_diagnosed:lr_too_high",
        ]
        episode = EpisodeState(actions_taken=actions)
        found = _submitted_diagnosis(episode)
        assert found == "lr_too_high"