| | import argparse |
| | import json |
| | import os |
| | import random |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Dict, List |
| |
|
| | import torch |
| | from datasets import load_dataset |
| | from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training |
| | from transformers import ( |
| | AutoModelForCausalLM, |
| | AutoTokenizer, |
| | BitsAndBytesConfig, |
| | Trainer, |
| | TrainingArguments, |
| | set_seed, |
| | ) |
| |
|
# Fixed instruction text that format_prompt() places in front of every prompt.
# Each piece ends with the separator it needs, and the last one ends with a
# blank line so the task text starts on its own paragraph.
SYSTEM_PREFIX = "".join(
    [
        "You are GravityLLM, a Spatial9 scene generation model. ",
        "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. ",
        "Do not return markdown. Do not explain your answer. ",
        "Respect hard constraints such as object budgets, anchor positions, and low-end centering.\n\n",
    ]
)
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Build the fine-tuning command line and parse ``sys.argv``.

    Returns:
        A namespace holding data/model paths, optimisation hyper-parameters,
        LoRA/QLoRA switches, mixed-precision flags and Hugging Face Hub options.
    """
    cli = argparse.ArgumentParser(description="Fine-tune GravityLLM for Spatial9 scene generation.")

    # (flag, add_argument keyword arguments) in the exact order shown by --help.
    option_table = [
        # Model and data locations.
        ("--model", dict(type=str, default="Qwen/Qwen2.5-1.5B-Instruct")),
        ("--train_file", dict(type=str, default="data/train.jsonl")),
        ("--valid_file", dict(type=str, default="data/valid.jsonl")),
        ("--output_dir", dict(type=str, default="outputs/GravityLLM-Qwen2.5-1.5B-S9")),
        ("--max_length", dict(type=int, default=2048)),
        # Optimisation schedule.
        ("--num_train_epochs", dict(type=float, default=1.0)),
        ("--learning_rate", dict(type=float, default=2e-4)),
        ("--train_batch_size", dict(type=int, default=1)),
        ("--eval_batch_size", dict(type=int, default=1)),
        ("--gradient_accumulation_steps", dict(type=int, default=16)),
        ("--warmup_ratio", dict(type=float, default=0.03)),
        ("--weight_decay", dict(type=float, default=0.0)),
        ("--logging_steps", dict(type=int, default=10)),
        ("--save_steps", dict(type=int, default=200)),
        ("--eval_steps", dict(type=int, default=200)),
        ("--seed", dict(type=int, default=42)),
        # Adapter configuration.
        ("--lora", dict(action="store_true", help="Enable LoRA adapters.")),
        ("--qlora", dict(action="store_true", help="Enable 4-bit QLoRA training.")),
        ("--lora_r", dict(type=int, default=16)),
        ("--lora_alpha", dict(type=int, default=32)),
        ("--lora_dropout", dict(type=float, default=0.05)),
        # Mixed precision.
        ("--bf16", dict(action="store_true")),
        ("--fp16", dict(action="store_true")),
        # Hugging Face Hub publishing.
        ("--push_to_hub", dict(action="store_true")),
        ("--hub_model_id", dict(type=str, default=None)),
        ("--hub_private_repo", dict(action="store_true")),
    ]
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
| |
|
| |
|
def load_jsonl(file_path: str):
    """Read a JSON Lines file into a Hugging Face ``Dataset``.

    The ``json`` builder puts every row from ``data_files`` under a split
    named ``"train"``. That split is selected here, so the caller gets one
    flat dataset whichever file (train or validation) is passed in.
    """
    rows = load_dataset("json", data_files=file_path, split="train")
    return rows
| |
|
| |
|
def format_prompt(raw_prompt: str) -> str:
    """Wrap a raw task prompt in the GravityLLM system prefix and output cue.

    Surrounding whitespace is trimmed. A leading ``GravityLLM:`` tag
    (matched case-insensitively) is removed together with the whitespace
    after it. The result is ``SYSTEM_PREFIX + prompt + "\\n\\nOUTPUT:\\n"``.
    """
    body = raw_prompt.strip()
    if body.lower().startswith("gravityllm:"):
        # Everything after the first colon is the actual task description.
        _, _, body = body.partition(":")
        body = body.strip()
    return f"{SYSTEM_PREFIX}{body}\n\nOUTPUT:\n"
| |
|
| |
|
def tokenize_example(example: Dict[str, str], tokenizer, max_length: int) -> Dict[str, List[int]]:
    """Tokenize one prompt/completion pair for supervised causal-LM training.

    Args:
        example: Row with string fields ``"prompt"`` and ``"completion"``.
        tokenizer: Hugging Face tokenizer used for both parts.
        max_length: Maximum number of tokens kept per example.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` (all ones) and ``labels``.
        Prompt positions are labelled ``-100``, so only the completion and its
        trailing EOS token are supervised.

    Raises:
        ValueError: If the tokenizer defines no EOS token id. Without it the
            model would never learn where a completion stops.

    Note:
        Over-long sequences are cut from the right. The completion (and its
        EOS) is lost first. If the prompt alone fills ``max_length``, every
        label is ``-100`` and the example carries no training signal.
    """
    prompt_text = format_prompt(example["prompt"])
    completion_text = example["completion"].strip()

    eos_id = tokenizer.eos_token_id
    if eos_id is None:
        raise ValueError("Tokenizer has no EOS token; cannot mark the end of a completion.")

    prompt_ids = list(tokenizer(prompt_text, add_special_tokens=False)["input_ids"])
    # Append EOS by id instead of concatenating its text form and re-tokenizing.
    # That way the result does not depend on the tokenizer recognising the
    # special-token string inside ordinary text.
    completion_ids = list(tokenizer(completion_text, add_special_tokens=False)["input_ids"]) + [eos_id]

    input_ids = prompt_ids + completion_ids
    labels = [-100] * len(prompt_ids) + completion_ids

    # Right-truncate; slicing is a no-op when the sequence already fits.
    input_ids = input_ids[:max_length]
    labels = labels[:max_length]

    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
| |
|
| |
|
@dataclass
class CausalDataCollator:
    """Right-pad a list of tokenized features into rectangular ``torch.long`` tensors."""

    # Token id written into the padded tail of ``input_ids``.
    pad_token_id: int
    # Value written into the padded tail of ``labels``. -100 is the same marker
    # used for prompt tokens, and it is PyTorch's default CrossEntropyLoss ``ignore_index``.
    label_pad_token_id: int = -100

    def __call__(self, features):
        """Collate ``features`` (dicts of int lists) into a batch.

        Every sequence is extended on the right up to the longest
        ``input_ids`` in the batch. The pad amount for each feature is taken
        from its ``input_ids`` length and applied to all three fields.
        Returns a dict with ``input_ids``, ``attention_mask`` and ``labels``
        tensors of shape ``(len(features), longest)``.
        """
        longest = max(len(feature["input_ids"]) for feature in features)

        # Fill value per output field. Dict order fixes the key order of the result.
        fill_values = {
            "input_ids": self.pad_token_id,
            "attention_mask": 0,
            "labels": self.label_pad_token_id,
        }
        padded_rows = {field: [] for field in fill_values}

        for feature in features:
            shortfall = longest - len(feature["input_ids"])
            for field, fill in fill_values.items():
                padded_rows[field].append(feature[field] + [fill] * shortfall)

        return {field: torch.tensor(rows, dtype=torch.long) for field, rows in padded_rows.items()}
| |
|
| |
|
def prepare_model(args: argparse.Namespace):
    """Load the base causal LM and optionally attach LoRA / QLoRA adapters.

    Args:
        args: Parsed CLI namespace. Reads ``model``, ``bf16``, ``fp16``,
            ``qlora``, ``lora``, ``lora_r``, ``lora_alpha`` and ``lora_dropout``.

    Returns:
        A PEFT-wrapped model when ``--lora`` or ``--qlora`` is set, otherwise
        the plain ``AutoModelForCausalLM`` instance.
    """
    model_kwargs = {}
    if args.qlora:
        # 4-bit NF4 weights with double quantisation. Matmuls run in bf16 when
        # --bf16 is given, fp16 otherwise.
        compute_dtype = torch.bfloat16 if args.bf16 else torch.float16
        model_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=compute_dtype,
        )
        model_kwargs["device_map"] = "auto"

    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else None),
        # NOTE(review): this executes Python shipped with the model repo; only
        # point --model at trusted repositories.
        trust_remote_code=True,
        **model_kwargs,
    )
    # The KV cache is not used for training and conflicts with gradient checkpointing.
    model.config.use_cache = False

    if args.qlora:
        # Also enables gradient checkpointing and input-embedding grads.
        model = prepare_model_for_kbit_training(model)
    elif args.lora and hasattr(model, "enable_input_require_grads"):
        # Under plain LoRA the base weights are frozen, so the embedding output
        # does not require grad. The training loop in this script uses
        # (reentrant) gradient checkpointing, and that would then leave the
        # adapter weights inside the checkpointed blocks without gradients.
        # Force the embedding output to require grad, as
        # prepare_model_for_kbit_training already does for QLoRA.
        model.enable_input_require_grads()

    if args.lora or args.qlora:
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules="all-linear",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    return model
| |
|
| |
|
def _tokenize_split(dataset, tokenizer, max_length: int, desc: str):
    """Map ``tokenize_example`` over *dataset*, replacing the raw text columns."""
    return dataset.map(
        lambda row: tokenize_example(row, tokenizer, max_length),
        remove_columns=dataset.column_names,
        desc=desc,
    )


def _drop_unsupervised(dataset, split_name: str):
    """Remove rows whose labels are all -100 and return ``(dataset, n_dropped)``.

    This happens when the prompt alone fills ``max_length``. Such rows have no
    target tokens: under mean-reduced loss they produce 0/0 = NaN and can
    corrupt the update, and at best they waste compute.
    """
    kept = dataset.filter(lambda row: any(label != -100 for label in row["labels"]))
    dropped = len(dataset) - len(kept)
    if dropped:
        print(
            f"[{split_name}] dropped {dropped} example(s) with no completion tokens "
            "left after truncation to max_length."
        )
    return kept, dropped


def main() -> None:
    """Fine-tune the model selected on the command line and write its artifacts.

    The function parses args and seeds the RNGs, then loads and tokenizes the
    train set and the optional validation set. It builds the model (optionally
    LoRA/QLoRA) and trains it with ``Trainer``. Finally it saves the model,
    the tokenizer, ``training_metrics.json`` and ``run_config.json`` to
    ``--output_dir``, and pushes to the Hub when ``--push_to_hub`` is set.

    Raises:
        ValueError: If no training example keeps any supervised token after
            truncation to ``--max_length``.
    """
    import inspect

    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)
    set_seed(args.seed)

    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, trust_remote_code=True)
    # CausalDataCollator appends padding on the right, so keep the tokenizer consistent with it.
    tokenizer.padding_side = "right"
    if tokenizer.pad_token is None:
        # Padded positions get attention 0 and label -100, so reusing EOS as pad is harmless.
        tokenizer.pad_token = tokenizer.eos_token

    train_ds = _tokenize_split(load_jsonl(args.train_file), tokenizer, args.max_length, "Tokenizing train set")
    train_ds, train_dropped = _drop_unsupervised(train_ds, "train")
    if len(train_ds) == 0:
        raise ValueError(
            "No training example has completion tokens within max_length="
            f"{args.max_length}; increase --max_length or shorten prompts."
        )

    valid_ds = None
    valid_dropped = 0
    if args.valid_file and Path(args.valid_file).exists():
        valid_ds = _tokenize_split(load_jsonl(args.valid_file), tokenizer, args.max_length, "Tokenizing valid set")
        valid_ds, valid_dropped = _drop_unsupervised(valid_ds, "valid")
        if len(valid_ds) == 0:
            print("[valid] no usable examples remain; evaluation disabled.")
            valid_ds = None

    model = prepare_model(args)

    # transformers renamed `evaluation_strategy` to `eval_strategy` (4.41) and
    # later removed the old name. Use whichever this installation accepts.
    training_arg_names = inspect.signature(TrainingArguments.__init__).parameters
    eval_strategy_key = "eval_strategy" if "eval_strategy" in training_arg_names else "evaluation_strategy"

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,
        num_train_epochs=args.num_train_epochs,
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.train_batch_size,
        per_device_eval_batch_size=args.eval_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        save_strategy="steps",
        bf16=args.bf16,
        fp16=args.fp16,
        report_to="none",
        gradient_checkpointing=True,
        lr_scheduler_type="cosine",
        optim="paged_adamw_32bit" if (args.lora or args.qlora) else "adamw_torch",
        max_grad_norm=1.0,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_private_repo=args.hub_private_repo,
        hub_strategy="end" if args.push_to_hub else "every_save",
        **{eval_strategy_key: "steps" if valid_ds is not None else "no"},
    )

    trainer_kwargs = {
        "model": model,
        "args": training_args,
        "train_dataset": train_ds,
        "eval_dataset": valid_ds,
        "data_collator": CausalDataCollator(pad_token_id=tokenizer.pad_token_id),
    }
    # Newer transformers take the tokenizer as `processing_class`; `tokenizer=`
    # is deprecated there. Older releases only know `tokenizer`.
    trainer_arg_names = inspect.signature(Trainer.__init__).parameters
    tokenizer_key = "processing_class" if "processing_class" in trainer_arg_names else "tokenizer"
    trainer_kwargs[tokenizer_key] = tokenizer
    trainer = Trainer(**trainer_kwargs)

    train_result = trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    metrics = train_result.metrics
    with open(Path(args.output_dir) / "training_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    run_meta = vars(args).copy()
    run_meta["train_examples"] = len(train_ds)
    run_meta["valid_examples"] = len(valid_ds) if valid_ds is not None else 0
    run_meta["train_examples_dropped"] = train_dropped
    run_meta["valid_examples_dropped"] = valid_dropped
    with open(Path(args.output_dir) / "run_config.json", "w", encoding="utf-8") as f:
        json.dump(run_meta, f, indent=2)

    if args.push_to_hub:
        trainer.push_to_hub(commit_message="Add GravityLLM fine-tuned adapter")
    print(f"Training complete. Artifacts saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
| |
|