"""
quick_benchmark.py - Quick benchmark with smaller vocabulary for fast validation.

This script uses a character-level tokenizer (much smaller vocab) for faster
training and lower memory usage. Ideal for quick architecture comparison.

Usage:
    python quick_benchmark.py [--max-iters N] [--batch-size N] [--seed N]
"""

import argparse
import gc
import json
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Add parent paths: make the directory three levels above this file importable
# so the `src` and `validation` packages imported below resolve from it.
_PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(_PROJECT_ROOT))

from src.config import RippleConfig  # noqa: E402
from src.model import RippleGPT  # noqa: E402
from validation.benchmarks.baseline_gpt2 import VanillaGPT2, GPT2Config  # noqa: E402

# Context length shared by the dataset and both models. Kept in one place so
# the sample length and the models' configured block_size cannot drift apart.
BLOCK_SIZE = 256

# Default RNG seed; makes weight init, dropout and batch order reproducible.
DEFAULT_SEED = 1337


# ============================================================================
# SIMPLE CHARACTER-LEVEL DATASET
# ============================================================================

class SimpleTextDataset(Dataset):
    """
    Simple character-level dataset for quick benchmarks.

    Much smaller vocab size (~100) compared to BPE (~50k).

    Sample ``idx`` is a pair ``(x, y)`` of ``block_size`` token ids where ``y``
    is ``x`` shifted one character to the right (next-character targets).

    Args:
        text: Training corpus. The vocabulary is the sorted set of its characters.
        block_size: Number of tokens per sample.

    Raises:
        ValueError: if ``block_size`` is smaller than 1.
    """

    def __init__(self, text: str, block_size: int = BLOCK_SIZE):
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")

        # Build vocabulary (sorted, so ids are deterministic for a given text)
        chars = sorted(set(text))
        self.vocab_size = len(chars)
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

        # Encode text
        self.data = torch.tensor([self.stoi[ch] for ch in text], dtype=torch.long)
        self.block_size = block_size

    def __len__(self) -> int:
        # Sample idx reads data[idx : idx + block_size + 1], so the last valid
        # start index is len(data) - block_size - 1, i.e. there are
        # len(data) - block_size samples. Clamp at 0 so a text shorter than
        # block_size + 1 gives an empty dataset rather than a negative length.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.block_size]
        y = self.data[idx + 1:idx + self.block_size + 1]
        return x, y


def get_sample_text() -> str:
    """Generate sample text for quick benchmarks.

    The corpus is a fixed list of Python-like code snippets followed by simple
    story sentences, with that sequence repeated 100 times. It is small and
    highly repetitive, so both models should be able to learn it.
    """
    # Python-like code patterns
    code_patterns: List[str] = [
        "def hello():\n    print('hello world')\n\n",
        "for i in range(10):\n    x = i * 2\n    print(x)\n\n",
        "class MyClass:\n    def __init__(self):\n        self.x = 0\n\n",
        "if x > 0:\n    result = x + 1\nelse:\n    result = 0\n\n",
        "def add(a, b):\n    return a + b\n\n",
        "numbers = [1, 2, 3, 4, 5]\nfor n in numbers:\n    print(n)\n\n",
    ]

    # Story-like patterns
    story_patterns: List[str] = [
        "Once upon a time, there was a little cat. The cat liked to play. ",
        "The dog ran fast. It was happy. The sun was shining bright. ",
        "A bird flew in the sky. It sang a beautiful song. Everyone listened. ",
        "The boy went to school. He learned many things. He was smart. ",
    ]

    # Repeat patterns to create the dataset: code then story, 100 times over.
    return "".join((code_patterns + story_patterns) * 100)


# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def get_device() -> torch.device:
    """Get the best available device (CUDA, then Apple MPS, then CPU)."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def get_memory_mb() -> float:
    """Get the current process's resident memory (RSS) in MB.

    Requires the third-party ``psutil`` package; it is imported lazily so the
    rest of the script works without it.
    """
    import psutil
    return psutil.Process().memory_info().rss / 1024 / 1024


def _rate(count: int, seconds: float) -> float:
    """Return ``count / seconds``, or 0.0 when no measurable time has elapsed."""
    return count / seconds if seconds > 0 else 0.0


def _release_memory(device: torch.device) -> None:
    """Run the garbage collector and release cached accelerator memory."""
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps" and hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
        torch.mps.empty_cache()


# ============================================================================
# MODEL CREATION
# ============================================================================

def create_ripple_model(vocab_size: int, block_size: int = BLOCK_SIZE) -> RippleGPT:
    """Create a small RippleGPT model (4 layers, 4 heads, 256-dim embeddings)."""
    config = RippleConfig(
        vocab_size=vocab_size,
        n_layer=4,
        n_head=4,
        n_embd=256,
        block_size=block_size,
        dropout=0.1,
        use_absolute_pos_emb=False,
    )
    return RippleGPT(config)


def create_baseline_model(vocab_size: int, block_size: int = BLOCK_SIZE) -> VanillaGPT2:
    """Create a small VanillaGPT2 model with the same size hyper-parameters."""
    config = GPT2Config(
        vocab_size=vocab_size,
        n_layer=4,
        n_head=4,
        n_embd=256,
        block_size=block_size,
        dropout=0.1,
    )
    return VanillaGPT2(config)


# ============================================================================
# TRAINING
# ============================================================================

def _cycle_batches(dataloader: DataLoader):
    """Yield batches forever, re-iterating (and so re-shuffling) the loader each epoch.

    Raises:
        ValueError: if a full pass over the loader yields no batches, which
            would otherwise loop forever.
    """
    while True:
        produced = False
        for batch in dataloader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("dataloader yielded no batches; is the dataset empty?")


def _seeded_loader(dataset: Dataset, batch_size: int, seed: Optional[int]) -> DataLoader:
    """Build a shuffling DataLoader whose batch order is fixed by ``seed``.

    Two loaders built with the same seed yield the same batch sequence, so both
    models are trained on identical data. ``seed=None`` keeps torch's default
    (unseeded) shuffling.
    """
    generator = None
    if seed is not None:
        generator = torch.Generator()
        generator.manual_seed(seed)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, generator=generator)


def train_model(
    model: nn.Module,
    dataloader: DataLoader,
    max_iters: int,
    model_name: str,
    device: torch.device,
    log_interval: int = 50,
) -> Dict:
    """Train ``model`` for ``max_iters`` optimizer steps and collect metrics.

    Uses AdamW (lr=1e-3) with cosine annealing over ``max_iters`` steps and
    gradient clipping at a global norm of 1.0. The dataloader is cycled, so
    ``max_iters`` may exceed the number of batches in one epoch. ``model(x, y)``
    must return a 2-tuple whose second element is the scalar training loss.

    Args:
        model: Model to train; it is moved to ``device``.
        dataloader: Source of ``(x, y)`` batches.
        max_iters: Number of optimizer steps.
        model_name: Label used in progress output.
        device: Device to train on.
        log_interval: Record and print the loss every this many steps (the
            final step is always recorded).

    Returns:
        Dict with ``train_losses`` (list of ``(iteration, loss)``),
        ``final_loss`` (last recorded loss, ``inf`` if none was recorded),
        ``samples_per_sec`` and ``total_time_sec``.

    Raises:
        ValueError: if ``max_iters`` is negative, ``log_interval`` is below 1,
            or the dataloader yields no batches.
    """
    if max_iters < 0:
        raise ValueError(f"max_iters must be >= 0, got {max_iters}")
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    # T_max must be positive; max(1, ...) only matters for max_iters == 0.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, max_iters))

    train_losses: List[tuple] = []
    total_samples = 0
    batches = _cycle_batches(dataloader)

    print(f"\n🏋️ Training {model_name}...")
    print(f" Max iterations: {max_iters}")

    model.train()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()

    for iteration in range(1, max_iters + 1):
        x, y = next(batches)
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        _, loss = model(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

        total_samples += x.size(0)

        if iteration % log_interval == 0 or iteration == max_iters:
            loss_value = loss.item()
            train_losses.append((iteration, loss_value))
            elapsed = time.perf_counter() - start_time
            print(f" [{iteration:4d}/{max_iters}] loss: {loss_value:.4f} | "
                  f"{_rate(total_samples, elapsed):.1f} samples/sec")

    elapsed_time = time.perf_counter() - start_time

    return {
        "train_losses": train_losses,
        "final_loss": train_losses[-1][1] if train_losses else float('inf'),
        "samples_per_sec": _rate(total_samples, elapsed_time),
        "total_time_sec": elapsed_time,
    }


def _train_and_release(
    model: nn.Module,
    model_name: str,
    dataset: Dataset,
    batch_size: int,
    max_iters: int,
    device: torch.device,
    seed: Optional[int],
) -> Dict:
    """Train one model on a seeded loader, then move it off the device.

    Re-seeding before each run makes a model's dropout stream independent of
    which model was trained first. Moving the finished model to the CPU and
    releasing cached memory keeps it from occupying accelerator memory while
    the next model trains.
    """
    if seed is not None:
        torch.manual_seed(seed)
    loader = _seeded_loader(dataset, batch_size, seed)
    results = train_model(model, loader, max_iters, model_name, device)
    model.to("cpu")
    _release_memory(device)
    return results


# ============================================================================
# MAIN
# ============================================================================

def _winner(ripple_value: float, baseline_value: float, lower_is_better: bool) -> str:
    """Name the better model for one metric, or "Tie" when the values are equal."""
    if ripple_value == baseline_value:
        return "Tie"
    if lower_is_better:
        ripple_better = ripple_value < baseline_value
    else:
        ripple_better = ripple_value > baseline_value
    return "RippleGPT" if ripple_better else "VanillaGPT2"


def _print_summary(
    ripple_params: int,
    baseline_params: int,
    ripple_results: Dict,
    baseline_results: Dict,
) -> None:
    """Print a side-by-side comparison table of the two training runs."""
    print("\n" + "=" * 60)
    print("📊 RESULTS SUMMARY")
    print("=" * 60)

    header = f"{'Metric':<25} {'RippleGPT':<15} {'VanillaGPT2':<15} {'Winner':<12}"
    print("\n" + header)
    print("-" * len(header))

    r_loss, b_loss = ripple_results["final_loss"], baseline_results["final_loss"]
    r_speed, b_speed = ripple_results["samples_per_sec"], baseline_results["samples_per_sec"]
    r_time, b_time = ripple_results["total_time_sec"], baseline_results["total_time_sec"]

    rows = [
        ("Parameters", f"{ripple_params:,}", f"{baseline_params:,}",
         _winner(ripple_params, baseline_params, lower_is_better=True)),
        ("Final Loss", f"{r_loss:.4f}", f"{b_loss:.4f}",
         _winner(r_loss, b_loss, lower_is_better=True)),
        ("Speed (samples/sec)", f"{r_speed:.1f}", f"{b_speed:.1f}",
         _winner(r_speed, b_speed, lower_is_better=False)),
        ("Time (sec)", f"{r_time:.1f}", f"{b_time:.1f}",
         _winner(r_time, b_time, lower_is_better=True)),
    ]
    # Pad every cell to the header's column widths so the table lines up.
    for metric, ripple_text, baseline_text, winner in rows:
        print(f"{metric:<25} {ripple_text:<15} {baseline_text:<15} {winner:<12}")

    print("=" * 60)


def run_quick_benchmark(
    max_iters: int = 1000,
    batch_size: int = 32,
    seed: Optional[int] = DEFAULT_SEED,
) -> Dict:
    """Run a quick comparative benchmark of RippleGPT vs VanillaGPT2.

    Both models are trained on the same character-level corpus with the same
    batch order, a summary table is printed, and the results are written as
    JSON under ``<root>/validation/benchmarks/results``. ``<root>`` is the
    directory this script adds to ``sys.path``, so the location does not
    depend on the current working directory.

    Args:
        max_iters: Optimizer steps per model.
        batch_size: Sequences per batch.
        seed: RNG seed for weight init, dropout and batch order; ``None`` for
            unseeded runs.

    Returns:
        The results dict that was saved to disk.
    """
    device = get_device()

    print("\n" + "=" * 60)
    print("🚀 QUICK BENCHMARK: RippleGPT vs VanillaGPT2")
    print("=" * 60)
    print(f"Device: {device}")

    # Create dataset
    print("\n📚 Creating dataset...")
    text = get_sample_text()
    dataset = SimpleTextDataset(text, block_size=BLOCK_SIZE)

    print(f" Vocab size: {dataset.vocab_size}")
    print(f" Dataset size: {len(dataset)} samples")
    print(f" Block size: {dataset.block_size}")

    # Create models (seeded so weight initialisation is reproducible)
    print("\n🔧 Creating models...")
    if seed is not None:
        torch.manual_seed(seed)
    ripple_model = create_ripple_model(dataset.vocab_size)
    baseline_model = create_baseline_model(dataset.vocab_size)

    ripple_params = ripple_model.get_num_params()
    baseline_params = baseline_model.get_num_params()

    print(f" RippleGPT: {ripple_params:,} parameters")
    print(f" VanillaGPT2: {baseline_params:,} parameters")
    print(f" Difference: {baseline_params - ripple_params:+,} "
          f"({(baseline_params / ripple_params - 1) * 100:+.1f}%)")

    # Train RippleGPT
    print("\n" + "=" * 50)
    ripple_results = _train_and_release(
        ripple_model, "RippleGPT", dataset, batch_size, max_iters, device, seed
    )

    # Train VanillaGPT2
    print("\n" + "=" * 50)
    baseline_results = _train_and_release(
        baseline_model, "VanillaGPT2", dataset, batch_size, max_iters, device, seed
    )

    # Summary
    _print_summary(ripple_params, baseline_params, ripple_results, baseline_results)

    # Save results; one timestamp so metadata and filename always agree.
    now = datetime.now()
    results = {
        "metadata": {
            "timestamp": now.isoformat(),
            "device": str(device),
            "vocab_size": dataset.vocab_size,
            "max_iters": max_iters,
            "batch_size": batch_size,
            "block_size": dataset.block_size,
            "seed": seed,
        },
        "parameters": {
            "ripple": ripple_params,
            "baseline": baseline_params,
        },
        "ripple": ripple_results,
        "baseline": baseline_results,
    }

    output_dir = _PROJECT_ROOT / "validation" / "benchmarks" / "results"
    output_dir.mkdir(parents=True, exist_ok=True)
    result_file = output_dir / f"quick_benchmark_{now.strftime('%Y%m%d_%H%M%S')}.json"

    with open(result_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2)

    print(f"\n💾 Results saved to: {result_file}")

    return results


def main(argv: Optional[List[str]] = None) -> Dict:
    """Parse command-line arguments and run the benchmark.

    Args:
        argv: Arguments to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Quick character-level benchmark: RippleGPT vs VanillaGPT2."
    )
    parser.add_argument("--max-iters", type=int, default=1000,
                        help="optimizer steps per model (default: 1000)")
    parser.add_argument("--batch-size", type=int, default=32,
                        help="sequences per batch (default: 32)")
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED,
                        help=f"RNG seed for init, dropout and batch order (default: {DEFAULT_SEED})")
    args = parser.parse_args(argv)
    return run_quick_benchmark(
        max_iters=args.max_iters, batch_size=args.batch_size, seed=args.seed
    )


if __name__ == '__main__':
    main()