# Hugging Face Spaces page header (scrape residue): "Spaces: Running on Zero".
| """Modal LoRA fine-tuning scaffold for Objectverse Diary text generation.""" | |
| from __future__ import annotations | |
| import argparse | |
| import inspect | |
| import json | |
| import math | |
| import sys | |
| from collections.abc import Callable, Mapping, Sequence | |
| from dataclasses import asdict, dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Put the repository root at the front of sys.path (only once) so the script
# can import project modules when it is launched directly as a file.
if not any(entry == str(PROJECT_ROOT) for entry in sys.path):
    sys.path.insert(0, str(PROJECT_ROOT))

try:
    import modal
except ImportError:  # Modal is optional for local dry-run and tests.
    modal = None  # type: ignore[assignment]
# Name under which the Modal app is registered.
APP_NAME = "objectverse-diary-lora"

# Defaults shared by the local CLI, the Modal entrypoint and TrainingConfig.
DEFAULT_DATASET_PATH = Path("data/train/objectverse_sft_preview.jsonl")
DEFAULT_RUN_NAME = "objectverse-diary-qwen15b-preview"
DEFAULT_BASE_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"

# Number of seconds in one hour.
HOURS = 3600

# Container paths: the Hugging Face cache (exported as HF_HOME in the image)
# and the root under which each run writes its adapter and metrics.
CACHE_DIR = "/cache"
OUTPUT_DIR = "/outputs"

# Module names handed to LoraConfig(target_modules=...). In Qwen2-style models
# these are presumably the attention (q/k/v/o) and MLP (gate/up/down)
# projections -- confirm when switching base models.
LORA_TARGET_MODULES = (
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
)
@dataclass
class TrainingConfig:
    """Serializable training settings shared by dry-run and Modal execution.

    Instances are built from CLI flags, from the Modal entrypoint's keyword
    arguments, or from the plain dict produced by :meth:`as_remote_dict`.
    Every range is checked in ``__post_init__`` so bad values fail before any
    dataset or model work begins.
    """

    run_name: str = DEFAULT_RUN_NAME
    base_model: str = DEFAULT_BASE_MODEL
    # A positive max_steps caps training; 0 means "train for num_train_epochs".
    max_steps: int = 80
    num_train_epochs: float = 3.0
    learning_rate: float = 2e-4
    max_seq_length: int = 1024
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 4
    # Fraction of records held out for evaluation; 0 disables the eval split.
    eval_ratio: float = 0.1
    eval_steps: int = 10
    warmup_ratio: float = 0.03
    weight_decay: float = 0.0
    logging_steps: int = 5
    save_total_limit: int = 2
    seed: int = 42
    # When True, tokens before the final assistant turn are masked out of the loss.
    assistant_only_loss: bool = True
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    target_modules: tuple[str, ...] = field(default_factory=lambda: LORA_TARGET_MODULES)

    def __post_init__(self) -> None:
        """Validate field ranges and normalize ``target_modules`` to a tuple.

        Raises:
            ValueError: If any field is outside its accepted range.
        """
        if self.max_steps < 0:
            raise ValueError("max_steps must be 0 or greater.")
        if self.max_steps == 0 and self.num_train_epochs <= 0:
            raise ValueError("num_train_epochs must be greater than 0 when max_steps is 0.")
        if self.per_device_train_batch_size < 1:
            raise ValueError("per_device_train_batch_size must be at least 1.")
        if self.gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be at least 1.")
        if not 0 <= self.eval_ratio < 1:
            raise ValueError("eval_ratio must be between 0 and 1.")
        if self.eval_steps < 1:
            raise ValueError("eval_steps must be at least 1.")
        if self.logging_steps < 1:
            raise ValueError("logging_steps must be at least 1.")
        if self.save_total_limit < 1:
            raise ValueError("save_total_limit must be at least 1.")
        if self.lora_r < 1:
            raise ValueError("lora_r must be at least 1.")
        if self.lora_alpha < 1:
            raise ValueError("lora_alpha must be at least 1.")
        if not 0 <= self.lora_dropout < 1:
            raise ValueError("lora_dropout must be between 0 and 1.")
        # Checks below cover fields that previously reached the trainer unvalidated.
        if self.learning_rate <= 0:
            raise ValueError("learning_rate must be greater than 0.")
        if self.max_seq_length < 1:
            raise ValueError("max_seq_length must be at least 1.")
        if not 0 <= self.warmup_ratio <= 1:
            raise ValueError("warmup_ratio must be between 0 and 1.")
        if self.weight_decay < 0:
            raise ValueError("weight_decay must be 0 or greater.")
        # A bare string would otherwise be split into single characters by tuple().
        if isinstance(self.target_modules, (str, bytes)):
            raise ValueError("target_modules must be a sequence of strings.")
        self.target_modules = tuple(self.target_modules)
        if not self.target_modules:
            raise ValueError("target_modules must not be empty.")
        # run_name becomes a directory under OUTPUT_DIR; an empty name would
        # write directly into the output root.
        if not str(self.run_name).strip():
            raise ValueError("run_name must not be empty.")
        if not str(self.base_model).strip():
            raise ValueError("base_model must not be empty.")

    def as_remote_dict(self) -> dict[str, object]:
        """Return all fields as a JSON-friendly dict (``target_modules`` as a list)."""
        payload = asdict(self)
        payload["target_modules"] = list(self.target_modules)
        return payload
