"""Training implementations for pretrain, SFT, and RL."""

import time
from typing import Optional, Dict, Tuple
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from taoTrain.core.base import BaseModel, BaseTrainer
from taoTrain.config import TrainingConfig, PretrainConfig, SFTConfig, RLConfig
from taoTrain.data.loaders import get_dataloader
from taoTrain.data.async_loader import AsyncBatchIterator
from taoTrain.data.tokenization_queue import TokenizationQueue
from taoTrain.logging import AimLogger
from taoTrain.optimizers import get_optimizer
from taoTrain.schedulers import get_scheduler
from taoTrain.utils import set_seed, get_dtype


# ============================================================================
# Metrics
# ============================================================================


class MetricsTracker:
    """Accumulate per-step scalar metrics and summarize them.

    Each call to :meth:`update` appends one value per metric name; the history
    is kept until :meth:`reset` is called.
    """

    def __init__(self):
        """Initialize an empty tracker."""
        # Maps metric name -> list of recorded values, in update order.
        self.metrics: Dict[str, list] = {}

    def update(self, metrics: Dict[str, float]):
        """Append one value for each metric in ``metrics``."""
        for key, value in metrics.items():
            self.metrics.setdefault(key, []).append(value)

    def get_average(self) -> Dict[str, float]:
        """Return the mean of every metric that has at least one recorded value."""
        return {
            key: sum(values) / len(values)
            for key, values in self.metrics.items()
            if values
        }

    def get_latest(self) -> Dict[str, float]:
        """Return the most recent value of every metric (0.0 if it has none)."""
        return {
            key: values[-1] if values else 0.0
            for key, values in self.metrics.items()
        }

    def reset(self):
        """Discard all recorded values."""
        self.metrics = {}


# ============================================================================
# Base Trainer Implementation
# ============================================================================


