# ObjectverseDiary / tests / test_finetune_lora_tooling.py
# qqyule's picture
# Deploy latest Objectverse Diary from fa09aac
# dd6cefc verified
"""Tests for Modal LoRA fine-tuning scaffolding helpers."""
from __future__ import annotations
import json
import tempfile
import unittest
from pathlib import Path
from scripts import finetune_lora
class FakeTokenizer:
    """Deterministic stand-in for a chat tokenizer, used by the tests below.

    Chat messages are rendered as ``<role>content</role>`` tags and every
    character of a text becomes exactly one token id, so the tokenization
    tests can run without loading a real tokenizer.
    """

    pad_token = "<pad>"
    eos_token = "</s>"

    def apply_chat_template(
        self,
        messages: list[dict[str, str]],
        *,
        tokenize: bool,
        add_generation_prompt: bool,
    ) -> str:
        """Render *messages* as one string of ``<role>content</role>`` tags.

        ``tokenize`` is accepted but never read. When
        ``add_generation_prompt`` is true, an opening ``<assistant>`` tag is
        appended after the last message.
        """
        pieces = []
        for turn in messages:
            role = turn["role"]
            pieces.append(f"<{role}>{turn['content']}</{role}>")
        if add_generation_prompt:
            pieces.append("<assistant>")
        return "".join(pieces)

    def __call__(
        self,
        text: str,
        *,
        truncation: bool,
        max_length: int,
        padding: bool,
        add_special_tokens: bool = False,
    ) -> dict[str, list[int]]:
        """Encode *text* with one token id per character.

        Each id is ``ord(char) % 251 + 1`` and therefore lies in ``1..251``.
        With ``truncation`` set, at most ``max_length`` ids are kept. The
        attention mask is all ones and as long as ``input_ids``.
        """
        # Present only so the call signature accepts them; neither is used.
        del padding, add_special_tokens
        # One character yields one id, so trimming the text up front gives
        # the same result as trimming the id list afterwards.
        kept = text[:max_length] if truncation else text
        token_ids = [ord(char) % 251 + 1 for char in kept]
        return {
            "input_ids": token_ids,
            "attention_mask": [1 for _ in token_ids],
        }
def _valid_record() -> dict[str, object]:
    """Return a fresh, well-formed SFT record for the tests.

    The record has an ``id`` and a three-turn ``messages`` list
    (system, user, assistant). A new dict is built on every call, so a test
    can mutate its copy without affecting other tests.
    """
    turns = (
        ("system", "You are Objectverse Diary."),
        ("user", "Create a persona."),
        ("assistant", "{\"persona\": {}, \"diary\": {}}"),
    )
    conversation = [{"role": role, "content": content} for role, content in turns]
    return {"id": "sft-preview-0001", "messages": conversation}