def load_sft_records(path: Path) -> list[dict[str, object]]:
    """Load and validate chat-style SFT JSONL records.

    Each non-blank line must be a JSON object whose ``messages`` pass
    ``_validate_sft_record``; blank lines are skipped but still counted for
    line numbers in error messages.

    Args:
        path: JSONL file to read.

    Returns:
        The parsed records, in file order.

    Raises:
        FileNotFoundError: If *path* does not exist or is not a regular file.
        ValueError: On invalid JSON, a non-object line, malformed messages, or
            a file with no records.
    """
    if not path.is_file():
        if path.exists():
            raise FileNotFoundError(f"Dataset path is not a regular file: {path}")
        raise FileNotFoundError(f"Dataset path does not exist: {path}")

    records: list[dict[str, object]] = []
    # "utf-8-sig" drops a leading byte-order mark that would otherwise make
    # line 1 invalid JSON. Iterating the file object splits only on real
    # newlines, unlike str.splitlines(), which also breaks on U+2028/U+2029
    # and NEL -- characters JSON allows unescaped inside strings.
    with path.open(encoding="utf-8-sig") as handle:
        for line_number, line in enumerate(handle, start=1):
            if not line.strip():
                continue
            try:
                raw = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Invalid JSON on line {line_number}: {exc.msg}") from exc
            if not isinstance(raw, dict):
                raise ValueError(f"Line {line_number} must be a JSON object.")
            records.append(_validate_sft_record(raw, line_number))

    if not records:
        raise ValueError(f"Dataset has no records: {path}")
    return records
def record_to_training_text(record: Mapping[str, object]) -> str:
    """Render one chat record in the plain ``role:`` header text format.

    This is the tokenizer-free rendering (used for the dry-run preview); real
    training prefers the tokenizer's chat template.

    Raises:
        ValueError: If the record's ``messages`` are missing or malformed.
    """
    validated = _validate_messages(record.get("messages"), line_number=None)
    rendered = _messages_to_training_text(validated)
    return rendered
| def _messages_to_training_text( | |
| messages: Sequence[Mapping[str, str]], | |
| *, | |
| add_generation_prompt: bool = False, | |
| ) -> str: | |
| blocks = [] | |
| for message in messages: | |
| role = str(message["role"]).strip().lower() | |
| content = str(message["content"]).strip() | |
| blocks.append(f"{role}:\n{content}") | |
| if add_generation_prompt: | |
| blocks.append("assistant:\n") | |
| return "\n\n".join(blocks).strip() | |
def run_training_entrypoint(
    *,
    dataset: Path,
    config: TrainingConfig,
    dry_run: bool,
    allow_remote: bool,
    remote_runner: Callable[[list[dict[str, object]], TrainingConfig], dict[str, object]] | None = None,
) -> dict[str, object]:
    """Validate inputs and either return a dry-run summary or launch training.

    The dataset is always loaded and validated first. A dry run stops there
    and returns a summary. Otherwise training only proceeds when
    *allow_remote* is set, using *remote_runner* if given or the Modal runner
    by default.

    Raises:
        RuntimeError: If a real run is requested without *allow_remote*.
    """
    records = load_sft_records(dataset)

    if dry_run:
        return _dry_run_summary(dataset, records, config)

    if allow_remote:
        launch = remote_runner if remote_runner else _run_modal_training
        return launch(records, config)

    raise RuntimeError("Use `modal run scripts/finetune_lora.py ...` for real training.")
