"""Benchmarking suite for evaluating trained models.""" import time from pathlib import Path from typing import Optional, Dict import torch from torch.utils.data import DataLoader from taoTrain.core import BaseModel from taoTrain.config import TrainingConfig from taoTrain.data.loaders import get_dataloader from taoTrain.inference import Inferencer class BenchmarkRunner: """Run benchmarks on a trained model.""" def __init__( self, model: BaseModel, device: torch.device, dtype: torch.dtype = torch.float32, ): """ Initialize benchmark runner. Args: model: Trained model device: Device for inference dtype: Data type """ self.model = model.to(device) self.model.eval() self.device = device self.dtype = dtype @staticmethod def load_from_checkpoint( checkpoint_path: str | Path, device: Optional[torch.device] = None, ) -> "BenchmarkRunner": """Load model from checkpoint.""" if device is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") checkpoint = torch.load(checkpoint_path, map_location=device) # Reconstruct model config from taoTrain.config import ModelConfig from taoTrain.models import get_model model_config = ModelConfig(**checkpoint.get("config", {}).get("model", {})) model = get_model(model_config, device=device) model.load_state_dict(checkpoint["model_state_dict"]) return BenchmarkRunner(model, device) def benchmark_perplexity( self, dataset: "DataLoader", num_batches: Optional[int] = None, ) -> float: """ Compute perplexity on a dataset. Args: dataset: DataLoader for evaluation num_batches: Limit evaluation to N batches Returns: Perplexity (exp of average loss) """ total_loss = 0.0 total_tokens = 0 with torch.no_grad(): for batch_idx, batch in enumerate(dataset): if num_batches and batch_idx >= num_batches: break # Move to device input_ids = batch["input_ids"].to(self.device) attention_mask = batch.get("attention_mask") if attention_mask is not None: attention_mask = attention_mask.to(self.device) labels = batch.get("labels") if labels is not None: labels = labels.to(self.device) # Forward pass with torch.autocast( device_type="cuda" if self.device.type == "cuda" else "cpu", dtype=torch.bfloat16 if self.dtype == torch.bfloat16 else torch.float32, ): outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, labels=labels, ) loss = outputs.get("loss") if loss is not None: total_loss += loss.item() * input_ids.shape[0] total_tokens += input_ids.shape[0] avg_loss = total_loss / total_tokens if total_tokens > 0 else float('inf') perplexity = torch.exp(torch.tensor(avg_loss)).item() return perplexity def benchmark_throughput( self, batch_size: int = 32, seq_length: int = 1024, num_iters: int = 10, ) -> Dict[str, float]: """ Benchmark forward pass throughput. Args: batch_size: Batch size seq_length: Sequence length num_iters: Number of iterations Returns: Dict with throughput metrics """ # Create dummy batch dummy_input = torch.randint( 0, self.model.config.vocab_size, (batch_size, seq_length) ).to(self.device) # Warmup with torch.no_grad(): for _ in range(2): _ = self.model(dummy_input) torch.cuda.synchronize() if torch.cuda.is_available() else None # Benchmark forward pass start = time.time() with torch.no_grad(): for _ in range(num_iters): _ = self.model(dummy_input) torch.cuda.synchronize() if torch.cuda.is_available() else None elapsed = time.time() - start total_tokens = batch_size * seq_length * num_iters tokens_per_sec = total_tokens / elapsed return { "throughput_tokens_per_sec": tokens_per_sec, "throughput_samples_per_sec": (batch_size * num_iters) / elapsed, "avg_time_per_iter_ms": (elapsed / num_iters) * 1000, } def benchmark_memory(self) -> Dict[str, float]: """ Benchmark peak GPU memory usage. Returns: Dict with memory stats """ if not torch.cuda.is_available(): return {"peak_memory_gb": 0.0} torch.cuda.reset_peak_memory_stats() torch.cuda.synchronize() # Create dummy batch dummy_input = torch.randint( 0, self.model.config.vocab_size, (16, 1024) ).to(self.device) with torch.no_grad(): _ = self.model(dummy_input) torch.cuda.synchronize() peak_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # GB return {"peak_memory_gb": peak_memory} def run_all_benchmarks( self, dataset: Optional["DataLoader"] = None, batch_size: int = 32, seq_length: int = 1024, ) -> Dict[str, float]: """ Run all benchmarks. Args: dataset: DataLoader for perplexity benchmark batch_size: Batch size for throughput benchmark seq_length: Sequence length for throughput benchmark Returns: Dict with all benchmark results """ results = {} if dataset is not None: print("Running perplexity benchmark...") ppl = self.benchmark_perplexity(dataset, num_batches=10) results["perplexity"] = ppl print("Running throughput benchmark...") throughput = self.benchmark_throughput(batch_size, seq_length) results.update(throughput) print("Running memory benchmark...") memory = self.benchmark_memory() results.update(memory) return results