# RippleGPT-Nano — validation/memory/extrapolation_test.py
# (Hugging Face upload-page artifact: "Tavernari's picture",
#  "Upload folder using huggingface_hub", commit 148b631 verified)
"""
extrapolation_test.py - Scientific Extrapolation Test for Ripple Field
This test validates the MAIN THESIS of RippleGPT:
"A model trained with block_size=X can infer with quality on 2X, 4X, etc."
The test:
1. Loads a trained model (e.g. block_size=512)
2. Measures perplexity on contexts of 256, 512, 1024, 2048 tokens
3. Compares the quality degradation
IF perplexity remains stable beyond the training block_size,
the ALiBi/Ripple Field architecture is VALIDATED.
Usage:
python validation/memory/extrapolation_test.py --config medium
python validation/memory/extrapolation_test.py --config large --max-context 4096
"""
import os
import sys
import argparse
import pickle
import time
from typing import Tuple, List, Dict
import torch
import numpy as np
import psutil
# Add root directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
from src.model import RippleGPT
from src.config import RippleConfig
from validation.memory.model_configs import get_config
# Directories
# Validation data and checkpoints live in subfolders next to this script.
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')
CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')
# Device preference: CUDA, then Apple Metal (MPS), then CPU fallback.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
def load_model(config_name: str) -> Tuple[RippleGPT, RippleConfig]:
    """Loads trained model without modifying block_size.

    Prefers the 'best' checkpoint over the 'final' one.

    Raises:
        FileNotFoundError: if neither checkpoint exists for this config.
    """
    candidates = [
        os.path.join(CKPT_DIR, f'ckpt_{config_name}_best.pt'),
        os.path.join(CKPT_DIR, f'ckpt_{config_name}_final.pt'),
    ]
    ckpt_path = next((p for p in candidates if os.path.exists(p)), None)
    if ckpt_path is None:
        raise FileNotFoundError(
            f"Checkpoint not found for config '{config_name}'\n"
            f"Run: python validation/memory/train_large.py --config {config_name}"
        )
    print(f"๐Ÿ“ฆ Loading model from: {ckpt_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)
    config = checkpoint['config']
    model = RippleGPT(config)
    # torch.compile() prefixes parameter names with '_orig_mod.'; strip it
    # so the state dict matches a freshly constructed model.
    state_dict = checkpoint['model']
    compiled_prefix = '_orig_mod.'
    for key in list(state_dict.keys()):
        if key.startswith(compiled_prefix):
            state_dict[key[len(compiled_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    model.to(DEVICE)
    model.eval()
    print(f" โœ… Model loaded ({model.get_num_params()/1e6:.2f}M params)")
    print(f" ๐Ÿ“ Training block size: {config.block_size}")
    return model, config
def load_data() -> torch.Tensor:
    """Loads validation data.

    Returns:
        1-D int64 tensor of token ids read from data/val.bin.
    """
    val_path = os.path.join(DATA_DIR, 'val.bin')
    if os.path.exists(val_path):
        # Tokens are stored as uint16 on disk; widen to int64 so they can
        # be used directly as embedding indices.
        tokens = np.fromfile(val_path, dtype=np.uint16)
        return torch.from_numpy(tokens.astype(np.int64))
    raise FileNotFoundError(
        f"Validation data not found at {val_path}\n"
        f"Run: python validation/memory/prepare_large_data.py"
    )
@torch.no_grad()
def measure_perplexity(
    model: RippleGPT,
    data: torch.Tensor,
    context_len: int,
    num_batches: int = 20
) -> Dict:
    """
    Measures perplexity on a specific context.

    Slides up to `num_batches` non-overlapping windows of `context_len`
    tokens over `data`, scoring each as a batch of 1 and averaging loss.

    Args:
        model: trained RippleGPT, called as model(x, y) -> (logits, loss).
        data: 1-D LongTensor of token ids.
        context_len: window length in tokens.
        num_batches: maximum number of windows to evaluate.

    Returns:
        Dict with loss, perplexity, memory usage, time — or a dict with an
        'error' key when the context cannot be evaluated (too little data,
        OOM, or no batch processed).
    """
    if len(data) < context_len + 1:
        return {'error': 'Insufficient data for this context'}
    # Measure memory before (device-specific: CUDA allocator vs process RSS).
    if DEVICE == 'cuda':
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated() / 1e6
    else:
        mem_before = psutil.Process().memory_info().rss / 1e6
    total_loss = 0.0
    valid_batches = 0
    # FIX: perf_counter() is monotonic and high-resolution; time.time() can
    # go backwards and has coarse resolution on some platforms.
    start_time = time.perf_counter()
    for i in range(num_batches):
        start_idx = i * context_len
        if start_idx + context_len + 1 > len(data):
            break
        # Targets are the inputs shifted by one token.
        x = data[start_idx : start_idx + context_len].unsqueeze(0).to(DEVICE)
        y = data[start_idx + 1 : start_idx + context_len + 1].unsqueeze(0).to(DEVICE)
        try:
            _, loss = model(x, y)
            total_loss += loss.item()
            valid_batches += 1
        except RuntimeError as e:
            # OOM is an expected outcome at long contexts; report it
            # gracefully. Any other RuntimeError is a real bug: re-raise.
            if 'out of memory' in str(e).lower():
                if DEVICE == 'cuda':
                    torch.cuda.empty_cache()
                return {'error': f'OOM on context {context_len}', 'memory_error': True}
            raise
    elapsed = time.perf_counter() - start_time
    # Measure memory after
    if DEVICE == 'cuda':
        mem_after = torch.cuda.max_memory_allocated() / 1e6
    else:
        mem_after = psutil.Process().memory_info().rss / 1e6
    if valid_batches == 0:
        return {'error': 'No batch processed'}
    avg_loss = total_loss / valid_batches
    perplexity = float(np.exp(avg_loss))
    return {
        'context_len': context_len,
        'loss': avg_loss,
        'perplexity': perplexity,
        'memory_mb': mem_after - mem_before,
        'peak_memory_mb': mem_after,
        'time_seconds': elapsed,
        # FIX: guard against a zero-duration clock reading (ZeroDivisionError).
        'tokens_per_second': (context_len * valid_batches) / max(elapsed, 1e-9)
    }
def run_extrapolation_test(
    model: RippleGPT,
    config: RippleConfig,
    data: torch.Tensor,
    max_context: int = 4096
) -> Dict:
    """
    Executes progressive extrapolation test.

    Evaluates perplexity at 0.5x, 1x, 2x, 4x and 8x the training block
    size (contexts clamped to the range [64, max_context]) and records the
    degradation of each run relative to the 1x (training-length) baseline.

    Returns:
        Dict with 'train_block_size' and a 'tests' list of per-context
        result dicts from measure_perplexity(), each augmented with
        'is_extrapolation', 'extrapolation_ratio' and (on success)
        'degradation_pct'.
    """
    train_block_size = config.block_size
    # Contexts to test: 0.5x, 1x, 2x, 4x, 8x of training block_size
    multipliers = [0.5, 1.0, 2.0, 4.0, 8.0]
    contexts = [int(train_block_size * m) for m in multipliers]
    contexts = [c for c in contexts if c <= max_context and c >= 64]
    print(f"\n๐Ÿ“Š Testing extrapolation:")
    print(f" Training block size: {train_block_size}")
    print(f" Contexts to test: {contexts}")
    print("-" * 70)
    results = {
        'train_block_size': train_block_size,
        'tests': []
    }
    # Baseline perplexity at the training context. Contexts are ascending,
    # so the 1x run sets this before any longer (extrapolation) context.
    baseline_perplexity = None
    for ctx_len in contexts:
        is_extrapolation = ctx_len > train_block_size
        marker = "๐Ÿ”ฌ" if is_extrapolation else "๐Ÿ“"
        label = f"({ctx_len/train_block_size:.1f}x)" if ctx_len != train_block_size else "(train)"
        print(f"\n{marker} Context: {ctx_len} tokens {label}")
        result = measure_perplexity(model, data, ctx_len)
        if 'error' in result:
            # Record the failure (e.g. OOM) but keep testing other contexts.
            print(f" โŒ {result['error']}")
            result['is_extrapolation'] = is_extrapolation
            result['extrapolation_ratio'] = ctx_len / train_block_size
            results['tests'].append(result)
            continue
        # Save baseline
        if ctx_len == train_block_size:
            baseline_perplexity = result['perplexity']
        # Calculate degradation (0 when there is no baseline yet, i.e. the
        # 0.5x run, which executes before the 1x run).
        if baseline_perplexity:
            degradation = (result['perplexity'] - baseline_perplexity) / baseline_perplexity * 100
        else:
            degradation = 0
        result['is_extrapolation'] = is_extrapolation
        result['extrapolation_ratio'] = ctx_len / train_block_size
        result['degradation_pct'] = degradation
        status = "โœ…" if degradation < 20 else ("โš ๏ธ" if degradation < 50 else "โŒ")
        print(f" Loss: {result['loss']:.4f}")
        print(f" Perplexity: {result['perplexity']:.2f}")
        print(f" Degradation vs train: {degradation:+.1f}%")
        print(f" Memory: {result['peak_memory_mb']:.1f} MB")
        print(f" Status: {status}")
        results['tests'].append(result)
    return results
def print_summary(results: Dict):
    """Prints extrapolation test summary.

    Args:
        results: output of run_extrapolation_test(); reads the per-context
            dicts in results['tests'].
    """
    print("\n" + "=" * 70)
    print("๐Ÿ“ˆ EXTRAPOLATION TEST SUMMARY")
    print("=" * 70)
    # Keep only runs that produced numbers (drop OOM / error entries).
    tests = [t for t in results['tests'] if 'error' not in t]
    if not tests:
        print("โŒ No test completed successfully.")
        return
    print(f"\n{'Context':<12} {'Ratio':<8} {'Loss':<10} {'PPL':<10} {'Degrad.':<10} {'Mem (MB)':<12}")
    print("-" * 70)
    for t in tests:
        ctx = t['context_len']
        ratio = f"{t['extrapolation_ratio']:.1f}x"
        loss = f"{t['loss']:.4f}"
        ppl = f"{t['perplexity']:.2f}"
        deg = f"{t.get('degradation_pct', 0):+.1f}%"
        mem = f"{t['peak_memory_mb']:.1f}"
        marker = "๐Ÿ”ฌ" if t['is_extrapolation'] else "๐Ÿ“"
        print(f"{marker} {ctx:<10} {ratio:<8} {loss:<10} {ppl:<10} {deg:<10} {mem:<12}")
    # Verdict
    extrapolation_tests = [t for t in tests if t['is_extrapolation']]
    if not extrapolation_tests:
        print("\nโš ๏ธ No extrapolation test was executed.")
        return
    avg_degradation = sum(t.get('degradation_pct', 0) for t in extrapolation_tests) / len(extrapolation_tests)
    # BUG FIX: max() over an empty sequence raised ValueError whenever every
    # extrapolation test degraded by >= 50%; collect the ratios first and
    # handle the empty case explicitly.
    successful_ratios = [
        t['extrapolation_ratio']
        for t in extrapolation_tests
        if t.get('degradation_pct', 100) < 50
    ]
    print("\n" + "-" * 70)
    print(f"Average degradation in extrapolation: {avg_degradation:.1f}%")
    if successful_ratios:
        print(f"Max ratio with <50% degradation: {max(successful_ratios):.1f}x")
    else:
        print("Max ratio with <50% degradation: none")
    if avg_degradation < 15:
        print("\n๐ŸŽ‰ VERDICT: EXCELLENT! Ripple Field extrapolates with quality!")
        print(" The ALiBi architecture is working as expected.")
    elif avg_degradation < 30:
        print("\nโœ… VERDICT: GOOD. Functional extrapolation with moderate degradation.")
    elif avg_degradation < 50:
        print("\nโš ๏ธ VERDICT: MARGINAL. Extrapolation works, but with significant loss.")
    else:
        print("\nโŒ VERDICT: FAIL. The model does not extrapolate well beyond training context.")
    print("=" * 70)
def main():
    """CLI entry point: parse args, load artifacts, run and summarize.

    Returns:
        Process exit code: 0 on success, 1 when model or data is missing.
    """
    parser = argparse.ArgumentParser(description='Ripple Field Extrapolation Test')
    parser.add_argument('--config', type=str, default='medium',
                        choices=['small', 'medium', 'large', 'xlarge'])
    parser.add_argument('--max-context', type=int, default=4096,
                        help='Max context to test')
    args = parser.parse_args()

    banner = "=" * 70
    print(banner)
    print("๐Ÿ”ฌ EXTRAPOLATION TEST - RippleGPT ALiBi Validation")
    print(banner)
    print("\nโš ๏ธ NOTE: This test validates the central thesis of RippleGPT:")
    print(" 'Train on N tokens, infer on 2N-4N with quality'")
    print(" Memory scales with O(Tยฒ) - OOM expected in very long contexts.")

    # Both loaders raise FileNotFoundError with a how-to-fix message;
    # report it and abort with a non-zero exit code.
    try:
        model, config = load_model(args.config)
        data = load_data()
    except FileNotFoundError as e:
        print(f"\nโŒ {e}")
        return 1
    print(f"\n๐Ÿ“š Data loaded: {len(data)} tokens")

    # Run tests and print the summary.
    print_summary(run_extrapolation_test(model, config, data, args.max_context))
    return 0
if __name__ == '__main__':
    # FIX: use sys.exit() instead of the builtin exit() — the latter is a
    # site-module convenience that is absent under `python -S`; sys.exit()
    # reliably propagates main()'s return value as the process exit code.
    sys.exit(main())