def _validate_sft_record(raw: dict[str, object], line_number: int) -> dict[str, object]:
    """Check the ``messages`` of one parsed JSONL object; return the object unmodified.

    Raises:
        ValueError: If the messages are missing or malformed (message names *line_number*).
    """
    messages = raw.get("messages")
    _validate_messages(messages, line_number=line_number)
    return raw
| def _validate_messages(raw_messages: object, line_number: int | None) -> list[dict[str, str]]: | |
| location = f"line {line_number}" if line_number is not None else "record" | |
| if not isinstance(raw_messages, list) or not raw_messages: | |
| raise ValueError(f"{location} must include a non-empty messages list.") | |
| messages: list[dict[str, str]] = [] | |
| for index, raw_message in enumerate(raw_messages, start=1): | |
| if not isinstance(raw_message, dict): | |
| raise ValueError(f"{location} message {index} must be an object.") | |
| role = raw_message.get("role") | |
| content = raw_message.get("content") | |
| if not isinstance(role, str) or not role.strip(): | |
| raise ValueError(f"{location} message {index} must include a role.") | |
| if not isinstance(content, str) or not content.strip(): | |
| raise ValueError(f"{location} message {index} must include content.") | |
| messages.append({"role": role.strip(), "content": content.strip()}) | |
| return messages | |
def _dry_run_summary(
    dataset: Path,
    records: Sequence[Mapping[str, object]],
    config: TrainingConfig,
) -> dict[str, object]:
    """Summarize what a training run would use, without launching Modal.

    Reports the split sizes, the key hyperparameters, the LoRA settings and
    the character length of the first record's plain-text rendering.
    """
    preview_text = record_to_training_text(records[0])
    total = len(records)
    held_out = _eval_record_count(total, config.eval_ratio)

    per_device = config.per_device_train_batch_size
    accumulation = config.gradient_accumulation_steps
    lora_settings = {
        "r": config.lora_r,
        "alpha": config.lora_alpha,
        "dropout": config.lora_dropout,
        "target_modules": list(config.target_modules),
    }

    return {
        "mode": "dry-run",
        "dataset": str(dataset),
        "record_count": total,
        "train_record_count": total - held_out,
        "eval_record_count": held_out,
        "base_model": config.base_model,
        "run_name": config.run_name,
        "max_steps": config.max_steps,
        "num_train_epochs": config.num_train_epochs,
        "learning_rate": config.learning_rate,
        "max_seq_length": config.max_seq_length,
        "per_device_train_batch_size": per_device,
        "gradient_accumulation_steps": accumulation,
        "effective_batch_size": per_device * accumulation,
        "eval_ratio": config.eval_ratio,
        "eval_steps": config.eval_steps,
        "warmup_ratio": config.warmup_ratio,
        "weight_decay": config.weight_decay,
        "assistant_only_loss": config.assistant_only_loss,
        "lora": lora_settings,
        "first_training_text_chars": len(preview_text),
        "will_launch_modal": False,
    }
def _run_modal_training(
    records: list[dict[str, object]],
    config: TrainingConfig,
) -> dict[str, object]:
    """Send the validated records and serialized config to the Modal function.

    Raises:
        RuntimeError: If the ``modal`` package could not be imported.
    """
    if modal is not None:
        return train_lora_remote.remote(records, config.as_remote_dict())
    raise RuntimeError("Modal is not installed. Install `requirements-training.txt` first.")
