workbench / tests /unit /test_training_planner.py
GitHub Actions
Initial ZeroGPU deployment with spaces shim
7f9dfed
Raw
History Blame Contribute Delete
1.77 kB
from __future__ import annotations
import tempfile
import unittest
from pathlib import Path
from training.planner import (
LoraConfig,
TrainingConfig,
build_training_plan,
load_training_config,
training_hardware_notes,
validate_training_plan,
)
class TrainingPlannerTest(unittest.TestCase):
    """Unit tests for the planning helpers imported from ``training.planner``."""

    def test_loads_training_config(self) -> None:
        """Loading the training config returns the expected rank and epoch count."""
        # NOTE(review): the path is relative and this file does not anchor it
        # to any directory, so it presumably depends on the working directory
        # the tests are launched from — confirm against load_training_config.
        lora, training = load_training_config("config/training.yaml")

        self.assertEqual(lora.rank, 16)
        self.assertEqual(training.epochs, 1)

    def test_validates_missing_dataset(self) -> None:
        """Validation reports a "does not exist" error for an absent dataset."""
        errors = validate_training_plan(
            "missing.jsonl",
            LoraConfig(rank=16, alpha=32, dropout=0.05),
            TrainingConfig(epochs=1, batch_size=2, grad_accum=4, lr=0.0002, report_to="none"),
        )

        # Guard the index below: an empty result should surface as a readable
        # assertion failure rather than an IndexError.
        self.assertTrue(
            errors,
            "expected at least one validation error for a missing dataset",
        )
        self.assertIn("does not exist", errors[0])

    def test_builds_valid_dry_run_plan(self) -> None:
        """A plan built for an existing dataset validates and does not execute training."""
        with tempfile.TemporaryDirectory() as tmp:
            workdir = Path(tmp)

            # One-record JSONL file so the dataset path actually exists.
            dataset = workdir / "field_notes.jsonl"
            dataset.write_text('{"prompt":"hello","correction":"world"}\n', encoding="utf-8")

            plan = build_training_plan(
                str(dataset),
                rank=8,
                epochs=2,
                output_root=workdir / "out",
            )

            # The explicit overrides must be reflected in the resulting plan.
            self.assertEqual(plan.lora.rank, 8)
            self.assertEqual(plan.training.epochs, 2)
            self.assertEqual(plan.validation_errors, [])
            self.assertFalse(plan.executes_training)
            # The output directory is expected to mention the dataset's stem.
            self.assertIn("field_notes", plan.output_dir)

    def test_hardware_notes_warn_about_git_weights(self) -> None:
        """At least one hardware note mentions "git"."""
        self.assertTrue(
            any("git" in note for note in training_hardware_notes()),
            "expected at least one hardware note to mention 'git'",
        )
# Allow running this test module directly as a script; when the module is
# merely imported (e.g. by a test runner), nothing is executed here.
if __name__ == "__main__":
    unittest.main()