Spaces:
Running on Zero
Running on Zero
| """Tests for Modal LoRA fine-tuning scaffolding helpers.""" | |
| from __future__ import annotations | |
| import json | |
| import tempfile | |
| import unittest | |
| from pathlib import Path | |
| from scripts import finetune_lora | |
class FakeTokenizer:
    """Deterministic, dependency-free tokenizer double for the tests below.

    It exposes only the surface these tests pass in: ``pad_token`` and
    ``eos_token`` attributes, ``apply_chat_template`` and ``__call__``.
    """

    pad_token = "<pad>"
    eos_token = "</s>"

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        *,
        tokenize: bool,
        add_generation_prompt: bool,
    ) -> str:
        """Render ``messages`` as concatenated ``<role>content</role>`` tags.

        Args:
            messages: Mappings that each provide ``"role"`` and ``"content"``.
            tokenize: Accepted for call compatibility only; text is always
                returned regardless of its value.
            add_generation_prompt: When true, append an opening
                ``<assistant>`` tag after the rendered messages.

        Returns:
            The rendered conversation as one string (empty for no messages
            and no generation prompt).
        """
        # Discard explicitly, mirroring how __call__ treats its unused flags.
        del tokenize
        text = "".join(
            f"<{message['role']}>{message['content']}</{message['role']}>"
            for message in messages
        )
        if add_generation_prompt:
            text += "<assistant>"
        return text

    def __call__(
        self,
        text: str,
        *,
        truncation: bool,
        max_length: int,
        padding: bool,
        add_special_tokens: bool = False,
    ) -> dict[str, list[int]]:
        """Encode ``text`` with one id per character.

        Args:
            text: String to encode.
            truncation: When true, keep at most ``max_length`` ids.
            max_length: Upper bound applied only if ``truncation`` is true.
            padding: Ignored.
            add_special_tokens: Ignored.

        Returns:
            A dict with ``"input_ids"`` and an all-ones ``"attention_mask"``
            of the same length.
        """
        del padding, add_special_tokens
        # ``% 251 + 1`` keeps every id in 1..251, so id 0 is never produced.
        ids = [ord(character) % 251 + 1 for character in text]
        if truncation:
            ids = ids[:max_length]
        return {"input_ids": ids, "attention_mask": [1] * len(ids)}