def _train_lora_impl(
    records: list[dict[str, object]],
    config_payload: Mapping[str, object],
) -> dict[str, object]:
    """Fine-tune a LoRA adapter on *records* and persist it under OUTPUT_DIR.

    Called by ``train_lora_remote`` (the Modal function). The ML libraries are
    imported inside this function, so importing the module (e.g. for a local
    dry-run) does not require them.

    Args:
        records: Chat-style SFT records as returned by ``load_sft_records``.
        config_payload: Plain-dict settings (``TrainingConfig.as_remote_dict``
            output), rebuilt into a ``TrainingConfig`` before use.

    Returns:
        A summary with the run name, record/split counts, the adapter
        directory and the path of the written ``metrics.json``.

    Raises:
        ValueError: From config reconstruction or tokenization (e.g. when
            ``max_seq_length`` truncates away every assistant label).
    """
    from datasets import Dataset
    import torch
    from peft import LoraConfig, TaskType, get_peft_model
    from transformers import (
        AutoModelForCausalLM,
        AutoTokenizer,
        Trainer,
        TrainingArguments,
    )

    config = _training_config_from_payload(config_payload)
    # Each run writes to OUTPUT_DIR/<run_name>; the adapter goes in "adapter/".
    output_path = Path(OUTPUT_DIR) / config.run_name
    adapter_path = output_path / "adapter"
    output_path.mkdir(parents=True, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(config.base_model, trust_remote_code=True)
    # Batching needs a pad token; reuse EOS when none is defined. Padded
    # positions never contribute to the loss because the collator fills
    # labels beyond each example with -100.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model_kwargs: dict[str, object] = {"trust_remote_code": True}
    # Half-precision weights only when a GPU is available.
    if torch.cuda.is_available():
        model_kwargs["torch_dtype"] = torch.float16
    model = AutoModelForCausalLM.from_pretrained(config.base_model, **model_kwargs)
    # KV caching is presumably only useful for generation; disable it while training.
    if hasattr(model, "config"):
        model.config.use_cache = False

    peft_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=list(config.target_modules),
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )
    # Wrap the base model with LoRA adapters on the configured target modules.
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # Tokenize everything up front; each row has input_ids, attention_mask, labels.
    dataset = Dataset.from_list(
        [
            _tokenize_training_example(
                record,
                tokenizer,
                max_length=config.max_seq_length,
                assistant_only_loss=config.assistant_only_loss,
            )
            for record in records
        ]
    )
    # eval_dataset is None when eval_ratio yields no held-out records.
    train_dataset, eval_dataset = _split_dataset(dataset, config)

    training_kwargs = _training_arguments_kwargs(
        output_dir=output_path / "trainer",
        config=config,
        has_eval=eval_dataset is not None,
        training_arguments_cls=TrainingArguments,
    )
    # fp16 mixed precision only on GPU, matching the float16 load above.
    # NOTE(review): this combines float16 base weights with fp16 AMP; it relies
    # on the installed PEFT keeping trainable LoRA weights in float32 -- confirm
    # for the (unpinned) peft version in the image.
    training_kwargs["fp16"] = torch.cuda.is_available()
    training_args = TrainingArguments(**training_kwargs)
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=_build_supervised_data_collator(tokenizer, torch),
    )
    train_result = trainer.train()
    eval_metrics: dict[str, object] = {}
    # With an eval split, _training_arguments_kwargs requests
    # load_best_model_at_end, so this presumably scores the best checkpoint.
    if eval_dataset is not None:
        eval_metrics = dict(trainer.evaluate())

    # Save the PEFT-wrapped model and the tokenizer side by side in adapter/.
    model.save_pretrained(adapter_path)
    tokenizer.save_pretrained(adapter_path)

    metrics = dict(train_result.metrics)
    metrics.update(eval_metrics)
    metrics["train_records"] = len(train_dataset)
    metrics["eval_records"] = len(eval_dataset) if eval_dataset is not None else 0
    metrics["base_model"] = config.base_model
    (output_path / "metrics.json").write_text(
        json.dumps(metrics, indent=2, sort_keys=True),
        encoding="utf-8",
    )
    (output_path / "training_config.json").write_text(
        json.dumps(config.as_remote_dict(), indent=2, sort_keys=True),
        encoding="utf-8",
    )
    # Commit the output volume (when running under Modal) so the written files
    # are persisted; presumably required before they are visible elsewhere.
    if _OUTPUT_VOLUME is not None:
        _OUTPUT_VOLUME.commit()
    return {
        "mode": "modal-training",
        "run_name": config.run_name,
        "record_count": len(records),
        "train_record_count": len(train_dataset),
        "eval_record_count": len(eval_dataset) if eval_dataset is not None else 0,
        "adapter_path": str(adapter_path),
        "metrics_path": str(output_path / "metrics.json"),
    }
