"""Advanced Trainer with Multi-Task Learning, Curriculum, and MoE Support""" import logging import os from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from torch.utils.data import DataLoader from torch.cuda.amp import GradScaler, autocast from tqdm import tqdm from transformers import PreTrainedModel from ..configs import ( DataConfig, TrainingConfig, ZenithConfig, get_7b_config, get_32b_config, get_70b_config, ) from ..data import ( OpenThoughtsProcessor, OpenThoughtsConfig, CurriculumSampler, QualityFilter, ) from ..evaluation import BenchmarkSuite, BenchmarkConfig from ..utils import CheckpointManager, MetricsLogger, setup_logging logger = logging.getLogger(__name__) @dataclass class TrainerConfig: """Complete trainer configuration.""" model_config: ZenithConfig data_config: DataConfig training_config: TrainingConfig # Paths output_dir: str = "./outputs" logging_dir: str = "./logs" checkpoint_dir: str = "./checkpoints" # Distributed training local_rank: int = -1 world_size: int = 1 distributed: bool = False # Mixed precision use_amp: bool = True amp_dtype: str = "bfloat16" # Gradient accumulation gradient_accumulation_steps: int = 4 # Logging and evaluation log_interval: int = 10 eval_interval: int = 500 save_interval: int = 1000 # Resume resume_from_checkpoint: Optional[str] = None def __post_init__(self): """Setup derived configs.""" self.training_config.gradient_accumulation_steps = self.gradient_accumulation_steps class MultiTaskLoss(nn.Module): """Multi-task loss for different objectives.""" def __init__(self, task_weights: Dict[str, float]): super().__init__() self.task_weights = task_weights self.loss_fns = { "next_token": nn.CrossEntropyLoss(ignore_index=-100), "thoughts": nn.MSELoss(), "eq_classification": nn.CrossEntropyLoss(), "frustration_detection": nn.MSELoss(), } def forward( self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor], ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """Compute weighted multi-task loss.""" total_loss = 0.0 losses = {} # Next token prediction (primary LM loss) if "next_token" in self.task_weights: lm_loss = self.loss_fns["next_token"]( outputs["logits"].view(-1, outputs["logits"].size(-1)), batch["labels"].view(-1), ) total_loss += self.task_weights["next_token"] * lm_loss losses["next_token"] = lm_loss # Thoughts prediction (auxiliary) if "thoughts" in self.task_weights and "thoughts_logits" in outputs: thoughts_loss = self.loss_fns["thoughts"]( outputs["thoughts_logits"], batch.get("thoughts_labels", torch.zeros_like(outputs["thoughts_logits"])), ) total_loss += self.task_weights["thoughts"] * thoughts_loss losses["thoughts"] = thoughts_loss # Emotion classification if "eq_classification" in self.task_weights and "emotion_logits" in outputs: emotion_loss = self.loss_fns["eq_classification"]( outputs["emotion_logits"], batch.get("emotion_labels", torch.zeros_like(outputs["emotion_logits"][:, 0]).long()), ) total_loss += self.task_weights["eq_classification"] * emotion_loss losses["eq_classification"] = emotion_loss # Frustration detection if "frustration_detection" in self.task_weights and "frustration_logits" in outputs: frustration_loss = self.loss_fns["frustration_detection"]( outputs["frustration_logits"].squeeze(-1), batch.get("frustration_labels", torch.zeros_like(outputs["frustration_logits"].squeeze(-1))), ) total_loss += self.task_weights["frustration_detection"] * frustration_loss 
losses["frustration_detection"] = frustration_loss losses["total"] = total_loss return total_loss, losses class Trainer: """Advanced trainer with all Zenith features.""" def __init__( self, model: nn.Module, config: TrainerConfig, train_loader: DataLoader, val_loader: Optional[DataLoader] = None, optimizer: Optional[torch.optim.Optimizer] = None, scheduler: Optional[Any] = None, ): self.model = model self.config = config self.train_loader = train_loader self.val_loader = val_loader # Setup device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) # Setup optimizer if optimizer is None: optimizer_config = config.training_config.optimizer self.optimizer = self._create_optimizer(optimizer_config) else: self.optimizer = optimizer # Setup scheduler self.scheduler = scheduler # Mixed precision self.scaler = GradScaler() if config.use_amp and torch.cuda.is_available() else None # Loss self.criterion = MultiTaskLoss(config.data_config.task_weights) # Logging self.metrics_logger = MetricsLogger(config.logging_dir) self.checkpoint_manager = CheckpointManager( config.checkpoint_dir, save_total_limit=config.training_config.save_total_limit, ) # Curriculum sampler self.curriculum_sampler = None if isinstance(train_loader.sampler, CurriculumSampler): self.curriculum_sampler = train_loader.sampler # State self.global_step = 0 self.epoch = 0 logger.info(f"Trainer initialized on {self.device}") def _create_optimizer(self, optimizer_config) -> torch.optim.Optimizer: """Create optimizer from config.""" if optimizer_config.use_8bit: import bitsandbytes as bnb optimizer = bnb.optim.AdamW8bit( self.model.parameters(), lr=optimizer_config.learning_rate, betas=(optimizer_config.beta1, optimizer_config.beta2), weight_decay=optimizer_config.weight_decay, eps=optimizer_config.epsilon, ) else: optimizer = torch.optim.AdamW( self.model.parameters(), lr=optimizer_config.learning_rate, betas=(optimizer_config.beta1, optimizer_config.beta2), weight_decay=optimizer_config.weight_decay, eps=optimizer_config.epsilon, ) return optimizer def train(self): """Main training loop.""" logger.info("Starting training...") # Resume from checkpoint if specified if self.config.resume_from_checkpoint: self._load_checkpoint(self.config.resume_from_checkpoint) max_steps = self.config.training_config.max_steps num_epochs = self.config.training_config.num_train_epochs for epoch in range(self.epoch, num_epochs): self.epoch = epoch # Update curriculum sampler if self.curriculum_sampler: self.curriculum_sampler.set_epoch(epoch) # Train one epoch epoch_loss = self._train_epoch() # Evaluation if self.val_loader and (epoch + 1) % self.config.eval_interval == 0: eval_metrics = self.evaluate() self.metrics_logger.log(eval_metrics, self.global_step, prefix="eval") # Save checkpoint if (epoch + 1) % self.config.save_interval == 0: self._save_checkpoint() logger.info(f"Epoch {epoch} completed. 

    def _train_epoch(self) -> float:
        """Train for one epoch and return the average (unnormalized) loss."""
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch}")

        for batch_idx, batch in enumerate(progress_bar):
            # Move batch to device
            batch = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }

            # Forward pass with mixed precision
            with autocast(enabled=self.config.use_amp, dtype=getattr(torch, self.config.amp_dtype)):
                outputs = self.model(**batch)
                loss, task_losses = self.criterion(outputs, batch)

            # Normalize so the accumulated gradient matches a single large-batch step
            loss = loss / self.config.gradient_accumulation_steps

            # Backward pass
            if self.scaler:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Take an optimizer step every gradient_accumulation_steps micro-batches
            if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradient clipping (unscale first when loss scaling is active)
                if self.config.training_config.max_grad_norm > 0:
                    if self.scaler:
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.training_config.max_grad_norm,
                    )

                # Optimizer step
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.global_step += 1

                # Scheduler step
                if self.scheduler:
                    self.scheduler.step()

                # Step-based logging, evaluation, and checkpointing
                if self.global_step % self.config.log_interval == 0:
                    self._log_metrics(loss, task_losses, progress_bar)
                if self.val_loader and self.global_step % self.config.eval_interval == 0:
                    eval_metrics = self.evaluate()
                    self.metrics_logger.log(eval_metrics, self.global_step, prefix="eval")
                if self.global_step % self.config.save_interval == 0:
                    self._save_checkpoint()

                # Respect the global step budget mid-epoch
                if 0 < self.config.training_config.max_steps <= self.global_step:
                    break

            # Undo the accumulation normalization for epoch-average reporting
            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

        return total_loss / num_batches if num_batches > 0 else 0.0

    def _log_metrics(self, loss: torch.Tensor, task_losses: Dict[str, torch.Tensor], progress_bar: tqdm):
        """Log metrics to the metrics logger and the progress bar."""
        metrics = {"loss": loss.item()}
        metrics.update({f"{k}_loss": v.item() for k, v in task_losses.items()})

        if self.scheduler:
            metrics["lr"] = self.scheduler.get_last_lr()[0]

        self.metrics_logger.log(metrics, self.global_step, prefix="train")

        # Update progress bar
        progress_bar.set_postfix(metrics)

    def evaluate(self) -> Dict[str, float]:
        """Run evaluation on the validation set."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluation"):
                batch = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }
                with autocast(enabled=self.config.use_amp, dtype=getattr(torch, self.config.amp_dtype)):
                    outputs = self.model(**batch)
                    loss, _ = self.criterion(outputs, batch)
                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        # Note: avg_loss is the weighted multi-task loss, so this is a true LM
        # perplexity only when next_token is the sole active task
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        self.model.train()
        return {"loss": avg_loss, "perplexity": perplexity}

    def _save_checkpoint(self, final: bool = False):
        """Save a full training-state checkpoint."""
        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler else None,
            "scaler_state_dict": self.scaler.state_dict() if self.scaler else None,
            "config": self.config,
        }

        name = "final" if final else f"step-{self.global_step}"
        path = self.checkpoint_manager.save_checkpoint(checkpoint, name)
        logger.info(f"Checkpoint saved to {path}")

    def _load_checkpoint(self, path: str):
        """Load a checkpoint and restore full training state."""
        logger.info(f"Loading checkpoint from {path}")
        # weights_only=False because the checkpoint stores the config object,
        # not just tensors
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        if self.scheduler and checkpoint["scheduler_state_dict"]:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        if self.scaler and checkpoint["scaler_state_dict"]:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]

        logger.info(f"Resumed from epoch {self.epoch}, step {self.global_step}")
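

# Typical usage (sketch; the config constructors' arguments are elided and
# assumed, not documented here):
#
#     config = TrainerConfig(
#         model_config=ZenithConfig(...),
#         data_config=DataConfig(...),
#         training_config=TrainingConfig(...),
#     )
#     trainer = Trainer(model, config, train_loader, val_loader)
#     trainer.train()
#
# When starting from raw datasets, prefer train_zenith_model below, which
# builds the dataloaders through OpenThoughtsProcessor first.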

def train_zenith_model(
    model: nn.Module,
    tokenizer: Any,
    config: TrainerConfig,
    train_dataset: Any,
    val_dataset: Optional[Any] = None,
) -> Trainer:
    """Build dataloaders, construct a Trainer, and run training end to end.

    Note: ``tokenizer`` is accepted for API compatibility but is not used
    directly here.
    """
    # Create data processor
    data_processor = OpenThoughtsProcessor(config.data_config)

    # Create dataloaders
    train_loader = data_processor.create_dataloader(
        train_dataset,
        batch_size=config.training_config.train_batch_size,
        shuffle=True,
        num_workers=config.training_config.dataloader_num_workers,
        curriculum_epoch=0,
    )

    if val_dataset:
        val_loader = data_processor.create_dataloader(
            val_dataset,
            batch_size=config.training_config.eval_batch_size,
            shuffle=False,
            num_workers=config.training_config.dataloader_num_workers,
        )
    else:
        val_loader = None

    # Create trainer and run training
    trainer = Trainer(
        model=model,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
    )
    trainer.train()

    return trainer
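

# Example invocation (sketch; model, tokenizer, and the datasets are assumed
# to be constructed elsewhere):
#
#     trainer = train_zenith_model(
#         model=model,
#         tokenizer=tokenizer,
#         config=trainer_config,
#         train_dataset=train_dataset,
#         val_dataset=val_dataset,
#     )
#     metrics = trainer.evaluate()  # final validation pass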