class FinetuneLoraToolingTest(unittest.TestCase):
    """Tests for the local helpers exposed by ``scripts.finetune_lora``.

    Covers JSONL record loading, training-text rendering, ``TrainingConfig``
    defaults and serialization, the dry-run entry point, and label masking
    during tokenization.
    """

    @staticmethod
    def _write_jsonl(directory: str, filename: str, record: dict[str, object]) -> Path:
        """Write *record* as a one-line JSONL file and return its path."""
        target = Path(directory) / filename
        target.write_text(json.dumps(record) + "\n", encoding="utf-8")
        return target

    @staticmethod
    def _labels_for(*, assistant_only_loss: bool) -> list[int]:
        """Tokenize the valid record with the fake tokenizer; return labels."""
        encoded = finetune_lora._tokenize_training_example(
            _valid_record(),
            FakeTokenizer(),
            max_length=512,
            assistant_only_loss=assistant_only_loss,
        )
        return encoded["labels"]

    def test_load_sft_records_rejects_missing_messages(self) -> None:
        """A record without a ``messages`` key must raise ``ValueError``."""
        with tempfile.TemporaryDirectory() as workdir:
            dataset = self._write_jsonl(workdir, "bad.jsonl", {"id": "bad"})
            self.assertRaises(ValueError, finetune_lora.load_sft_records, dataset)

    def test_load_sft_records_rejects_malformed_messages(self) -> None:
        """A message lacking ``content`` must raise ``ValueError``."""
        malformed = {"id": "bad", "messages": [{"role": "user"}]}
        with tempfile.TemporaryDirectory() as workdir:
            dataset = self._write_jsonl(workdir, "bad.jsonl", malformed)
            self.assertRaises(ValueError, finetune_lora.load_sft_records, dataset)

    def test_record_to_training_text_is_non_empty(self) -> None:
        """Rendered text names every role and includes the message content."""
        rendered = finetune_lora.record_to_training_text(_valid_record())
        for fragment in ("system:", "user:", "assistant:", "Objectverse Diary"):
            self.assertIn(fragment, rendered)

    def test_default_training_config_uses_safe_qwen_lora_defaults(self) -> None:
        """A no-argument ``TrainingConfig`` carries the expected defaults."""
        defaults = finetune_lora.TrainingConfig()
        expected = {
            "base_model": "Qwen/Qwen2.5-1.5B-Instruct",
            "lora_r": 16,
            "lora_alpha": 32,
            "lora_dropout": 0.05,
            "max_steps": 80,
            "num_train_epochs": 3.0,
            "per_device_train_batch_size": 1,
            "gradient_accumulation_steps": 4,
            "eval_ratio": 0.1,
        }
        # Checked in declaration order; the field name is attached so a
        # failure still says which default drifted.
        for field, value in expected.items():
            self.assertEqual(getattr(defaults, field), value, msg=field)
        self.assertTrue(defaults.assistant_only_loss)
        for module_name in ("q_proj", "down_proj"):
            self.assertIn(module_name, defaults.target_modules)

    def test_training_config_serializes_v2_experiment_settings(self) -> None:
        """Overridden settings survive ``as_remote_dict`` unchanged."""
        overrides = {
            "num_train_epochs": 4.0,
            "per_device_train_batch_size": 2,
            "gradient_accumulation_steps": 8,
            "eval_ratio": 0.2,
            "eval_steps": 25,
            "lora_r": 32,
            "lora_alpha": 64,
        }
        config = finetune_lora.TrainingConfig(
            max_steps=0,
            assistant_only_loss=False,
            **overrides,
        )
        payload = config.as_remote_dict()
        for key, value in overrides.items():
            self.assertEqual(payload[key], value, msg=key)
        self.assertFalse(payload["assistant_only_loss"])

    def test_dry_run_does_not_call_remote_runner(self) -> None:
        """Dry-run mode summarizes the dataset without invoking the runner."""

        def fail_remote_call(
            records: list[dict[str, object]],
            config: finetune_lora.TrainingConfig,
        ) -> dict[str, object]:
            # Reaching this runner at all is the failure being tested for.
            raise AssertionError("dry-run should not call remote training")

        with tempfile.TemporaryDirectory() as workdir:
            dataset = self._write_jsonl(workdir, "records.jsonl", _valid_record())
            summary = finetune_lora.run_training_entrypoint(
                dataset=dataset,
                config=finetune_lora.TrainingConfig(),
                dry_run=True,
                allow_remote=False,
                remote_runner=fail_remote_call,
            )
            self.assertEqual(summary["mode"], "dry-run")
            self.assertEqual(summary["record_count"], 1)
            self.assertEqual(summary["base_model"], "Qwen/Qwen2.5-1.5B-Instruct")
            self.assertEqual(summary["train_record_count"], 1)
            self.assertEqual(summary["eval_record_count"], 0)

    def test_dry_run_reports_eval_split_for_larger_datasets(self) -> None:
        """Twenty records with ``eval_ratio=0.2`` split into 16 train / 4 eval."""
        twenty_records = [_valid_record() for _ in range(20)]
        config = finetune_lora.TrainingConfig(eval_ratio=0.2)
        summary = finetune_lora._dry_run_summary(
            Path("records.jsonl"),
            twenty_records,
            config,
        )
        self.assertEqual(summary["train_record_count"], 16)
        self.assertEqual(summary["eval_record_count"], 4)

    def test_assistant_only_tokenization_masks_prompt_labels(self) -> None:
        """Assistant-only loss masks a leading run of labels but not all."""
        labels = self._labels_for(assistant_only_loss=True)
        # -100 is the label value these tests treat as "masked".
        self.assertIn(-100, labels)
        unmasked_positions = [
            position for position, label in enumerate(labels) if label != -100
        ]
        self.assertTrue(unmasked_positions)
        # The first trainable label must come after at least one masked one.
        self.assertGreater(unmasked_positions[0], 0)

    def test_full_loss_tokenization_keeps_all_labels(self) -> None:
        """With assistant-only loss disabled, no label is masked."""
        labels = self._labels_for(assistant_only_loss=False)
        self.assertNotIn(-100, labels)
if __name__ == "__main__":
    # Run the tests in this module when the file is executed as a script.
    unittest.main()