#!/usr/bin/env python3
"""
Main training entry point for TouchGrass models.

Fine-tunes Qwen3.5 with LoRA and music modules.
"""

import argparse
import sys
from pathlib import Path

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType

from configs.touchgrass_3b_config import TOUCHGRASS_3B_CONFIG
from configs.touchgrass_7b_config import TOUCHGRASS_7B_CONFIG
from configs.training_config import (
    TRAINING_CONFIG_3B_CUDA,
    TRAINING_CONFIG_7B_CUDA,
    TRAINING_CONFIG_MPS,
)
from data.dataset_loader import TouchGrassDataset
from training.trainer import TouchGrassTrainer
from tokenizer.music_token_extension import MusicTokenizerExtension


def parse_args():
    """Parse command-line arguments for TouchGrass training.

    Returns:
        argparse.Namespace with model/device selection, data and output
        directories, training overrides, LoRA hyperparameters, and
        data-generation options.
    """
    parser = argparse.ArgumentParser(description="Train TouchGrass music assistant model")
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["3b", "7b"],
        default="3b",
        help="Model size to train",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "mps", "cpu"],
        help="Device to train on",
    )
    parser.add_argument(
        "--use_mps",
        action="store_true",
        help="Use MPS backend (Apple Silicon)",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data/processed",
        help="Directory with processed data shards",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./checkpoints",
        help="Output directory for checkpoints",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Override max training steps",
    )
    parser.add_argument(
        "--micro_batch_size",
        type=int,
        default=None,
        help="Override micro batch size",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=16,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from checkpoint",
    )
    parser.add_argument(
        "--generate_data",
        action="store_true",
        help="Generate synthetic training data before training",
    )
    parser.add_argument(
        "--num_train_samples",
        type=int,
        default=10000,
        help="Number of training samples to generate",
    )
    return parser.parse_args()


def load_tokenizer(config: dict, args):
    """Load the base tokenizer and extend it with music tokens.

    Args:
        config: Model config dict; reads "base_model" and optional
            "special_tokens".
        args: Parsed CLI namespace (unused here; kept for a uniform
            loader signature).

    Returns:
        (tokenizer_ext, tokenizer) — the MusicTokenizerExtension wrapper
        and the underlying extended tokenizer.
    """
    base_model = config["base_model"]
    print(f"Loading base tokenizer: {base_model}")

    # Extend tokenizer with music tokens.
    tokenizer_ext = MusicTokenizerExtension(
        base_tokenizer_name=base_model,
        special_tokens=config.get("special_tokens"),
    )
    tokenizer = tokenizer_ext.get_tokenizer()

    # BUGFIX: tokenizer.vocab_size excludes tokens added on top of the base
    # vocab; len(tokenizer) is the true extended size.
    print(f"Extended tokenizer vocab size: {len(tokenizer)}")
    return tokenizer_ext, tokenizer


