# Zenith-7b-V1 / training/trainer.py
# Source: Hugging Face upload ("Upload Zenith-7B model" by Zandy-Wandy, commit 8d18b7c, verified)
"""Advanced Trainer with Multi-Task Learning, Curriculum, and MoE Support"""
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from transformers import PreTrainedModel
from ..configs import (
DataConfig,
TrainingConfig,
ZenithConfig,
get_7b_config,
get_32b_config,
get_70b_config,
)
from ..data import (
OpenThoughtsProcessor,
OpenThoughtsConfig,
CurriculumSampler,
QualityFilter,
)
from ..evaluation import BenchmarkSuite, BenchmarkConfig
from ..utils import CheckpointManager, MetricsLogger, setup_logging
logger = logging.getLogger(__name__)
@dataclass
class TrainerConfig:
    """Complete trainer configuration.

    Bundles the model/data/training sub-configs together with output
    paths, distributed-training flags, mixed-precision settings, and
    logging/evaluation/checkpointing cadence used by ``Trainer``.
    """

    model_config: ZenithConfig
    data_config: DataConfig
    training_config: TrainingConfig
    # Paths
    output_dir: str = "./outputs"
    logging_dir: str = "./logs"
    checkpoint_dir: str = "./checkpoints"
    # Distributed training
    # NOTE(review): these fields are not consumed by Trainer in this file —
    # presumably read by an external launcher; confirm before relying on them.
    local_rank: int = -1
    world_size: int = 1
    distributed: bool = False
    # Mixed precision
    use_amp: bool = True
    # Name of an attribute on `torch` (e.g. "bfloat16", "float16"); resolved
    # via getattr(torch, amp_dtype) when entering autocast.
    amp_dtype: str = "bfloat16"
    # Gradient accumulation
    gradient_accumulation_steps: int = 4
    # Logging and evaluation cadence
    log_interval: int = 10
    eval_interval: int = 500
    save_interval: int = 1000
    # Resume
    resume_from_checkpoint: Optional[str] = None

    def __post_init__(self):
        """Setup derived configs."""
        # Keep the nested training config in sync with the trainer-level
        # accumulation setting so both report the same value.
        self.training_config.gradient_accumulation_steps = self.gradient_accumulation_steps
