| """
|
| Trainer for TouchGrass LoRA fine-tuning.
|
| Handles training loop, checkpointing, evaluation.
|
| """
|
|
|
| import os
|
| import json
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils.data import DataLoader
|
| from typing import Optional, Dict, List, Any, Callable
|
| from pathlib import Path
|
| import logging
|
| from tqdm import tqdm
|
| from .losses import TouchGrassLoss, compute_lora_gradient_norm, get_parameter_groups
|
|
|
|
|
class TouchGrassTrainer:
    """
    Trainer for TouchGrass LoRA fine-tuning.

    Handles gradient accumulation, gradient clipping, checkpointing, and
    periodic evaluation for a LoRA-adapted base model plus optional
    auxiliary "music" modules trained alongside it.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset,
        config: Dict,
        eval_dataset: Optional[Any] = None,
        music_modules: Optional[Dict[str, nn.Module]] = None,
    ):
        """
        Initialize trainer.

        Args:
            model: Base model with LoRA adapters
            tokenizer: Tokenizer (kept for downstream use; not called directly here)
            train_dataset: Training dataset yielding dicts with ``input_ids``,
                ``attention_mask``, and ``labels`` tensors
            config: Training configuration dictionary
            eval_dataset: Optional evaluation dataset
            music_modules: Optional dict of music modules to include in training
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config
        self.music_modules = music_modules or {}

        self.device = torch.device(config.get("device", "cuda"))
        self.model.to(self.device)
        for module in self.music_modules.values():
            module.to(self.device)

        # NOTE: the logger must exist BEFORE _create_optimizer() runs, because
        # that method logs; the original order raised AttributeError here.
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

        self.optimizer = self._create_optimizer()

        self.loss_fn = TouchGrassLoss(config)

        self.global_step = 0
        self.epoch = 0

    def _create_optimizer(self):
        """
        Create AdamW optimizer over model LoRA groups plus music modules.

        Bug fix: trainable music-module parameters were previously collected
        but never handed to the optimizer, so they silently received no
        updates. They are now appended as an extra parameter group.
        """
        weight_decay = self.config.get("weight_decay", 0.1)

        # assumes get_parameter_groups returns a list of param-group dicts
        # accepted by torch.optim — TODO confirm against losses.py
        param_groups = get_parameter_groups(self.model, weight_decay)

        music_params = [
            param
            for module in self.music_modules.values()
            for param in module.parameters()
            if param.requires_grad
        ]
        if music_params:
            param_groups.append({"params": music_params, "weight_decay": weight_decay})

        optimizer = torch.optim.AdamW(
            param_groups,
            lr=self.config.get("learning_rate", 2e-4),
            betas=(self.config.get("beta1", 0.9), self.config.get("beta2", 0.95)),
        )

        num_trainable = sum(len(group["params"]) for group in optimizer.param_groups)
        self.logger.info(
            f"Optimizer: {len(optimizer.param_groups)} parameter groups, "
            f"{num_trainable} trainable params"
        )

        return optimizer

    def _optimizer_step(self):
        """Clip gradients over every optimized parameter, step, and reset."""
        # Clip ALL optimized params (model + music modules), not just the model's.
        params_to_clip = [
            param for group in self.optimizer.param_groups for param in group["params"]
        ]
        torch.nn.utils.clip_grad_norm_(
            params_to_clip,
            self.config.get("clip_grad_norm", 1.0),
        )
        self.optimizer.step()
        self.optimizer.zero_grad()
        self.global_step += 1

    def train(self):
        """
        Main training loop.

        Each micro-batch loss is divided by the accumulation interval before
        backward() so the accumulated gradient equals the average over the
        logical batch; leftover gradients at epoch end are flushed with a
        final optimizer step. Checkpointing and evaluation are triggered only
        once per qualifying optimizer step.
        """
        self.logger.info("Starting training...")

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.get("micro_batch_size", 8),
            shuffle=True,
            num_workers=self.config.get("num_workers", 4),
            pin_memory=self.config.get("pin_memory", True),
        )

        accum_steps = self.config.get("gradient_accumulation_steps", 1)

        self.model.train()
        for module in self.music_modules.values():
            module.train()

        for epoch in range(self.config.get("max_epochs", 3)):
            self.epoch = epoch
            epoch_loss = 0.0
            avg_loss = 0.0  # defined even if the loader yields no batches
            pending_grads = False

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")
            for batch_idx, batch in enumerate(progress_bar):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    return_dict=True,
                )

                loss_dict = self.loss_fn.forward(
                    logits=outputs["logits"],
                    labels=batch["labels"],
                )
                loss = loss_dict["total_loss"]

                # Scale so accumulated gradients average over the logical batch.
                (loss / accum_steps).backward()
                pending_grads = True

                stepped = False
                if (batch_idx + 1) % accum_steps == 0:
                    self._optimizer_step()
                    pending_grads = False
                    stepped = True

                epoch_loss += loss.item()
                avg_loss = epoch_loss / (batch_idx + 1)
                progress_bar.set_postfix({"loss": avg_loss})

                # Only fire intervals right after an optimizer step; the old
                # check ran every micro-batch and matched global_step == 0.
                if stepped:
                    if self.global_step % self.config.get("save_interval", 1000) == 0:
                        self.save_checkpoint()
                    if (
                        self.eval_dataset
                        and self.global_step % self.config.get("eval_interval", 1000) == 0
                    ):
                        self.evaluate()

            if pending_grads:
                # Flush leftover gradients when the epoch's batch count is not
                # a multiple of the accumulation interval.
                self._optimizer_step()

            self.logger.info(f"Epoch {epoch} completed. Average loss: {avg_loss:.4f}")

        self.logger.info("Training complete!")

    def evaluate(self):
        """Run evaluation on the held-out set, restoring train mode after."""
        if not self.eval_dataset:
            return

        self.logger.info("Running evaluation...")
        self.model.eval()
        for module in self.music_modules.values():
            module.eval()

        eval_loader = DataLoader(
            self.eval_dataset,
            batch_size=self.config.get("micro_batch_size", 8),
            shuffle=False,
        )

        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="Evaluating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    return_dict=True,
                )
                total_loss += outputs["loss"].item()
                num_batches += 1

        # Guard against an empty eval loader (previously a ZeroDivisionError).
        if num_batches > 0:
            avg_eval_loss = total_loss / num_batches
            self.logger.info(f"Evaluation loss: {avg_eval_loss:.4f}")
        else:
            self.logger.warning("Evaluation dataset produced no batches.")

        self.model.train()
        for module in self.music_modules.values():
            module.train()

    def save_checkpoint(self, path: Optional[str] = None):
        """
        Save training checkpoint.

        Only trainable (LoRA + music-module) parameters are stored, keeping
        checkpoints small relative to the frozen base model.

        Args:
            path: Target directory; defaults to
                ``<checkpoint_dir>/checkpoint-<global_step>``.
        """
        if path is None:
            checkpoint_dir = Path(self.config.get("checkpoint_dir", "checkpoints"))
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            path = checkpoint_dir / f"checkpoint-{self.global_step}"

        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # detach() so saved tensors carry no autograd history.
        state_dict = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                state_dict[name] = param.detach().cpu()

        for module_name, module in self.music_modules.items():
            for name, param in module.named_parameters():
                if param.requires_grad:
                    state_dict[f"music_modules.{module_name}.{name}"] = param.detach().cpu()

        checkpoint = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "model_state_dict": state_dict,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config,
        }

        torch.save(checkpoint, path / "checkpoint.pt")
        self.logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """
        Load training checkpoint.

        Restores trainable parameters, music-module weights, optimizer state,
        and step/epoch counters.

        Args:
            path: Path to a ``checkpoint.pt`` file written by save_checkpoint.
        """
        # NOTE(security): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=self.device)

        model_state_dict = checkpoint["model_state_dict"]
        # strict=False: the checkpoint holds only the trainable subset.
        self.model.load_state_dict(model_state_dict, strict=False)

        music_state = {
            k: v for k, v in model_state_dict.items() if k.startswith("music_modules.")
        }
        for module_name, module in self.music_modules.items():
            prefix = f"music_modules.{module_name}."
            module_state = {
                k.replace(prefix, ""): v
                for k, v in music_state.items()
                if k.startswith(prefix)
            }
            if module_state:
                module.load_state_dict(module_state)

        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        self.global_step = checkpoint["global_step"]
        self.epoch = checkpoint["epoch"]

        self.logger.info(f"Checkpoint loaded from {path} (step {self.global_step})")
