# TouchGrass-7b — training/trainer.py
"""
Trainer for TouchGrass LoRA fine-tuning.
Handles training loop, checkpointing, evaluation.
"""
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from typing import Optional, Dict, List, Any, Callable
from pathlib import Path
import logging
from tqdm import tqdm
from .losses import TouchGrassLoss, compute_lora_gradient_norm, get_parameter_groups
class TouchGrassTrainer:
    """
    Trainer for TouchGrass LoRA fine-tuning.

    Handles gradient accumulation, gradient clipping, mixed-batch device
    transfer, checkpointing, and periodic evaluation for a base model with
    LoRA adapters plus optional auxiliary "music" modules.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer,
        train_dataset,
        config: Dict,
        eval_dataset: Optional[Any] = None,
        music_modules: Optional[Dict[str, nn.Module]] = None,
    ):
        """
        Initialize trainer.

        Args:
            model: Base model with LoRA adapters
            tokenizer: Tokenizer
            train_dataset: Training dataset
            config: Training configuration dictionary
            eval_dataset: Optional evaluation dataset
            music_modules: Optional dict of music modules to include in training
        """
        self.model = model
        self.tokenizer = tokenizer
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.config = config
        self.music_modules = music_modules or {}

        # Setup device and move everything onto it before building the
        # optimizer (optimizer state is created on the parameters' device).
        self.device = torch.device(config.get("device", "cuda"))
        self.model.to(self.device)
        for module in self.music_modules.values():
            module.to(self.device)

        # Setup optimizer (trains LoRA params + music-module params).
        self.optimizer = self._create_optimizer()

        # Setup loss
        self.loss_fn = TouchGrassLoss(config)

        # Training state
        self.global_step = 0
        self.epoch = 0

        # Logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    def _create_optimizer(self):
        """Create AdamW optimizer over LoRA parameter groups AND music modules.

        Returns:
            torch.optim.AdamW over all trainable parameters.
        """
        # Count trainable params for the log line below.
        trainable_params = [p for _, p in self.model.named_parameters() if p.requires_grad]

        weight_decay = self.config.get("weight_decay", 0.1)
        param_groups = get_parameter_groups(self.model, weight_decay)

        # BUGFIX: music-module parameters were previously collected but never
        # handed to the optimizer, so they were silently frozen. Add them as
        # their own parameter group.
        music_params = [
            p
            for module in self.music_modules.values()
            for p in module.parameters()
            if p.requires_grad
        ]
        if music_params:
            param_groups.append({"params": music_params, "weight_decay": weight_decay})
            trainable_params.extend(music_params)

        optimizer = torch.optim.AdamW(
            param_groups,
            lr=self.config.get("learning_rate", 2e-4),
            betas=(self.config.get("beta1", 0.9), self.config.get("beta2", 0.95)),
        )
        self.logger.info(
            f"Optimizer: {len(param_groups)} parameter groups, {len(trainable_params)} trainable params"
        )
        return optimizer

    def _trainable_parameters(self):
        """Yield every parameter the optimizer updates (model + music modules)."""
        yield from self.model.parameters()
        for module in self.music_modules.values():
            yield from module.parameters()

    def train(self):
        """Main training loop with gradient accumulation and clipping."""
        self.logger.info("Starting training...")

        train_loader = DataLoader(
            self.train_dataset,
            batch_size=self.config.get("micro_batch_size", 8),
            shuffle=True,
            num_workers=self.config.get("num_workers", 4),
            pin_memory=self.config.get("pin_memory", True),
        )

        accum_steps = self.config.get("gradient_accumulation_steps", 1)

        self.model.train()
        for epoch in range(self.config.get("max_epochs", 3)):
            self.epoch = epoch
            epoch_loss = 0.0
            avg_loss = 0.0  # defined even if the loader yields no batches
            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch}")
            for batch_idx, batch in enumerate(progress_bar):
                batch = {k: v.to(self.device) for k, v in batch.items()}

                outputs = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    return_dict=True,
                )
                loss_dict = self.loss_fn.forward(
                    logits=outputs["logits"],
                    labels=batch["labels"],
                )
                loss = loss_dict["total_loss"]

                # BUGFIX: scale by the accumulation count so accumulated
                # gradients average instead of sum; otherwise the effective
                # learning rate grows linearly with accum_steps.
                (loss / accum_steps).backward()

                if (batch_idx + 1) % accum_steps == 0:
                    # Clip ALL trainable params, including music modules.
                    torch.nn.utils.clip_grad_norm_(
                        self._trainable_parameters(),
                        self.config.get("clip_grad_norm", 1.0),
                    )
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    self.global_step += 1

                    # BUGFIX: checkpoint/eval only when global_step advances.
                    # Previously these checks ran every micro-batch, saving the
                    # same step repeatedly (and at step 0 before any update).
                    if self.global_step % self.config.get("save_interval", 1000) == 0:
                        self.save_checkpoint()
                    if self.eval_dataset and self.global_step % self.config.get("eval_interval", 1000) == 0:
                        self.evaluate()

                # Log the unscaled loss so the displayed value is comparable
                # across different accumulation settings.
                epoch_loss += loss.item()
                avg_loss = epoch_loss / (batch_idx + 1)
                progress_bar.set_postfix({"loss": avg_loss})

            self.logger.info(f"Epoch {epoch} completed. Average loss: {avg_loss:.4f}")
        self.logger.info("Training complete!")

    def evaluate(self):
        """Run evaluation over the eval dataset; restores train mode after."""
        if not self.eval_dataset:
            return
        self.logger.info("Running evaluation...")
        self.model.eval()

        eval_loader = DataLoader(
            self.eval_dataset,
            batch_size=self.config.get("micro_batch_size", 8),
            shuffle=False,
        )

        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in tqdm(eval_loader, desc="Evaluating"):
                batch = {k: v.to(self.device) for k, v in batch.items()}
                outputs = self.model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    return_dict=True,
                )
                total_loss += outputs["loss"].item()
                num_batches += 1

        # Guard against an empty eval loader (previously a ZeroDivisionError).
        if num_batches:
            avg_eval_loss = total_loss / num_batches
            self.logger.info(f"Evaluation loss: {avg_eval_loss:.4f}")
        else:
            self.logger.warning("Evaluation skipped: eval loader produced no batches.")
        self.model.train()

    def save_checkpoint(self, path: Optional[str] = None):
        """Save training checkpoint (LoRA + music-module weights only).

        Args:
            path: Target directory; defaults to
                <checkpoint_dir>/checkpoint-<global_step>.
        """
        if path is None:
            checkpoint_dir = Path(self.config.get("checkpoint_dir", "checkpoints"))
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            path = checkpoint_dir / f"checkpoint-{self.global_step}"
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Only persist trainable (LoRA) parameters; detach so we store plain
        # tensors rather than live autograd leaves.
        state_dict = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                state_dict[name] = param.detach().cpu()
        # Music modules are namespaced so load_checkpoint can route them back.
        for module_name, module in self.music_modules.items():
            for name, param in module.named_parameters():
                if param.requires_grad:
                    state_dict[f"music_modules.{module_name}.{name}"] = param.detach().cpu()

        checkpoint = {
            "global_step": self.global_step,
            "epoch": self.epoch,
            "model_state_dict": state_dict,
            "optimizer_state_dict": self.optimizer.state_dict(),
            "config": self.config,
        }
        torch.save(checkpoint, path / "checkpoint.pt")
        self.logger.info(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Load training checkpoint saved by :meth:`save_checkpoint`.

        Args:
            path: Path to the checkpoint.pt file.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (weights_only=True would reject the
        # embedded config dict, so it is not used here).
        checkpoint = torch.load(path, map_location=self.device)

        # Load model weights (strict=False: checkpoint holds only LoRA params).
        model_state_dict = checkpoint["model_state_dict"]
        self.model.load_state_dict(model_state_dict, strict=False)

        # Route namespaced music-module weights back to their modules.
        music_state = {k: v for k, v in model_state_dict.items() if k.startswith("music_modules.")}
        for module_name, module in self.music_modules.items():
            prefix = f"music_modules.{module_name}."
            module_state = {
                k[len(prefix):]: v for k, v in music_state.items() if k.startswith(prefix)
            }
            if module_state:
                module.load_state_dict(module_state)

        # Restore optimizer and progress counters.
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        self.global_step = checkpoint["global_step"]
        self.epoch = checkpoint["epoch"]
        self.logger.info(f"Checkpoint loaded from {path} (step {self.global_step})")
def test_trainer():
    """Smoke-test TouchGrassTrainer: build a LoRA model, construct the
    trainer, and run one forward/backward pass on dummy data.

    Requires network access to download the base model; intended as a
    manual script, not an automated unit test.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import LoraConfig, get_peft_model, TaskType

    print("Testing TouchGrassTrainer...\n")

    # Load base model and tokenizer
    print("Loading base model...")
    # NOTE(review): verify this model id exists on the Hub ("Qwen3.5" is an
    # unusual version string — possibly meant Qwen2.5).
    model_name = "Qwen/Qwen3.5-3B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32,  # Use float32 for testing
        trust_remote_code=True,
    )

    # Add LoRA adapters on the attention projections.
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=16,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    )
    model = get_peft_model(model, lora_config)
    # BUGFIX: print_trainable_parameters() prints internally and returns None,
    # so wrapping it in an f-string printed "None".
    print("Model trainable parameters:")
    model.print_trainable_parameters()

    # Dummy dataset producing fixed-length random token sequences.
    class DummyDataset(torch.utils.data.Dataset):
        def __init__(self, size=10):
            self.size = size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            return {
                "input_ids": torch.randint(0, 32000, (128,)),
                "attention_mask": torch.ones(128),
                "labels": torch.randint(0, 32000, (128,)),
            }

    train_dataset = DummyDataset(20)
    eval_dataset = DummyDataset(5)

    # Config
    train_config = {
        "learning_rate": 2e-4,
        "weight_decay": 0.1,
        "beta1": 0.9,
        "beta2": 0.95,
        "clip_grad_norm": 1.0,
        "micro_batch_size": 2,
        "gradient_accumulation_steps": 4,
        "max_epochs": 1,
        "loss_weights": {
            "lm_loss": 1.0,
            "eq_loss": 0.1,
            "music_module_loss": 0.05,
        },
        "checkpoint_dir": "./test_checkpoints",
        "save_interval": 5,
        "eval_interval": 5,
    }

    # Create trainer
    trainer = TouchGrassTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        config=train_config,
        eval_dataset=eval_dataset,
    )
    print("\nTrainer initialized successfully!")
    print(f"Device: {trainer.device}")
    print(f"Number of training samples: {len(train_dataset)}")

    # Test one batch
    print("\nTesting single forward/backward pass...")
    batch = train_dataset[0]
    # BUGFIX: a single sample is 1-D (seq_len,); HF causal-LM models expect a
    # batch dimension, so unsqueeze to (1, seq_len) before the forward pass.
    batch = {k: v.unsqueeze(0).to(trainer.device) for k, v in batch.items()}
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    print(f"Forward pass loss: {loss.item():.4f}")
    print("Backward pass completed!")

    print("\nTrainer test complete!")
# Script entry point: run the manual smoke test (downloads the base model).
if __name__ == "__main__":
    test_trainer()