StarMist0012's picture
Add files using upload-large-folder tool
3270dae verified
"""Base classes for models, trainers, and datasets."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Any, Iterator
import torch
import torch.nn as nn
from torch.utils.data import Dataset as TorchDataset
from taoTrain.config import TrainingConfig, ModelConfig
# ============================================================================
# Base Model
# ============================================================================
class BaseModel(nn.Module, ABC):
    """Abstract base class for language models.

    Subclasses must implement ``forward``. The base keeps the model config on
    ``self.config`` and offers small introspection helpers.
    """

    def __init__(self, config: ModelConfig):
        """Set up ``nn.Module`` internals, then keep a reference to ``config``."""
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> dict[str, torch.Tensor]:
        """
        Run the model on one batch.

        Args:
            input_ids: Shape (batch_size, seq_length)
            attention_mask: Shape (batch_size, seq_length), optional
            labels: Shape (batch_size, seq_length), optional (for loss computation)

        Returns:
            Dict with keys:
                - 'logits': Shape (batch_size, seq_length, vocab_size)
                - 'loss': Scalar (if labels provided)
        """
        ...

    def count_parameters(self) -> int:
        """Return the total element count of all parameters with ``requires_grad`` set."""
        trainable = filter(lambda param: param.requires_grad, self.parameters())
        return sum(param.numel() for param in trainable)

    def get_num_layers(self) -> int:
        """Return the layer count stored in ``config.num_layers``."""
        return self.config.num_layers
# ============================================================================
# Base Dataset
# ============================================================================
class BaseDataset(TorchDataset, ABC):
    """Abstract base class for datasets built on ``torch.utils.data.Dataset``.

    Subclasses must implement ``__len__`` and ``__getitem__``. The
    ``load_dataset`` and ``preprocess`` hooks are no-ops here and may be
    overridden.
    """

    def __init__(self, config: "TrainingConfig"):
        """Keep the config; ``self.data`` starts out as None."""
        self.config = config
        self.data = None

    @abstractmethod
    def __len__(self) -> int:
        """Number of samples in the dataset."""
        ...

    @abstractmethod
    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """
        Fetch the sample at position ``idx``.

        Returns:
            Dict with keys:
                - 'input_ids': 1D tensor of token IDs
                - 'attention_mask': 1D tensor of attention mask
                - 'labels': 1D tensor of labels (optional)
        """
        ...

    def load_dataset(self) -> None:
        """Hook for loading raw data (HuggingFace or another source); does nothing by default."""
        return None

    def preprocess(self) -> None:
        """Hook for preprocessing such as tokenization; does nothing by default."""
        return None
# ============================================================================
# Base Trainer
# ============================================================================
class BaseTrainer(ABC):
"""Abstract base class for trainers."""
def __init__(
self,
model: BaseModel,
train_dataset: BaseDataset,
val_dataset: Optional[BaseDataset],
config: TrainingConfig,
device: torch.device,
):
"""Initialize trainer."""
self.model = model.to(device)
self.train_dataset = train_dataset
self.val_dataset = val_dataset
self.config = config
self.device = device
# Training state
self.global_step = 0
self.current_epoch = 0
self.best_loss = float('inf')
# Logging
self.logger = None
# Optimizer and scheduler (to be set up by subclass)
self.optimizer = None
self.scheduler = None
@abstractmethod
def training_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
"""
Single training step.
Args:
batch: Training batch with input_ids, attention_mask, labels, etc.
Returns:
Dict with metrics (e.g., {'loss': 0.5, 'accuracy': 0.8})
"""
pass
@abstractmethod
def validation_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
"""
Single validation step.
Args:
batch: Validation batch
Returns:
Dict with validation metrics
"""
pass
@abstractmethod
def train_epoch(self) -> dict[str, float]:
"""
Train for one epoch.
Returns:
Dict with epoch-level metrics
"""
pass
@abstractmethod
def validate(self) -> dict[str, float]:
"""
Run validation on the entire validation set.
Returns:
Dict with validation metrics
"""
pass
def save_checkpoint(self, path: str | Path) -> None:
"""
Save checkpoint in canonical format.
Uses canonical checkpoint format:
{
'step': int,
'model_state': state_dict,
'optimizer_state': state_dict,
'config': dict,
'metrics': dict,
'global_step': int, # Legacy compat
'current_epoch': int, # Legacy compat
'best_loss': float, # Legacy compat
}
Args:
path: Path to save checkpoint
"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
# Save in canonical format
checkpoint = {
# Canonical format keys
'step': self.global_step,
'model_state': self.model.state_dict(),
'optimizer_state': self.optimizer.state_dict() if self.optimizer else None,
'config': self.config.to_dict(),
'metrics': {},
# Legacy format keys (for backward compatibility with code that reads them)
'global_step': self.global_step,
'current_epoch': self.current_epoch,
'best_loss': self.best_loss,
}
torch.save(checkpoint, path)
def load_checkpoint(self, path: str | Path) -> None:
"""
Load checkpoint (handles both canonical and legacy formats).
Args:
path: Path to checkpoint
"""
path = Path(path)
checkpoint = torch.load(path, map_location=self.device)
# Try canonical keys first, fall back to legacy keys
model_state_key = 'model_state' if 'model_state' in checkpoint else 'model_state_dict'
optimizer_state_key = 'optimizer_state' if 'optimizer_state' in checkpoint else 'optimizer_state_dict'
self.model.load_state_dict(checkpoint[model_state_key])
if self.optimizer and checkpoint.get(optimizer_state_key):
self.optimizer.load_state_dict(checkpoint[optimizer_state_key])
# Try canonical 'step' first, fall back to legacy 'global_step'
self.global_step = checkpoint.get('step', checkpoint.get('global_step', 0))
self.current_epoch = checkpoint.get('current_epoch', 0)
self.best_loss = checkpoint.get('best_loss', float('inf'))
def _get_lr(self) -> float:
"""Get current learning rate from optimizer."""
for param_group in self.optimizer.param_groups:
return param_group['lr']
return 0.0
# ============================================================================
# Utility functions
# ============================================================================
def create_model(config: TrainingConfig, device: torch.device) -> BaseModel:
    """Build the model described by ``config.model`` through the model registry.

    The registry lookup is ``taoTrain.models.get_model``. It is imported at
    call time, presumably to avoid a circular import.
    """
    from taoTrain.models import get_model as registry_get_model

    model_config = config.model
    return registry_get_model(model_config, device=device)
def create_datasets(
    config: TrainingConfig,
) -> tuple[BaseDataset, Optional[BaseDataset]]:
    """Create train and validation datasets using factory pattern.

    A validation dataset is created only for non-local datasets whose config
    has ``validation_split`` set to a value other than None.

    Args:
        config: Training config. ``config.dataset.local`` and, if present,
            ``config.dataset.validation_split`` decide whether a validation
            dataset is built.

    Returns:
        ``(train_dataset, val_dataset)``; ``val_dataset`` is None when no
        validation split is configured.
    """
    # Import here to avoid circular imports
    from taoTrain.data import DatasetFactory

    # Create train dataset
    train_dataset = DatasetFactory.create_dataset(config, split="train")

    # Create validation dataset (only for HuggingFace datasets with explicit validation split).
    # hasattr() alone is True for any declared field, even one left at None, so test the
    # value rather than mere presence of the attribute.
    dataset_config = config.dataset
    val_dataset = None
    if (
        not dataset_config.local
        and getattr(dataset_config, "validation_split", None) is not None
    ):
        val_dataset = DatasetFactory.create_dataset(config, split="validation")
    return train_dataset, val_dataset