# StarMist0012's picture
# Add files using upload-large-folder tool
# 05b535a verified
"""Training implementations for pretrain, SFT, and RL."""
import time
from typing import Optional, Dict, Tuple
from pathlib import Path
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from taoTrain.core.base import BaseModel, BaseTrainer
from taoTrain.config import TrainingConfig, PretrainConfig, SFTConfig, RLConfig
from taoTrain.data.loaders import get_dataloader
from taoTrain.data.async_loader import AsyncBatchIterator
from taoTrain.data.tokenization_queue import TokenizationQueue
from taoTrain.logging import AimLogger
from taoTrain.optimizers import get_optimizer
from taoTrain.schedulers import get_scheduler
from taoTrain.utils import set_seed, get_dtype
# ============================================================================
# Metrics
# ============================================================================
class MetricsTracker:
    """Accumulate a per-key history of metric values and summarize it.

    Every ``update`` call appends one value per key it contains;
    ``get_average`` and ``get_latest`` summarize everything recorded since
    construction or the last ``reset``.
    """

    def __init__(self):
        """Start with an empty history."""
        self.metrics = {}

    def update(self, metrics: Dict[str, float]):
        """Append each value in ``metrics`` to the history of its key."""
        for name, value in metrics.items():
            self.metrics.setdefault(name, []).append(value)

    def get_average(self) -> Dict[str, float]:
        """Return the mean of each key's history; keys with no values are omitted."""
        averages = {}
        for name, history in self.metrics.items():
            if not history:
                continue
            averages[name] = sum(history) / len(history)
        return averages

    def get_latest(self) -> Dict[str, float]:
        """Return the most recent value per key (0.0 for an empty history)."""
        latest = {}
        for name, history in self.metrics.items():
            latest[name] = history[-1] if history else 0.0
        return latest

    def reset(self):
        """Discard all recorded history."""
        self.metrics = {}