| def _valid_record() -> dict[str, object]: | |
| return { | |
| "id": "sft-preview-0001", | |
| "messages": [ | |
| {"role": "system", "content": "You are Objectverse Diary."}, | |
| {"role": "user", "content": "Create a persona."}, | |
| {"role": "assistant", "content": "{\"persona\": {}, \"diary\": {}}"}, | |
| ], | |
| } | |
class FinetuneLoraToolingTest(unittest.TestCase):
    """Unit tests for the ``scripts.finetune_lora`` helper functions."""

    def test_load_sft_records_rejects_missing_messages(self) -> None:
        """A JSONL row without a ``messages`` key must raise ``ValueError``."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir) / "bad.jsonl"
            path.write_text(json.dumps({"id": "bad"}) + "\n", encoding="utf-8")
            # Keep the call inside the directory context so the file exists.
            with self.assertRaises(ValueError):
                finetune_lora.load_sft_records(path)

    def test_load_sft_records_rejects_malformed_messages(self) -> None:
        """A message lacking ``content`` must raise ``ValueError``."""
        bad_record = {"id": "bad", "messages": [{"role": "user"}]}
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir) / "bad.jsonl"
            path.write_text(json.dumps(bad_record) + "\n", encoding="utf-8")
            with self.assertRaises(ValueError):
                finetune_lora.load_sft_records(path)

    def test_record_to_training_text_is_non_empty(self) -> None:
        """Rendered training text names every role and keeps message content."""
        text = finetune_lora.record_to_training_text(_valid_record())
        self.assertIn("system:", text)
        self.assertIn("user:", text)
        self.assertIn("assistant:", text)
        self.assertIn("Objectverse Diary", text)

    def test_default_training_config_uses_safe_qwen_lora_defaults(self) -> None:
        """Pin the default ``TrainingConfig`` values the tooling relies on."""
        config = finetune_lora.TrainingConfig()
        self.assertEqual(config.base_model, "Qwen/Qwen2.5-1.5B-Instruct")
        self.assertEqual(config.lora_r, 16)
        self.assertEqual(config.lora_alpha, 32)
        self.assertEqual(config.lora_dropout, 0.05)
        self.assertEqual(config.max_steps, 80)
        self.assertEqual(config.num_train_epochs, 3.0)
        self.assertEqual(config.per_device_train_batch_size, 1)
        self.assertEqual(config.gradient_accumulation_steps, 4)
        self.assertEqual(config.eval_ratio, 0.1)
        self.assertTrue(config.assistant_only_loss)
        self.assertIn("q_proj", config.target_modules)
        self.assertIn("down_proj", config.target_modules)

    def test_training_config_serializes_v2_experiment_settings(self) -> None:
        """Overridden settings must survive ``as_remote_dict`` unchanged."""
        config = finetune_lora.TrainingConfig(
            max_steps=0,
            num_train_epochs=4.0,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,
            eval_ratio=0.2,
            eval_steps=25,
            lora_r=32,
            lora_alpha=64,
            assistant_only_loss=False,
        )
        payload = config.as_remote_dict()
        self.assertEqual(payload["num_train_epochs"], 4.0)
        self.assertEqual(payload["per_device_train_batch_size"], 2)
        self.assertEqual(payload["gradient_accumulation_steps"], 8)
        self.assertEqual(payload["eval_ratio"], 0.2)
        self.assertEqual(payload["eval_steps"], 25)
        self.assertEqual(payload["lora_r"], 32)
        self.assertEqual(payload["lora_alpha"], 64)
        self.assertFalse(payload["assistant_only_loss"])

    def test_dry_run_does_not_call_remote_runner(self) -> None:
        """A dry run summarizes the dataset without invoking the remote runner."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            path = Path(tmp_dir) / "records.jsonl"
            path.write_text(json.dumps(_valid_record()) + "\n", encoding="utf-8")

            def fail_remote_call(
                records: list[dict[str, object]],
                config: finetune_lora.TrainingConfig,
            ) -> dict[str, object]:
                # Any invocation fails the test: dry-run must stay local.
                raise AssertionError("dry-run should not call remote training")

            summary = finetune_lora.run_training_entrypoint(
                dataset=path,
                config=finetune_lora.TrainingConfig(),
                dry_run=True,
                allow_remote=False,
                remote_runner=fail_remote_call,
            )
            self.assertEqual(summary["mode"], "dry-run")
            self.assertEqual(summary["record_count"], 1)
            self.assertEqual(summary["base_model"], "Qwen/Qwen2.5-1.5B-Instruct")
            self.assertEqual(summary["train_record_count"], 1)
            self.assertEqual(summary["eval_record_count"], 0)

    def test_dry_run_reports_eval_split_for_larger_datasets(self) -> None:
        """With 20 records and ``eval_ratio=0.2`` the split is 16 train / 4 eval."""
        records = [_valid_record() for _ in range(20)]
        summary = finetune_lora._dry_run_summary(
            Path("records.jsonl"),
            records,
            finetune_lora.TrainingConfig(eval_ratio=0.2),
        )
        self.assertEqual(summary["train_record_count"], 16)
        self.assertEqual(summary["eval_record_count"], 4)

    def test_assistant_only_tokenization_masks_prompt_labels(self) -> None:
        """Assistant-only loss masks a leading span of labels with ``-100``."""
        tokenized = finetune_lora._tokenize_training_example(
            _valid_record(),
            FakeTokenizer(),
            max_length=512,
            assistant_only_loss=True,
        )
        labels = tokenized["labels"]
        self.assertIn(-100, labels)
        self.assertTrue(any(label != -100 for label in labels))
        # The first trainable label must come after at least one masked one.
        first_unmasked = next(index for index, label in enumerate(labels) if label != -100)
        self.assertGreater(first_unmasked, 0)

    def test_full_loss_tokenization_keeps_all_labels(self) -> None:
        """With assistant-only loss disabled, no label is masked with ``-100``."""
        tokenized = finetune_lora._tokenize_training_example(
            _valid_record(),
            FakeTokenizer(),
            max_length=512,
            assistant_only_loss=False,
        )
        self.assertNotIn(-100, tokenized["labels"])
# Run the suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()