class MultiTaskLoss(nn.Module):
    """Combine several task losses into one weighted training objective.

    A task contributes only when its key appears in ``task_weights``;
    auxiliary heads additionally require their logits to be present in the
    model outputs. When a label tensor is absent from the batch, an
    all-zero target of matching shape is used instead.
    """

    def __init__(self, task_weights: Dict[str, float]):
        super().__init__()
        self.task_weights = task_weights
        # Per-task criteria; label value -100 is ignored by the LM loss.
        self.loss_fns = {
            "next_token": nn.CrossEntropyLoss(ignore_index=-100),
            "thoughts": nn.MSELoss(),
            "eq_classification": nn.CrossEntropyLoss(),
            "frustration_detection": nn.MSELoss(),
        }

    def forward(
        self,
        outputs: Dict[str, torch.Tensor],
        batch: Dict[str, torch.Tensor],
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """Return ``(total, per_task)``: the weighted sum and each raw loss.

        ``per_task`` also carries the weighted sum under the key "total".
        """
        weights = self.task_weights
        per_task: Dict[str, torch.Tensor] = {}
        total = 0.0

        # Primary language-modeling objective over flattened token logits.
        if "next_token" in weights:
            logits = outputs["logits"]
            flat_logits = logits.view(-1, logits.size(-1))
            per_task["next_token"] = self.loss_fns["next_token"](
                flat_logits, batch["labels"].view(-1)
            )
            total = total + weights["next_token"] * per_task["next_token"]

        # Auxiliary thought-trace regression head.
        if "thoughts" in weights and "thoughts_logits" in outputs:
            preds = outputs["thoughts_logits"]
            target = batch.get("thoughts_labels", torch.zeros_like(preds))
            per_task["thoughts"] = self.loss_fns["thoughts"](preds, target)
            total = total + weights["thoughts"] * per_task["thoughts"]

        # Emotion classification head (integer class targets).
        if "eq_classification" in weights and "emotion_logits" in outputs:
            preds = outputs["emotion_logits"]
            target = batch.get(
                "emotion_labels", torch.zeros_like(preds[:, 0]).long()
            )
            per_task["eq_classification"] = self.loss_fns["eq_classification"](
                preds, target
            )
            total = total + weights["eq_classification"] * per_task["eq_classification"]

        # Frustration-level regression head (trailing singleton dim dropped).
        if "frustration_detection" in weights and "frustration_logits" in outputs:
            preds = outputs["frustration_logits"].squeeze(-1)
            target = batch.get("frustration_labels", torch.zeros_like(preds))
            per_task["frustration_detection"] = self.loss_fns["frustration_detection"](
                preds, target
            )
            total = total + weights["frustration_detection"] * per_task["frustration_detection"]

        per_task["total"] = total
        return total, per_task
class Trainer:
    """Advanced trainer with all Zenith features.

    Supports mixed-precision training (CUDA only), gradient accumulation,
    curriculum sampling, multi-task losses, step-based logging/evaluation/
    checkpointing, early stopping at ``max_steps``, and checkpoint resume.
    """

    def __init__(
        self,
        model: nn.Module,
        config: TrainerConfig,
        train_loader: DataLoader,
        val_loader: Optional[DataLoader] = None,
        optimizer: Optional[torch.optim.Optimizer] = None,
        scheduler: Optional[Any] = None,
    ):
        """Initialize trainer state, optimizer, AMP, logging, and checkpointing.

        Args:
            model: Model to train; moved onto the detected device.
            config: Complete trainer configuration.
            train_loader: Training dataloader. If its sampler is a
                ``CurriculumSampler``, the epoch index is forwarded to it.
            val_loader: Optional validation dataloader.
            optimizer: Pre-built optimizer; when omitted one is created from
                ``config.training_config.optimizer``.
            scheduler: Optional LR scheduler, stepped once per optimizer step.
        """
        self.model = model
        self.config = config
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Single-device selection; the distributed fields in the config are
        # not consumed here.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Optimizer.
        if optimizer is None:
            self.optimizer = self._create_optimizer(config.training_config.optimizer)
        else:
            self.optimizer = optimizer
        self.scheduler = scheduler

        # Mixed precision is CUDA-only. Fix: gate autocast on the same
        # condition as the GradScaler, so a CPU run with use_amp=True does
        # not enter a CUDA autocast context it cannot use.
        self.amp_enabled = config.use_amp and torch.cuda.is_available()
        self.scaler = GradScaler() if self.amp_enabled else None

        # Multi-task loss weighted per the data config.
        self.criterion = MultiTaskLoss(config.data_config.task_weights)

        # Logging and checkpointing helpers.
        self.metrics_logger = MetricsLogger(config.logging_dir)
        self.checkpoint_manager = CheckpointManager(
            config.checkpoint_dir,
            save_total_limit=config.training_config.save_total_limit,
        )

        # Curriculum sampler, if the train loader uses one.
        self.curriculum_sampler = None
        if isinstance(getattr(train_loader, "sampler", None), CurriculumSampler):
            self.curriculum_sampler = train_loader.sampler

        # Training state.
        self.global_step = 0
        self.epoch = 0
        self._should_stop = False  # set once max_steps is reached
        logger.info(f"Trainer initialized on {self.device}")

    def _create_optimizer(self, optimizer_config) -> torch.optim.Optimizer:
        """Create an AdamW optimizer (8-bit via bitsandbytes when configured)."""
        common = dict(
            lr=optimizer_config.learning_rate,
            betas=(optimizer_config.beta1, optimizer_config.beta2),
            weight_decay=optimizer_config.weight_decay,
            eps=optimizer_config.epsilon,
        )
        if optimizer_config.use_8bit:
            # Optional dependency: imported lazily so plain AdamW runs do
            # not require bitsandbytes to be installed.
            import bitsandbytes as bnb
            return bnb.optim.AdamW8bit(self.model.parameters(), **common)
        return torch.optim.AdamW(self.model.parameters(), **common)

    def train(self):
        """Main training loop.

        Runs up to ``num_train_epochs`` epochs, stopping early once
        ``max_steps`` optimizer steps have been taken (when max_steps > 0).
        Evaluation and checkpointing fire every ``eval_interval`` /
        ``save_interval`` optimizer steps inside ``_train_epoch``; a final
        checkpoint is always written on exit.
        """
        logger.info("Starting training...")
        # Resume from checkpoint if specified.
        if self.config.resume_from_checkpoint:
            self._load_checkpoint(self.config.resume_from_checkpoint)

        max_steps = self.config.training_config.max_steps
        num_epochs = self.config.training_config.num_train_epochs
        self._should_stop = False

        for epoch in range(self.epoch, num_epochs):
            self.epoch = epoch
            # Let the curriculum sampler adjust its schedule for this epoch.
            if self.curriculum_sampler:
                self.curriculum_sampler.set_epoch(epoch)
            epoch_loss = self._train_epoch()
            logger.info(f"Epoch {epoch} completed. Average loss: {epoch_loss:.4f}")
            # Fix: max_steps was previously read but never enforced.
            if self._should_stop:
                logger.info(f"Reached max_steps ({max_steps}); stopping early.")
                break

        # Final save.
        self._save_checkpoint(final=True)

    def _train_epoch(self) -> float:
        """Train for one epoch; return the mean (unscaled) batch loss.

        Performs forward/backward per batch, an optimizer step every
        ``gradient_accumulation_steps`` batches, and step-based logging,
        evaluation, and checkpointing.
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        max_steps = self.config.training_config.max_steps
        progress_bar = tqdm(self.train_loader, desc=f"Epoch {self.epoch}")
        for batch_idx, batch in enumerate(progress_bar):
            # Move tensors to device; leave non-tensor entries untouched.
            batch = {
                k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                for k, v in batch.items()
            }
            # Forward pass under autocast (no-op on CPU, see __init__).
            with autocast(enabled=self.amp_enabled, dtype=getattr(torch, self.config.amp_dtype)):
                outputs = self.model(**batch)
                loss, task_losses = self.criterion(outputs, batch)

            # Scale down for gradient accumulation; keep the raw loss for
            # logging/averaging. Fix: the scaled loss was previously logged
            # as "loss" while task_losses["total"] stayed unscaled.
            scaled_loss = loss / self.config.gradient_accumulation_steps
            if self.scaler:
                self.scaler.scale(scaled_loss).backward()
            else:
                scaled_loss.backward()

            # Optimizer step every gradient_accumulation_steps batches.
            if (batch_idx + 1) % self.config.gradient_accumulation_steps == 0:
                # Gradient clipping (unscale first when using AMP).
                if self.config.training_config.max_grad_norm > 0:
                    if self.scaler:
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(),
                        self.config.training_config.max_grad_norm,
                    )
                if self.scaler:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()
                self.optimizer.zero_grad()
                self.global_step += 1
                # Scheduler steps once per optimizer step.
                if self.scheduler:
                    self.scheduler.step()

                # Logging.
                if self.global_step % self.config.log_interval == 0:
                    self._log_metrics(loss, task_losses, progress_bar)
                # Fix: eval/save intervals are step counts (cf. the 500/1000
                # defaults and log_interval's use of global_step); they were
                # previously compared against the epoch index and so almost
                # never fired.
                if self.val_loader and self.global_step % self.config.eval_interval == 0:
                    eval_metrics = self.evaluate()
                    self.metrics_logger.log(eval_metrics, self.global_step, prefix="eval")
                if self.global_step % self.config.save_interval == 0:
                    self._save_checkpoint()
                # Early-stop once the configured step budget is exhausted.
                if 0 < max_steps <= self.global_step:
                    self._should_stop = True

            total_loss += loss.item()
            num_batches += 1
            if self._should_stop:
                break
        return total_loss / num_batches if num_batches > 0 else 0.0

    def _log_metrics(self, loss: torch.Tensor, task_losses: Dict[str, torch.Tensor], progress_bar: tqdm):
        """Log the unscaled batch loss, per-task losses, and LR.

        Args:
            loss: Unscaled total loss for the current batch.
            task_losses: Per-task losses from ``MultiTaskLoss`` (includes "total").
            progress_bar: tqdm bar whose postfix is updated with the metrics.
        """
        metrics = {"loss": loss.item()}
        metrics.update({f"{k}_loss": v.item() for k, v in task_losses.items()})
        if self.scheduler:
            metrics["lr"] = self.scheduler.get_last_lr()[0]
        self.metrics_logger.log(metrics, self.global_step, prefix="train")
        # Update progress bar.
        progress_bar.set_postfix(metrics)

    def evaluate(self) -> Dict[str, float]:
        """Run evaluation over the validation loader.

        Returns:
            Dict with the mean multi-task loss and its exponential as
            "perplexity" (NOTE(review): exp of the *combined* multi-task
            loss is only a true perplexity when the LM loss is the sole
            task — confirm the intended metric).
        """
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluation"):
                batch = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }
                with autocast(enabled=self.amp_enabled, dtype=getattr(torch, self.config.amp_dtype)):
                    outputs = self.model(**batch)
                    loss, _ = self.criterion(outputs, batch)
                total_loss += loss.item()
                num_batches += 1
        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        perplexity = torch.exp(torch.tensor(avg_loss)).item()
        # Restore training mode for the caller.
        self.model.train()
        return {"loss": avg_loss, "perplexity": perplexity}

    def _save_checkpoint(self, final: bool = False):
        """Save model/optimizer/scheduler/scaler state plus trainer config.

        Args:
            final: When True, save under the name "final"; otherwise under
                "step-<global_step>".
        """
        checkpoint = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
            "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler else None,
            "scaler_state_dict": self.scaler.state_dict() if self.scaler else None,
            "config": self.config,
        }
        name = "final" if final else f"step-{self.global_step}"
        path = self.checkpoint_manager.save_checkpoint(checkpoint, name)
        logger.info(f"Checkpoint saved to {path}")

    def _load_checkpoint(self, path: str):
        """Load checkpoint state written by ``_save_checkpoint``.

        Args:
            path: Filesystem path to the checkpoint file.
        """
        logger.info(f"Loading checkpoint from {path}")
        # SECURITY: torch.load unpickles arbitrary objects (the checkpoint
        # stores the full TrainerConfig) — only load checkpoints you trust.
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint["model_state_dict"])
        self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        # Scheduler/scaler state may be absent (saved as None or from an
        # older run); restore only when both sides exist.
        if self.scheduler and checkpoint.get("scheduler_state_dict"):
            self.scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        if self.scaler and checkpoint.get("scaler_state_dict"):
            self.scaler.load_state_dict(checkpoint["scaler_state_dict"])
        # NOTE(review): resuming restarts the checkpointed epoch from its
        # beginning (train() ranges from self.epoch), re-visiting batches
        # already seen before a mid-epoch save — confirm this is intended.
        self.epoch = checkpoint["epoch"]
        self.global_step = checkpoint["global_step"]
        logger.info(f"Resumed from epoch {self.epoch}, step {self.global_step}")
def train_zenith_model(
    model: nn.Module,
    tokenizer: Any,
    config: TrainerConfig,
    train_dataset: Any,
    val_dataset: Optional[Any] = None,
) -> Trainer:
    """Build dataloaders, construct a ``Trainer``, and run training.

    Args:
        model: The model to train.
        tokenizer: Accepted for interface compatibility; not used directly
            here (presumably the data processor handles tokenization —
            TODO confirm).
        config: Complete trainer configuration.
        train_dataset: Dataset for training.
        val_dataset: Optional dataset for evaluation.

    Returns:
        The ``Trainer`` instance after training completes.
    """
    # Create data processor.
    data_processor = OpenThoughtsProcessor(config.data_config)

    # Training dataloader (curriculum starts at epoch 0).
    train_loader = data_processor.create_dataloader(
        train_dataset,
        batch_size=config.training_config.train_batch_size,
        shuffle=True,
        num_workers=config.training_config.dataloader_num_workers,
        curriculum_epoch=0,
    )

    # Fix: compare against None rather than truthiness — a zero-length
    # dataset is falsy via __len__ and would silently disable evaluation.
    val_loader = None
    if val_dataset is not None:
        val_loader = data_processor.create_dataloader(
            val_dataset,
            batch_size=config.training_config.eval_batch_size,
            shuffle=False,
            num_workers=config.training_config.dataloader_num_workers,
        )

    # Create trainer and run the full training loop.
    trainer = Trainer(
        model=model,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
    )
    trainer.train()
    return trainer