# ============================================================================
# Base Trainer Implementation
# ============================================================================
class BaseTrainerImpl(BaseTrainer):
    """Base trainer implementation with common functionality.

    Wires together the optimizer/scheduler factories, the Aim logger,
    train/val metric tracking, optional async loading for JSONL datasets,
    and the shared train/validate loop used by the stage-specific trainers.
    """

    def __init__(
        self,
        model: BaseModel,
        train_dataset,
        val_dataset,
        config: TrainingConfig,
        device: torch.device,
    ):
        """Initialize trainer.

        Args:
            model: Model to train.
            train_dataset: Training dataset. JSONL-based datasets get the
                async tokenization/batching pipeline; anything else is read
                through a regular DataLoader.
            val_dataset: Validation dataset, or None to skip validation.
            config: Training configuration.
            device: Device that batches are moved to.
        """
        super().__init__(model, train_dataset, val_dataset, config, device)

        # Setup optimizer and scheduler using factories
        print("\n✓ Setting up optimizer and scheduler...")
        self.optimizer = get_optimizer(self.model, config)
        # The scheduler must know the total number of optimizer steps up front.
        num_training_steps = self._compute_num_steps()
        self.scheduler = get_scheduler(self.optimizer, config, num_training_steps)
        print(f"✓ Optimizer and scheduler setup complete. Total training steps: {num_training_steps}")

        # Setup AimStack logging
        self.logger = AimLogger(config)
        print("✓ AimStack logger initialized.")

        # Data type; autocast is only enabled for non-float32 configs.
        self.dtype = get_dtype(config.dtype.value)
        self.use_autocast = config.dtype.value != "float32"

        # Metrics trackers
        self.train_metrics = MetricsTracker()
        self.val_metrics = MetricsTracker()

        # Setup async loading if using JSONL datasets
        print("\n✓ Setting up data loading...")
        self._setup_async_loading(train_dataset)

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------

    def _grad_accum_steps(self) -> int:
        """Return the gradient-accumulation factor, treating 0/None as 1."""
        return self.config.gradient_accumulation_steps or 1

    def _is_due(self, interval) -> bool:
        """Return True when ``global_step`` is a multiple of ``interval``.

        A falsy interval (0 or None) disables the periodic action instead of
        raising on the modulo.
        """
        return bool(interval) and self.global_step % interval == 0

    def _compute_num_steps(self) -> int:
        """Compute the total number of optimizer steps for the scheduler.

        Uses ``config.max_steps`` when set. Otherwise it is
        ``len(train_dataset) // (batch_size * accumulation) * num_epochs``.
        """
        if self.config.max_steps:
            return self.config.max_steps
        effective_batch_size = self.config.batch_size * self._grad_accum_steps()
        num_steps_per_epoch = len(self.train_dataset) // effective_batch_size
        return num_steps_per_epoch * self.config.num_epochs

    def _setup_async_loading(self, train_dataset):
        """
        Setup async loading for JSONL datasets.

        For JSONL datasets (instances of BaseJSONLDataset), creates a
        TokenizationQueue feeding an AsyncBatchIterator and stores it in
        ``self.async_loader``. For any other dataset (e.g. HuggingFace),
        ``self.async_loader`` stays None and ``train_epoch`` builds a
        regular DataLoader instead.

        Note: All JSONL datasets now operate in async-only mode.
        """
        self.async_loader = None
        # Local import, kept function-scoped as in this module's design.
        from taoTrain.data.jsonl_base import BaseJSONLDataset

        print("\n✓ Checking dataset type for async loading...")
        if not isinstance(train_dataset, BaseJSONLDataset):
            return

        print("✓ Detected JSONL dataset, setting up async loading...")
        print("✓ Creating TokenizationQueue...")
        # Read from the same object that passed the isinstance check above.
        tokenization_queue = TokenizationQueue(
            chunk_manager=train_dataset.chunk_manager,
            tokenizer=train_dataset.tokenizer,
            config=self.config,
            max_queue_size=32,  # Memory constraint
            shuffle_chunks=True,
            num_threads=self.config.dataset.tokenizer_threads,
        )

        print("✓ Creating AsyncBatchIterator...")
        self.async_loader = AsyncBatchIterator(
            tokenization_queue=tokenization_queue,
            batch_size=self.config.batch_size,
            device=self.device,
            drop_last=True,
            gradient_accumulation_steps=self._grad_accum_steps(),
        )

    def _move_to_device(self, batch):
        """Return a copy of ``batch`` with every tensor value moved to ``self.device``.

        Non-tensor values pass through unchanged. ``Tensor.to`` returns the
        tensor itself when it is already on the target device, as may be
        the case for batches from the async loader.
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def _forward(self, batch):
        """Run the model on ``batch`` under (optional) autocast and return its outputs."""
        device_type = "cuda" if self.device.type == "cuda" else "cpu"
        # NOTE(review): autocast always uses bfloat16 when enabled.
        # ``self.dtype`` (from get_dtype) is computed in __init__ but is not
        # used here. If a float16 config is meant to autocast to float16,
        # that presumably also needs a GradScaler. Confirm intended behavior.
        autocast_dtype = torch.bfloat16 if self.use_autocast else torch.float32
        with torch.autocast(device_type=device_type,
                            dtype=autocast_dtype,
                            enabled=self.use_autocast):
            return self.model(
                input_ids=batch["input_ids"],
                attention_mask=batch.get("attention_mask"),
                labels=batch.get("labels"),
            )

    # ------------------------------------------------------------------
    # Steps
    # ------------------------------------------------------------------

    def training_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """
        Single training step on one micro-batch.

        Runs forward/backward on ``batch``. The optimizer and scheduler step
        (after optional gradient clipping) once every
        ``gradient_accumulation_steps`` micro-batches.

        Args:
            batch: Training batch
        Returns:
            Dict with the unscaled loss ("loss") and current learning rate ("lr")
        """
        self.model.train()
        batch = self._move_to_device(batch)
        outputs = self._forward(batch)
        loss = outputs["loss"]

        accum_steps = self._grad_accum_steps()
        # Scale so that gradients summed over accum_steps micro-batches average.
        if accum_steps > 1:
            loss = loss / accum_steps
        loss.backward()

        # train_epoch increments global_step after this returns, so
        # (global_step + 1) is the 1-based index of the current micro-batch.
        if (self.global_step + 1) % accum_steps == 0:
            max_norm = self.config.max_grad_norm
            if max_norm and max_norm > 0:
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    max_norm,
                )
            self.optimizer.step()
            self.scheduler.step()
            self.optimizer.zero_grad()

        # Undo the accumulation scaling so the logged loss is per-batch.
        raw_loss = loss.item() * accum_steps
        return {
            "loss": raw_loss,
            "lr": self.scheduler.get_last_lr()[0],
        }

    def validation_step(self, batch: Dict[str, torch.Tensor]) -> Dict[str, float]:
        """Single validation step: compute the loss without gradients.

        Switches the model to eval mode; ``training_step`` switches it back.
        """
        self.model.eval()
        batch = self._move_to_device(batch)
        with torch.no_grad():
            outputs = self._forward(batch)
        return {"val_loss": outputs["loss"].item()}

    # ------------------------------------------------------------------
    # Loops
    # ------------------------------------------------------------------

    def train_epoch(self) -> Dict[str, float]:
        """
        Train for one epoch.

        Logs, validates, and saves checkpoints on the configured step
        intervals; an interval of 0/None disables that action. Stops early
        once ``config.max_steps`` is reached.

        Returns:
            Dict with epoch-averaged training metrics
        """
        self.current_epoch += 1
        self.train_metrics.reset()

        # Use async loader for JSONL datasets, regular DataLoader for HuggingFace
        if self.async_loader is not None:
            print("\n✓ Using AsyncBatchIterator for training...")
            train_iterator = self.async_loader
        else:
            print("\n✓ Creating DataLoader for training dataset...")
            train_iterator = get_dataloader(
                self.train_dataset,
                self.config,
                shuffle=True,
                drop_last=True,
            )

        # The context manager closes the progress bar even on early break.
        with tqdm(train_iterator, desc=f"Epoch {self.current_epoch}") as pbar:
            for batch in pbar:
                # NOTE(review): global_step counts micro-batches (one per
                # training_step). The scheduler, however, is sized in optimizer
                # steps. With gradient accumulation > 1 and max_steps set,
                # training stops after max_steps / accumulation optimizer
                # steps. Confirm which unit max_steps is meant in.
                if self.config.max_steps and self.global_step >= self.config.max_steps:
                    print(f"\n✓ Reached max steps ({self.global_step}), ending training.")
                    break

                metrics = self.training_step(batch)
                self.train_metrics.update(metrics)
                self.global_step += 1
                pbar.set_postfix(self.train_metrics.get_latest())

                if self._is_due(self.config.log_every_steps):
                    log_dict = {"step": self.global_step, "epoch": self.current_epoch}
                    log_dict.update(self.train_metrics.get_latest())
                    self.logger.log_metrics(log_dict)

                if self._is_due(self.config.eval_every_steps):
                    self._validate_and_maybe_save_best()

                if self._is_due(self.config.save_every_steps):
                    ckpt_path = Path(self.config.checkpoint_dir) / f"checkpoint_step_{self.global_step}.pt"
                    self.save_checkpoint(ckpt_path)

        print(f"\n✓ Finished epoch {self.current_epoch}.")
        return self.train_metrics.get_average()

    def _validate_and_maybe_save_best(self) -> None:
        """Run validation and log it.

        Updates ``best_loss`` on improvement and, if ``save_best_model`` is
        set, saves ``best_model.pt``.
        """
        val_metrics = self.validate()
        self.logger.log_metrics({"step": self.global_step, **val_metrics})
        if val_metrics.get("val_loss", float('inf')) < self.best_loss:
            self.best_loss = val_metrics["val_loss"]
            if self.config.save_best_model:
                ckpt_path = Path(self.config.checkpoint_dir) / "best_model.pt"
                self.save_checkpoint(ckpt_path)

    def validate(self) -> Dict[str, float]:
        """Run validation over the whole validation dataset.

        Returns:
            Averaged validation metrics, or an empty dict when
            ``val_dataset`` is None.
        """
        if self.val_dataset is None:
            return {}
        val_loader = get_dataloader(
            self.val_dataset,
            self.config,
            shuffle=False,
            drop_last=False,
        )
        self.val_metrics.reset()
        with torch.no_grad():
            for batch in val_loader:
                self.val_metrics.update(self.validation_step(batch))
        return self.val_metrics.get_average()
# ============================================================================
# Stage-Specific Trainers
# ============================================================================
class PretrainTrainer(BaseTrainerImpl):
    """Trainer for pretraining.

    Adds nothing of its own: the full training loop, validation, and
    checkpointing come from BaseTrainerImpl.
    """
class SFTTrainer(BaseTrainerImpl):
    """Trainer for supervised fine-tuning.

    Currently identical to BaseTrainerImpl; SFT-specific behavior can be
    added here by overriding its methods.
    """
class RLTrainer(BaseTrainerImpl):
    """Trainer for reinforcement learning.

    Placeholder: currently behaves exactly like BaseTrainerImpl. PPO/DPO
    logic is planned to live in a separate module.
    """