|
|
|
|
|
def test_trainer():
    """
    Smoke-test TouchGrassTrainer with a small LoRA model and dummy data.

    Requires the ``transformers`` and ``peft`` packages plus network access
    to download the base model; intended as a manual sanity check only.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model, TaskType

    print("Testing TouchGrassTrainer...\n")

    print("Loading base model...")
    # NOTE(review): verify this model id exists on the Hub before running.
    model_name = "Qwen/Qwen3.5-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )

    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    )
    model = get_peft_model(model, lora_config)

    # print_trainable_parameters() prints its report and returns None, so it
    # must not be embedded in an f-string (that printed "None" before).
    print("Model trainable parameters:")
    model.print_trainable_parameters()

    class DummyDataset(torch.utils.data.Dataset):
        """Random-token dataset shaped like the real training data."""

        def __init__(self, size=10):
            self.size = size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, 32000, (128,)),
                "attention_mask": torch.ones(128),
                "labels": torch.randint(0, 32000, (128,)),
            }

    train_dataset = DummyDataset(20)
    eval_dataset = DummyDataset(5)

    train_config = {
        "learning_rate": 2e-4,
        "weight_decay": 0.1,
        "beta1": 0.9,
        "beta2": 0.95,
        "clip_grad_norm": 1.0,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "max_epochs": 1,
        "loss_weights": {
            "lm_loss": 1.0,
            "eq_loss": 0.1,
            "music_module_loss": 0.05,
        },
        "checkpoint_dir": "./test_checkpoints",
        "save_interval": 5,
        "eval_interval": 5,
    }

    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )

    print("\nTrainer initialized successfully!")
    print(f"Device: {trainer.device}")
    print(f"Number of training samples: {len(train_dataset)}")

    print("\nTesting single forward/backward pass...")
    batch = train_dataset[0]
    # Dataset items are unbatched 1-D tensors; HF causal-LM models expect a
    # leading batch dimension, so add it before the forward pass.
    batch = {k: v.unsqueeze(0).to(trainer.device) for k, v in batch.items()}

    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()

    print(f"Forward pass loss: {loss.item():.4f}")
    print("Backward pass completed!")

    print("\nTrainer test complete!")
|
|
|
|
|
# Allow running this module directly as a manual smoke test.
if __name__ == "__main__":
    test_trainer()