# RippleGPT-Nano / validation/benchmarks/quick_benchmark.py
# Provenance: uploaded via huggingface_hub by Tavernari (commit 148b631, verified).
"""
quick_benchmark.py - Quick benchmark with smaller vocabulary for fast validation.
This script uses a character-level tokenizer (much smaller vocab) for faster
training and lower memory usage. Ideal for quick architecture comparison.
"""
import argparse
import json
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
import gc
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
# Add parent paths
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import RippleConfig
from src.model import RippleGPT
from validation.benchmarks.baseline_gpt2 import VanillaGPT2, GPT2Config
# ============================================================================
# SIMPLE CHARACTER-LEVEL DATASET
# ============================================================================
class SimpleTextDataset(Dataset):
    """
    Character-level dataset for quick benchmarking.

    Builds its vocabulary from the distinct characters of the input text,
    giving a tiny vocab (~100 symbols) compared with BPE's ~50k.
    """
    def __init__(self, text: str, block_size: int = 256):
        # One id per distinct character, assigned in sorted order.
        alphabet = sorted(set(text))
        self.vocab_size = len(alphabet)
        self.stoi = dict(zip(alphabet, range(len(alphabet))))
        self.itos = dict(enumerate(alphabet))
        # Encode the whole corpus once as a 1-D long tensor.
        encoded = [self.stoi[c] for c in text]
        self.data = torch.tensor(encoded, dtype=torch.long)
        self.block_size = block_size
    def __len__(self):
        return len(self.data) - self.block_size - 1
    def __getitem__(self, idx):
        # Grab block_size + 1 tokens; input drops the last, target drops
        # the first (next-token prediction).
        window = self.data[idx:idx + self.block_size + 1]
        return window[:-1], window[1:]
def get_sample_text() -> str:
    """Build a small synthetic corpus for quick benchmarks.

    Returns 100 repetitions of a fixed "epoch" of text: all code snippets
    followed by all story snippets, concatenated in order.
    """
    # Python-like code snippets (simple, learnable patterns).
    code_patterns = [
        "def hello():\n print('hello world')\n\n",
        "for i in range(10):\n x = i * 2\n print(x)\n\n",
        "class MyClass:\n def __init__(self):\n self.x = 0\n\n",
        "if x > 0:\n result = x + 1\nelse:\n result = 0\n\n",
        "def add(a, b):\n return a + b\n\n",
        "numbers = [1, 2, 3, 4, 5]\nfor n in numbers:\n print(n)\n\n",
    ]
    # Short natural-language snippets.
    story_patterns = [
        "Once upon a time, there was a little cat. The cat liked to play. ",
        "The dog ran fast. It was happy. The sun was shining bright. ",
        "A bird flew in the sky. It sang a beautiful song. Everyone listened. ",
        "The boy went to school. He learned many things. He was smart. ",
    ]
    # One pass = code then stories; repeat to get a trainable amount of data.
    epoch = "".join(code_patterns) + "".join(story_patterns)
    return epoch * 100
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================
def get_device() -> torch.device:
    """Pick the fastest available backend: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # MPS backend may be absent on older torch builds; probe defensively.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
def get_memory_mb() -> float:
    """Return this process's resident set size (RSS) in megabytes."""
    import psutil  # local import: only needed when memory tracking is requested
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / (1024 * 1024)
# ============================================================================
# MODEL CREATION
# ============================================================================
def create_ripple_model(vocab_size: int) -> RippleGPT:
    """Instantiate a small 4-layer RippleGPT sized for quick benchmarks.

    Sized to mirror create_baseline_model (same layers/heads/width) so the
    two architectures can be compared head-to-head.
    """
    return RippleGPT(
        RippleConfig(
            vocab_size=vocab_size,
            n_layer=4,
            n_head=4,
            n_embd=256,
            block_size=256,
            dropout=0.1,
            # NOTE(review): absolute positional embeddings disabled —
            # presumably RippleGPT has its own positional mechanism; confirm.
            use_absolute_pos_emb=False
        )
    )
def create_baseline_model(vocab_size: int) -> VanillaGPT2:
    """Instantiate a small 4-layer VanillaGPT2 baseline.

    Hyperparameters deliberately match create_ripple_model so parameter
    counts and speed are directly comparable.
    """
    return VanillaGPT2(
        GPT2Config(
            vocab_size=vocab_size,
            n_layer=4,
            n_head=4,
            n_embd=256,
            block_size=256,
            dropout=0.1
        )
    )
