# artydemo / tests/test_train_cnn_safe.py
# Pablo Dejuan
# Add HF trainer utilities (dataset+OOM-safe training)
# cf79e95
from __future__ import annotations

import importlib.util
import subprocess
import sys
from pathlib import Path
from unittest.mock import patch
# Load scripts/train_cnn_safe.py as a module by file path, so the tests can
# call its functions without the scripts directory being an importable package.
ROOT = Path(__file__).resolve().parent.parent
SCRIPT = ROOT / "scripts" / "train_cnn_safe.py"
_MODULE_NAME = "train_cnn_safe"

spec = importlib.util.spec_from_file_location(_MODULE_NAME, SCRIPT)
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# and `spec` itself may be None when no loader matches the path.
if spec is None or spec.loader is None:
    raise ImportError(f"cannot load {SCRIPT}: no import loader available")
mod = importlib.util.module_from_spec(spec)
# Register before executing so that code in the script that looks its own
# module up via sys.modules during import can find it.
sys.modules[_MODULE_NAME] = mod
try:
    spec.loader.exec_module(mod)
except BaseException:
    # Do not leave a half-initialised module behind for later imports.
    sys.modules.pop(_MODULE_NAME, None)
    raise
def test_train_cnn_safe_retries_on_failure() -> None:
    """Check that ``run_train`` returns the training subprocess's exit code.

    ``subprocess.run`` is patched so the first launched process exits with 1
    and the second with 0. This mimics a failed training attempt followed by
    a retry at a smaller batch size (32 -> 16). The retry is performed by this
    test calling ``run_train`` twice. The test pins that each call launches
    exactly one subprocess and returns that subprocess's ``returncode``.
    """
    run_train = mod.run_train
    # Real CompletedProcess objects (rather than ad-hoc stand-ins) also carry
    # ``args``/``stdout``/``stderr`` in case run_train inspects them.
    fake_results = [
        subprocess.CompletedProcess(args=[], returncode=1),
        subprocess.CompletedProcess(args=[], returncode=0),
    ]
    with patch("subprocess.run", side_effect=fake_results) as run:
        # First attempt: the training process fails.
        rc1 = run_train(
            arch="cnnrnn", epochs=1, batch_size=32, cpu=False, extra_args=None
        )
        assert rc1 == 1
        # Retry with a halved batch size: the training process succeeds.
        rc2 = run_train(
            arch="cnnrnn", epochs=1, batch_size=16, cpu=False, extra_args=None
        )
        assert rc2 == 0
    # One subprocess launch per run_train call, no hidden extra launches.
    assert run.call_count == 2