ascad-training-pipeline / test_speed_optimizations.py
lemousehunter
feat: training speed optimizations — mixed precision, vectorized augmentation, cached eval predictions
1fe1a19
"""
Tests for training speed optimizations:
1. Mixed precision CLI flag and model dtype compatibility
2. Vectorized augmentation (gather-based shift)
3. Cached predictions in evaluation (single forward pass)
4. Agent command builder mapping for mixed_precision
These tests validate code structure, imports, and logic without
requiring TensorFlow GPU or the ASCAD dataset.
"""
import os
import sys
import inspect
import subprocess
import unittest
from unittest.mock import MagicMock
import numpy as np
# The directory containing this test module is the project root; put it first
# on the import path so ``src.*`` resolves regardless of the runner's cwd.
_PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _PROJECT_ROOT)
class TestMixedPrecisionCLI(unittest.TestCase):
    """Test that the --mixed-precision CLI flag exists and works.

    ``train_mtl.py --help`` is executed once for the whole class (in
    ``setUpClass``) instead of once per test: both tests only inspect the
    same help text, and launching the script pays its import cost each time.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls.help_result = subprocess.run(
            [sys.executable, "train_mtl.py", "--help"],
            capture_output=True, text=True,
            cwd=os.path.dirname(os.path.abspath(__file__)),
            timeout=30,
        )

    def _assert_help_succeeded(self):
        """Fail, showing the script's stderr, if ``--help`` did not exit 0."""
        result = self.help_result
        self.assertEqual(
            result.returncode, 0,
            f"train_mtl.py --help exited with {result.returncode}; "
            f"stderr:\n{result.stderr}",
        )

    def test_cli_has_mixed_precision_flag(self):
        """CLI --help should include --mixed-precision."""
        self._assert_help_succeeded()
        self.assertIn("--mixed-precision", self.help_result.stdout)

    def test_cli_backward_compatible(self):
        """Old CLI args should still work without --mixed-precision."""
        self._assert_help_succeeded()
        # All old flags still present
        for flag in ["--variant", "--desync", "--upload", "--wandb-project",
                     "--augment-shift", "--gradnorm", "--dtp",
                     "--spectral-decoupling", "--label-smoothing"]:
            self.assertIn(flag, self.help_result.stdout,
                          f"Missing flag: {flag}")
class TestModelFloat32Outputs(unittest.TestCase):
    """Test that model output layers compute in float32 for mixed precision safety.

    Each model is built under the ``mixed_float16`` global policy. Checking
    ``layer.dtype`` would not work: Keras reports the *variable* dtype
    there, which is float32 under ``mixed_float16`` whether or not the layer
    was created with ``dtype="float32"``. ``layer.compute_dtype`` is the
    dtype the layer actually computes its outputs in, so that is asserted.
    """

    def _assert_byte_outputs_float32(self, model_cls):
        """Build ``model_cls().build()`` under mixed_float16 and check byte_0..byte_15.

        The previously active global policy is restored afterwards so other
        tests are unaffected.

        NOTE(review): assumes the models are built with ``tf.keras``. A model
        built with an unrelated standalone Keras install would ignore this
        policy, and the check would pass trivially, as the old test did.
        """
        import tensorflow as tf

        mixed_precision = tf.keras.mixed_precision
        previous_policy = mixed_precision.global_policy()
        mixed_precision.set_global_policy("mixed_float16")
        try:
            model = model_cls().build()
        finally:
            mixed_precision.set_global_policy(previous_policy)

        # The models expose 16 output layers named byte_0 .. byte_15.
        for i in range(16):
            output_layer = model.get_layer(f"byte_{i}")
            self.assertEqual(
                output_layer.compute_dtype, "float32",
                f"byte_{i} output layer should compute in float32 under "
                f"mixed precision, got {output_layer.compute_dtype}",
            )

    def test_hps_output_dtype(self):
        """HPS model output layers should have dtype=float32."""
        from src.models.mtl import HPSModel
        self._assert_byte_outputs_float32(HPSModel)

    def test_mtan_lite_output_dtype(self):
        """MTAN-Lite model output layers should have dtype=float32."""
        from src.models.mtl import MTANLiteModel
        self._assert_byte_outputs_float32(MTANLiteModel)
