| """Advanced Trainer with Multi-Task Learning, Curriculum, and MoE Support"""
|
|
|
| import logging
|
| import os
|
| from dataclasses import dataclass, field
|
| from pathlib import Path
|
| from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
| import torch
|
| import torch.nn as nn
|
| from torch.utils.data import DataLoader
|
| from torch.cuda.amp import GradScaler, autocast
|
| from tqdm import tqdm
|
| from transformers import PreTrainedModel
|
|
|
| from ..configs import (
|
| DataConfig,
|
| TrainingConfig,
|
| ZenithConfig,
|
| get_7b_config,
|
| get_32b_config,
|
| get_70b_config,
|
| )
|
| from ..data import (
|
| OpenThoughtsProcessor,
|
| OpenThoughtsConfig,
|
| CurriculumSampler,
|
| QualityFilter,
|
| )
|
| from ..evaluation import BenchmarkSuite, BenchmarkConfig
|
| from ..utils import CheckpointManager, MetricsLogger, setup_logging
|
|
|
| logger = logging.getLogger(__name__)


@dataclass
class TrainerConfig:
    """Complete trainer configuration."""

    model_config: ZenithConfig
    data_config: DataConfig
    training_config: TrainingConfig

    # Output locations
    output_dir: str = "./outputs"
    logging_dir: str = "./logs"
    checkpoint_dir: str = "./checkpoints"

    # Distributed training
    local_rank: int = -1
    world_size: int = 1
    distributed: bool = False

    # Mixed precision
    use_amp: bool = True
    amp_dtype: str = "bfloat16"

    # Gradient accumulation
    gradient_accumulation_steps: int = 4

    # Intervals, measured in optimizer steps
    log_interval: int = 10
    eval_interval: int = 500
    save_interval: int = 1000

    resume_from_checkpoint: Optional[str] = None

    def __post_init__(self):
        """Propagate derived settings into the nested training config."""
        self.training_config.gradient_accumulation_steps = self.gradient_accumulation_steps
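
# A construction sketch (values illustrative, not project defaults; the three
# sub-configs are assumed to build with no arguments):
#
#     trainer_config = TrainerConfig(
#         model_config=ZenithConfig(),
#         data_config=DataConfig(),
#         training_config=TrainingConfig(),
#         gradient_accumulation_steps=8,  # __post_init__ copies this into training_config
#     )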


class MultiTaskLoss(nn.Module):
    """Weighted multi-task loss over the model's training objectives."""

    def __init__(self, task_weights: Dict[str, float]):
        super().__init__()
        self.task_weights = task_weights
        self.loss_fns = {
            "next_token": nn.CrossEntropyLoss(ignore_index=-100),
            "thoughts": nn.MSELoss(),
            "eq_classification": nn.CrossEntropyLoss(),
            "frustration_detection": nn.MSELoss(),
        }

    def forward(
        self,
        outputs: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the weighted sum of all active task losses."""
        # Becomes a tensor as soon as the first weighted task fires; at least
        # one task (typically next_token) is expected to be active.
        total_loss = 0.0
        losses = {}

        # Standard next-token language-modeling loss over the flattened logits.
        if "next_token" in self.task_weights:
            lm_loss = self.loss_fns["next_token"](
                outputs["logits"].view(-1, outputs["logits"].size(-1)),
                batch["labels"].view(-1),
            )
            total_loss += self.task_weights["next_token"] * lm_loss
            losses["next_token"] = lm_loss

        # Auxiliary heads are only scored when the model actually produced them;
        # missing labels fall back to zero targets so the loss stays well-defined.
        if "thoughts" in self.task_weights and "thoughts_logits" in outputs:
            thoughts_loss = self.loss_fns["thoughts"](
                outputs["thoughts_logits"],
                batch.get("thoughts_labels", torch.zeros_like(outputs["thoughts_logits"])),
            )
            total_loss += self.task_weights["thoughts"] * thoughts_loss
            losses["thoughts"] = thoughts_loss

        if "eq_classification" in self.task_weights and "emotion_logits" in outputs:
            emotion_loss = self.loss_fns["eq_classification"](
                outputs["emotion_logits"],
                batch.get("emotion_labels", torch.zeros_like(outputs["emotion_logits"][:, 0]).long()),
            )
            total_loss += self.task_weights["eq_classification"] * emotion_loss
            losses["eq_classification"] = emotion_loss

        if "frustration_detection" in self.task_weights and "frustration_logits" in outputs:
            frustration_loss = self.loss_fns["frustration_detection"](
                outputs["frustration_logits"].squeeze(-1),
                batch.get("frustration_labels", torch.zeros_like(outputs["frustration_logits"].squeeze(-1))),
            )
            total_loss += self.task_weights["frustration_detection"] * frustration_loss
            losses["frustration_detection"] = frustration_loss

        losses["total"] = total_loss
        return total_loss, losses
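
# Expected task_weights layout (a sketch; keys must match self.loss_fns above,
# and the weight values shown are illustrative):
#
#     task_weights = {
#         "next_token": 1.0,             # primary LM objective
#         "thoughts": 0.5,               # auxiliary reasoning head (MSE)
#         "eq_classification": 0.1,      # emotion classification head
#         "frustration_detection": 0.1,  # frustration regression head
#     }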


class Trainer:
    """Advanced trainer with all Zenith features."""

    def __init__(
        self,
        model: nn.Module,
        config: TrainerConfig,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
    ):
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        if optimizer is None:
            self.optimizer = self._create_optimizer(config.training_config.optimizer)
        else:
            self.optimizer = optimizer

        self.scheduler = scheduler

        # Gradient scaling is only needed for float16 autocast; bfloat16 has
        # enough dynamic range that unscaled gradients are safe.
        self.scaler = (
            GradScaler()
            if config.use_amp and config.amp_dtype == "float16" and torch.cuda.is_available()
            else None
        )

        self.criterion = MultiTaskLoss(config.data_config.task_weights)

        self.metrics_logger = MetricsLogger(config.logging_dir)
        self.checkpoint_manager = CheckpointManager(
            config.checkpoint_dir,
            save_total_limit=config.training_config.save_total_limit,
        )

        # Keep a handle on the curriculum sampler (if any) so train() can
        # advance its difficulty schedule each epoch.
        self.curriculum_sampler = None
        if isinstance(train_loader.sampler, CurriculumSampler):
            self.curriculum_sampler = train_loader.sampler

        self.global_step = 0
        self.epoch = 0

        logger.info(f"Trainer initialized on {self.device}")

    def _create_optimizer(self, optimizer_config) -> torch.optim.Optimizer:
        """Create the optimizer from config, optionally in 8-bit."""
        if optimizer_config.use_8bit:
            # Deferred import so bitsandbytes is only required when 8-bit is on.
            import bitsandbytes as bnb

            optimizer = bnb.optim.AdamW8bit(
                self.model.parameters(),
                lr=optimizer_config.learning_rate,
                betas=(optimizer_config.beta1, optimizer_config.beta2),
                weight_decay=optimizer_config.weight_decay,
                eps=optimizer_config.epsilon,
            )
        else:
            optimizer = torch.optim.AdamW(
                self.model.parameters(),
                lr=optimizer_config.learning_rate,
                betas=(optimizer_config.beta1, optimizer_config.beta2),
                weight_decay=optimizer_config.weight_decay,
                eps=optimizer_config.epsilon,
            )
        return optimizer
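
    # _create_optimizer reads the following attributes off optimizer_config (an
    # assumed shape inferred from the accesses above; values are illustrative):
    #
    #     optimizer_config.learning_rate   # e.g. 3e-4
    #     optimizer_config.beta1           # e.g. 0.9
    #     optimizer_config.beta2           # e.g. 0.95
    #     optimizer_config.weight_decay    # e.g. 0.1
    #     optimizer_config.epsilon         # e.g. 1e-8
    #     optimizer_config.use_8bit        # True -> bitsandbytes AdamW8bit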

    def train(self):
        """Main training loop."""
        logger.info("Starting training...")

        if self.config.resume_from_checkpoint:
            self._load_checkpoint(self.config.resume_from_checkpoint)

        max_steps = self.config.training_config.max_steps
        num_epochs = self.config.training_config.num_train_epochs

        for epoch in range(self.epoch, num_epochs):
            self.epoch = epoch

            # Advance the curriculum so sampling difficulty tracks training progress.
            if self.curriculum_sampler:
                self.curriculum_sampler.set_epoch(epoch)

            epoch_loss = self._train_epoch()
            logger.info(f"Epoch {epoch} completed. Average loss: {epoch_loss:.4f}")

            # Evaluation and checkpointing run at step intervals inside
            # _train_epoch; here we only honor the max_steps budget.
            if max_steps > 0 and self.global_step >= max_steps:
                logger.info(f"Reached max_steps ({max_steps}); stopping early.")
                break

        self._save_checkpoint(final=True)

    def _train_epoch(self) -> float:
        """Train for one epoch and return the average unscaled loss."""
        self.model.train()

        total_loss = 0.0
        num_batches = 0
        max_steps = self.config.training_config.max_steps

        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch}")
        for batch_idx, batch in enumerate(progress_bar):
            batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

            with autocast(enabled=self.config.use_amp, dtype=getattr(torch, self.config.amp_dtype)):
                outputs = self.model(**batch)
                loss, task_losses = self.criterion(outputs, batch)

            # Scale the loss so gradients average correctly over the accumulation window.
            loss = loss / self.config.gradient_accumulation_steps

            if self.scaler:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            # Step the optimizer only once a full accumulation window has elapsed.
            if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
                if self.config.training_config.max_grad_norm > 0:
                    # Unscale before clipping so the threshold applies to real gradient norms.
                    if self.scaler:
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.training_config.max_grad_norm)

                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad()
                self.global_step += 1

                if self.scheduler:
                    self.scheduler.step()

                if self.global_step % self.config.log_interval == 0:
                    self._log_metrics(loss, task_losses, progress_bar)

                # eval_interval and save_interval count optimizer steps, matching
                # the step-based log_interval above.
                if self.val_loader and self.global_step % self.config.eval_interval == 0:
                    eval_metrics = self.evaluate()
                    self.metrics_logger.log(eval_metrics, self.global_step, prefix="eval")

                if self.global_step % self.config.save_interval == 0:
                    self._save_checkpoint()

                if max_steps > 0 and self.global_step >= max_steps:
                    break

            total_loss += loss.item() * self.config.gradient_accumulation_steps
            num_batches += 1

        return total_loss / num_batches if num_batches > 0 else 0.0

    def _log_metrics(self, loss: torch.Tensor, task_losses: Dict[str, torch.Tensor], progress_bar: tqdm):
        """Log metrics to console and the metrics logger."""
        # Undo the gradient-accumulation scaling so the logged loss is comparable
        # to the unscaled per-task losses.
        metrics = {"loss": loss.item() * self.config.gradient_accumulation_steps}
        metrics.update({f"{k}_loss": v.item() for k, v in task_losses.items()})

        if self.scheduler:
            metrics["lr"] = self.scheduler.get_last_lr()[0]

        self.metrics_logger.log(metrics, self.global_step, prefix="train")
        progress_bar.set_postfix(metrics)

    def evaluate(self) -> Dict[str, float]:
        """Run evaluation over the validation loader."""
        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluation"):
                batch = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

                with autocast(enabled=self.config.use_amp, dtype=getattr(torch, self.config.amp_dtype)):
                    outputs = self.model(**batch)
                    loss, _ = self.criterion(outputs, batch)

                total_loss += loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        # This exponentiates the multi-task loss, so it only equals true LM
        # perplexity when next_token is the sole weighted task.
        perplexity = torch.exp(torch.tensor(avg_loss)).item()

        self.model.train()
        return {"loss": avg_loss, "perplexity": perplexity}

    def _save_checkpoint(self, final: bool = False):
        """Save a full training-state checkpoint."""
        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler else None,
            "scaler_state_dict": self.scaler.state_dict() if self.scaler else None,
            "config": self.config,
        }

        name = "final" if final else f"step-{self.global_step}"
        path = self.checkpoint_manager.save_checkpoint(checkpoint, name)
        logger.info(f"Checkpoint saved to {path}")

    def _load_checkpoint(self, path: str):
        """Load a checkpoint and restore full training state."""
        logger.info(f"Loading checkpoint from {path}")
        # weights_only=False is required because the checkpoint pickles the full
        # TrainerConfig alongside the tensors (torch>=2.6 defaults to True).
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)

        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if self.scheduler and checkpoint["scheduler_state_dict"]:
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        if self.scaler and checkpoint["scaler_state_dict"]:
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])

        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]

        logger.info(f"Resumed from epoch {self.epoch}, step {self.global_step}")


def train_zenith_model(
    model: nn.Module,
    tokenizer: Any,
    config: TrainerConfig,
    train_dataset: Any,
    val_dataset: Optional[Any] = None,
) -> Trainer:
    """Build dataloaders, construct a Trainer, and run training end to end."""
    data_processor = OpenThoughtsProcessor(config.data_config)

    train_loader = data_processor.create_dataloader(
        train_dataset,
        batch_size=config.training_config.train_batch_size,
        shuffle=True,
        num_workers=config.training_config.dataloader_num_workers,
        curriculum_epoch=0,
    )

    if val_dataset:
        val_loader = data_processor.create_dataloader(
            val_dataset,
            batch_size=config.training_config.eval_batch_size,
            shuffle=False,
            num_workers=config.training_config.dataloader_num_workers,
        )
    else:
        val_loader = None

    trainer = Trainer(
        model=model,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
    )

    trainer.train()
    return trainer
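

# Usage sketch (hypothetical caller-side names -- `model`, `tokenizer`,
# `train_dataset`, and `val_dataset` must be provided by the caller; the
# sub-configs are assumed to build with no arguments):
#
#     config = TrainerConfig(
#         model_config=ZenithConfig(),
#         data_config=DataConfig(),
#         training_config=TrainingConfig(),
#     )
#     trainer = train_zenith_model(model, tokenizer, config, train_dataset, val_dataset)
#     metrics = trainer.evaluate()  # requires val_dataset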