| """ |
| generation_demo.py - Demonstrates text generation from trained models. |
| |
| Trains both RippleGPT and VanillaGPT2 briefly, then generates text |
| from the same prompt to show qualitative differences. |
| """ |
|
|
| import sys |
| from pathlib import Path |
| import torch |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent.parent)) |
|
|
| from src.config import RippleConfig |
| from src.model import RippleGPT |
| from validation.benchmarks.baseline_gpt2 import VanillaGPT2, GPT2Config |
| from validation.benchmarks.quick_benchmark import ( |
| SimpleTextDataset, |
| get_sample_text, |
| get_device |
| ) |
| from torch.utils.data import DataLoader |
|
|
|
|
def train_model_quick(model, dataloader, device, iterations=1000):
    """Run a short demonstration training loop.

    Args:
        model: Module whose forward pass ``model(x, y)`` returns
            ``(logits, loss)``.
        dataloader: Iterable yielding ``(x, y)`` batch pairs; it is
            re-created and cycled whenever it is exhausted.
        device: Device to move the model and every batch onto.
        iterations: Number of optimizer steps to perform.

    Returns:
        The trained model (already moved to ``device``).
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    model.train()
    batch_iter = iter(dataloader)

    for step in range(iterations):
        try:
            inputs, targets = next(batch_iter)
        except StopIteration:
            # Loader exhausted — start a fresh pass over the data.
            batch_iter = iter(dataloader)
            inputs, targets = next(batch_iter)

        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        _, loss = model(inputs, targets)
        loss.backward()
        optimizer.step()

        # Periodic progress report every 50 steps.
        if (step + 1) % 50 == 0:
            print(f" Iteration {step+1}/{iterations}, loss: {loss.item():.4f}")

    return model
|
|
|
|
def generate_text(model, dataset, prompt_str, max_tokens=100, temperature=0.8):
    """Sample text from ``model`` conditioned on a character prompt.

    Args:
        model: Trained model exposing a ``generate`` method.
        dataset: Object carrying ``stoi``/``itos`` character vocab maps.
        prompt_str: Prompt text; characters absent from the vocabulary
            are encoded as token id 0.
        max_tokens: Number of new tokens to sample.
        temperature: Sampling temperature passed through to ``generate``.

    Returns:
        The decoded string (prompt plus continuation); unknown token ids
        decode to ``'?'``.
    """
    model.eval()
    device = next(model.parameters()).device

    # Encode prompt characters; unknown characters fall back to id 0.
    encoded = torch.tensor(
        [[dataset.stoi.get(ch, 0) for ch in prompt_str]],
        dtype=torch.long,
        device=device,
    )

    with torch.no_grad():
        sampled = model.generate(
            encoded, max_new_tokens=max_tokens, temperature=temperature, top_k=40
        )

    # Decode the full sequence (prompt included) back to characters.
    token_ids = sampled[0].tolist()
    return ''.join(dataset.itos.get(tok, '?') for tok in token_ids)
|
|
|
|
| def main(): |
| device = get_device() |
| print("="*70) |
| print("๐ญ TEXT GENERATION DEMO: RippleGPT vs VanillaGPT2") |
| print("="*70) |
| print(f"Device: {device}") |
| |
| |
| print("\n๐ Creating dataset...") |
| text = get_sample_text() |
| dataset = SimpleTextDataset(text, block_size=256) |
| dataloader = DataLoader(dataset, batch_size=32, shuffle=True) |
| |
| print(f" Vocab size: {dataset.vocab_size}") |
| print(f" Dataset size: {len(dataset)} samples") |
| |
| |
| print("\n๐ง Creating models...") |
| |
| ripple_config = RippleConfig( |
| vocab_size=dataset.vocab_size, |
| n_layer=4, |
| n_head=4, |
| n_embd=256, |
| block_size=256, |
| dropout=0.1, |
| use_absolute_pos_emb=False |
| ) |
| ripple_model = RippleGPT(ripple_config) |
| |
| baseline_config = GPT2Config( |
| vocab_size=dataset.vocab_size, |
| n_layer=4, |
| n_head=4, |
| n_embd=256, |
| block_size=256, |
| dropout=0.1 |
| ) |
| baseline_model = VanillaGPT2(baseline_config) |
| |
| print(f" RippleGPT: {ripple_model.get_num_params():,} params") |
| print(f" VanillaGPT2: {baseline_model.get_num_params():,} params") |
| |
| |
| print("\n๐๏ธ Training RippleGPT (200 iterations)...") |
| ripple_model = train_model_quick(ripple_model, dataloader, device) |
| |
| print("\n๐๏ธ Training VanillaGPT2 (200 iterations)...") |
| baseline_model = train_model_quick(baseline_model, dataloader, device) |
| |
| |
| prompts = [ |
| "def hello():\n ", |
| "for i in range(", |
| "Once upon a time, ", |
| "class MyClass:\n def ", |
| "The cat ", |
| ] |
| |
| print("\n" + "="*70) |
| print("๐ GENERATION EXAMPLES") |
| print("="*70) |
| |
| for prompt in prompts: |
| print(f"\n{'='*50}") |
| print(f"PROMPT: {repr(prompt)}") |
| print("-"*50) |
| |
| |
| ripple_output = generate_text(ripple_model, dataset, prompt, max_tokens=60) |
| print(f"\n๐ข RippleGPT:") |
| print(ripple_output) |
| |
| |
| baseline_output = generate_text(baseline_model, dataset, prompt, max_tokens=60) |
| print(f"\n๐ต VanillaGPT2:") |
| print(baseline_output) |
| |
| print("\n" + "="*70) |
| print("โ
Generation demo complete!") |
| print("="*70) |
|
|
|
|
# Allow running this demo directly as a script.
if __name__ == '__main__':
    main()
|
|