"""
Trainer for TouchGrass LoRA fine-tuning.

Handles training loop, checkpointing, evaluation.
"""

import os
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Optional, Dict, List, Any, Callable
from pathlib import Path
import logging
from tqdm import tqdm

from .losses import TouchGrassLoss, compute_lora_gradient_norm, get_parameter_groups


class TouchGrassTrainer:
    """
    Trainer for TouchGrass LoRA fine-tuning.

    Handles gradient accumulation, gradient clipping, periodic checkpointing
    and periodic evaluation. Only parameters with ``requires_grad=True`` (on
    the model and on the optional music modules) are optimized and saved.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset,
        config: Dict,
        eval_dataset: Optional[Any] = None,
        music_modules: Optional[Dict[str, nn.Module]] = None,
    ):
        """
        Initialize trainer.

        Args:
            model: Base model with LoRA adapters
            tokenizer: Tokenizer
            train_dataset: Training dataset
            config: Training configuration dictionary
            eval_dataset: Optional evaluation dataset
            music_modules: Optional dict of music modules to include in training
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config
        self.music_modules = music_modules or {}

        # Logging must exist before _create_optimizer(), which logs a summary.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        # Setup device; fall back to CPU when CUDA is unavailable and no device
        # was configured explicitly.
        default_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(config.get("device", default_device))
        self.model.to(self.device)

        # Move music modules to device
        for module in self.music_modules.values():
            module.to(self.device)

        # Setup optimizer (only train LoRA + music modules)
        self.optimizer = self._create_optimizer()

        # Setup loss
        self.loss_fn = TouchGrassLoss(config)

        # Training state
        self.global_step = 0
        self.epoch = 0

    def _trainable_parameters(self) -> List[torch.nn.Parameter]:
        """Return the unique ``requires_grad`` parameters of the model and music modules."""
        params: List[torch.nn.Parameter] = []
        seen = set()
        for module in [self.model, *self.music_modules.values()]:
            for param in module.parameters():
                if param.requires_grad and id(param) not in seen:
                    seen.add(id(param))
                    params.append(param)
        return params

    def _create_optimizer(self) -> torch.optim.Optimizer:
        """Create an AdamW optimizer over the model's and music modules' trainable params."""
        weight_decay = self.config.get("weight_decay", 0.1)
        trainable_params = self._trainable_parameters()

        # Model parameters are grouped by the project helper.
        # NOTE(review): assumes get_parameter_groups returns AdamW param-group dicts.
        param_groups = list(get_parameter_groups(self.model, weight_decay))

        # get_parameter_groups only sees self.model, so music-module parameters
        # would otherwise never be updated. Add them as extra groups, skipping any
        # that are also registered on the model (AdamW rejects duplicates).
        model_param_ids = {id(p) for p in self.model.parameters()}
        music_params = [p for p in trainable_params if id(p) not in model_param_ids]
        # Conventional AdamW split: no weight decay on 1-D tensors (biases, norms).
        decay = [p for p in music_params if p.ndim >= 2]
        no_decay = [p for p in music_params if p.ndim < 2]
        if decay:
            param_groups.append({"params": decay, "weight_decay": weight_decay})
        if no_decay:
            param_groups.append({"params": no_decay, "weight_decay": 0.0})

        optimizer = torch.optim.AdamW(
            param_groups,
            lr=self.config.get("learning_rate", 2e-4),
            betas=(self.config.get("beta1", 0.9), self.config.get("beta2", 0.95)),
        )

        self.logger.info(
            f"Optimizer: {len(param_groups)} parameter groups, "
            f"{len(trainable_params)} trainable params"
        )
        return optimizer

    def _optimizer_step(self, params_to_clip: List[torch.nn.Parameter]):
        """Clip gradients, apply one optimizer update, then run periodic save/eval."""
        torch.nn.utils.clip_grad_norm_(
            params_to_clip,
            self.config.get("clip_grad_norm", 1.0),
        )
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.global_step += 1

        # Interval checks run only right after global_step changes, so each step
        # triggers at most one save/eval (never at step 0, never repeated per
        # accumulated micro-batch).
        if self.global_step % self.config.get("save_interval", 1000) == 0:
            self.save_checkpoint()
        if (
            self.eval_dataset is not None
            and self.global_step % self.config.get("eval_interval", 1000) == 0
        ):
            self.evaluate()

    def train(self):
        """Main training loop.

        Runs ``max_epochs`` epochs. Gradients are accumulated over
        ``gradient_accumulation_steps`` micro-batches (loss scaled so they are
        averaged), and any partial accumulation window at the end of an epoch
        is flushed with a final optimizer step.
        """
        self.logger.info("Starting training...")

        accum_steps = max(1, int(self.config.get("gradient_accumulation_steps", 1)))
        params_to_clip = self._trainable_parameters()

        # Create dataloader
        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.get("micro_batch_size", 8),
            shuffle=True,
            num_workers=self.config.get("num_workers", 4),
            pin_memory=self.config.get("pin_memory", self.device.type == "cuda"),
        )

        # Training loop
        self.model.train()
        # Drop any stale gradients left by code that ran before training.
        self.optimizer.zero_grad()

        for epoch in range(self.config.get("max_epochs", 3)):
            self.epoch = epoch
            epoch_loss = 0.0
            avg_loss = 0.0
            num_batches = 0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")

            for batch_idx, batch in enumerate(progress_bar):
                # Move batch to device
                batch = {k: v.to(self.device) for k, v in batch.items()}

                # Forward pass
                outputs = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    return_dict=True,
                )

                # Compute loss
                loss_dict = self.loss_fn.forward(
                    logits=outputs["logits"],
                    labels=batch["labels"],
                )
                loss = loss_dict["total_loss"]

                # Backward pass; scale so accumulated gradients are an average
                # over the micro-batches rather than a sum.
                (loss / accum_steps).backward()

                # Logging (unscaled loss)
                num_batches = batch_idx + 1
                epoch_loss += loss.item()
                avg_loss = epoch_loss / num_batches
                progress_bar.set_postfix({"loss": avg_loss})

                # Gradient accumulation
                if num_batches % accum_steps == 0:
                    self._optimizer_step(params_to_clip)

            # Flush an incomplete accumulation window so its gradients are not
            # silently carried into the next epoch.
            if num_batches % accum_steps != 0:
                self._optimizer_step(params_to_clip)

            self.logger.info(f"Epoch {epoch} completed. Average loss: {avg_loss:.4f}")

        self.logger.info("Training complete!")

    def evaluate(self) -> Optional[float]:
        """Run evaluation and return the mean model loss, or None if nothing was evaluated.

        The model's train/eval mode is restored afterwards.
        """
        if self.eval_dataset is None:
            return None

        self.logger.info("Running evaluation...")

        eval_loader = DataLoader(
            self.eval_dataset,
            batch_size=self.config.get("micro_batch_size", 8),
            shuffle=False,
        )

        total_loss = 0.0
        num_batches = 0

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in tqdm(eval_loader, desc="Evaluating"):
                    batch = {k: v.to(self.device) for k, v in batch.items()}

                    outputs = self.model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"],
                        return_dict=True,
                    )

                    total_loss += outputs["loss"].item()
                    num_batches += 1
        finally:
            self.model.train(was_training)

        if num_batches == 0:
            self.logger.warning("Evaluation dataset is empty; skipping.")
            return None

        avg_eval_loss = total_loss / num_batches
        self.logger.info(f"Evaluation loss: {avg_eval_loss:.4f}")
        return avg_eval_loss

    def save_checkpoint(self, path: Optional[str] = None) -> Path:
        """Save trainable weights, optimizer state and progress to ``<path>/checkpoint.pt``.

        Args:
            path: Target directory. Defaults to
                ``<checkpoint_dir>/checkpoint-<global_step>``.

        Returns:
            The directory the checkpoint was written to.
        """
        if path is None:
            checkpoint_dir = Path(self.config.get("checkpoint_dir", "checkpoints"))
            path = checkpoint_dir / f"checkpoint-{self.global_step}"

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save model state dict (only LoRA + music modules); detach so the saved
        # tensors carry no autograd state.
        state_dict = {
            name: param.detach().cpu()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

        # Add music modules
        for module_name, module in self.music_modules.items():
            for name, param in module.named_parameters():
                if param.requires_grad:
                    key = f"music_modules.{module_name}.{name}"
                    state_dict[key] = param.detach().cpu()

        checkpoint = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "model_state_dict": state_dict,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config,
        }

        torch.save(checkpoint, path / "checkpoint.pt")
        self.logger.info(f"Checkpoint saved to {path}")
        return path

    def load_checkpoint(self, path: str):
        """Load a checkpoint written by ``save_checkpoint``.

        Args:
            path: Either the ``checkpoint.pt`` file or the directory containing it.
        """
        checkpoint_path = Path(path)
        if checkpoint_path.is_dir():
            checkpoint_path = checkpoint_path / "checkpoint.pt"

        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Load model weights (only trainable params were saved, hence strict=False)
        model_state_dict = checkpoint["model_state_dict"]
        base_state = {
            k: v for k, v in model_state_dict.items() if not k.startswith("music_modules.")
        }
        self.model.load_state_dict(base_state, strict=False)

        # Load music modules if present
        for module_name, module in self.music_modules.items():
            prefix = f"music_modules.{module_name}."
            module_state = {
                k[len(prefix):]: v
                for k, v in model_state_dict.items()
                if k.startswith(prefix)
            }
            if module_state:
                # Frozen params/buffers were never saved, so a strict load of
                # this partial state dict would fail.
                module.load_state_dict(module_state, strict=False)

        # Load optimizer
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.global_step = checkpoint["global_step"]
        self.epoch = checkpoint["epoch"]

        self.logger.info(f"Checkpoint loaded from {path} (step {self.global_step})")


