""" extrapolation_test.py - Scientific Extrapolation Test for Ripple Field This test validates the MAIN THESIS of RippleGPT: "A model trained with block_size=X can infer with quality on 2X, 4X, etc." The test: 1. Loads a trained model (e.g. block_size=512) 2. Measures perplexity on contexts of 256, 512, 1024, 2048 tokens 3. Compares the quality degradation IF perplexity remains stable beyond the training block_size, the ALiBi/Ripple Field architecture is VALIDATED. Usage: python validation/memory/extrapolation_test.py --config medium python validation/memory/extrapolation_test.py --config large --max-context 4096 """ import os import sys import argparse import pickle import time from typing import Tuple, List, Dict import torch import numpy as np import psutil # Add root directory to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) from src.model import RippleGPT from src.config import RippleConfig from validation.memory.model_configs import get_config # Directories DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' def load_model(config_name: str) -> Tuple[RippleGPT, RippleConfig]: """Loads trained model without modifying block_size.""" best_path = os.path.join(CKPT_DIR, f'ckpt_{config_name}_best.pt') final_path = os.path.join(CKPT_DIR, f'ckpt_{config_name}_final.pt') if os.path.exists(best_path): ckpt_path = best_path elif os.path.exists(final_path): ckpt_path = final_path else: raise FileNotFoundError( f"Checkpoint not found for config '{config_name}'\n" f"Run: python validation/memory/train_large.py --config {config_name}" ) print(f"📦 Loading model from: {ckpt_path}") checkpoint = torch.load(ckpt_path, map_location=DEVICE, weights_only=False) config = checkpoint['config'] model = RippleGPT(config) state_dict = checkpoint['model'] unwanted_prefix = '_orig_mod.' for k in list(state_dict.keys()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) model.load_state_dict(state_dict) model.to(DEVICE) model.eval() print(f" ✅ Model loaded ({model.get_num_params()/1e6:.2f}M params)") print(f" 📏 Training block size: {config.block_size}") return model, config def load_data() -> torch.Tensor: """Loads validation data.""" val_path = os.path.join(DATA_DIR, 'val.bin') if not os.path.exists(val_path): raise FileNotFoundError( f"Validation data not found at {val_path}\n" f"Run: python validation/memory/prepare_large_data.py" ) data = np.fromfile(val_path, dtype=np.uint16) return torch.from_numpy(data.astype(np.int64)) @torch.no_grad() def measure_perplexity( model: RippleGPT, data: torch.Tensor, context_len: int, num_batches: int = 20 ) -> Dict: """ Measures perplexity on a specific context. Returns: Dict with loss, perplexity, memory usage, time """ if len(data) < context_len + 1: return {'error': 'Insufficient data for this context'} # Measure memory before if DEVICE == 'cuda': torch.cuda.reset_peak_memory_stats() mem_before = torch.cuda.memory_allocated() / 1e6 else: mem_before = psutil.Process().memory_info().rss / 1e6 total_loss = 0 valid_batches = 0 start_time = time.time() for i in range(num_batches): start_idx = i * context_len if start_idx + context_len + 1 > len(data): break x = data[start_idx : start_idx + context_len].unsqueeze(0).to(DEVICE) y = data[start_idx + 1 : start_idx + context_len + 1].unsqueeze(0).to(DEVICE) try: _, loss = model(x, y) total_loss += loss.item() valid_batches += 1 except RuntimeError as e: if 'out of memory' in str(e).lower(): if DEVICE == 'cuda': torch.cuda.empty_cache() return {'error': f'OOM on context {context_len}', 'memory_error': True} raise elapsed = time.time() - start_time # Measure memory after if DEVICE == 'cuda': mem_after = torch.cuda.max_memory_allocated() / 1e6 else: mem_after = psutil.Process().memory_info().rss / 1e6 if valid_batches == 0: return {'error': 'No batch processed'} avg_loss = total_loss / valid_batches perplexity = np.exp(avg_loss) return { 'context_len': context_len, 'loss': avg_loss, 'perplexity': perplexity, 'memory_mb': mem_after - mem_before, 'peak_memory_mb': mem_after, 'time_seconds': elapsed, 'tokens_per_second': (context_len * valid_batches) / elapsed } def run_extrapolation_test( model: RippleGPT, config: RippleConfig, data: torch.Tensor, max_context: int = 4096 ) -> Dict: """ Executes progressive extrapolation test. """ train_block_size = config.block_size # Contexts to test: 0.5x, 1x, 2x, 4x, 8x of training block_size multipliers = [0.5, 1.0, 2.0, 4.0, 8.0] contexts = [int(train_block_size * m) for m in multipliers] contexts = [c for c in contexts if c <= max_context and c >= 64] print(f"\n📊 Testing extrapolation:") print(f" Training block size: {train_block_size}") print(f" Contexts to test: {contexts}") print("-" * 70) results = { 'train_block_size': train_block_size, 'tests': [] } baseline_perplexity = None for ctx_len in contexts: is_extrapolation = ctx_len > train_block_size marker = "🔬" if is_extrapolation else "📏" label = f"({ctx_len/train_block_size:.1f}x)" if ctx_len != train_block_size else "(train)" print(f"\n{marker} Context: {ctx_len} tokens {label}") result = measure_perplexity(model, data, ctx_len) if 'error' in result: print(f" ❌ {result['error']}") result['is_extrapolation'] = is_extrapolation result['extrapolation_ratio'] = ctx_len / train_block_size results['tests'].append(result) continue # Save baseline if ctx_len == train_block_size: baseline_perplexity = result['perplexity'] # Calculate degradation if baseline_perplexity: degradation = (result['perplexity'] - baseline_perplexity) / baseline_perplexity * 100 else: degradation = 0 result['is_extrapolation'] = is_extrapolation result['extrapolation_ratio'] = ctx_len / train_block_size result['degradation_pct'] = degradation status = "✅" if degradation < 20 else ("⚠️" if degradation < 50 else "❌") print(f" Loss: {result['loss']:.4f}") print(f" Perplexity: {result['perplexity']:.2f}") print(f" Degradation vs train: {degradation:+.1f}%") print(f" Memory: {result['peak_memory_mb']:.1f} MB") print(f" Status: {status}") results['tests'].append(result) return results def print_summary(results: Dict): """Prints extrapolation test summary.""" print("\n" + "=" * 70) print("📈 EXTRAPOLATION TEST SUMMARY") print("=" * 70) train_bs = results['train_block_size'] tests = [t for t in results['tests'] if 'error' not in t] if not tests: print("❌ No test completed successfully.") return print(f"\n{'Context':<12} {'Ratio':<8} {'Loss':<10} {'PPL':<10} {'Degrad.':<10} {'Mem (MB)':<12}") print("-" * 70) for t in tests: ctx = t['context_len'] ratio = f"{t['extrapolation_ratio']:.1f}x" loss = f"{t['loss']:.4f}" ppl = f"{t['perplexity']:.2f}" deg = f"{t.get('degradation_pct', 0):+.1f}%" mem = f"{t['peak_memory_mb']:.1f}" marker = "🔬" if t['is_extrapolation'] else "📏" print(f"{marker} {ctx:<10} {ratio:<8} {loss:<10} {ppl:<10} {deg:<10} {mem:<12}") # Verdict extrapolation_tests = [t for t in tests if t['is_extrapolation']] if not extrapolation_tests: print("\n⚠️ No extrapolation test was executed.") return avg_degradation = sum(t.get('degradation_pct', 0) for t in extrapolation_tests) / len(extrapolation_tests) max_successful_ratio = max(t['extrapolation_ratio'] for t in extrapolation_tests if t.get('degradation_pct', 100) < 50) print("\n" + "-" * 70) print(f"Average degradation in extrapolation: {avg_degradation:.1f}%") print(f"Max ratio with <50% degradation: {max_successful_ratio:.1f}x") if avg_degradation < 15: print("\n🎉 VERDICT: EXCELLENT! Ripple Field extrapolates with quality!") print(" The ALiBi architecture is working as expected.") elif avg_degradation < 30: print("\n✅ VERDICT: GOOD. Functional extrapolation with moderate degradation.") elif avg_degradation < 50: print("\n⚠️ VERDICT: MARGINAL. Extrapolation works, but with significant loss.") else: print("\n❌ VERDICT: FAIL. The model does not extrapolate well beyond training context.") print("=" * 70) def main(): parser = argparse.ArgumentParser(description='Ripple Field Extrapolation Test') parser.add_argument('--config', type=str, default='medium', choices=['small', 'medium', 'large', 'xlarge']) parser.add_argument('--max-context', type=int, default=4096, help='Max context to test') args = parser.parse_args() print("=" * 70) print("🔬 EXTRAPOLATION TEST - RippleGPT ALiBi Validation") print("=" * 70) print("\n⚠️ NOTE: This test validates the central thesis of RippleGPT:") print(" 'Train on N tokens, infer on 2N-4N with quality'") print(" Memory scales with O(T²) - OOM expected in very long contexts.") # Load model try: model, config = load_model(args.config) except FileNotFoundError as e: print(f"\n❌ {e}") return 1 # Load data try: data = load_data() print(f"\n📚 Data loaded: {len(data)} tokens") except FileNotFoundError as e: print(f"\n❌ {e}") return 1 # Run tests results = run_extrapolation_test(model, config, data, args.max_context) # Print summary print_summary(results) return 0 if __name__ == '__main__': exit(main())