| """Tests for parametric fallback in simulation.py. |
| |
| These test the fallback paths that run when real training is unavailable. |
| We force fallback by monkeypatching _get_real_curves to return None. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from unittest.mock import patch |
|
|
| from ml_training_debugger.scenarios import sample_scenario |
| from ml_training_debugger.simulation import ( |
| gen_loss_history, |
| gen_val_accuracy_history, |
| gen_val_loss_history, |
| ) |
|
|
|
|
| def _force_fallback(*args, **kwargs): |
| return None |
|
|
|
|
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
class TestParametricFallbackLoss:
    """Training-loss history produced by the parametric fallback.

    The class-level ``patch`` wraps every ``test*`` method individually, so
    ``_get_real_curves`` is replaced by ``_force_fallback`` while each test
    runs and restored afterwards.
    """

    @staticmethod
    def _assert_fallback_length(task_id: str) -> None:
        """Sample ``task_id`` (seed 42) and require a 20-entry loss history."""
        scenario = sample_scenario(task_id, seed=42)
        history = gen_loss_history(scenario)
        assert len(history) == 20

    def test_task_001_fallback(self) -> None:
        self._assert_fallback_length("task_001")

    def test_task_002_fallback(self) -> None:
        self._assert_fallback_length("task_002")

    def test_task_003_fallback(self) -> None:
        self._assert_fallback_length("task_003")

    def test_task_004_fallback(self) -> None:
        self._assert_fallback_length("task_004")

    def test_task_005_fallback(self) -> None:
        self._assert_fallback_length("task_005")

    def test_task_006_fallback(self) -> None:
        self._assert_fallback_length("task_006")

    def test_task_007_fallback(self) -> None:
        self._assert_fallback_length("task_007")
|
|
|
|
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
class TestParametricFallbackValAcc:
    """Validation-accuracy history produced by the parametric fallback.

    The class-level ``patch`` wraps every ``test*`` method individually, so
    ``_get_real_curves`` is replaced by ``_force_fallback`` while each test
    runs and restored afterwards.  task_002 is not covered here.
    """

    @staticmethod
    def _assert_fallback_length(task_id: str) -> None:
        """Sample ``task_id`` (seed 42) and require a 20-entry val-accuracy history."""
        scenario = sample_scenario(task_id, seed=42)
        history = gen_val_accuracy_history(scenario)
        assert len(history) == 20

    def test_task_001_fallback(self) -> None:
        self._assert_fallback_length("task_001")

    def test_task_003_fallback(self) -> None:
        self._assert_fallback_length("task_003")

    def test_task_004_fallback(self) -> None:
        self._assert_fallback_length("task_004")

    def test_task_005_fallback(self) -> None:
        self._assert_fallback_length("task_005")

    def test_task_006_fallback(self) -> None:
        self._assert_fallback_length("task_006")

    def test_task_007_fallback(self) -> None:
        self._assert_fallback_length("task_007")
|
|
|
|
@patch("ml_training_debugger.simulation._get_real_curves", _force_fallback)
class TestParametricFallbackValLoss:
    """Validation-loss history produced by the parametric fallback.

    The class-level ``patch`` wraps every ``test*`` method individually, so
    ``_get_real_curves`` is replaced by ``_force_fallback`` while each test
    runs and restored afterwards.  task_002 and task_003 are not covered here.
    """

    @staticmethod
    def _assert_fallback_length(task_id: str) -> None:
        """Sample ``task_id`` (seed 42) and require a 20-entry val-loss history."""
        scenario = sample_scenario(task_id, seed=42)
        history = gen_val_loss_history(scenario)
        assert len(history) == 20

    def test_task_001_fallback(self) -> None:
        self._assert_fallback_length("task_001")

    def test_task_004_fallback(self) -> None:
        self._assert_fallback_length("task_004")

    def test_task_005_fallback(self) -> None:
        self._assert_fallback_length("task_005")

    def test_task_006_fallback(self) -> None:
        self._assert_fallback_length("task_006")

    def test_task_007_fallback(self) -> None:
        self._assert_fallback_length("task_007")

    def test_fallback_default(self) -> None:
        """Drive the last-resort val-loss fallback with a hand-built scenario.

        Uses an unlisted task id (task_999) with SCHEDULER_MISCONFIGURED;
        NOTE(review): presumably no task-specific branch in simulation.py
        matches this combination -- confirm it reaches the default path.
        """
        from ml_training_debugger.models import RootCauseDiagnosis
        from ml_training_debugger.scenarios import ScenarioParams

        scenario = ScenarioParams(
            task_id="task_999",
            root_cause=RootCauseDiagnosis.SCHEDULER_MISCONFIGURED,
            seed=42,
        )
        history = gen_val_loss_history(scenario)
        assert len(history) == 20
|
|