def _tokenize_training_example(
    record: Mapping[str, object],
    tokenizer: Any,
    *,
    max_length: int,
    assistant_only_loss: bool,
) -> dict[str, list[int]]:
    """Tokenize one chat record into ``input_ids``, ``attention_mask`` and ``labels``.

    Labels start as a copy of the input ids. With *assistant_only_loss*, the
    leading positions covered by the tokenized prompt (everything before the
    final assistant turn plus the generation cue) are set to -100 so they do
    not contribute to the loss.

    Raises:
        ValueError: If masking leaves no trainable label (e.g. the prompt alone
            reaches *max_length*), or if the record has no assistant message.
    """
    tokenize_options = {
        "truncation": True,
        "max_length": max_length,
        "padding": False,
        "add_special_tokens": False,
    }
    encoded = tokenizer(_format_training_text(record, tokenizer), **tokenize_options)
    input_ids = list(encoded["input_ids"])
    labels = list(input_ids)

    if assistant_only_loss:
        prompt_ids = tokenizer(_format_prompt_text(record, tokenizer), **tokenize_options)["input_ids"]
        masked = min(len(prompt_ids), len(labels))
        labels = [-100] * masked + labels[masked:]
        if all(label == -100 for label in labels):
            raise ValueError(
                "max_seq_length truncates all assistant labels; increase max_seq_length."
            )

    return {
        "input_ids": input_ids,
        "attention_mask": list(encoded["attention_mask"]),
        "labels": labels,
    }
def _split_dataset(dataset: Any, config: TrainingConfig) -> tuple[Any, Any | None]:
    """Return ``(train, eval)`` splits, with eval ``None`` when nothing is held out.

    The split is shuffled with ``config.seed`` via ``dataset.train_test_split``
    (presumably a Hugging Face ``datasets.Dataset``).
    """
    holdout = _eval_record_count(len(dataset), config.eval_ratio)
    if not holdout:
        return dataset, None
    parts = dataset.train_test_split(test_size=holdout, shuffle=True, seed=config.seed)
    return parts["train"], parts["test"]
