lemousehunter
feat: training speed optimizations — mixed precision, vectorized augmentation, cached eval predictions
1fe1a19 | """ | |
| Tests for training speed optimizations: | |
| 1. Mixed precision CLI flag and model dtype compatibility | |
| 2. Vectorized augmentation (gather-based shift) | |
| 3. Cached predictions in evaluation (single forward pass) | |
| 4. Agent command builder mapping for mixed_precision | |
| These tests validate code structure, imports, and logic without | |
| requiring TensorFlow GPU or the ASCAD dataset. | |
| """ | |
| import os | |
| import sys | |
| import inspect | |
| import subprocess | |
| import unittest | |
| from unittest.mock import MagicMock | |
| import numpy as np | |
# Put the directory containing this file first on sys.path so the
# `from src....` imports used by the tests below resolve against it.
# NOTE(review): assumes this test file lives at the project root — confirm
# if it is ever moved into a tests/ subdirectory.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
| class TestMixedPrecisionCLI(unittest.TestCase): | |
| """Test that the --mixed-precision CLI flag exists and works.""" | |
| def test_cli_has_mixed_precision_flag(self): | |
| """CLI --help should include --mixed-precision.""" | |
| result = subprocess.run( | |
| [sys.executable, "train_mtl.py", "--help"], | |
| capture_output=True, text=True, | |
| cwd=os.path.dirname(os.path.abspath(__file__)), | |
| timeout=30, | |
| ) | |
| self.assertEqual(result.returncode, 0) | |
| self.assertIn("--mixed-precision", result.stdout) | |
| def test_cli_backward_compatible(self): | |
| """Old CLI args should still work without --mixed-precision.""" | |
| result = subprocess.run( | |
| [sys.executable, "train_mtl.py", "--help"], | |
| capture_output=True, text=True, | |
| cwd=os.path.dirname(os.path.abspath(__file__)), | |
| timeout=30, | |
| ) | |
| self.assertEqual(result.returncode, 0) | |
| # All old flags still present | |
| for flag in ["--variant", "--desync", "--upload", "--wandb-project", | |
| "--augment-shift", "--gradnorm", "--dtp", | |
| "--spectral-decoupling", "--label-smoothing"]: | |
| self.assertIn(flag, result.stdout, f"Missing flag: {flag}") | |
| class TestModelFloat32Outputs(unittest.TestCase): | |
| """Test that model output layers have dtype=float32 for mixed precision safety.""" | |
| def test_hps_output_dtype(self): | |
| """HPS model output layers should have dtype=float32.""" | |
| from src.models.mtl import HPSModel | |
| hps = HPSModel() | |
| model = hps.build() | |
| for i in range(16): | |
| output_layer = model.get_layer(f"byte_{i}") | |
| self.assertEqual( | |
| output_layer.dtype, "float32", | |
| f"byte_{i} output layer dtype should be float32, got {output_layer.dtype}" | |
| ) | |
| def test_mtan_lite_output_dtype(self): | |
| """MTAN-Lite model output layers should have dtype=float32.""" | |
| from src.models.mtl import MTANLiteModel | |
| mtan = MTANLiteModel() | |
| model = mtan.build() | |
| for i in range(16): | |
| output_layer = model.get_layer(f"byte_{i}") | |
| self.assertEqual( | |
| output_layer.dtype, "float32", | |
| f"byte_{i} output layer dtype should be float32, got {output_layer.dtype}" | |
| ) | |
| class TestVectorizedAugmentation(unittest.TestCase): | |
| """Test the vectorized augmentation (no tf.map_fn).""" | |
| def test_augmentor_import(self): | |
| """RandomShiftAugmentor should still be importable.""" | |
| from src.augmentation import RandomShiftAugmentor | |
| aug = RandomShiftAugmentor(max_shift=5) | |
| self.assertEqual(aug.max_shift, 5) | |
| def test_augment_numpy_still_works(self): | |
| """NumPy augmentation should still produce correct shapes.""" | |
| from src.augmentation import RandomShiftAugmentor | |
| aug = RandomShiftAugmentor(max_shift=3) | |
| traces = np.random.randn(10, 100).astype(np.float32) | |
| result = aug.augment_numpy(traces) | |
| self.assertEqual(result.shape, (10, 100)) | |
| def test_augment_numpy_3d_still_works(self): | |
| """NumPy augmentation with 3D input should still work.""" | |
| from src.augmentation import RandomShiftAugmentor | |
| aug = RandomShiftAugmentor(max_shift=3) | |
| traces = np.random.randn(10, 100, 1).astype(np.float32) | |
| result = aug.augment_numpy(traces) | |
| self.assertEqual(result.shape, (10, 100, 1)) | |
| def test_make_tf_dataset_exists(self): | |
| """make_tf_dataset method should still exist.""" | |
| from src.augmentation import RandomShiftAugmentor | |
| self.assertTrue(hasattr(RandomShiftAugmentor, 'make_tf_dataset')) | |
| def test_make_tf_dataset_no_map_fn(self): | |
| """The augmentation code should not call tf.map_fn anymore.""" | |
| import ast | |
| aug_path = os.path.join( | |
| os.path.dirname(os.path.abspath(__file__)), | |
| "src", "augmentation.py" | |
| ) | |
| with open(aug_path) as f: | |
| source = f.read() | |
| # Parse the AST to check for actual tf.map_fn calls, | |
| # ignoring comments and docstrings | |
| tree = ast.parse(source) | |
| map_fn_calls = [] | |
| for node in ast.walk(tree): | |
| if isinstance(node, ast.Call): | |
| func = node.func | |
| # Check for tf.map_fn(...) | |
| if (isinstance(func, ast.Attribute) | |
| and func.attr == "map_fn" | |
| and isinstance(func.value, ast.Name) | |
| and func.value.id == "tf"): | |
| map_fn_calls.append(node.lineno) | |
| self.assertEqual( | |
| len(map_fn_calls), 0, | |
| f"augmentation.py should use vectorized shift, not tf.map_fn " | |
| f"(found calls at lines: {map_fn_calls})" | |
| ) | |
| class TestCachedPredictions(unittest.TestCase): | |
| """Test that evaluate_model accepts cached_predictions parameter.""" | |
| def test_evaluate_model_has_cached_predictions_param(self): | |
| """evaluate_model should accept cached_predictions kwarg.""" | |
| from src.evaluation import evaluate_model | |
| sig = inspect.signature(evaluate_model) | |
| self.assertIn("cached_predictions", sig.parameters) | |
| def test_cached_predictions_default_none(self): | |
| """cached_predictions should default to None.""" | |
| from src.evaluation import evaluate_model | |
| sig = inspect.signature(evaluate_model) | |
| param = sig.parameters["cached_predictions"] | |
| self.assertIs(param.default, None) | |
| class TestAgentCommandBuilder(unittest.TestCase): | |
| """Test that agent.py command builder maps mixed_precision config.""" | |
| def test_agent_has_mixed_precision_mapping(self): | |
| """agent.py should contain mixed_precision CLI mapping.""" | |
| agent_path = os.path.join( | |
| os.path.dirname(os.path.abspath(__file__)), | |
| "orchestrator", "worker", "agent.py" | |
| ) | |
| with open(agent_path) as f: | |
| source = f.read() | |
| self.assertIn("mixed_precision", source, | |
| "agent.py should map mixed_precision config to CLI flag") | |
| self.assertIn("--mixed-precision", source, | |
| "agent.py should include --mixed-precision CLI flag") | |
if __name__ == "__main__":
    # Allow running this file directly with the interpreter, in addition to
    # discovery by a test runner.
    unittest.main()