# RippleGPT-Nano / validation / benchmarks / generation_demo.py
# Uploaded via huggingface_hub (revision 148b631, verified).
"""
generation_demo.py - Demonstrates text generation from trained models.
Trains both RippleGPT and VanillaGPT2 briefly, then generates text
from the same prompt to show qualitative differences.
"""
import sys
from pathlib import Path
import torch
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
from src.config import RippleConfig
from src.model import RippleGPT
from validation.benchmarks.baseline_gpt2 import VanillaGPT2, GPT2Config
from validation.benchmarks.quick_benchmark import (
SimpleTextDataset,
get_sample_text,
get_device
)
from torch.utils.data import DataLoader
def train_model_quick(model, dataloader, device, iterations=1000):
    """Run a short AdamW training loop and return the trained model.

    Args:
        model: module whose forward ``model(x, y)`` returns ``(logits, loss)``.
        dataloader: yields ``(input, target)`` batches; rewound when exhausted.
        device: torch device to move the model and every batch onto.
        iterations: number of optimizer steps to take.

    Returns:
        The same ``model`` instance, trained in place.
    """
    model = model.to(device)
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    batches = iter(dataloader)
    for step in range(iterations):
        try:
            inputs, targets = next(batches)
        except StopIteration:
            # Epoch boundary: restart the loader (reshuffles if shuffle=True).
            batches = iter(dataloader)
            inputs, targets = next(batches)
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        _, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()

        # Periodic progress report so long runs are observable.
        if (step + 1) % 50 == 0:
            print(f" Iteration {step+1}/{iterations}, loss: {loss.item():.4f}")
    return model
def generate_text(model, dataset, prompt_str, max_tokens=100, temperature=0.8):
    """Sample a text continuation of ``prompt_str`` from a trained model.

    The prompt is character-encoded with ``dataset.stoi`` (unknown chars map
    to id 0), extended via ``model.generate`` with top-k sampling, and the
    full sequence (prompt + continuation) is decoded back through
    ``dataset.itos`` (unknown ids render as '?').
    """
    model.eval()
    # Put the prompt on the same device the model's weights live on.
    device = next(model.parameters()).device
    token_ids = torch.tensor(
        [[dataset.stoi.get(ch, 0) for ch in prompt_str]],
        dtype=torch.long,
        device=device,
    )
    with torch.no_grad():
        sampled = model.generate(
            token_ids,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=40,
        )
    return ''.join(dataset.itos.get(tok, '?') for tok in sampled[0].tolist())
def main():
    """Run the demo end-to-end: build the dataset, train both models for the
    same (short) number of steps, then print sample generations side by side.
    """
    device = get_device()
    print("="*70)
    print("🎭 TEXT GENERATION DEMO: RippleGPT vs VanillaGPT2")
    print("="*70)
    print(f"Device: {device}")

    # Build the character-level dataset from the bundled sample text.
    print("\n📚 Creating dataset...")
    text = get_sample_text()
    dataset = SimpleTextDataset(text, block_size=256)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    print(f" Vocab size: {dataset.vocab_size}")
    print(f" Dataset size: {len(dataset)} samples")

    # Create both models with matched capacity so the comparison is fair.
    print("\n🔧 Creating models...")
    ripple_config = RippleConfig(
        vocab_size=dataset.vocab_size,
        n_layer=4,
        n_head=4,
        n_embd=256,
        block_size=256,
        dropout=0.1,
        use_absolute_pos_emb=False
    )
    ripple_model = RippleGPT(ripple_config)
    baseline_config = GPT2Config(
        vocab_size=dataset.vocab_size,
        n_layer=4,
        n_head=4,
        n_embd=256,
        block_size=256,
        dropout=0.1
    )
    baseline_model = VanillaGPT2(baseline_config)
    print(f" RippleGPT: {ripple_model.get_num_params():,} params")
    print(f" VanillaGPT2: {baseline_model.get_num_params():,} params")

    # BUG FIX: the banners advertised "200 iterations" but train_model_quick
    # was called with its default of 1000. Drive both the banner and the call
    # from one constant so they cannot drift apart again.
    train_iters = 200
    print(f"\n🏋️ Training RippleGPT ({train_iters} iterations)...")
    ripple_model = train_model_quick(
        ripple_model, dataloader, device, iterations=train_iters)
    print(f"\n🏋️ Training VanillaGPT2 ({train_iters} iterations)...")
    baseline_model = train_model_quick(
        baseline_model, dataloader, device, iterations=train_iters)

    # Prompts chosen to probe both code-like and prose-like continuations.
    prompts = [
        "def hello():\n ",
        "for i in range(",
        "Once upon a time, ",
        "class MyClass:\n def ",
        "The cat ",
    ]
    print("\n" + "="*70)
    print("📝 GENERATION EXAMPLES")
    print("="*70)
    for prompt in prompts:
        print(f"\n{'='*50}")
        print(f"PROMPT: {repr(prompt)}")
        print("-"*50)
        # RippleGPT generation
        ripple_output = generate_text(ripple_model, dataset, prompt, max_tokens=60)
        print(f"\n🟢 RippleGPT:")
        print(ripple_output)
        # VanillaGPT2 generation
        baseline_output = generate_text(baseline_model, dataset, prompt, max_tokens=60)
        print(f"\n🔵 VanillaGPT2:")
        print(baseline_output)

    print("\n" + "="*70)
    print("✅ Generation demo complete!")
    print("="*70)
# Run the demo only when executed directly as a script (not on import).
if __name__ == '__main__':
    main()