| def _eval_record_count(record_count: int, eval_ratio: float) -> int: | |
| if record_count < 2 or eval_ratio <= 0: | |
| return 0 | |
| return max(1, min(record_count - 1, math.ceil(record_count * eval_ratio))) | |
| def _training_arguments_kwargs( | |
| *, | |
| output_dir: Path, | |
| config: TrainingConfig, | |
| has_eval: bool, | |
| training_arguments_cls: Any | None = None, | |
| ) -> dict[str, object]: | |
| kwargs: dict[str, object] = { | |
| "output_dir": str(output_dir), | |
| "per_device_train_batch_size": config.per_device_train_batch_size, | |
| "gradient_accumulation_steps": config.gradient_accumulation_steps, | |
| "learning_rate": config.learning_rate, | |
| "logging_steps": config.logging_steps, | |
| "warmup_ratio": config.warmup_ratio, | |
| "weight_decay": config.weight_decay, | |
| "report_to": [], | |
| "optim": "adamw_torch", | |
| "seed": config.seed, | |
| "data_seed": config.seed, | |
| } | |
| if config.max_steps > 0: | |
| kwargs["max_steps"] = config.max_steps | |
| else: | |
| kwargs["num_train_epochs"] = config.num_train_epochs | |
| if has_eval: | |
| kwargs.update( | |
| { | |
| "eval_steps": config.eval_steps, | |
| "save_steps": config.eval_steps, | |
| "save_strategy": "steps", | |
| "save_total_limit": config.save_total_limit, | |
| "load_best_model_at_end": True, | |
| "metric_for_best_model": "eval_loss", | |
| "greater_is_better": False, | |
| } | |
| ) | |
| if training_arguments_cls is None: | |
| kwargs["eval_strategy"] = "steps" | |
| else: | |
| _set_eval_strategy_kwarg(kwargs, training_arguments_cls, "steps") | |
| else: | |
| kwargs["save_strategy"] = "no" | |
| return kwargs | |
| def _set_eval_strategy_kwarg( | |
| kwargs: dict[str, object], | |
| training_arguments_cls: Any, | |
| strategy: str, | |
| ) -> None: | |
| parameters = inspect.signature(training_arguments_cls.__init__).parameters | |
| if "eval_strategy" in parameters: | |
| kwargs["eval_strategy"] = strategy | |
| elif "evaluation_strategy" in parameters: | |
| kwargs["evaluation_strategy"] = strategy | |
| else: | |
| kwargs["do_eval"] = strategy != "no" | |
| def _build_supervised_data_collator(tokenizer: Any, torch_module: Any) -> Callable: | |
| def collate(features: list[Mapping[str, list[int]]]) -> dict[str, object]: | |
| labels = [list(feature["labels"]) for feature in features] | |
| model_features = [ | |
| { | |
| "input_ids": list(feature["input_ids"]), | |
| "attention_mask": list(feature["attention_mask"]), | |
| } | |
| for feature in features | |
| ] | |
| batch = tokenizer.pad(model_features, padding=True, return_tensors="pt") | |
| max_length = batch["input_ids"].shape[1] | |
| label_tensor = torch_module.full( | |
| (len(labels), max_length), | |
| -100, | |
| dtype=torch_module.long, | |
| ) | |
| for index, label in enumerate(labels): | |
| label_tensor[index, : len(label)] = torch_module.tensor( | |
| label, | |
| dtype=torch_module.long, | |
| ) | |
| batch["labels"] = label_tensor | |
| return batch | |
| return collate | |
def _training_config_from_payload(payload: Mapping[str, object]) -> TrainingConfig:
    """Rebuild a ``TrainingConfig`` from the plain dict sent to the Modal function.

    Missing keys fall back to the same defaults as the CLI. Present values are
    coerced with the field's builtin type before ``TrainingConfig`` validates them.

    Raises:
        ValueError: If ``target_modules`` is a string/bytes or not a sequence,
            or if ``TrainingConfig`` rejects a value.
    """
    modules = payload.get("target_modules", LORA_TARGET_MODULES)
    if isinstance(modules, (str, bytes)) or not isinstance(modules, Sequence):
        raise ValueError("target_modules must be a sequence of strings.")

    def pick(key: str, default: object, cast: Callable[[Any], Any]) -> Any:
        return cast(payload.get(key, default))

    return TrainingConfig(
        run_name=pick("run_name", DEFAULT_RUN_NAME, str),
        base_model=pick("base_model", DEFAULT_BASE_MODEL, str),
        max_steps=pick("max_steps", 80, int),
        num_train_epochs=pick("num_train_epochs", 3.0, float),
        learning_rate=pick("learning_rate", 2e-4, float),
        max_seq_length=pick("max_seq_length", 1024, int),
        per_device_train_batch_size=pick("per_device_train_batch_size", 1, int),
        gradient_accumulation_steps=pick("gradient_accumulation_steps", 4, int),
        eval_ratio=pick("eval_ratio", 0.1, float),
        eval_steps=pick("eval_steps", 10, int),
        warmup_ratio=pick("warmup_ratio", 0.03, float),
        weight_decay=pick("weight_decay", 0.0, float),
        logging_steps=pick("logging_steps", 5, int),
        save_total_limit=pick("save_total_limit", 2, int),
        seed=pick("seed", 42, int),
        assistant_only_loss=pick("assistant_only_loss", True, bool),
        lora_r=pick("lora_r", 16, int),
        lora_alpha=pick("lora_alpha", 32, int),
        lora_dropout=pick("lora_dropout", 0.05, float),
        target_modules=tuple(map(str, modules)),
    )
def _format_training_text(record: Mapping[str, object], tokenizer: Any) -> str:
    """Render the record's whole conversation (no generation cue) for training.

    Raises:
        ValueError: If the record's messages are missing or malformed.
    """
    validated = _validate_messages(record.get("messages"), line_number=None)
    rendered = _format_messages(validated, tokenizer, add_generation_prompt=False)
    return rendered
