Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import tempfile | |
| import unittest | |
| from pathlib import Path | |
| from training.planner import ( | |
| LoraConfig, | |
| TrainingConfig, | |
| build_training_plan, | |
| load_training_config, | |
| training_hardware_notes, | |
| validate_training_plan, | |
| ) | |
class TrainingPlannerTest(unittest.TestCase):
    """Tests for the config-loading, validation and planning helpers imported from ``training.planner``."""

    def test_loads_training_config(self) -> None:
        """The checked-in training config is expected to yield LoRA rank 16 and a single epoch."""
        # NOTE(review): the path is relative, so this presumably has to run
        # from the repository root -- confirm against the CI invocation.
        lora, training = load_training_config("config/training.yaml")

        self.assertEqual(lora.rank, 16)
        self.assertEqual(training.epochs, 1)

    def test_validates_missing_dataset(self) -> None:
        """Validating a plan for a dataset path that is not on disk reports "does not exist" first."""
        errors = validate_training_plan(
            "missing.jsonl",
            LoraConfig(rank=16, alpha=32, dropout=0.05),
            TrainingConfig(epochs=1, batch_size=2, grad_accum=4, lr=0.0002, report_to="none"),
        )

        # Guard the index below: an empty result should surface as a readable
        # assertion failure rather than an IndexError from ``errors[0]``.
        self.assertTrue(errors, "expected at least one validation error for a missing dataset")
        self.assertIn("does not exist", errors[0])

    def test_builds_valid_dry_run_plan(self) -> None:
        """A plan built over an existing one-record dataset applies the overrides and has no errors."""
        with tempfile.TemporaryDirectory() as tmp:
            tmp_root = Path(tmp)
            dataset = tmp_root / "field_notes.jsonl"
            dataset.write_text('{"prompt":"hello","correction":"world"}\n', encoding="utf-8")

            plan = build_training_plan(
                str(dataset),
                rank=8,
                epochs=2,
                output_root=tmp_root / "out",
            )

            # The keyword overrides must be reflected in the resulting plan.
            self.assertEqual(plan.lora.rank, 8)
            self.assertEqual(plan.training.epochs, 2)
            self.assertEqual(plan.validation_errors, [])
            # The plan is expected to be flagged as not executing training.
            self.assertFalse(plan.executes_training)
            # The output directory is expected to mention the dataset's stem.
            self.assertIn("field_notes", plan.output_dir)

    def test_hardware_notes_warn_about_git_weights(self) -> None:
        """At least one hardware note mentions "git"."""
        self.assertTrue(any("git" in note for note in training_hardware_notes()))
# Allow running this test module directly as a script, not only via a test runner.
if __name__ == "__main__":
    unittest.main()