File size: 5,653 Bytes
45eee48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""Tests for parametric fallback in simulation.py.

These test the fallback paths that run when real training is unavailable.
We force fallback by monkeypatching _get_real_curves to return None.
"""

from __future__ import annotations

from unittest.mock import patch

from ml_training_debugger.scenarios import sample_scenario
from ml_training_debugger.simulation import (
    gen_loss_history,
    gen_val_accuracy_history,
    gen_val_loss_history,
)


def _force_fallback(*args, **kwargs):
    return None


@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
class TestParametricFallbackLoss:
    """Length checks for ``gen_loss_history`` on task_001 through task_007.

    The class-level ``patch`` is applied to every ``test*`` method, so each
    test runs with ``_get_real_curves`` swapped for ``_force_fallback``.
    """

    def test_task_001_fallback(self) -> None:
        """task_001 (seed 42) yields a 20-entry loss history."""
        scenario = sample_scenario("task_001", seed=42)
        losses = gen_loss_history(scenario)
        assert len(losses) == 20

    def test_task_002_fallback(self) -> None:
        """task_002 (seed 42) yields a 20-entry loss history."""
        scenario = sample_scenario("task_002", seed=42)
        losses = gen_loss_history(scenario)
        assert len(losses) == 20

    def test_task_003_fallback(self) -> None:
        """task_003 (seed 42) yields a 20-entry loss history."""
        scenario = sample_scenario("task_003", seed=42)
        losses = gen_loss_history(scenario)
        assert len(losses) == 20

    def test_task_004_fallback(self) -> None:
        """task_004 (seed 42) yields a 20-entry loss history."""
        scenario = sample_scenario("task_004", seed=42)
        losses = gen_loss_history(scenario)
        assert len(losses) == 20

    def test_task_005_fallback(self) -> None:
        """task_005 (seed 42) yields a 20-entry loss history."""
        scenario = sample_scenario("task_005", seed=42)
        losses = gen_loss_history(scenario)
        assert len(losses) == 20

    def test_task_006_fallback(self) -> None:
        """task_006 (seed 42) yields a 20-entry loss history."""
        scenario = sample_scenario("task_006", seed=42)
        losses = gen_loss_history(scenario)
        assert len(losses) == 20

    def test_task_007_fallback(self) -> None:
        """task_007 (seed 42) yields a 20-entry loss history."""
        scenario = sample_scenario("task_007", seed=42)
        losses = gen_loss_history(scenario)
        assert len(losses) == 20


@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
class TestParametricFallbackValAcc:
    """Length checks for ``gen_val_accuracy_history`` on tasks 001 and 003-007.

    The class-level ``patch`` is applied to every ``test*`` method, so each
    test runs with ``_get_real_curves`` swapped for ``_force_fallback``.
    """

    def test_task_001_fallback(self) -> None:
        """task_001 (seed 42) yields a 20-entry validation-accuracy history."""
        scenario = sample_scenario("task_001", seed=42)
        accuracies = gen_val_accuracy_history(scenario)
        assert len(accuracies) == 20

    def test_task_003_fallback(self) -> None:
        """task_003 (seed 42) yields a 20-entry validation-accuracy history."""
        scenario = sample_scenario("task_003", seed=42)
        accuracies = gen_val_accuracy_history(scenario)
        assert len(accuracies) == 20

    def test_task_004_fallback(self) -> None:
        """task_004 (seed 42) yields a 20-entry validation-accuracy history."""
        scenario = sample_scenario("task_004", seed=42)
        accuracies = gen_val_accuracy_history(scenario)
        assert len(accuracies) == 20

    def test_task_005_fallback(self) -> None:
        """task_005 (seed 42) yields a 20-entry validation-accuracy history."""
        scenario = sample_scenario("task_005", seed=42)
        accuracies = gen_val_accuracy_history(scenario)
        assert len(accuracies) == 20

    def test_task_006_fallback(self) -> None:
        """task_006 (seed 42) yields a 20-entry validation-accuracy history."""
        scenario = sample_scenario("task_006", seed=42)
        accuracies = gen_val_accuracy_history(scenario)
        assert len(accuracies) == 20

    def test_task_007_fallback(self) -> None:
        """task_007 (seed 42) yields a 20-entry validation-accuracy history."""
        scenario = sample_scenario("task_007", seed=42)
        accuracies = gen_val_accuracy_history(scenario)
        assert len(accuracies) == 20


@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
class TestParametricFallbackValLoss:
    """Length checks for ``gen_val_loss_history`` on tasks 001, 004-007 and a custom scenario.

    The class-level ``patch`` is applied to every ``test*`` method, so each
    test runs with ``_get_real_curves`` swapped for ``_force_fallback``.
    """

    def test_task_001_fallback(self) -> None:
        """task_001 (seed 42) yields a 20-entry validation-loss history."""
        scenario = sample_scenario("task_001", seed=42)
        val_losses = gen_val_loss_history(scenario)
        assert len(val_losses) == 20

    def test_task_004_fallback(self) -> None:
        """task_004 (seed 42) yields a 20-entry validation-loss history."""
        scenario = sample_scenario("task_004", seed=42)
        val_losses = gen_val_loss_history(scenario)
        assert len(val_losses) == 20

    def test_task_005_fallback(self) -> None:
        """task_005 (seed 42) yields a 20-entry validation-loss history."""
        scenario = sample_scenario("task_005", seed=42)
        val_losses = gen_val_loss_history(scenario)
        assert len(val_losses) == 20

    def test_task_006_fallback(self) -> None:
        """task_006 (seed 42) yields a 20-entry validation-loss history."""
        scenario = sample_scenario("task_006", seed=42)
        val_losses = gen_val_loss_history(scenario)
        assert len(val_losses) == 20

    def test_task_007_fallback(self) -> None:
        """task_007 (seed 42) yields a 20-entry validation-loss history."""
        scenario = sample_scenario("task_007", seed=42)
        val_losses = gen_val_loss_history(scenario)
        assert len(val_losses) == 20

    def test_fallback_default(self) -> None:
        """A hand-built ``ScenarioParams`` still yields a 20-entry validation-loss history.

        Intended to exercise the last-resort branch of the fallback.
        NOTE(review): the scenario uses task_id "task_999" with the
        SCHEDULER_MISCONFIGURED root cause rather than an unknown root cause;
        confirm this combination really reaches the default branch.
        """
        # Imported locally: these names are only needed by this one test.
        from ml_training_debugger.models import RootCauseDiagnosis
        from ml_training_debugger.scenarios import ScenarioParams

        # Constructed directly instead of via sample_scenario().
        custom_scenario = ScenarioParams(
            task_id="task_999",
            root_cause=RootCauseDiagnosis.SCHEDULER_MISCONFIGURED,
            seed=42,
        )
        val_losses = gen_val_loss_history(custom_scenario)
        assert len(val_losses) == 20