File size: 4,918 Bytes
e2f8b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b9b77b
 
e2f8b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Test all Pydantic models instantiate and serialize correctly."""

from __future__ import annotations

import json

from openenv.core.env_server.types import Action, Observation

from ml_training_debugger.models import (
    EpisodeState,
    GradientStats,
    MLTrainingAction,
    MLTrainingObservation,
    RootCauseDiagnosis,
    TrainingConfig,
)


class TestRootCauseDiagnosis:
    """Sanity checks for the RootCauseDiagnosis enum."""

    def test_all_values_exist(self):
        # Guard against members being added or removed unintentionally.
        assert len(list(RootCauseDiagnosis)) == 7

    def test_values_are_strings(self):
        assert all(isinstance(member.value, str) for member in RootCauseDiagnosis)

    def test_specific_values(self):
        # Spot-check a couple of members against their wire values.
        expected = {
            RootCauseDiagnosis.LR_TOO_HIGH: "lr_too_high",
            RootCauseDiagnosis.CODE_BUG: "code_bug",
        }
        for member, wire_value in expected.items():
            assert member.value == wire_value


class TestTrainingConfig:
    """Default construction and JSON round-tripping of TrainingConfig."""

    def test_default_instantiation(self):
        cfg = TrainingConfig()
        assert cfg.learning_rate == 0.001
        assert cfg.gradient_clip_norm is None

    def test_json_roundtrip(self):
        # Serialize to JSON text, parse it back, and re-validate into a model.
        original = TrainingConfig(learning_rate=0.01, weight_decay=0.1)
        payload = json.loads(original.model_dump_json())
        rebuilt = TrainingConfig.model_validate(payload)
        assert (rebuilt.learning_rate, rebuilt.weight_decay) == (0.01, 0.1)


class TestGradientStats:
    """Flag combinations on GradientStats: exploding, vanishing, and normal."""

    @staticmethod
    def _make(layer, norm, exploding, vanishing):
        # All norm-valued fields share a single magnitude; only the flags
        # distinguish the three scenarios under test.
        return GradientStats(
            layer_name=layer,
            norm_history=[norm],
            mean_norm=norm,
            max_norm=norm,
            is_exploding=exploding,
            is_vanishing=vanishing,
        )

    def test_exploding(self):
        assert self._make("fc", 15.0, True, False).is_exploding

    def test_vanishing(self):
        assert self._make("conv1", 1e-7, False, True).is_vanishing

    def test_normal(self):
        stats = self._make("conv1", 0.5, False, False)
        assert not stats.is_exploding
        assert not stats.is_vanishing


class TestEpisodeState:
    """Action availability is gated by EpisodeState progress flags."""

    def test_fresh_state(self):
        fresh = EpisodeState()
        assert fresh.step_count == 0
        assert not fresh.gradients_inspected
        assert not fresh.diagnosis_submitted

    def test_available_actions_initial(self):
        # A brand-new episode exposes inspection/diagnosis but locks the
        # fix/restart actions behind their prerequisites.
        available = EpisodeState().compute_available_actions()
        for action in ("inspect_gradients", "mark_diagnosed"):
            assert action in available
        for locked in ("fix_code", "restart_run"):
            assert locked not in available

    def test_fix_code_available_after_code_inspected(self):
        available = EpisodeState(code_inspected=True).compute_available_actions()
        assert "fix_code" in available

    def test_restart_run_available_after_fix(self):
        available = EpisodeState(fix_action_taken=True).compute_available_actions()
        assert "restart_run" in available

    def test_mark_diagnosed_disappears_after_submission(self):
        available = EpisodeState(diagnosis_submitted=True).compute_available_actions()
        assert "mark_diagnosed" not in available


class TestMLTrainingObservation:
    """MLTrainingObservation subclassing, core fields, and serialization."""

    def test_extends_observation(self):
        assert issubclass(MLTrainingObservation, Observation)

    def test_has_done_and_reward(self):
        observation = MLTrainingObservation(done=True, reward=0.5)
        assert observation.done is True
        assert observation.reward == 0.5

    def test_json_serialization(self):
        observation = MLTrainingObservation(
            run_id="test",
            training_loss_history=[1.0, 2.0],
            val_accuracy_history=[0.5],
        )
        payload = json.loads(observation.model_dump_json())
        assert payload["run_id"] == "test"
        # The default framework value should survive serialization.
        assert payload["framework"] == "pytorch"


class TestMLTrainingAction:
    """MLTrainingAction construction across the supported action types."""

    def test_extends_action(self):
        assert issubclass(MLTrainingAction, Action)

    def test_basic_action(self):
        basic = MLTrainingAction(action_type="inspect_gradients")
        assert basic.action_type == "inspect_gradients"

    def test_modify_config_action(self):
        modify = MLTrainingAction(
            action_type="modify_config",
            target="learning_rate",
            value=0.001,
        )
        assert modify.target == "learning_rate"

    def test_mark_diagnosed_action(self):
        diagnose = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )
        assert diagnose.diagnosis == "lr_too_high"

    def test_fix_code_action(self):
        fix = MLTrainingAction(
            action_type="fix_code",
            line=13,
            replacement="loss = criterion(output, batch_y)",
        )
        assert fix.line == 13