"""Base classes for models, trainers, and datasets.""" from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Any, Iterator import torch import torch.nn as nn from torch.utils.data import Dataset as TorchDataset from taoTrain.config import TrainingConfig, ModelConfig # ============================================================================ # Base Model # ============================================================================ class BaseModel(nn.Module, ABC): """Abstract base class for language models.""" def __init__(self, config: ModelConfig): """Initialize model with config.""" super().__init__() self.config = config @abstractmethod def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: """ Forward pass. Args: input_ids: Shape (batch_size, seq_length) attention_mask: Shape (batch_size, seq_length), optional labels: Shape (batch_size, seq_length), optional (for loss computation) Returns: Dict with keys: - 'logits': Shape (batch_size, seq_length, vocab_size) - 'loss': Scalar (if labels provided) """ pass def count_parameters(self) -> int: """Count total trainable parameters.""" return sum(p.numel() for p in self.parameters() if p.requires_grad) def get_num_layers(self) -> int: """Get number of layers (for model architecture).""" return self.config.num_layers # ============================================================================ # Base Dataset # ============================================================================ class BaseDataset(TorchDataset, ABC): """Abstract base class for datasets.""" def __init__(self, config: "TrainingConfig"): """Initialize dataset.""" self.config = config self.data = None @abstractmethod def __len__(self) -> int: """Return dataset size.""" pass @abstractmethod def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: """ Get a single sample. Returns: Dict with keys: - 'input_ids': 1D tensor of token IDs - 'attention_mask': 1D tensor of attention mask - 'labels': 1D tensor of labels (optional) """ pass def load_dataset(self) -> None: """Load dataset from HuggingFace or other source.""" pass def preprocess(self) -> None: """Preprocess dataset (tokenization, etc).""" pass # ============================================================================ # Base Trainer # ============================================================================ class BaseTrainer(ABC): """Abstract base class for trainers.""" def __init__( self, model: BaseModel, train_dataset: BaseDataset, val_dataset: Optional[BaseDataset], config: TrainingConfig, device: torch.device, ): """Initialize trainer.""" self.model = model.to(device) self.train_dataset = train_dataset self.val_dataset = val_dataset self.config = config self.device = device # Training state self.global_step = 0 self.current_epoch = 0 self.best_loss = float('inf') # Logging self.logger = None # Optimizer and scheduler (to be set up by subclass) self.optimizer = None self.scheduler = None @abstractmethod def training_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]: """ Single training step. Args: batch: Training batch with input_ids, attention_mask, labels, etc. Returns: Dict with metrics (e.g., {'loss': 0.5, 'accuracy': 0.8}) """ pass @abstractmethod def validation_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]: """ Single validation step. Args: batch: Validation batch Returns: Dict with validation metrics """ pass @abstractmethod def train_epoch(self) -> dict[str, float]: """ Train for one epoch. Returns: Dict with epoch-level metrics """ pass @abstractmethod def validate(self) -> dict[str, float]: """ Run validation on the entire validation set. Returns: Dict with validation metrics """ pass def save_checkpoint(self, path: str | Path) -> None: """ Save checkpoint in canonical format. Uses canonical checkpoint format: { 'step': int, 'model_state': state_dict, 'optimizer_state': state_dict, 'config': dict, 'metrics': dict, 'global_step': int, # Legacy compat 'current_epoch': int, # Legacy compat 'best_loss': float, # Legacy compat } Args: path: Path to save checkpoint """ path = Path(path) path.parent.mkdir(parents=True, exist_ok=True) # Save in canonical format checkpoint = { # Canonical format keys 'step': self.global_step, 'model_state': self.model.state_dict(), 'optimizer_state': self.optimizer.state_dict() if self.optimizer else None, 'config': self.config.to_dict(), 'metrics': {}, # Legacy format keys (for backward compatibility with code that reads them) 'global_step': self.global_step, 'current_epoch': self.current_epoch, 'best_loss': self.best_loss, } torch.save(checkpoint, path) def load_checkpoint(self, path: str | Path) -> None: """ Load checkpoint (handles both canonical and legacy formats). Args: path: Path to checkpoint """ path = Path(path) checkpoint = torch.load(path, map_location=self.device) # Try canonical keys first, fall back to legacy keys model_state_key = 'model_state' if 'model_state' in checkpoint else 'model_state_dict' optimizer_state_key = 'optimizer_state' if 'optimizer_state' in checkpoint else 'optimizer_state_dict' self.model.load_state_dict(checkpoint[model_state_key]) if self.optimizer and checkpoint.get(optimizer_state_key): self.optimizer.load_state_dict(checkpoint[optimizer_state_key]) # Try canonical 'step' first, fall back to legacy 'global_step' self.global_step = checkpoint.get('step', checkpoint.get('global_step', 0)) self.current_epoch = checkpoint.get('current_epoch', 0) self.best_loss = checkpoint.get('best_loss', float('inf')) def _get_lr(self) -> float: """Get current learning rate from optimizer.""" for param_group in self.optimizer.param_groups: return param_group['lr'] return 0.0 # ============================================================================ # Utility functions # ============================================================================ def create_model(config: TrainingConfig, device: torch.device) -> BaseModel: """Create model from config (calls registry).""" from taoTrain.models import get_model return get_model(config.model, device=device) def create_datasets( config: TrainingConfig, ) -> tuple[BaseDataset, Optional[BaseDataset]]: """Create train and validation datasets using factory pattern.""" # Import here to avoid circular imports from taoTrain.data import DatasetFactory # Create train dataset train_dataset = DatasetFactory.create_dataset(config, split="train") # Create validation dataset (only for HuggingFace datasets with explicit validation split) val_dataset = None if not config.dataset.local and hasattr(config.dataset, "validation_split"): val_dataset = DatasetFactory.create_dataset(config, split="validation") return train_dataset, val_dataset