def load_model(config: dict, args, tokenizer):
    """Load the base causal-LM, resize embeddings, and apply LoRA.

    Args:
        config: Model config dict; reads "base_model".
        args: Parsed CLI namespace; reads device and LoRA hyperparameters.
        tokenizer: Extended tokenizer used to size the embedding matrix.

    Returns:
        The PEFT-wrapped model with LoRA adapters attached.
    """
    base_model = config["base_model"]
    print(f"Loading base model: {base_model}")

    # Pick the widest-supported dtype for the target device.
    if args.device == "cuda" and torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    elif args.device == "mps":
        dtype = torch.float32  # MPS doesn't support bf16 well
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    # BUGFIX: resize to len(tokenizer), not tokenizer.vocab_size — the
    # latter does not count added music/special tokens, so the new token
    # ids would index past the embedding table.
    model.resize_token_embeddings(len(tokenizer))

    # Apply LoRA adapters to the attention projections.
    print("Applying LoRA...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        bias="none",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    return model


def generate_synthetic_data(config: dict, args, tokenizer):
    """Generate, format, and split synthetic music-QA training data.

    Args:
        config: Model config dict (currently unused; kept for signature
            consistency with the other loaders).
        args: Parsed CLI namespace; reads num_train_samples and data_dir.
        tokenizer: Tokenizer passed to the chat formatter.

    Returns:
        The split mapping produced by the formatter (expects "train" and
        "val" keys — confirm against ChatFormatter.create_pretraining_dataset).
    """
    # Deferred imports: these modules are only needed for --generate_data.
    from data.music_qa_generator import MusicQAGenerator
    from data.chat_formatter import ChatFormatter

    print("Generating synthetic training data...")

    generator = MusicQAGenerator(seed=42)  # fixed seed for reproducibility

    output_dir = Path(args.data_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    dataset = generator.generate_dataset(
        num_samples=args.num_train_samples,
        output_path=output_dir / "synthetic_music_qa.jsonl",
    )

    # Re-shape raw QA items into the chat format the trainer consumes.
    # NOTE(review): assumes messages[1]/[2] are the user question and
    # assistant answer (index 0 presumably a system prompt) — verify
    # against MusicQAGenerator's output schema.
    formatter = ChatFormatter(tokenizer=tokenizer)
    formatted_samples = []
    for item in dataset:
        formatted = formatter.format_qa_pair(
            question=item["messages"][1]["content"],
            answer=item["messages"][2]["content"],
            context=None,  # Context already in question
        )
        formatted_samples.append(formatted)

    # Create train/val splits on disk.
    splits = formatter.create_pretraining_dataset(
        formatted_samples,
        output_dir=output_dir,
        train_split=0.9,
    )
    print(f"Data generation complete. Train: {splits['train']}, Val: {splits['val']}")
    return splits


def load_datasets(args, tokenizer):
    """Load train/val datasets from disk, exiting if they are missing.

    Args:
        args: Parsed CLI namespace; reads data_dir.
        tokenizer: Tokenizer forwarded to the dataset wrappers.

    Returns:
        (train_dataset, val_dataset) TouchGrassDataset instances.
    """
    data_dir = Path(args.data_dir)
    train_path = data_dir / "train.jsonl"
    val_path = data_dir / "val.jsonl"

    if not train_path.exists() or not val_path.exists():
        print(f"Data not found in {data_dir}. Generate with --generate_data")
        sys.exit(1)

    print(f"Loading datasets from {data_dir}")
    train_dataset = TouchGrassDataset(
        data_path=str(train_path),
        tokenizer=tokenizer,
        max_seq_length=4096,
        mode="train",
    )
    val_dataset = TouchGrassDataset(
        data_path=str(val_path),
        tokenizer=tokenizer,
        max_seq_length=4096,
        mode="eval",
    )
    return train_dataset, val_dataset


def main():
    """Run the full training pipeline: config, data, model, train, save."""
    args = parse_args()

    # Select model + training configs by size; copy so CLI overrides
    # don't mutate the shared module-level dicts.
    if args.model_size == "3b":
        model_config = TOUCHGRASS_3B_CONFIG.copy()
        train_config = TRAINING_CONFIG_3B_CUDA.copy()
    else:
        model_config = TOUCHGRASS_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()

    # MPS replaces the CUDA training config wholesale.
    if args.use_mps or args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Apply CLI overrides. BUGFIX: compare against None so an explicit
    # `--max_steps 0` / `--micro_batch_size 0` is not silently ignored.
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size

    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training TouchGrass-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")
    print(f"LoRA: r={args.lora_r}, alpha={args.lora_alpha}")

    # Tokenizer first: both data generation and model resizing need it.
    tokenizer_ext, tokenizer = load_tokenizer(model_config, args)

    if args.generate_data:
        generate_synthetic_data(model_config, args, tokenizer)

    train_dataset, val_dataset = load_datasets(args, tokenizer)
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    model = load_model(model_config, args, tokenizer)

    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=val_dataset,
    )

    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    trainer.train()

    # BUGFIX: model_size is already "3b"/"7b" — the old f-string appended
    # an extra "b" (e.g. "touchgrass-3bb-final").
    output_dir = Path(args.output_dir) / f"touchgrass-{args.model_size}-final"
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving final model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    # Persist the music-token extension metadata alongside the model.
    tokenizer_ext.save_pretrained(output_dir)
    print("Training complete! Model saved.")


if __name__ == "__main__":
    main()