"""Fine-tune GravityLLM, a causal LM that emits Spatial9Scene JSON from music constraints."""

import argparse
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
    set_seed,
)

SYSTEM_PREFIX = (
    "You are GravityLLM, a Spatial9 scene generation model. "
    "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
    "Do not return markdown. Do not explain your answer. "
    "Respect hard constraints such as object budgets, anchor positions, and low-end centering.\n\n"
)


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Fine-tune GravityLLM for Spatial9 scene generation.")
    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--train_file", type=str, default="data/train.jsonl")
    parser.add_argument("--valid_file", type=str, default="data/valid.jsonl")
    parser.add_argument("--output_dir", type=str, default="outputs/GravityLLM-Qwen2.5-1.5B-S9")
    parser.add_argument("--max_length", type=int, default=2048)
    parser.add_argument("--num_train_epochs", type=float, default=1.0)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--eval_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=16)
    parser.add_argument("--warmup_ratio", type=float, default=0.03)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=200)
    parser.add_argument("--eval_steps", type=int, default=200)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lora", action="store_true", help="Enable LoRA adapters.")
    parser.add_argument("--qlora", action="store_true", help="Enable 4-bit QLoRA training.")
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--push_to_hub", action="store_true")
    parser.add_argument("--hub_model_id", type=str, default=None)
    parser.add_argument("--hub_private_repo", action="store_true")
    return parser.parse_args()


def load_jsonl(file_path: str):
    return load_dataset("json", data_files=file_path, split="train")


def format_prompt(raw_prompt: str) -> str:
    raw_prompt = raw_prompt.strip()
    # Strip an optional "GravityLLM:" speaker prefix left over from data collection.
    if raw_prompt.lower().startswith("gravityllm:"):
        raw_prompt = raw_prompt.split(":", 1)[1].strip()
    return SYSTEM_PREFIX + raw_prompt + "\n\nOUTPUT:\n"


def tokenize_example(example: Dict[str, str], tokenizer, max_length: int) -> Dict[str, List[int]]:
    prompt_text = format_prompt(example["prompt"])
    completion_text = example["completion"].strip()
    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    completion_ids = tokenizer(completion_text + tokenizer.eos_token, add_special_tokens=False)["input_ids"]
    input_ids = prompt_ids + completion_ids
    # Mask the prompt with -100 so the loss is computed on the completion (the JSON scene) only.
    labels = [-100] * len(prompt_ids) + completion_ids
    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    attention_mask = [1] * len(input_ids)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }
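# Note on the data format: tokenize_example expects each JSONL record to carry
# "prompt" and "completion" string fields. An illustrative (made-up) record:
#
#   {"prompt": "GravityLLM: bpm=120; stems=[kick, bass, pads]; budget=9 objects",
#    "completion": "{\"version\": 1, \"objects\": [...]}"}
#
# The field values above are hypothetical; only the key names are assumed by this script.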
attention_mask, "labels": labels, } @dataclass class CausalDataCollator: pad_token_id: int label_pad_token_id: int = -100 def __call__(self, features): max_len = max(len(f["input_ids"]) for f in features) input_ids = [] attention_mask = [] labels = [] for f in features: pad_len = max_len - len(f["input_ids"]) input_ids.append(f["input_ids"] + [self.pad_token_id] * pad_len) attention_mask.append(f["attention_mask"] + [0] * pad_len) labels.append(f["labels"] + [self.label_pad_token_id] * pad_len) batch = { "input_ids": torch.tensor(input_ids, dtype=torch.long), "attention_mask": torch.tensor(attention_mask, dtype=torch.long), "labels": torch.tensor(labels, dtype=torch.long), } return batch def prepare_model(args: argparse.Namespace): model_kwargs = {} if args.qlora: compute_dtype = torch.bfloat16 if args.bf16 else torch.float16 model_kwargs["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=compute_dtype, ) model_kwargs["device_map"] = "auto" model = AutoModelForCausalLM.from_pretrained( args.model, torch_dtype=torch.bfloat16 if args.bf16 else (torch.float16 if args.fp16 else None), trust_remote_code=True, **model_kwargs, ) model.config.use_cache = False if args.qlora: model = prepare_model_for_kbit_training(model) if args.lora or args.qlora: lora_config = LoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=args.lora_dropout, bias="none", task_type="CAUSAL_LM", target_modules="all-linear", ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() return model def main() -> None: args = parse_args() os.makedirs(args.output_dir, exist_ok=True) set_seed(args.seed) tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True, trust_remote_code=True) tokenizer.padding_side = "right" if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token train_ds = load_jsonl(args.train_file) valid_ds = load_jsonl(args.valid_file) if args.valid_file and Path(args.valid_file).exists() else None train_ds = train_ds.map( lambda row: tokenize_example(row, tokenizer, args.max_length), remove_columns=train_ds.column_names, desc="Tokenizing train set", ) if valid_ds is not None: valid_ds = valid_ds.map( lambda row: tokenize_example(row, tokenizer, args.max_length), remove_columns=valid_ds.column_names, desc="Tokenizing valid set", ) model = prepare_model(args) training_args = TrainingArguments( output_dir=args.output_dir, overwrite_output_dir=True, num_train_epochs=args.num_train_epochs, learning_rate=args.learning_rate, per_device_train_batch_size=args.train_batch_size, per_device_eval_batch_size=args.eval_batch_size, gradient_accumulation_steps=args.gradient_accumulation_steps, warmup_ratio=args.warmup_ratio, weight_decay=args.weight_decay, logging_steps=args.logging_steps, save_steps=args.save_steps, eval_steps=args.eval_steps, evaluation_strategy="steps" if valid_ds is not None else "no", save_strategy="steps", bf16=args.bf16, fp16=args.fp16, report_to="none", gradient_checkpointing=True, lr_scheduler_type="cosine", optim="paged_adamw_32bit" if (args.lora or args.qlora) else "adamw_torch", max_grad_norm=1.0, push_to_hub=args.push_to_hub, hub_model_id=args.hub_model_id, hub_private_repo=args.hub_private_repo, hub_strategy="end" if args.push_to_hub else "every_save", ) trainer = Trainer( model=model, args=training_args, train_dataset=train_ds, eval_dataset=valid_ds, data_collator=CausalDataCollator(pad_token_id=tokenizer.pad_token_id), tokenizer=tokenizer, ) 
    train_result = trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)

    metrics = train_result.metrics
    with open(Path(args.output_dir) / "training_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)

    run_meta = vars(args).copy()
    run_meta["train_examples"] = len(train_ds)
    run_meta["valid_examples"] = len(valid_ds) if valid_ds is not None else 0
    with open(Path(args.output_dir) / "run_config.json", "w", encoding="utf-8") as f:
        json.dump(run_meta, f, indent=2)

    if args.push_to_hub:
        trainer.push_to_hub(commit_message="Add GravityLLM fine-tuned adapter")

    print(f"Training complete. Artifacts saved to: {args.output_dir}")


if __name__ == "__main__":
    main()
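# Example invocation (a sketch; the script filename is an assumption, and the
# paths simply echo the argument defaults above):
#
#   python train_gravityllm.py --qlora --bf16 \
#       --train_file data/train.jsonl \
#       --valid_file data/valid.jsonl \
#       --output_dir outputs/GravityLLM-Qwen2.5-1.5B-S9
#
# With --lora/--qlora the saved artifact is a PEFT adapter rather than full
# model weights; load it for inference with peft.PeftModel.from_pretrained
# on top of the base model.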