|
|
""" |
|
|
Unit tests for the headless training module. |
|
|
|
|
|
Tests cover: |
|
|
- Environment variable parsing (get_env, get_env_int, get_env_float) |
|
|
- Checkpoint discovery (find_latest_checkpoint) |
|
|
- Model upload to HuggingFace Hub (upload_to_hub) |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from pathlib import Path |
|
|
from unittest.mock import MagicMock, patch |
|
|
|
|
|
import pytest |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestGetEnv:
    """Tests for ``get_env``: lookup, default fallback, required-variable exit."""

    def test_get_env_returns_value(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A variable that is set is returned unchanged."""
        monkeypatch.setenv("TEST_VAR", "test_value")
        from prolewiki_llm.train_headless import get_env

        fetched = get_env("TEST_VAR")
        assert fetched == "test_value"

    def test_get_env_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An unset variable falls back to the caller-supplied default."""
        monkeypatch.delenv("NONEXISTENT_VAR", raising=False)
        from prolewiki_llm.train_headless import get_env

        fetched = get_env("NONEXISTENT_VAR", "default_value")
        assert fetched == "default_value"

    def test_get_env_required_exits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A missing variable requested with ``required=True`` exits with code 1."""
        monkeypatch.delenv("REQUIRED_VAR", raising=False)
        from prolewiki_llm.train_headless import get_env

        with pytest.raises(SystemExit) as raised:
            get_env("REQUIRED_VAR", required=True)

        # Inspect the exit status only after the context manager captured it.
        assert raised.value.code == 1
|
|
|
|
|
|
|
|
class TestGetEnvInt:
    """Tests for ``get_env_int``: integer parsing and default fallback."""

    def test_get_env_int_parses_integer(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A numeric string in the environment is converted to an ``int``."""
        monkeypatch.setenv("INT_VAR", "42")
        from prolewiki_llm.train_headless import get_env_int

        parsed = get_env_int("INT_VAR", 0)
        assert parsed == 42

    def test_get_env_int_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An unset variable yields the supplied integer default."""
        monkeypatch.delenv("NONEXISTENT_INT", raising=False)
        from prolewiki_llm.train_headless import get_env_int

        parsed = get_env_int("NONEXISTENT_INT", 100)
        assert parsed == 100
|
|
|
|
|
|
|
|
class TestGetEnvFloat:
    """Tests for ``get_env_float``: decimal/scientific parsing and default fallback."""

    def test_get_env_float_parses_float(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A plain decimal string is converted to a ``float``."""
        monkeypatch.setenv("FLOAT_VAR", "3.14")
        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)
        assert parsed == pytest.approx(3.14)

    def test_get_env_float_parses_scientific(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Scientific notation such as ``5e-6`` (a typical learning rate) parses."""
        monkeypatch.setenv("FLOAT_VAR", "5e-6")
        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)
        assert parsed == pytest.approx(5e-6)

    def test_get_env_float_returns_default(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """An unset variable yields the supplied float default."""
        monkeypatch.delenv("NONEXISTENT_FLOAT", raising=False)
        from prolewiki_llm.train_headless import get_env_float

        # 1.5 is exactly representable, so exact comparison is safe here.
        parsed = get_env_float("NONEXISTENT_FLOAT", 1.5)
        assert parsed == 1.5
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestFindLatestCheckpoint:
    """Tests for ``find_latest_checkpoint`` directory discovery."""

    @staticmethod
    def _make_root(tmp_path: Path, *subdirs: str) -> Path:
        """Create ``tmp_path/checkpoints`` and then each named subdirectory, in order."""
        root = tmp_path / "checkpoints"
        root.mkdir()
        for name in subdirs:
            (root / name).mkdir()
        return root

    def test_returns_none_for_nonexistent_dir(self, tmp_path: Path) -> None:
        """A checkpoint path that was never created yields ``None``."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        missing_root = tmp_path / "nonexistent"
        assert find_latest_checkpoint(missing_root) is None

    def test_returns_none_for_empty_dir(self, tmp_path: Path) -> None:
        """An existing but empty checkpoint directory yields ``None``."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = self._make_root(tmp_path)
        assert find_latest_checkpoint(root) is None

    def test_returns_none_when_no_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Entries not named ``checkpoint-*`` (dirs or files) yield ``None``."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = self._make_root(tmp_path, "random_dir")
        (root / "other_file.txt").write_text("test")

        assert find_latest_checkpoint(root) is None

    def test_finds_single_checkpoint(self, tmp_path: Path) -> None:
        """A lone ``checkpoint-*`` directory is returned."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = self._make_root(tmp_path, "checkpoint-100")

        found = find_latest_checkpoint(root)
        assert found == root / "checkpoint-100"

    def test_finds_latest_checkpoint(self, tmp_path: Path) -> None:
        """The directory with the numerically highest step wins."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        # Created out of order on purpose. Lexically "checkpoint-50" would be
        # the maximum, so this set separates numeric from string ordering.
        root = self._make_root(
            tmp_path,
            "checkpoint-50",
            "checkpoint-200",
            "checkpoint-100",
            "checkpoint-150",
        )

        found = find_latest_checkpoint(root)
        assert found == root / "checkpoint-200"

    def test_ignores_non_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Sibling directories not named ``checkpoint-*`` are skipped."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = self._make_root(
            tmp_path, "checkpoint-50", "logs", "checkpoint-100", "outputs"
        )

        found = find_latest_checkpoint(root)
        assert found == root / "checkpoint-100"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUploadToHub:
    """Tests for ``upload_to_hub`` with ``huggingface_hub.HfApi`` replaced by a mock."""

    @staticmethod
    def _make_model_dir(tmp_path: Path) -> Path:
        """Create ``tmp_path/lora-output`` holding a fake adapter weights file."""
        model_dir = tmp_path / "lora-output"
        model_dir.mkdir()
        (model_dir / "adapter_model.safetensors").write_bytes(b"mock model")
        return model_dir

    @staticmethod
    def _upload_with(api: MagicMock, model_dir: Path) -> None:
        """Run ``upload_to_hub`` while ``huggingface_hub.HfApi(...)`` returns *api*."""
        from prolewiki_llm.train_headless import upload_to_hub

        with patch("huggingface_hub.HfApi", return_value=api):
            upload_to_hub(model_dir, "test-org/test-model", "test-token")

    def test_creates_repo(self, tmp_path: Path) -> None:
        """The target repo is requested as private, with ``exist_ok=True``."""
        api = MagicMock()
        self._upload_with(api, self._make_model_dir(tmp_path))

        api.create_repo.assert_called_once_with(
            "test-org/test-model", exist_ok=True, private=True
        )

    def test_uploads_folder(self, tmp_path: Path) -> None:
        """The whole model directory is pushed with a fixed commit message."""
        api = MagicMock()
        model_dir = self._make_model_dir(tmp_path)
        self._upload_with(api, model_dir)

        api.upload_folder.assert_called_once_with(
            folder_path=str(model_dir),
            repo_id="test-org/test-model",
            commit_message="Headless GRPO training run",
        )

    def test_handles_repo_creation_failure(self, tmp_path: Path) -> None:
        """The upload still happens when ``create_repo`` raises an exception."""
        api = MagicMock()
        # Simulate the Hub rejecting repo creation (e.g. because it already exists).
        api.create_repo.side_effect = Exception("Repo already exists")

        self._upload_with(api, self._make_model_dir(tmp_path))

        api.upload_folder.assert_called_once()
|
|
|