class BaseTrainerImpl(BaseTrainer):
    """Base trainer implementation with common functionality.

    Step accounting: ``self.global_step`` is incremented once per
    :meth:`training_step` call (i.e. per micro-batch) in :meth:`train_epoch`.
    The optimizer and LR scheduler step once every
    ``gradient_accumulation_steps`` micro-batches, so the scheduler horizon
    computed by :meth:`_compute_num_steps` is expressed in optimizer steps.
    """

    def __init__(
        self,
        model: BaseModel,
        train_dataset,
        val_dataset,
        config: TrainingConfig,
        device: torch.device,
    ):
        """Initialize trainer: optimizer, scheduler, logger, metrics and data loading.

        Args:
            model: Model to train.
            train_dataset: Training dataset (JSONL datasets get async loading).
            val_dataset: Validation dataset, or None to skip validation.
            config: Training configuration.
            device: Device the model and batches run on.
        """
        super().__init__(model, train_dataset, val_dataset, config, device)

        # Setup optimizer and scheduler using factories
        print("\n✓ Setting up optimizer and scheduler...")
        self.optimizer = get_optimizer(self.model, config)

        # The scheduler needs its horizon in optimizer steps.
        num_training_steps = self._compute_num_steps()
        self.scheduler = get_scheduler(self.optimizer, config, num_training_steps)
        print(f"✓ Optimizer and scheduler setup complete. Total training steps: {num_training_steps}")

        # Setup AimStack logging
        self.logger = AimLogger(config)
        print("✓ AimStack logger initialized.")

        # Data type. NOTE(review): self.dtype is not read by the autocast
        # context in this file; autocast always uses bfloat16 when enabled.
        self.dtype = get_dtype(config.dtype.value)
        self.use_autocast = config.dtype.value != "float32"

        # Metrics trackers
        self.train_metrics = MetricsTracker()
        self.val_metrics = MetricsTracker()

        # Setup async loading if using JSONL datasets
        print("\n✓ Setting up data loading...")
        self._setup_async_loading(train_dataset)

    def _grad_accum_steps(self) -> int:
        """Return the gradient-accumulation factor, treating 0/None as 1."""
        return max(1, self.config.gradient_accumulation_steps or 1)

    def _compute_num_steps(self) -> int:
        """Compute the total number of optimizer (= scheduler) steps.

        Both ``max_steps`` and the epoch loop count micro-batches (they are
        compared against ``global_step``). The scheduler, however, is stepped
        only once per ``accum`` micro-batches (see :meth:`training_step`).
        Both sources are therefore converted to optimizer steps.

        Returns:
            Number of scheduler steps, at least 1 so schedulers never get a
            zero-length horizon.
        """
        accum = self._grad_accum_steps()
        if self.config.max_steps:
            total_micro_batches = self.config.max_steps
        else:
            # Assumes the training iterator yields len(dataset) // batch_size
            # batches per epoch (drop_last=True) — TODO confirm for the async loader.
            batches_per_epoch = len(self.train_dataset) // self.config.batch_size
            total_micro_batches = batches_per_epoch * self.config.num_epochs
        # global_step carries across epochs, so accumulation windows span epoch
        # boundaries; divide the overall micro-batch count, not the per-epoch one.
        return max(1, total_micro_batches // accum)

    def _setup_async_loading(self, train_dataset):
        """
        Setup async loading for JSONL datasets.

        For JSONL datasets, creates TokenizationQueue and AsyncBatchIterator.
        For any other dataset type, leaves ``self.async_loader`` as None.

        Note: All JSONL datasets now operate in async-only mode.
        """
        self.async_loader = None

        # Local import, as in the original layout.
        from taoTrain.data.jsonl_base import BaseJSONLDataset

        print("\n✓ Checking dataset type for async loading...")
        if isinstance(train_dataset, BaseJSONLDataset):
            # Set up async loading pipeline
            print("✓ Detected JSONL dataset, setting up async loading...")

            # Create tokenization queue
            print("✓ Creating TokenizationQueue...")
            tokenization_queue = TokenizationQueue(
                chunk_manager=self.train_dataset.chunk_manager,
                tokenizer=self.train_dataset.tokenizer,
                config=self.config,
                max_queue_size=32,  # Memory constraint
                shuffle_chunks=True,
                num_threads=self.config.dataset.tokenizer_threads,
            )

            # Create async batch iterator
            print("✓ Creating AsyncBatchIterator...")
            self.async_loader = AsyncBatchIterator(
                tokenization_queue=tokenization_queue,
                batch_size=self.config.batch_size,
                device=self.device,
                drop_last=True,
                gradient_accumulation_steps=self.config.gradient_accumulation_steps,
            )

    def _autocast_context(self):
        """Return the autocast context used for forward passes.

        Autocast runs in bfloat16 when ``use_autocast`` is set. It is disabled
        (``enabled=False``) for float32 training.
        """
        # NOTE(review): any non-float32 config dtype autocasts to bfloat16.
        # float16 autocast would normally also need a GradScaler, so this is
        # intentionally left as-is.
        return torch.autocast(
            device_type="cuda" if self.device.type == "cuda" else "cpu",
            dtype=torch.bfloat16 if self.use_autocast else torch.float32,
            enabled=self.use_autocast,
        )

    def _move_to_device(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Return a copy of ``batch`` with tensor values moved to ``self.device``."""
        return {
            k: v.to(self.device) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }

    def training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Single training step on one micro-batch.

        Runs forward and backward, and steps the optimizer/scheduler once every
        ``gradient_accumulation_steps`` micro-batches.

        Args:
            batch: Training batch with ``input_ids`` and optionally
                ``attention_mask`` and ``labels``.

        Returns:
            Dict with the unscaled ``loss`` and the current learning rate ``lr``.
        """
        self.model.train()
        accum = self._grad_accum_steps()

        # Move batch to device (may already be on device for async loader)
        batch = self._move_to_device(batch)

        # Forward pass with mixed precision
        with self._autocast_context():
            outputs = self.model(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask"),
                labels=batch.get("labels"),
            )
            loss = outputs["loss"]

        # Scale so accumulated gradients average over the micro-batches.
        if accum > 1:
            loss = loss / accum
        loss.backward()

        # train_epoch increments global_step after this returns, so "+1" makes
        # the optimizer step after every accum-th micro-batch.
        if (self.global_step + 1) % accum == 0:
            if self.config.max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(), self.config.max_grad_norm
                )
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        # Undo the accumulation scaling for logging.
        return {
            "loss": loss.item() * accum,
            "lr": self.scheduler.get_last_lr()[0],
        }

    def validation_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single validation step: compute the loss without gradients.

        Args:
            batch: Validation batch with the same keys as in :meth:`training_step`.

        Returns:
            Dict with ``val_loss``.
        """
        self.model.eval()

        batch = self._move_to_device(batch)

        with torch.no_grad(), self._autocast_context():
            outputs = self.model(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask"),
                labels=batch.get("labels"),
            )
            loss = outputs["loss"]

        return {"val_loss": loss.item()}

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch.

        Stops early once ``global_step`` reaches ``config.max_steps`` (if set).
        Logging, validation and periodic checkpoints fire every
        ``log_every_steps`` / ``eval_every_steps`` / ``save_every_steps``
        micro-batches; an interval of 0 (or None) disables that action.

        Returns:
            Dict with epoch-averaged training metrics.
        """
        self.current_epoch += 1
        self.train_metrics.reset()

        # Use async loader for JSONL datasets, regular DataLoader otherwise
        if self.async_loader is not None:
            print("\n✓ Using AsyncBatchIterator for training...")
            train_iterator = self.async_loader
        else:
            print("\n✓ Creating DataLoader for training dataset...")
            train_iterator = get_dataloader(
                self.train_dataset,
                self.config,
                shuffle=True,
                drop_last=True,
            )

        log_every = self.config.log_every_steps
        eval_every = self.config.eval_every_steps
        save_every = self.config.save_every_steps
        checkpoint_dir = Path(self.config.checkpoint_dir)

        pbar = tqdm(train_iterator, desc=f"Epoch {self.current_epoch}")
        for batch in pbar:
            # Check if we've hit max steps
            if self.config.max_steps and self.global_step >= self.config.max_steps:
                print(f"\n✓ Reached max steps ({self.global_step}), ending training.")
                break

            # Training step
            metrics = self.training_step(batch)
            self.train_metrics.update(metrics)
            self.global_step += 1

            latest_metrics = self.train_metrics.get_latest()
            pbar.set_postfix(latest_metrics)

            # Logging
            if log_every and self.global_step % log_every == 0:
                self.logger.log_metrics(
                    {"step": self.global_step, "epoch": self.current_epoch, **latest_metrics}
                )

            # Validation (skipped entirely when there is no validation data)
            if eval_every and self.global_step % eval_every == 0:
                val_metrics = self.validate()
                if val_metrics:
                    self.logger.log_metrics({"step": self.global_step, **val_metrics})

                    # Save checkpoint if best
                    if val_metrics.get("val_loss", float("inf")) < self.best_loss:
                        self.best_loss = val_metrics["val_loss"]
                        if self.config.save_best_model:
                            self.save_checkpoint(checkpoint_dir / "best_model.pt")

            # Save periodic checkpoint
            if save_every and self.global_step % save_every == 0:
                self.save_checkpoint(
                    checkpoint_dir / f"checkpoint_step_{self.global_step}.pt"
                )

        print(f"\n✓ Finished epoch {self.current_epoch}.")
        return self.train_metrics.get_average()

    def validate(self) -> Dict[str, float]:
        """Run validation over the full validation set.

        Returns:
            Averaged validation metrics, or an empty dict when there is no
            validation dataset.
        """
        if self.val_dataset is None:
            return {}

        val_loader = get_dataloader(
            self.val_dataset,
            self.config,
            shuffle=False,
            drop_last=False,
        )

        self.val_metrics.reset()
        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", disable=True):
                metrics = self.validation_step(batch)
                self.val_metrics.update(metrics)

        return self.val_metrics.get_average()


# ============================================================================
# Stage-Specific Trainers
# ============================================================================


class PretrainTrainer(BaseTrainerImpl):
    """Trainer for pretraining; inherits all behavior from BaseTrainerImpl."""


class SFTTrainer(BaseTrainerImpl):
    """Trainer for supervised fine-tuning.

    Currently identical to BaseTrainerImpl; SFT-specific logic can go here.
    """


class RLTrainer(BaseTrainerImpl):
    """Trainer for reinforcement learning.

    PPO/DPO logic is planned for a separate module; currently identical to
    BaseTrainerImpl.
    """