"""
Main training entry point for TouchGrass models.

Fine-tunes Qwen3.5 with LoRA and music modules.
"""
|
| |
|
| | import argparse
|
| | import sys
|
| | from pathlib import Path
|
| |
|
| | import torch
|
| | from transformers import AutoModelForCausalLM, AutoTokenizer
|
| | from peft import LoraConfig, get_peft_model, TaskType
|
| |
|
| | from configs.touchgrass_3b_config import TOUCHGRASS_3B_CONFIG
|
| | from configs.touchgrass_7b_config import TOUCHGRASS_7B_CONFIG
|
| | from configs.training_config import (
|
| | TRAINING_CONFIG_3B_CUDA,
|
| | TRAINING_CONFIG_7B_CUDA,
|
| | TRAINING_CONFIG_MPS,
|
| | )
|
| | from data.dataset_loader import TouchGrassDataset
|
| | from training.trainer import TouchGrassTrainer
|
| | from tokenizer.music_token_extension import MusicTokenizerExtension
|
| |
|
| |
|
def parse_args(argv=None):
    """Parse command-line arguments for TouchGrass training.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse falls back to ``sys.argv[1:]``, so existing zero-argument
            callers are unaffected; passing a list makes the parser testable.

    Returns:
        argparse.Namespace with all training options.
    """
    parser = argparse.ArgumentParser(description="Train TouchGrass music assistant model")
    parser.add_argument(
        "--model_size",
        type=str,
        choices=["3b", "7b"],
        default="3b",
        help="Model size to train",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "mps", "cpu"],
        help="Device to train on",
    )
    parser.add_argument(
        "--use_mps",
        action="store_true",
        help="Use MPS backend (Apple Silicon)",
    )
    parser.add_argument(
        "--data_dir",
        type=str,
        default="./data/processed",
        help="Directory with processed data shards",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./checkpoints",
        help="Output directory for checkpoints",
    )
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Override max training steps",
    )
    parser.add_argument(
        "--micro_batch_size",
        type=int,
        default=None,
        help="Override micro batch size",
    )
    parser.add_argument(
        "--lora_r",
        type=int,
        default=16,
        help="LoRA rank",
    )
    parser.add_argument(
        "--lora_alpha",
        type=int,
        default=32,
        help="LoRA alpha",
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="Resume training from checkpoint",
    )
    parser.add_argument(
        "--generate_data",
        action="store_true",
        help="Generate synthetic training data before training",
    )
    parser.add_argument(
        "--num_train_samples",
        type=int,
        default=10000,
        help="Number of training samples to generate",
    )
    return parser.parse_args(argv)
|
| |
|
| |
|
def load_tokenizer(config: dict, args):
    """Load the base tokenizer and extend it with music tokens.

    Args:
        config: Model config dict; reads "base_model" and the optional
            "special_tokens" list.
        args: Parsed CLI args (unused here; kept for a uniform signature
            with the other loaders).

    Returns:
        (tokenizer_ext, tokenizer): the MusicTokenizerExtension wrapper and
        the underlying extended tokenizer.
    """
    base_model = config["base_model"]
    print(f"Loading base tokenizer: {base_model}")

    tokenizer_ext = MusicTokenizerExtension(
        base_tokenizer_name=base_model,
        special_tokens=config.get("special_tokens"),
    )

    tokenizer = tokenizer_ext.get_tokenizer()
    # len(tokenizer) includes tokens added on top of the base vocabulary;
    # HF-style `.vocab_size` does not, so the old print under-reported the
    # extended size. (Assumes get_tokenizer() returns an HF tokenizer —
    # consistent with the AutoTokenizer import; confirm in the wrapper.)
    print(f"Extended tokenizer vocab size: {len(tokenizer)}")

    return tokenizer_ext, tokenizer
