"""Check that every Pydantic model can be instantiated and serialized."""

from __future__ import annotations

import json

from openenv.core.env_server.types import Action, Observation

from ml_training_debugger.models import (
    EpisodeState,
    GradientStats,
    MLTrainingAction,
    MLTrainingObservation,
    RootCauseDiagnosis,
    TrainingConfig,
)


class TestRootCauseDiagnosis:
    """Checks on the members exposed by ``RootCauseDiagnosis``."""

    def test_all_values_exist(self):
        """Exactly seven members are defined."""
        member_count = len(RootCauseDiagnosis)
        assert member_count == 7

    def test_values_are_strings(self):
        """Every member carries a ``str`` value."""
        assert all(
            isinstance(member.value, str) for member in RootCauseDiagnosis
        )

    def test_specific_values(self):
        """Spot-check the raw values of two members."""
        expected_by_member = (
            (RootCauseDiagnosis.LR_TOO_HIGH, "lr_too_high"),
            (RootCauseDiagnosis.CODE_BUG, "code_bug"),
        )
        for member, raw_value in expected_by_member:
            assert member.value == raw_value


class TestTrainingConfig:
    """Defaults and JSON serialization of ``TrainingConfig``."""

    def test_default_instantiation(self):
        """A config built without arguments exposes the expected defaults."""
        defaults = TrainingConfig()
        assert defaults.learning_rate == 0.001
        assert defaults.gradient_clip_norm is None

    def test_json_roundtrip(self):
        """Dumping to JSON and validating the parsed data keeps the fields."""
        original = TrainingConfig(learning_rate=0.01, weight_decay=0.1)
        payload = original.model_dump_json()
        parsed = json.loads(payload)
        rebuilt = TrainingConfig.model_validate(parsed)
        assert rebuilt.learning_rate == 0.01
        assert rebuilt.weight_decay == 0.1


class TestGradientStats:
    """Construction of ``GradientStats`` for each flag combination."""

    def test_exploding(self):
        """The exploding flag is preserved on the instance."""
        fields = {
            "layer_name": "fc",
            "norm_history": [15.0],
            "mean_norm": 15.0,
            "max_norm": 15.0,
            "is_exploding": True,
            "is_vanishing": False,
        }
        exploding = GradientStats(**fields)
        assert exploding.is_exploding

    def test_vanishing(self):
        """The vanishing flag is preserved on the instance."""
        fields = {
            "layer_name": "conv1",
            "norm_history": [1e-7],
            "mean_norm": 1e-7,
            "max_norm": 1e-7,
            "is_exploding": False,
            "is_vanishing": True,
        }
        vanishing = GradientStats(**fields)
        assert vanishing.is_vanishing

    def test_normal(self):
        """Neither flag is set when both are passed as ``False``."""
        fields = {
            "layer_name": "conv1",
            "norm_history": [0.5],
            "mean_norm": 0.5,
            "max_norm": 0.5,
            "is_exploding": False,
            "is_vanishing": False,
        }
        healthy = GradientStats(**fields)
        assert not healthy.is_exploding
        assert not healthy.is_vanishing


class TestEpisodeState:
    """Initial values of ``EpisodeState`` and its available-action gating."""

    def test_fresh_state(self):
        """A new state starts at step zero with nothing inspected or submitted."""
        fresh = EpisodeState()
        assert fresh.step_count == 0
        assert not fresh.gradients_inspected
        assert not fresh.diagnosis_submitted

    def test_available_actions_initial(self):
        """Initially, inspection and diagnosis are offered; fix/restart are not."""
        available = EpisodeState().compute_available_actions()
        assert "inspect_gradients" in available
        assert "mark_diagnosed" in available
        assert "fix_code" not in available
        assert "restart_run" not in available

    def test_fix_code_available_after_code_inspected(self):
        """``fix_code`` is offered once ``code_inspected`` is set."""
        inspected = EpisodeState(code_inspected=True)
        available = inspected.compute_available_actions()
        assert "fix_code" in available

    def test_restart_run_available_after_fix(self):
        """``restart_run`` is offered once ``fix_action_taken`` is set."""
        fixed = EpisodeState(fix_action_taken=True)
        available = fixed.compute_available_actions()
        assert "restart_run" in available

    def test_mark_diagnosed_disappears_after_submission(self):
        """``mark_diagnosed`` is withdrawn once ``diagnosis_submitted`` is set."""
        submitted = EpisodeState(diagnosis_submitted=True)
        available = submitted.compute_available_actions()
        assert "mark_diagnosed" not in available


class TestMLTrainingObservation:
    """Inheritance, base fields and JSON output of ``MLTrainingObservation``."""

    def test_extends_observation(self):
        """The model is a subclass of ``Observation``."""
        assert issubclass(MLTrainingObservation, Observation)

    def test_has_done_and_reward(self):
        """``done`` and ``reward`` passed to the constructor are readable back."""
        observation = MLTrainingObservation(done=True, reward=0.5)
        assert observation.done is True
        assert observation.reward == 0.5

    def test_json_serialization(self):
        """The JSON dump contains the given run id and a ``pytorch`` framework."""
        observation = MLTrainingObservation(
            run_id="test",
            training_loss_history=[1.0, 2.0],
            val_accuracy_history=[0.5],
        )
        dumped = observation.model_dump_json()
        decoded = json.loads(dumped)
        assert decoded["run_id"] == "test"
        assert decoded["framework"] == "pytorch"


class TestMLTrainingAction:
    """Inheritance and per-action-type construction of ``MLTrainingAction``."""

    def test_extends_action(self):
        """The model is a subclass of ``Action``."""
        assert issubclass(MLTrainingAction, Action)

    def test_basic_action(self):
        """An action needs only ``action_type`` to be constructed."""
        inspect = MLTrainingAction(action_type="inspect_gradients")
        assert inspect.action_type == "inspect_gradients"

    def test_modify_config_action(self):
        """A ``modify_config`` action keeps its ``target``."""
        modify = MLTrainingAction(
            action_type="modify_config",
            target="learning_rate",
            value=0.001,
        )
        assert modify.target == "learning_rate"

    def test_mark_diagnosed_action(self):
        """A ``mark_diagnosed`` action keeps its ``diagnosis``."""
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )
        assert diagnose.diagnosis == "lr_too_high"

    def test_fix_code_action(self):
        """A ``fix_code`` action keeps its ``line``."""
        patch = MLTrainingAction(
            action_type="fix_code",
            line=13,
            replacement="loss = criterion(output, batch_y)",
        )
        assert patch.line == 13