File size: 9,621 Bytes
e2f8b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
eeb6913
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e2f8b29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""Test full episode lifecycle — reset, step, state transitions."""

from __future__ import annotations

import pytest

from ml_training_debugger.models import MLTrainingAction
from server.environment import MLTrainingEnvironment


@pytest.fixture
def env():
    """Provide a brand-new, not-yet-reset environment to each test."""
    fresh_environment = MLTrainingEnvironment()
    return fresh_environment


class TestReset:
    """reset() must return a clean observation with details still hidden."""

    def test_reset_returns_valid_observation(self, env):
        observation = env.reset(seed=42, episode_id="test", task_id="task_001")
        assert observation.run_id == "test"
        assert observation.framework == "pytorch"
        # Both metric curves are pre-filled with the same number of points.
        for history in (
            observation.training_loss_history,
            observation.val_accuracy_history,
        ):
            assert len(history) == 20
        assert observation.done is False

    def test_reset_initial_state(self, env):
        observation = env.reset(seed=42, episode_id="test", task_id="task_001")
        state = observation.episode_state
        assert state.step_count == 0
        assert not state.gradients_inspected
        assert not state.diagnosis_submitted

    def test_reset_progressive_reveal(self, env):
        observation = env.reset(seed=42, episode_id="test", task_id="task_001")
        # Detailed diagnostic fields start out unrevealed.
        assert observation.gradient_stats == []
        hidden_fields = (
            observation.model_weight_stats,
            observation.data_batch_stats,
            observation.model_mode_info,
            observation.code_snippet,
        )
        assert all(field is None for field in hidden_fields)

    def test_reset_available_actions(self, env):
        observation = env.reset(seed=42, episode_id="test", task_id="task_001")
        offered = observation.available_actions
        for expected in ("inspect_gradients", "mark_diagnosed"):
            assert expected in offered
        # Fix/restart actions are gated behind earlier steps in the episode.
        for gated in ("fix_code", "restart_run"):
            assert gated not in offered


class TestStepInspections:
    """Inspection actions reveal hidden stats and pay a one-time bonus."""

    # Reward components the assertions below are built from: every step
    # costs STEP_PENALTY, and the first inspection of each kind additionally
    # earns INVESTIGATION_BONUS (net +0.04 for a first-time inspection).
    STEP_PENALTY = -0.01
    INVESTIGATION_BONUS = 0.05

    def test_inspect_gradients_populates_stats(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
        assert len(obs.gradient_stats) > 0
        assert obs.episode_state.gradients_inspected

    def test_inspect_gradients_gives_investigation_bonus(self, env):
        """First-time gradient inspection: step penalty + investigation bonus."""
        env.reset(seed=42, episode_id="test", task_id="task_001")
        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
        assert obs.reward == pytest.approx(
            self.STEP_PENALTY + self.INVESTIGATION_BONUS
        )

    def test_inspect_data_batch_gives_investigation_bonus(self, env):
        """First-time data inspection: step penalty + investigation bonus."""
        env.reset(seed=42, episode_id="test", task_id="task_003")
        obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
        assert obs.reward == pytest.approx(
            self.STEP_PENALTY + self.INVESTIGATION_BONUS
        )

    def test_inspect_model_modes_gives_investigation_bonus(self, env):
        """First-time model-modes inspection: step penalty + investigation bonus."""
        env.reset(seed=42, episode_id="test", task_id="task_005")
        obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
        assert obs.reward == pytest.approx(
            self.STEP_PENALTY + self.INVESTIGATION_BONUS
        )

    def test_repeat_inspection_no_bonus(self, env):
        """Second inspection of the same type earns only the step penalty."""
        env.reset(seed=42, episode_id="test", task_id="task_001")
        env.step(MLTrainingAction(action_type="inspect_gradients"))
        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
        assert obs.reward == pytest.approx(self.STEP_PENALTY)

    def test_inspect_data_batch(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_003")
        obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
        assert obs.data_batch_stats is not None
        assert obs.episode_state.data_inspected

    def test_inspect_model_modes(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_005")
        obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
        assert obs.model_mode_info is not None
        assert obs.episode_state.model_modes_inspected

    def test_inspect_model_weights(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        obs = env.step(MLTrainingAction(action_type="inspect_model_weights"))
        assert obs.model_weight_stats is not None
        assert obs.episode_state.model_weights_inspected


class TestStepFixActions:
    """A config fix unlocks restart_run, and restarting records the fix."""

    @staticmethod
    def _lower_learning_rate():
        """Build the modify_config action that sets learning_rate to 0.001."""
        return MLTrainingAction(
            action_type="modify_config",
            target="learning_rate",
            value=0.001,
        )

    def test_modify_config(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        result = env.step(self._lower_learning_rate())
        assert result.episode_state.fix_action_taken
        assert "restart_run" in result.available_actions

    def test_restart_run_after_fix(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        env.step(self._lower_learning_rate())
        result = env.step(MLTrainingAction(action_type="restart_run"))
        assert result.episode_state.restart_after_fix


class TestStepDiagnosis:
    """mark_diagnosed ends the episode; later steps earn nothing."""

    @staticmethod
    def _diagnose_lr_too_high():
        """Build the terminal mark_diagnosed action used by these tests."""
        return MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )

    def test_mark_diagnosed_ends_episode(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        final = env.step(self._diagnose_lr_too_high())
        assert final.done is True
        assert final.episode_state.diagnosis_submitted

    def test_step_after_done(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        env.step(self._diagnose_lr_too_high())
        # A step after the episode ended must stay done and pay no reward.
        late = env.step(MLTrainingAction(action_type="inspect_gradients"))
        assert late.done is True
        assert late.reward == 0.0


class TestErrorHandling:
    """Malformed or disallowed actions are penalized and explained."""

    @staticmethod
    def _error_text(obs):
        """Return ``obs.error_log`` lower-cased, failing clearly if it is absent.

        Asserting presence first turns a missing message into a readable
        assertion failure instead of an AttributeError on ``None.lower()``.
        """
        assert obs.error_log is not None, "expected an error_log message"
        return obs.error_log.lower()

    def test_invalid_action_type(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        obs = env.step(MLTrainingAction(action_type="nonexistent_action"))
        # Step penalty plus the extra penalty for an unknown action type.
        assert obs.reward == pytest.approx(-0.01 + -0.05)
        assert obs.error_log is not None

    def test_action_not_in_available(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        # fix_code requires code_inspected=True
        obs = env.step(
            MLTrainingAction(
                action_type="fix_code",
                line=1,
                replacement="pass",
            )
        )
        assert obs.reward < 0

    def test_modify_config_missing_target(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        obs = env.step(MLTrainingAction(action_type="modify_config"))
        message = self._error_text(obs)
        assert "target" in message or "value" in message

    def test_mark_diagnosed_missing_diagnosis(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        obs = env.step(MLTrainingAction(action_type="mark_diagnosed"))
        assert "diagnosis" in self._error_text(obs)

    def test_mark_diagnosed_invalid_diagnosis(self, env):
        env.reset(seed=42, episode_id="test", task_id="task_001")
        obs = env.step(
            MLTrainingAction(
                action_type="mark_diagnosed",
                diagnosis="not_a_real_diagnosis",
            )
        )
        assert "invalid" in self._error_text(obs)

    def test_step_before_reset(self, env):
        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
        assert obs.done is True


class TestFullEpisodeFlow:
    """End-to-end episodes that chain inspections, fixes and diagnosis."""

    def test_task_001_full_flow(self, env):
        """Full optimal flow for Task 1."""
        obs = env.reset(seed=42, episode_id="test", task_id="task_001")
        assert not obs.done

        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
        assert obs.episode_state.gradients_inspected
        assert any(g.is_exploding for g in obs.gradient_stats)

        obs = env.step(
            MLTrainingAction(
                action_type="modify_config",
                target="learning_rate",
                value=0.001,
            )
        )
        assert obs.episode_state.fix_action_taken

        obs = env.step(MLTrainingAction(action_type="restart_run"))
        assert obs.episode_state.restart_after_fix

        obs = env.step(
            MLTrainingAction(
                action_type="mark_diagnosed",
                diagnosis="lr_too_high",
            )
        )
        assert obs.done
        assert obs.reward > 0

    def test_task_005_context_gated_penalty(self, env):
        """Task 5: inspect gradients (normal) → add_callback → penalty fires."""
        # The reset observation is not inspected here, so it is not kept.
        env.reset(seed=42, episode_id="test", task_id="task_005")

        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
        assert obs.episode_state.gradients_inspected
        assert obs.episode_state.gradients_were_normal
        # No layer may be flagged as exploding on this task.
        assert not any(g.is_exploding for g in obs.gradient_stats)

        # Having seen normal gradients, add_callback should trigger the
        # context-gated -0.20 penalty on top of the -0.01 step cost.
        obs = env.step(MLTrainingAction(action_type="add_callback"))
        assert obs.reward == pytest.approx(-0.01 + -0.20)

    def test_task_003_data_leakage(self, env):
        """Task 3: data inspection reveals leakage."""
        env.reset(seed=42, episode_id="test", task_id="task_003")

        obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
        assert obs.data_batch_stats is not None
        assert obs.data_batch_stats.class_overlap_score > 0.5