|
| |
|
| |
|
def load_model(config: dict, args, tokenizer):
    """Load the base causal LM and wrap it with LoRA adapters.

    Args:
        config: Model config dict; reads the "base_model" checkpoint name.
        args: Parsed CLI args; reads device, lora_r, lora_alpha.
        tokenizer: The extended tokenizer; its full length (base vocab plus
            added music tokens) determines the embedding-matrix size.

    Returns:
        A PEFT-wrapped model with only the LoRA parameters trainable.
    """
    base_model = config["base_model"]
    print(f"Loading base model: {base_model}")

    # bf16 on CUDA where supported, fp16 otherwise; MPS and CPU train in fp32.
    if args.device == "cuda" and torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=dtype,
        trust_remote_code=True,
    )

    # Use len(tokenizer), not tokenizer.vocab_size: HF's vocab_size excludes
    # tokens added via add_tokens/add_special_tokens, so sizing by it would
    # leave the new music-token ids without embedding rows and index out of
    # range during training.
    model.resize_token_embeddings(len(tokenizer))

    print("Applying LoRA...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
        bias="none",
    )

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    return model
|
| |
|
| |
|
def generate_synthetic_data(config: dict, args, tokenizer):
    """Generate synthetic training data."""
    # Imported lazily so the generator stack is only loaded when requested.
    from data.music_qa_generator import MusicQAGenerator
    from data.chat_formatter import ChatFormatter

    print("Generating synthetic training data...")

    qa_generator = MusicQAGenerator(seed=42)

    out_dir = Path(args.data_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    raw_dataset = qa_generator.generate_dataset(
        num_samples=args.num_train_samples,
        output_path=out_dir / "synthetic_music_qa.jsonl",
    )

    chat_formatter = ChatFormatter(tokenizer=tokenizer)
    # messages[1]/messages[2] are the user question and assistant answer.
    formatted_samples = [
        chat_formatter.format_qa_pair(
            question=sample["messages"][1]["content"],
            answer=sample["messages"][2]["content"],
            context=None,
        )
        for sample in raw_dataset
    ]

    splits = chat_formatter.create_pretraining_dataset(
        formatted_samples,
        output_dir=out_dir,
        train_split=0.9,
    )

    print(f"Data generation complete. Train: {splits['train']}, Val: {splits['val']}")

    return splits
|
| |
|
| |
|
def load_datasets(args, tokenizer):
    """Load training and validation datasets."""
    data_dir = Path(args.data_dir)
    train_path = data_dir / "train.jsonl"
    val_path = data_dir / "val.jsonl"

    # Both splits must exist; otherwise point the user at --generate_data.
    if not (train_path.exists() and val_path.exists()):
        print(f"Data not found in {data_dir}. Generate with --generate_data")
        sys.exit(1)

    print(f"Loading datasets from {data_dir}")

    # Same dataset class for both splits; only path and mode differ.
    datasets = tuple(
        TouchGrassDataset(
            data_path=str(path),
            tokenizer=tokenizer,
            max_seq_length=4096,
            mode=mode,
        )
        for path, mode in ((train_path, "train"), (val_path, "eval"))
    )
    return datasets
|
| |
|
| |
|
def main():
    """CLI entry point: configure, load, train, and save a TouchGrass model."""
    args = parse_args()

    # Per-size model/training configs; .copy() keeps module constants pristine.
    if args.model_size == "3b":
        model_config = TOUCHGRASS_3B_CONFIG.copy()
        train_config = TRAINING_CONFIG_3B_CUDA.copy()
    else:
        model_config = TOUCHGRASS_7B_CONFIG.copy()
        train_config = TRAINING_CONFIG_7B_CUDA.copy()

    # MPS replaces the CUDA training config wholesale (for either model size).
    if args.use_mps or args.device == "mps":
        train_config = TRAINING_CONFIG_MPS.copy()
        train_config["use_mps"] = True

    # Explicit None checks: a plain truthiness test would silently ignore a
    # user-supplied 0 override.
    if args.max_steps is not None:
        train_config["max_steps"] = args.max_steps
    if args.micro_batch_size is not None:
        train_config["micro_batch_size"] = args.micro_batch_size

    device = torch.device(args.device)
    train_config["device"] = args.device

    print(f"Training TouchGrass-{args.model_size.upper()}")
    print(f"Device: {device}")
    print(f"Max steps: {train_config['max_steps']}")
    print(f"Micro batch size: {train_config['micro_batch_size']}")
    print(f"LoRA: r={args.lora_r}, alpha={args.lora_alpha}")

    tokenizer_ext, tokenizer = load_tokenizer(model_config, args)

    if args.generate_data:
        generate_synthetic_data(model_config, args, tokenizer)

    train_dataset, val_dataset = load_datasets(args, tokenizer)
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    model = load_model(model_config, args, tokenizer)

    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=val_dataset,
    )

    if args.resume_from_checkpoint:
        trainer.load_checkpoint(args.resume_from_checkpoint)

    trainer.train()

    # model_size is already "3b"/"7b": the old f-string appended an extra "b",
    # producing directories like "touchgrass-3bb-final".
    output_dir = Path(args.output_dir) / f"touchgrass-{args.model_size}-final"
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving final model to {output_dir}")
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

    tokenizer_ext.save_pretrained(output_dir)

    print("Training complete! Model saved.")
|
| |
|
| |
|
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()