"""Static checks for training/scripts/sft_warmstart.py. Phase 5d. No GPU / no HF Hub access — these tests verify config surface and preflight logic without loading Unsloth/torch or hitting the network. Live SFT training (~30 min A100, ~$1.25) is a manual post-Phase-5c step, not pytest-driven. """ from __future__ import annotations import importlib.util from pathlib import Path from unittest.mock import MagicMock, patch import pytest SCRIPT_PATH = Path(__file__).parent.parent / "training" / "scripts" / "sft_warmstart.py" def _load_module(): """Load training/scripts/sft_warmstart.py as a module without running main().""" import os os.environ.setdefault("HF_TOKEN", "test_token_static_only") os.environ.setdefault("OUTPUT_REPO", "test/sft_static_load") spec = importlib.util.spec_from_file_location("sft_warmstart_under_test", SCRIPT_PATH) assert spec is not None and spec.loader is not None module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) return module def test_script_exists_and_loads() -> None: """File present and importable.""" assert SCRIPT_PATH.exists(), f"missing {SCRIPT_PATH}" mod = _load_module() for attr in ("main", "preflight_model_access", "preflight_dataset_access"): assert hasattr(mod, attr), f"missing {attr}" def test_default_model_is_qwen3_7b() -> None: """B1-Qwen warm-start default.""" mod = _load_module() assert mod.MODEL_NAME.startswith("unsloth/Qwen3-7B-Instruct") def test_default_dataset_repo_matches_phase_5c() -> None: """SFT trainer reads the dataset Phase 5c writes.""" mod = _load_module() assert mod.SFT_DATASET_REPO == "Angshuman28/crisisworld-sft-trajectories" def test_default_lr_is_2e_minus_5() -> None: """M-FR-22: SFT LR ~10x higher than GRPO's 5e-6.""" mod = _load_module() assert abs(mod.LR - 2e-5) < 1e-9 def test_default_lora_rank_matches_grpo_downstream() -> None: """M-FR-19: LoRA rank 32, same as Phase 5b GRPO, for downstream compat.""" mod = _load_module() assert mod.LORA_RANK == 32 def test_default_max_train_steps_is_200() -> None: """Per spec.""" mod = _load_module() assert mod.MAX_TRAIN_STEPS == 200 def test_default_num_epochs_is_2() -> None: """M-FR-20: 2 epochs default with MAX_TRAIN_STEPS as cap.""" mod = _load_module() assert mod.NUM_EPOCHS == 2 def test_required_env_vars_raise_systemexit_when_missing() -> None: """HF_TOKEN and OUTPUT_REPO required.""" import os import subprocess import sys env = {k: v for k, v in os.environ.items() if k not in ("HF_TOKEN", "OUTPUT_REPO")} env["PYTHONPATH"] = str(SCRIPT_PATH.parent.parent.parent) result = subprocess.run( [sys.executable, str(SCRIPT_PATH)], env=env, capture_output=True, text=True, timeout=30, ) assert result.returncode != 0 out = result.stdout + result.stderr # Either HF_TOKEN or OUTPUT_REPO surfaces first; both are valid fail signals. assert "HF_TOKEN" in out or "OUTPUT_REPO" in out def test_preflight_dataset_aborts_on_missing() -> None: """preflight_dataset_access aborts cleanly if dataset doesn't exist.""" mod = _load_module() from huggingface_hub.utils import RepositoryNotFoundError class _FakeNotFound(RepositoryNotFoundError): def __init__(self, msg: str) -> None: Exception.__init__(self, msg) err = _FakeNotFound("no such dataset") with patch("huggingface_hub.HfApi") as MockApi: MockApi.return_value.dataset_info.side_effect = err with pytest.raises(SystemExit, match="not found"): mod.preflight_dataset_access("does/not/exist", "tok") def test_preflight_dataset_passes_when_exists() -> None: """preflight_dataset_access does not raise when dataset_info succeeds.""" mod = _load_module() fake_info = MagicMock() with patch("huggingface_hub.HfApi") as MockApi: MockApi.return_value.dataset_info.return_value = fake_info mod.preflight_dataset_access("real/dataset", "tok") # should not raise def test_preflight_model_aborts_on_gated() -> None: """preflight_model_access aborts on GatedRepoError (Llama-3.1-8B path).""" mod = _load_module() from huggingface_hub.utils import GatedRepoError class _FakeGated(GatedRepoError): def __init__(self, msg: str) -> None: Exception.__init__(self, msg) err = _FakeGated("license required") with patch("huggingface_hub.HfApi") as MockApi: MockApi.return_value.model_info.side_effect = err with pytest.raises(SystemExit, match="gated"): mod.preflight_model_access("meta-llama/Llama-3.1-8B-Instruct", "tok")