# ============================================================================
# TRAINING
# ============================================================================
def train_model(
    model: nn.Module,
    dataloader: DataLoader,
    max_iters: int,
    model_name: str,
    device: torch.device
) -> Dict:
    """Train `model` for `max_iters` steps, cycling over `dataloader`.

    Uses AdamW (lr=1e-3) with cosine annealing and gradient clipping at 1.0.
    Loss is sampled every 50 steps and at the final step.

    Returns:
        Dict with "train_losses" (list of (step, loss) samples),
        "final_loss", "samples_per_sec", and "total_time_sec".
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iters)
    loss_log = []
    seen_samples = 0
    started = time.time()
    print(f"\n🏋️ Training {model_name}...")
    print(f" Max iterations: {max_iters}")
    model.train()
    batch_iter = iter(dataloader)
    for step in range(1, max_iters + 1):
        # Pull the next batch, restarting the epoch when exhausted.
        try:
            inputs, targets = next(batch_iter)
        except StopIteration:
            batch_iter = iter(dataloader)
            inputs, targets = next(batch_iter)
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        _, loss = model(inputs, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        seen_samples += inputs.size(0)
        if step % 50 == 0 or step == max_iters:
            loss_log.append((step, loss.item()))
            rate = seen_samples / (time.time() - started)
            print(f" [{step:4d}/{max_iters}] loss: {loss.item():.4f} | "
                  f"{rate:.1f} samples/sec")
    total = time.time() - started
    return {
        "train_losses": loss_log,
        "final_loss": loss_log[-1][1] if loss_log else float('inf'),
        "samples_per_sec": seen_samples / total,
        "total_time_sec": total
    }
# ============================================================================
# MAIN
# ============================================================================
def run_quick_benchmark(max_iters: int = 1000):
    """Run a quick RippleGPT-vs-VanillaGPT2 comparison and save a JSON report.

    Trains both models on the same character-level dataset, prints a summary
    table (parameters, final loss, speed, wall-clock time), and writes the
    full metrics to validation/benchmarks/results/.

    Args:
        max_iters: training iterations per model. Defaults to 1000, matching
            the previously hard-coded value, so existing callers are unaffected.

    Returns:
        Dict with run metadata, parameter counts, and per-model metrics.
    """
    device = get_device()
    print("\n" + "="*60)
    print("🚀 QUICK BENCHMARK: RippleGPT vs VanillaGPT2")
    print("="*60)
    print(f"Device: {device}")
    # Dataset: tiny synthetic character-level corpus shared by both models.
    print("\n📚 Creating dataset...")
    text = get_sample_text()
    dataset = SimpleTextDataset(text, block_size=256)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    print(f" Vocab size: {dataset.vocab_size}")
    print(f" Dataset size: {len(dataset)} samples")
    print(f" Block size: 256")
    # Models: matched layer/head/width so the comparison is apples-to-apples.
    print("\n🔧 Creating models...")
    ripple_model = create_ripple_model(dataset.vocab_size)
    baseline_model = create_baseline_model(dataset.vocab_size)
    ripple_params = ripple_model.get_num_params()
    baseline_params = baseline_model.get_num_params()
    print(f" RippleGPT: {ripple_params:,} parameters")
    print(f" VanillaGPT2: {baseline_params:,} parameters")
    print(f" Difference: {baseline_params - ripple_params:+,} ({(baseline_params/ripple_params - 1)*100:+.1f}%)")
    # Train RippleGPT
    print("\n" + "="*50)
    ripple_results = train_model(ripple_model, dataloader, max_iters, "RippleGPT", device)
    # Train VanillaGPT2 on the same dataloader.
    print("\n" + "="*50)
    baseline_results = train_model(baseline_model, dataloader, max_iters, "VanillaGPT2", device)
    # Summary table: lower wins for params/loss/time, higher wins for speed.
    print("\n" + "="*60)
    print("📊 RESULTS SUMMARY")
    print("="*60)
    print(f"\n{'Metric':<25} {'RippleGPT':<15} {'VanillaGPT2':<15} {'Winner':<12}")
    print("-"*60)
    winner = "RippleGPT" if ripple_params < baseline_params else "VanillaGPT2"
    print(f"{'Parameters':<25} {ripple_params:,} {baseline_params:,} {winner:<12}")
    r_loss = ripple_results["final_loss"]
    b_loss = baseline_results["final_loss"]
    winner = "RippleGPT" if r_loss < b_loss else "VanillaGPT2"
    print(f"{'Final Loss':<25} {r_loss:.4f} {b_loss:.4f} {winner:<12}")
    r_speed = ripple_results["samples_per_sec"]
    b_speed = baseline_results["samples_per_sec"]
    winner = "RippleGPT" if r_speed > b_speed else "VanillaGPT2"
    print(f"{'Speed (samples/sec)':<25} {r_speed:.1f} {b_speed:.1f} {winner:<12}")
    r_time = ripple_results["total_time_sec"]
    b_time = baseline_results["total_time_sec"]
    winner = "RippleGPT" if r_time < b_time else "VanillaGPT2"
    print(f"{'Time (sec)':<25} {r_time:.1f} {b_time:.1f} {winner:<12}")
    print("="*60)
    # Persist everything needed to compare runs later.
    results = {
        "metadata": {
            "timestamp": datetime.now().isoformat(),
            "device": str(device),
            "vocab_size": dataset.vocab_size,
            "max_iters": max_iters
        },
        "parameters": {
            "ripple": ripple_params,
            "baseline": baseline_params
        },
        "ripple": ripple_results,
        "baseline": baseline_results
    }
    output_dir = Path("validation/benchmarks/results")
    output_dir.mkdir(parents=True, exist_ok=True)
    result_file = output_dir / f"quick_benchmark_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
    with open(result_file, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n💾 Results saved to: {result_file}")
    return results
# Script entry point: run the benchmark with default settings (1000 iters).
if __name__ == '__main__':
    run_quick_benchmark()