class TestVectorizedAugmentation(unittest.TestCase):
    """Test the vectorized augmentation (no tf.map_fn)."""

    @staticmethod
    def _find_map_fn_calls(source):
        """Return the sorted line numbers of every call to a ``map_fn`` in ``source``.

        Matches ``tf.map_fn(...)``, other attribute chains ending in
        ``map_fn`` (``tensorflow.map_fn``, ``tf.compat.v1.map_fn``) and bare
        ``map_fn(...)`` after ``from tensorflow import map_fn``. Only AST
        ``Call`` nodes are inspected, so comments and docstrings that merely
        mention ``map_fn`` are ignored.
        """
        import ast

        lines = []
        for node in ast.walk(ast.parse(source)):
            if not isinstance(node, ast.Call):
                continue
            func = node.func
            if isinstance(func, ast.Attribute):
                called = func.attr
            elif isinstance(func, ast.Name):
                called = func.id
            else:
                continue
            if called == "map_fn":
                lines.append(node.lineno)
        return sorted(lines)

    def test_augmentor_import(self):
        """RandomShiftAugmentor should still be importable."""
        from src.augmentation import RandomShiftAugmentor
        aug = RandomShiftAugmentor(max_shift=5)
        self.assertEqual(aug.max_shift, 5)

    def test_augment_numpy_still_works(self):
        """NumPy augmentation should still produce correct shapes."""
        from src.augmentation import RandomShiftAugmentor
        aug = RandomShiftAugmentor(max_shift=3)
        # Seeded so any failure is reproducible.
        traces = np.random.default_rng(0).standard_normal((10, 100)).astype(np.float32)
        result = aug.augment_numpy(traces)
        self.assertEqual(result.shape, (10, 100))

    def test_augment_numpy_3d_still_works(self):
        """NumPy augmentation with 3D input should still work."""
        from src.augmentation import RandomShiftAugmentor
        aug = RandomShiftAugmentor(max_shift=3)
        traces = np.random.default_rng(0).standard_normal((10, 100, 1)).astype(np.float32)
        result = aug.augment_numpy(traces)
        self.assertEqual(result.shape, (10, 100, 1))

    def test_make_tf_dataset_exists(self):
        """make_tf_dataset method should still exist."""
        from src.augmentation import RandomShiftAugmentor
        self.assertTrue(hasattr(RandomShiftAugmentor, 'make_tf_dataset'))

    def test_make_tf_dataset_no_map_fn(self):
        """The augmentation code should not call tf.map_fn anymore."""
        aug_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "src", "augmentation.py"
        )
        # Explicit UTF-8: the default encoding is locale-dependent.
        with open(aug_path, encoding="utf-8") as f:
            source = f.read()
        map_fn_calls = self._find_map_fn_calls(source)
        self.assertEqual(
            map_fn_calls, [],
            f"augmentation.py should use vectorized shift, not tf.map_fn "
            f"(found calls at lines: {map_fn_calls})"
        )
class TestCachedPredictions(unittest.TestCase):
    """Pin the ``cached_predictions`` keyword on ``evaluate_model``.

    Only the signature is checked here (presence and default value); the
    evaluation behaviour itself is not exercised.
    """

    @staticmethod
    def _evaluate_model_params():
        """Return the parameter mapping of ``src.evaluation.evaluate_model``."""
        from src.evaluation import evaluate_model
        return inspect.signature(evaluate_model).parameters

    def test_evaluate_model_has_cached_predictions_param(self):
        """evaluate_model should accept a cached_predictions keyword."""
        params = self._evaluate_model_params()
        self.assertIn("cached_predictions", params)

    def test_cached_predictions_default_none(self):
        """cached_predictions should default to None."""
        cached = self._evaluate_model_params()["cached_predictions"]
        self.assertIsNone(cached.default)
class TestAgentCommandBuilder(unittest.TestCase):
    """Test that agent.py command builder maps mixed_precision config."""

    def test_agent_has_mixed_precision_mapping(self):
        """agent.py should contain mixed_precision CLI mapping.

        This is a textual check on the source of
        ``orchestrator/worker/agent.py``. It does not execute the command
        builder.
        """
        agent_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "orchestrator", "worker", "agent.py",
        )
        # Read explicitly as UTF-8. The default encoding is locale-dependent
        # (e.g. cp1252 on Windows) and can fail on non-ASCII source text.
        with open(agent_path, encoding="utf-8") as f:
            source = f.read()
        self.assertIn("mixed_precision", source,
                      "agent.py should map mixed_precision config to CLI flag")
        self.assertIn("--mixed-precision", source,
                      "agent.py should include --mixed-precision CLI flag")
# Allow running this module directly with the Python interpreter, in addition
# to discovery by a test runner.
if __name__ == "__main__":
    unittest.main()