def _format_prompt_text(record: Mapping[str, object], tokenizer: Any) -> str:
    """Render every message before the final assistant turn, plus a generation cue.

    Its token count tells ``_tokenize_training_example`` how many leading
    tokens to mask out of the loss.

    Raises:
        ValueError: If the record has no assistant message or invalid messages.
    """
    messages = _validate_messages(record.get("messages"), line_number=None)
    final_reply = max(
        (
            position
            for position, message in enumerate(messages)
            if message["role"].lower() == "assistant"
        ),
        default=None,
    )
    if final_reply is None:
        raise ValueError("assistant_only_loss requires at least one assistant message.")
    return _format_messages(messages[:final_reply], tokenizer, add_generation_prompt=True)
| def _format_messages( | |
| messages: Sequence[Mapping[str, str]], | |
| tokenizer: Any, | |
| *, | |
| add_generation_prompt: bool, | |
| ) -> str: | |
| if hasattr(tokenizer, "apply_chat_template"): | |
| try: | |
| return tokenizer.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=add_generation_prompt, | |
| ) | |
| except Exception: | |
| pass | |
| return _messages_to_training_text( | |
| messages, | |
| add_generation_prompt=add_generation_prompt, | |
| ) | |
| def _print_json(payload: Mapping[str, object]) -> None: | |
| print(json.dumps(payload, indent=2, sort_keys=True), flush=True) | |
def _build_config_from_args(args: argparse.Namespace) -> TrainingConfig:
    """Map parsed CLI options onto a TrainingConfig; target_modules keeps its default."""
    option_names = (
        "run_name",
        "base_model",
        "max_steps",
        "num_train_epochs",
        "learning_rate",
        "max_seq_length",
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "eval_ratio",
        "eval_steps",
        "warmup_ratio",
        "weight_decay",
        "logging_steps",
        "save_total_limit",
        "seed",
        "assistant_only_loss",
        "lora_r",
        "lora_alpha",
        "lora_dropout",
    )
    return TrainingConfig(**{name: getattr(args, name) for name in option_names})
def _parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse the local CLI flags; defaults mirror TrainingConfig's field defaults."""
    option_specs: tuple[tuple[str, dict[str, Any]], ...] = (
        ("--dataset", {"type": Path, "default": DEFAULT_DATASET_PATH}),
        ("--run-name", {"default": DEFAULT_RUN_NAME}),
        ("--base-model", {"default": DEFAULT_BASE_MODEL}),
        ("--max-steps", {"type": int, "default": 80}),
        ("--num-train-epochs", {"type": float, "default": 3.0}),
        ("--learning-rate", {"type": float, "default": 2e-4}),
        ("--max-seq-length", {"type": int, "default": 1024}),
        ("--per-device-train-batch-size", {"type": int, "default": 1}),
        ("--gradient-accumulation-steps", {"type": int, "default": 4}),
        ("--eval-ratio", {"type": float, "default": 0.1}),
        ("--eval-steps", {"type": int, "default": 10}),
        ("--warmup-ratio", {"type": float, "default": 0.03}),
        ("--weight-decay", {"type": float, "default": 0.0}),
        ("--logging-steps", {"type": int, "default": 5}),
        ("--save-total-limit", {"type": int, "default": 2}),
        ("--seed", {"type": int, "default": 42}),
        ("--assistant-only-loss", {"action": argparse.BooleanOptionalAction, "default": True}),
        ("--lora-r", {"type": int, "default": 16}),
        ("--lora-alpha", {"type": int, "default": 32}),
        ("--lora-dropout", {"type": float, "default": 0.05}),
        ("--dry-run", {"action": "store_true"}),
    )
    parser = argparse.ArgumentParser(description=__doc__)
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)
    return parser.parse_args(argv)
def _main(argv: Sequence[str] | None = None, *, allow_remote: bool = False) -> dict[str, object]:
    """CLI driver: parse flags, run (or dry-run) training, print and return the result."""
    parsed = _parse_args(argv)
    config = _build_config_from_args(parsed)
    summary = run_training_entrypoint(
        dataset=parsed.dataset,
        config=config,
        dry_run=parsed.dry_run,
        allow_remote=allow_remote,
    )
    _print_json(summary)
    return summary
