# llm-training/tests/unit/test_train_headless.py
# percyraskova's picture
# Upload folder using huggingface_hub
# 81b3473 verified
"""
Unit tests for the headless training module.
Tests cover:
- Environment variable parsing (get_env, get_env_int, get_env_float)
- Checkpoint discovery (find_latest_checkpoint)
- Model upload to HuggingFace Hub (upload_to_hub)
"""
from __future__ import annotations
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
# =============================================================================
# Environment Variable Parsing Tests
# =============================================================================
class TestGetEnv:
    """Checks for ``get_env``: set values, fallback defaults, required lookups."""

    def test_get_env_returns_value(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A variable that is set comes back with its value."""
        monkeypatch.setenv("TEST_VAR", "test_value")
        # The import is deliberately placed after the environment is patched,
        # to avoid module-level checks.
        from prolewiki_llm.train_headless import get_env

        observed = get_env("TEST_VAR")

        assert observed == "test_value"

    def test_get_env_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An unset variable falls back to the supplied default."""
        # raising=False: do not fail if the variable was never set.
        monkeypatch.delenv("NONEXISTENT_VAR", raising=False)
        from prolewiki_llm.train_headless import get_env

        observed = get_env("NONEXISTENT_VAR", "default_value")

        assert observed == "default_value"

    def test_get_env_required_exits(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A missing variable requested with ``required=True`` exits with code 1."""
        monkeypatch.delenv("REQUIRED_VAR", raising=False)
        from prolewiki_llm.train_headless import get_env

        with pytest.raises(SystemExit) as caught:
            get_env("REQUIRED_VAR", required=True)

        exit_code = caught.value.code
        assert exit_code == 1
class TestGetEnvInt:
    """Checks for ``get_env_int``: integer parsing and default fallback."""

    def test_get_env_int_parses_integer(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A decimal string in the environment is returned as an ``int``-equal value."""
        monkeypatch.setenv("INT_VAR", "42")
        from prolewiki_llm.train_headless import get_env_int

        parsed = get_env_int("INT_VAR", 0)

        assert parsed == 42

    def test_get_env_int_returns_default(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An unset variable yields the default passed by the caller."""
        # raising=False: do not fail if the variable was never set.
        monkeypatch.delenv("NONEXISTENT_INT", raising=False)
        from prolewiki_llm.train_headless import get_env_int

        fallback = get_env_int("NONEXISTENT_INT", 100)

        assert fallback == 100
class TestGetEnvFloat:
    """Checks for ``get_env_float``: plain and scientific notation, default fallback."""

    def test_get_env_float_parses_float(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """A plain decimal string is parsed to the matching float."""
        monkeypatch.setenv("FLOAT_VAR", "3.14")
        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)

        # approx() keeps the comparison tolerant of float representation error.
        assert parsed == pytest.approx(3.14)

    def test_get_env_float_parses_scientific(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A value written in scientific notation is parsed to the matching float."""
        monkeypatch.setenv("FLOAT_VAR", "5e-6")
        from prolewiki_llm.train_headless import get_env_float

        parsed = get_env_float("FLOAT_VAR", 0.0)

        assert parsed == pytest.approx(5e-6)

    def test_get_env_float_returns_default(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """An unset variable yields the default passed by the caller."""
        # raising=False: do not fail if the variable was never set.
        monkeypatch.delenv("NONEXISTENT_FLOAT", raising=False)
        from prolewiki_llm.train_headless import get_env_float

        fallback = get_env_float("NONEXISTENT_FLOAT", 1.5)

        assert fallback == 1.5
# =============================================================================
# Checkpoint Discovery Tests
# =============================================================================
class TestFindLatestCheckpoint:
    """Checks for ``find_latest_checkpoint`` against temporary directory layouts."""

    def test_returns_none_for_nonexistent_dir(self, tmp_path: Path) -> None:
        """A path that was never created yields ``None``."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        missing_root = tmp_path / "nonexistent"

        found = find_latest_checkpoint(missing_root)

        assert found is None

    def test_returns_none_for_empty_dir(self, tmp_path: Path) -> None:
        """An existing directory with no entries yields ``None``."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()

        found = find_latest_checkpoint(root)

        assert found is None

    def test_returns_none_when_no_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Entries not named ``checkpoint-*`` are not treated as checkpoints."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()
        # One unrelated directory and one unrelated regular file.
        unrelated_dir = root / "random_dir"
        unrelated_dir.mkdir()
        unrelated_file = root / "other_file.txt"
        unrelated_file.write_text("test")

        found = find_latest_checkpoint(root)

        assert found is None

    def test_finds_single_checkpoint(self, tmp_path: Path) -> None:
        """With exactly one checkpoint directory, that directory is returned."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()
        only_checkpoint = root / "checkpoint-100"
        only_checkpoint.mkdir()

        found = find_latest_checkpoint(root)

        assert found == only_checkpoint

    def test_finds_latest_checkpoint(self, tmp_path: Path) -> None:
        """Among several checkpoints, the one with the highest step number wins."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()
        # Created out of step order; a lexicographic sort would pick
        # "checkpoint-50" instead of "checkpoint-200".
        for dir_name in (
            "checkpoint-50",
            "checkpoint-200",
            "checkpoint-100",
            "checkpoint-150",
        ):
            (root / dir_name).mkdir()

        found = find_latest_checkpoint(root)

        expected = root / "checkpoint-200"
        assert found == expected

    def test_ignores_non_checkpoint_dirs(self, tmp_path: Path) -> None:
        """Directories outside the ``checkpoint-*`` pattern do not affect the result."""
        from prolewiki_llm.train_headless import find_latest_checkpoint

        root = tmp_path / "checkpoints"
        root.mkdir()
        # Checkpoint directories interleaved with unrelated ones.
        for dir_name in ("checkpoint-50", "logs", "checkpoint-100", "outputs"):
            (root / dir_name).mkdir()

        found = find_latest_checkpoint(root)

        expected = root / "checkpoint-100"
        assert found == expected
# =============================================================================
# HuggingFace Hub Upload Tests
# =============================================================================
class TestUploadToHub:
    """Checks for ``upload_to_hub`` with the HuggingFace ``HfApi`` client mocked out."""

    @staticmethod
    def _make_adapter_dir(base: Path) -> Path:
        """Create ``lora-output`` under *base* with one stub adapter file; return it."""
        adapter_dir = base / "lora-output"
        adapter_dir.mkdir()
        (adapter_dir / "adapter_model.safetensors").write_bytes(b"mock model")
        return adapter_dir

    def test_creates_repo(self, tmp_path: Path) -> None:
        """The target repository is created with ``exist_ok`` and ``private`` set."""
        from prolewiki_llm.train_headless import upload_to_hub

        adapter_dir = self._make_adapter_dir(tmp_path)
        hub_client = MagicMock()

        # upload_to_hub imports HfApi inside the function, so the patch is
        # applied on the huggingface_hub package itself.
        with patch("huggingface_hub.HfApi", return_value=hub_client):
            upload_to_hub(adapter_dir, "test-org/test-model", "test-token")

        hub_client.create_repo.assert_called_once_with(
            "test-org/test-model", exist_ok=True, private=True
        )

    def test_uploads_folder(self, tmp_path: Path) -> None:
        """The model folder is uploaded once with the expected keyword arguments."""
        from prolewiki_llm.train_headless import upload_to_hub

        adapter_dir = self._make_adapter_dir(tmp_path)
        hub_client = MagicMock()

        with patch("huggingface_hub.HfApi", return_value=hub_client):
            upload_to_hub(adapter_dir, "test-org/test-model", "test-token")

        hub_client.upload_folder.assert_called_once_with(
            folder_path=str(adapter_dir),
            repo_id="test-org/test-model",
            commit_message="Headless GRPO training run",
        )

    def test_handles_repo_creation_failure(self, tmp_path: Path) -> None:
        """An exception from ``create_repo`` does not stop the folder upload."""
        from prolewiki_llm.train_headless import upload_to_hub

        adapter_dir = self._make_adapter_dir(tmp_path)
        hub_client = MagicMock()
        hub_client.create_repo.side_effect = Exception("Repo already exists")

        with patch("huggingface_hub.HfApi", return_value=hub_client):
            # The call itself must complete without propagating the exception.
            upload_to_hub(adapter_dir, "test-org/test-model", "test-token")

        # The upload is still attempted after the failed repo creation.
        hub_client.upload_folder.assert_called_once()