def test_trainer():
    """Test the trainer with dummy data."""
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model, TaskType

    print("Testing TouchGrassTrainer...\n")

    # Load base model and tokenizer
    print("Loading base model...")
    model_name = "Qwen/Qwen3.5-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # Use float32 for testing
        trust_remote_code=True,
    )

    # Add LoRA
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    )
    model = get_peft_model(model, lora_config)
    # print_trainable_parameters() prints its own summary and returns None,
    # so call it directly rather than interpolating it into an f-string.
    print("Model trainable parameters:")
    model.print_trainable_parameters()

    # Dummy dataset
    class DummyDataset(torch.utils.data.Dataset):
        def __init__(self, size=10):
            self.size = size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, 32000, (128,)),
                "attention_mask": torch.ones(128, dtype=torch.long),
                "labels": torch.randint(0, 32000, (128,)),
            }

    train_dataset = DummyDataset(20)
    eval_dataset = DummyDataset(5)

    # Config
    train_config = {
        "learning_rate": 2e-4,
        "weight_decay": 0.1,
        "beta1": 0.9,
        "beta2": 0.95,
        "clip_grad_norm": 1.0,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "max_epochs": 1,
        "loss_weights": {
            "lm_loss": 1.0,
            "eq_loss": 0.1,
            "music_module_loss": 0.05,
        },
        "checkpoint_dir": "./test_checkpoints",
        "save_interval": 5,
        "eval_interval": 5,
    }

    # Create trainer
    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

    print("\nTrainer initialized successfully!")
    print(f"Device: {trainer.device}")
    print(f"Number of training samples: {len(train_dataset)}")

    # Test one batch; dataset items are un-batched, so add a batch dimension.
    print("\nTesting single forward/backward pass...")
    batch = train_dataset[0]
    batch = {k: v.unsqueeze(0).to(trainer.device) for k, v in batch.items()}

    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()

    print(f"Forward pass loss: {loss.item():.4f}")
    print("Backward pass completed!")

    print("\nTrainer test complete!")


if __name__ == "__main__":
    test_trainer()