if modal is not None:
    # Container image: the ML stack is installed only here, so local dry-runs
    # never need torch/transformers. HF_HOME points the Hugging Face cache at
    # CACHE_DIR.
    _IMAGE = (
        modal.Image.debian_slim(python_version="3.10")
        .uv_pip_install(
            "torch",
            "transformers>=4.40,<5",
            "datasets",
            "accelerate",
            "peft",
            "sentencepiece",
        )
        .env({"HF_HOME": CACHE_DIR})
    )
    _CACHE_VOLUME = modal.Volume.from_name("objectverse-diary-hf-cache", create_if_missing=True)
    _OUTPUT_VOLUME = modal.Volume.from_name(
        "objectverse-diary-lora-output",
        create_if_missing=True,
    )
    app = modal.App(APP_NAME)

    # Registering with @app.function is what provides the `.remote(...)` call
    # used by _run_modal_training. It also attaches the image, a GPU and the
    # two volumes at the paths _train_lora_impl reads and writes.
    # NOTE(review): the GPU type and timeout are assumptions for a 1.5B-param
    # LoRA run -- confirm against the intended Modal budget.
    @app.function(
        image=_IMAGE,
        gpu="A10G",
        timeout=2 * HOURS,
        volumes={CACHE_DIR: _CACHE_VOLUME, OUTPUT_DIR: _OUTPUT_VOLUME},
    )
    def train_lora_remote(
        records: list[dict[str, object]],
        config_payload: dict[str, object],
    ) -> dict[str, object]:
        """Run LoRA fine-tuning inside the Modal container."""
        return _train_lora_impl(records, config_payload)

    # Local entry point for `modal run scripts/finetune_lora.py ...` (the
    # command suggested by run_training_entrypoint). Its parameters mirror the
    # local CLI flags and TrainingConfig defaults.
    @app.local_entrypoint()
    def modal_entrypoint(
        dataset: str = str(DEFAULT_DATASET_PATH),
        run_name: str = DEFAULT_RUN_NAME,
        base_model: str = DEFAULT_BASE_MODEL,
        max_steps: int = 80,
        num_train_epochs: float = 3.0,
        learning_rate: float = 2e-4,
        max_seq_length: int = 1024,
        per_device_train_batch_size: int = 1,
        gradient_accumulation_steps: int = 4,
        eval_ratio: float = 0.1,
        eval_steps: int = 10,
        warmup_ratio: float = 0.03,
        weight_decay: float = 0.0,
        logging_steps: int = 5,
        save_total_limit: int = 2,
        seed: int = 42,
        assistant_only_loss: bool = True,
        lora_r: int = 16,
        lora_alpha: int = 32,
        lora_dropout: float = 0.05,
        dry_run: bool = False,
    ) -> None:
        """Validate the dataset locally, then dry-run or launch remote training."""
        result = run_training_entrypoint(
            dataset=Path(dataset),
            config=TrainingConfig(
                run_name=run_name,
                base_model=base_model,
                max_steps=max_steps,
                num_train_epochs=num_train_epochs,
                learning_rate=learning_rate,
                max_seq_length=max_seq_length,
                per_device_train_batch_size=per_device_train_batch_size,
                gradient_accumulation_steps=gradient_accumulation_steps,
                eval_ratio=eval_ratio,
                eval_steps=eval_steps,
                warmup_ratio=warmup_ratio,
                weight_decay=weight_decay,
                logging_steps=logging_steps,
                save_total_limit=save_total_limit,
                seed=seed,
                assistant_only_loss=assistant_only_loss,
                lora_r=lora_r,
                lora_alpha=lora_alpha,
                lora_dropout=lora_dropout,
            ),
            dry_run=dry_run,
            allow_remote=True,
        )
        _print_json(result)
else:
    # Without Modal: no output volume to commit, no app, and remote training
    # fails with an explicit installation hint.
    _OUTPUT_VOLUME = None
    app = None

    def train_lora_remote(
        records: list[dict[str, object]],
        config_payload: dict[str, object],
    ) -> dict[str, object]:
        """Stand-in used when Modal is not installed; always raises."""
        raise RuntimeError("Modal is not installed. Install `requirements-training.txt` first.")
if __name__ == "__main__":
    # Local CLI: dry-runs work here, but real training must go through
    # `modal run` (allow_remote=False makes run_training_entrypoint refuse).
    try:
        _main(allow_remote=False)
    except Exception as error:
        # Report any failure as a one-line message with exit status 1.
        raise SystemExit(f"{error}") from error