File size: 4,918 Bytes
e2f8b29 0b9b77b e2f8b29 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | """Test all Pydantic models instantiate and serialize correctly."""
from __future__ import annotations
import json
from openenv.core.env_server.types import Action, Observation
from ml_training_debugger.models import (
EpisodeState,
GradientStats,
MLTrainingAction,
MLTrainingObservation,
RootCauseDiagnosis,
TrainingConfig,
)
class TestRootCauseDiagnosis:
    """Checks on the ``RootCauseDiagnosis`` enumeration."""

    def test_all_values_exist(self):
        """The enum defines exactly seven diagnoses."""
        member_count = len(RootCauseDiagnosis)
        assert member_count == 7

    def test_values_are_strings(self):
        """Every member carries a ``str`` value."""
        non_string = [
            member for member in RootCauseDiagnosis
            if not isinstance(member.value, str)
        ]
        assert not non_string

    def test_specific_values(self):
        """Spot-check the wire values of two well-known members."""
        for member, raw in (
            (RootCauseDiagnosis.LR_TOO_HIGH, "lr_too_high"),
            (RootCauseDiagnosis.CODE_BUG, "code_bug"),
        ):
            assert member.value == raw
class TestTrainingConfig:
    """Defaults and JSON round-tripping of ``TrainingConfig``."""

    def test_default_instantiation(self):
        """A bare config uses lr=0.001 and no gradient clipping."""
        cfg = TrainingConfig()
        assert cfg.learning_rate == 0.001
        assert cfg.gradient_clip_norm is None

    def test_json_roundtrip(self):
        """Values survive model_dump_json -> json.loads -> model_validate."""
        payload = TrainingConfig(
            learning_rate=0.01, weight_decay=0.1
        ).model_dump_json()
        rebuilt = TrainingConfig.model_validate(json.loads(payload))
        assert (rebuilt.learning_rate, rebuilt.weight_decay) == (0.01, 0.1)
class TestGradientStats:
    """Construction of ``GradientStats`` for exploding, vanishing and healthy layers."""

    @staticmethod
    def _single_step(layer, norm, *, exploding, vanishing):
        """Build a one-step ``GradientStats`` whose history, mean and max all equal *norm*."""
        return GradientStats(
            layer_name=layer,
            norm_history=[norm],
            mean_norm=norm,
            max_norm=norm,
            is_exploding=exploding,
            is_vanishing=vanishing,
        )

    def test_exploding(self):
        stats = self._single_step("fc", 15.0, exploding=True, vanishing=False)
        assert stats.is_exploding

    def test_vanishing(self):
        stats = self._single_step("conv1", 1e-7, exploding=False, vanishing=True)
        assert stats.is_vanishing

    def test_normal(self):
        stats = self._single_step("conv1", 0.5, exploding=False, vanishing=False)
        # A healthy layer must be flagged as neither exploding nor vanishing.
        assert not stats.is_exploding
        assert not stats.is_vanishing
class TestEpisodeState:
    """Initial values and action gating of ``EpisodeState``."""

    def test_fresh_state(self):
        """A new episode starts at step 0 with nothing inspected or submitted."""
        fresh = EpisodeState()
        assert fresh.step_count == 0
        assert not fresh.gradients_inspected
        assert not fresh.diagnosis_submitted

    def test_available_actions_initial(self):
        """Inspection and diagnosis are open; fix/restart are gated at start."""
        available = EpisodeState().compute_available_actions()
        for expected in ("inspect_gradients", "mark_diagnosed"):
            assert expected in available
        for gated in ("fix_code", "restart_run"):
            assert gated not in available

    def test_fix_code_available_after_code_inspected(self):
        available = EpisodeState(code_inspected=True).compute_available_actions()
        assert "fix_code" in available

    def test_restart_run_available_after_fix(self):
        available = EpisodeState(fix_action_taken=True).compute_available_actions()
        assert "restart_run" in available

    def test_mark_diagnosed_disappears_after_submission(self):
        available = EpisodeState(diagnosis_submitted=True).compute_available_actions()
        assert "mark_diagnosed" not in available
class TestMLTrainingObservation:
    """``MLTrainingObservation`` integrates with the OpenEnv ``Observation`` base."""

    def test_extends_observation(self):
        assert issubclass(MLTrainingObservation, Observation)

    def test_has_done_and_reward(self):
        """``done`` and ``reward`` passed at construction are readable back."""
        obs = MLTrainingObservation(done=True, reward=0.5)
        assert obs.done is True
        assert obs.reward == 0.5

    def test_json_serialization(self):
        """Explicitly set fields and the ``framework`` default appear in the JSON dump."""
        obs = MLTrainingObservation(
            run_id="test",
            training_loss_history=[1.0, 2.0],
            val_accuracy_history=[0.5],
        )
        data = json.loads(obs.model_dump_json())
        assert data["run_id"] == "test"
        # Check the history lists as well, so a serializer that drops or
        # reorders them fails this test instead of passing silently.
        assert data["training_loss_history"] == [1.0, 2.0]
        assert data["val_accuracy_history"] == [0.5]
        # ``framework`` is not passed above, so this checks its default value.
        assert data["framework"] == "pytorch"
class TestMLTrainingAction:
    """``MLTrainingAction`` subclasses OpenEnv ``Action`` and stores its payload fields."""

    def test_extends_action(self):
        assert issubclass(MLTrainingAction, Action)

    def test_basic_action(self):
        action = MLTrainingAction(action_type="inspect_gradients")
        assert action.action_type == "inspect_gradients"

    def test_modify_config_action(self):
        """Both the config key and the new value are kept on the action."""
        action = MLTrainingAction(
            action_type="modify_config",
            target="learning_rate",
            value=0.001,
        )
        assert action.action_type == "modify_config"
        assert action.target == "learning_rate"
        # Without this check, a dropped or coerced ``value`` would go unnoticed.
        assert action.value == 0.001

    def test_mark_diagnosed_action(self):
        action = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )
        assert action.action_type == "mark_diagnosed"
        assert action.diagnosis == "lr_too_high"

    def test_fix_code_action(self):
        """Both the target line number and the replacement source are kept."""
        action = MLTrainingAction(
            action_type="fix_code",
            line=13,
            replacement="loss = criterion(output, batch_y)",
        )
        assert action.action_type == "fix_code"
        assert action.line == 13
        assert action.replacement == "loss = criterion(output, batch_y)"
|