""" Unit tests for the headless training module. Tests cover: - Environment variable parsing (get_env, get_env_int, get_env_float) - Checkpoint discovery (find_latest_checkpoint) - Model upload to HuggingFace Hub (upload_to_hub) """ from __future__ import annotations from pathlib import Path from unittest.mock import MagicMock, patch import pytest # ============================================================================= # Environment Variable Parsing Tests # ============================================================================= class TestGetEnv: """Test environment variable retrieval functions.""" def test_get_env_returns_value(self, monkeypatch: pytest.MonkeyPatch) -> None: """get_env returns the environment variable value when set.""" monkeypatch.setenv("TEST_VAR", "test_value") # Import after setting env to avoid module-level checks from prolewiki_llm.train_headless import get_env assert get_env("TEST_VAR") == "test_value" def test_get_env_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None: """get_env returns default when variable not set.""" monkeypatch.delenv("NONEXISTENT_VAR", raising=False) from prolewiki_llm.train_headless import get_env assert get_env("NONEXISTENT_VAR", "default_value") == "default_value" def test_get_env_required_exits(self, monkeypatch: pytest.MonkeyPatch) -> None: """get_env exits when required variable is missing.""" monkeypatch.delenv("REQUIRED_VAR", raising=False) from prolewiki_llm.train_headless import get_env with pytest.raises(SystemExit) as exc_info: get_env("REQUIRED_VAR", required=True) assert exc_info.value.code == 1 class TestGetEnvInt: """Test integer environment variable parsing.""" def test_get_env_int_parses_integer(self, monkeypatch: pytest.MonkeyPatch) -> None: """get_env_int correctly parses integer values.""" monkeypatch.setenv("INT_VAR", "42") from prolewiki_llm.train_headless import get_env_int assert get_env_int("INT_VAR", 0) == 42 def test_get_env_int_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None: """get_env_int returns default when variable not set.""" monkeypatch.delenv("NONEXISTENT_INT", raising=False) from prolewiki_llm.train_headless import get_env_int assert get_env_int("NONEXISTENT_INT", 100) == 100 class TestGetEnvFloat: """Test float environment variable parsing.""" def test_get_env_float_parses_float(self, monkeypatch: pytest.MonkeyPatch) -> None: """get_env_float correctly parses float values.""" monkeypatch.setenv("FLOAT_VAR", "3.14") from prolewiki_llm.train_headless import get_env_float assert get_env_float("FLOAT_VAR", 0.0) == pytest.approx(3.14) def test_get_env_float_parses_scientific( self, monkeypatch: pytest.MonkeyPatch ) -> None: """get_env_float correctly parses scientific notation.""" monkeypatch.setenv("FLOAT_VAR", "5e-6") from prolewiki_llm.train_headless import get_env_float assert get_env_float("FLOAT_VAR", 0.0) == pytest.approx(5e-6) def test_get_env_float_returns_default( self, monkeypatch: pytest.MonkeyPatch ) -> None: """get_env_float returns default when variable not set.""" monkeypatch.delenv("NONEXISTENT_FLOAT", raising=False) from prolewiki_llm.train_headless import get_env_float assert get_env_float("NONEXISTENT_FLOAT", 1.5) == 1.5 # ============================================================================= # Checkpoint Discovery Tests # ============================================================================= class TestFindLatestCheckpoint: """Test checkpoint discovery logic.""" def test_returns_none_for_nonexistent_dir(self, tmp_path: Path) -> None: """Returns None when checkpoint directory doesn't exist.""" from prolewiki_llm.train_headless import find_latest_checkpoint nonexistent = tmp_path / "nonexistent" assert find_latest_checkpoint(nonexistent) is None def test_returns_none_for_empty_dir(self, tmp_path: Path) -> None: """Returns None when checkpoint directory is empty.""" from prolewiki_llm.train_headless import find_latest_checkpoint checkpoint_dir = tmp_path / "checkpoints" checkpoint_dir.mkdir() assert find_latest_checkpoint(checkpoint_dir) is None def test_returns_none_when_no_checkpoint_dirs(self, tmp_path: Path) -> None: """Returns None when no checkpoint-* directories exist.""" from prolewiki_llm.train_headless import find_latest_checkpoint checkpoint_dir = tmp_path / "checkpoints" checkpoint_dir.mkdir() # Create non-checkpoint directories (checkpoint_dir / "random_dir").mkdir() (checkpoint_dir / "other_file.txt").write_text("test") assert find_latest_checkpoint(checkpoint_dir) is None def test_finds_single_checkpoint(self, tmp_path: Path) -> None: """Finds single checkpoint directory.""" from prolewiki_llm.train_headless import find_latest_checkpoint checkpoint_dir = tmp_path / "checkpoints" checkpoint_dir.mkdir() checkpoint = checkpoint_dir / "checkpoint-100" checkpoint.mkdir() result = find_latest_checkpoint(checkpoint_dir) assert result == checkpoint def test_finds_latest_checkpoint(self, tmp_path: Path) -> None: """Finds the checkpoint with the highest step number.""" from prolewiki_llm.train_headless import find_latest_checkpoint checkpoint_dir = tmp_path / "checkpoints" checkpoint_dir.mkdir() # Create checkpoints in random order (checkpoint_dir / "checkpoint-50").mkdir() (checkpoint_dir / "checkpoint-200").mkdir() (checkpoint_dir / "checkpoint-100").mkdir() (checkpoint_dir / "checkpoint-150").mkdir() result = find_latest_checkpoint(checkpoint_dir) assert result == checkpoint_dir / "checkpoint-200" def test_ignores_non_checkpoint_dirs(self, tmp_path: Path) -> None: """Ignores directories that don't match checkpoint-* pattern.""" from prolewiki_llm.train_headless import find_latest_checkpoint checkpoint_dir = tmp_path / "checkpoints" checkpoint_dir.mkdir() # Create mix of checkpoint and non-checkpoint dirs (checkpoint_dir / "checkpoint-50").mkdir() (checkpoint_dir / "logs").mkdir() (checkpoint_dir / "checkpoint-100").mkdir() (checkpoint_dir / "outputs").mkdir() result = find_latest_checkpoint(checkpoint_dir) assert result == checkpoint_dir / "checkpoint-100" # ============================================================================= # HuggingFace Hub Upload Tests # ============================================================================= class TestUploadToHub: """Test model upload to HuggingFace Hub.""" def test_creates_repo(self, tmp_path: Path) -> None: """upload_to_hub creates the repository if it doesn't exist.""" from prolewiki_llm.train_headless import upload_to_hub model_path = tmp_path / "lora-output" model_path.mkdir() (model_path / "adapter_model.safetensors").write_bytes(b"mock model") mock_api = MagicMock() # HfApi is imported inside upload_to_hub, so we patch at the source with patch("huggingface_hub.HfApi", return_value=mock_api): upload_to_hub(model_path, "test-org/test-model", "test-token") mock_api.create_repo.assert_called_once_with( "test-org/test-model", exist_ok=True, private=True ) def test_uploads_folder(self, tmp_path: Path) -> None: """upload_to_hub uploads the model folder.""" from prolewiki_llm.train_headless import upload_to_hub model_path = tmp_path / "lora-output" model_path.mkdir() (model_path / "adapter_model.safetensors").write_bytes(b"mock model") mock_api = MagicMock() with patch("huggingface_hub.HfApi", return_value=mock_api): upload_to_hub(model_path, "test-org/test-model", "test-token") mock_api.upload_folder.assert_called_once_with( folder_path=str(model_path), repo_id="test-org/test-model", commit_message="Headless GRPO training run", ) def test_handles_repo_creation_failure(self, tmp_path: Path) -> None: """upload_to_hub continues if repo already exists.""" from prolewiki_llm.train_headless import upload_to_hub model_path = tmp_path / "lora-output" model_path.mkdir() (model_path / "adapter_model.safetensors").write_bytes(b"mock model") mock_api = MagicMock() mock_api.create_repo.side_effect = Exception("Repo already exists") with patch("huggingface_hub.HfApi", return_value=mock_api): # Should not raise upload_to_hub(model_path, "test-org/test-model", "test-token") # Should still attempt upload mock_api.upload_folder.assert_called_once()