|
|
""" |
|
|
extrapolation_test.py - Scientific Extrapolation Test for Ripple Field |
|
|
|
|
|
This test validates the MAIN THESIS of RippleGPT: |
|
|
"A model trained with block_size=X can infer with quality on 2X, 4X, etc." |
|
|
|
|
|
The test: |
|
|
1. Loads a trained model (e.g. block_size=512) |
|
|
2. Measures perplexity on contexts of 256, 512, 1024, 2048 tokens |
|
|
3. Compares the quality degradation |
|
|
|
|
|
IF perplexity remains stable beyond the training block_size, |
|
|
the ALiBi/Ripple Field architecture is VALIDATED. |
|
|
|
|
|
Usage: |
|
|
python validation/memory/extrapolation_test.py --config medium |
|
|
python validation/memory/extrapolation_test.py --config large --max-context 4096 |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import argparse |
|
|
import pickle |
|
|
import time |
|
|
from typing import Tuple, List, Dict |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import psutil |
|
|
|
|
|
|
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) |
|
|
|
|
|
from src.model import RippleGPT |
|
|
from src.config import RippleConfig |
|
|
from validation.memory.model_configs import get_config |
|
|
|
|
|
|
|
|
# Directories for validation data and model checkpoints, resolved relative
# to this file so the script works regardless of the current working dir.
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')

CKPT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')

# Device selection: prefer CUDA, then Apple MPS, falling back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
|
|
|
|
|
|
|
def load_model(config_name: str) -> Tuple[RippleGPT, RippleConfig]:
    """Load a trained checkpoint for `config_name` without modifying block_size.

    Args:
        config_name: Config identifier ('small', 'medium', ...) used to
            locate the checkpoint file under CKPT_DIR.

    Returns:
        Tuple of (model moved to DEVICE in eval mode, its RippleConfig).

    Raises:
        FileNotFoundError: If neither the 'best' nor the 'final' checkpoint exists.
    """
    # Prefer the best-validation checkpoint; fall back to the final one.
    best_path = os.path.join(CKPT_DIR, f'ckpt_{config_name}_best.pt')

    final_path = os.path.join(CKPT_DIR, f'ckpt_{config_name}_final.pt')

    if os.path.exists(best_path):

        ckpt_path = best_path

    elif os.path.exists(final_path):

        ckpt_path = final_path

    else:

        raise FileNotFoundError(

            f"Checkpoint not found for config '{config_name}'\n"

            f"Run: python validation/memory/train_large.py --config {config_name}"

        )

    print(f"๐ฆ Loading model from: {ckpt_path}")

    # weights_only=False because the checkpoint pickles the full config object
    # alongside the state dict. Only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location=DEVICE, weights_only=False)

    config = checkpoint['config']

    model = RippleGPT(config)

    # Strip the '_orig_mod.' prefix that torch.compile adds to parameter
    # names so the state dict matches the uncompiled module.
    state_dict = checkpoint['model']

    unwanted_prefix = '_orig_mod.'

    for k in list(state_dict.keys()):

        if k.startswith(unwanted_prefix):

            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)

    model.load_state_dict(state_dict)

    model.to(DEVICE)

    model.eval()

    print(f" โ Model loaded ({model.get_num_params()/1e6:.2f}M params)")

    print(f" ๐ Training block size: {config.block_size}")

    return model, config
|
|
|
|
|
|
|
|
def load_data() -> torch.Tensor:
    """Load the validation token stream as a 1-D int64 tensor.

    Tokens are stored on disk as raw uint16 and widened to int64 for
    embedding lookups.

    Raises:
        FileNotFoundError: If val.bin has not been prepared yet.
    """
    val_path = os.path.join(DATA_DIR, 'val.bin')

    if not os.path.exists(val_path):
        raise FileNotFoundError(
            f"Validation data not found at {val_path}\n"
            f"Run: python validation/memory/prepare_large_data.py"
        )

    raw_tokens = np.fromfile(val_path, dtype=np.uint16)
    widened = raw_tokens.astype(np.int64)
    return torch.from_numpy(widened)
|
|
|
|
|
|
|
|
@torch.no_grad()

def measure_perplexity(

    model: RippleGPT,

    data: torch.Tensor,

    context_len: int,

    num_batches: int = 20

) -> Dict:

    """
    Measure perplexity of `model` over non-overlapping windows of `data`.

    Args:
        model: Trained model in eval mode; called as model(x, y) -> (logits, loss).
        data: 1-D tensor of token ids (int64).
        context_len: Window length in tokens for each forward pass.
        num_batches: Maximum number of windows to evaluate (each batch size 1).

    Returns:
        Dict with loss, perplexity, memory usage and timing on success, or a
        dict with an 'error' key when data is insufficient or the device OOMs.
    """

    # Need at least context_len inputs plus one extra token for the targets.
    if len(data) < context_len + 1:

        return {'error': 'Insufficient data for this context'}

    # Snapshot memory before the run: CUDA allocator stats on GPU,
    # process RSS otherwise (also used for MPS, where it is approximate).
    if DEVICE == 'cuda':

        torch.cuda.reset_peak_memory_stats()

        mem_before = torch.cuda.memory_allocated() / 1e6

    else:

        mem_before = psutil.Process().memory_info().rss / 1e6

    total_loss = 0

    valid_batches = 0

    start_time = time.time()

    for i in range(num_batches):

        # Non-overlapping windows: batch i covers [i*context_len, (i+1)*context_len).
        start_idx = i * context_len

        if start_idx + context_len + 1 > len(data):

            break

        # Targets are the inputs shifted by one token (next-token prediction).
        x = data[start_idx : start_idx + context_len].unsqueeze(0).to(DEVICE)

        y = data[start_idx + 1 : start_idx + context_len + 1].unsqueeze(0).to(DEVICE)

        try:

            _, loss = model(x, y)

            total_loss += loss.item()

            valid_batches += 1

        except RuntimeError as e:

            # Long contexts can exhaust device memory (attention is O(T^2));
            # report it as a structured result instead of crashing the sweep.
            if 'out of memory' in str(e).lower():

                if DEVICE == 'cuda':

                    torch.cuda.empty_cache()

                return {'error': f'OOM on context {context_len}', 'memory_error': True}

            raise

    elapsed = time.time() - start_time

    # Peak memory after the run (peak allocator usage on CUDA, RSS elsewhere).
    if DEVICE == 'cuda':

        mem_after = torch.cuda.max_memory_allocated() / 1e6

    else:

        mem_after = psutil.Process().memory_info().rss / 1e6

    if valid_batches == 0:

        return {'error': 'No batch processed'}

    avg_loss = total_loss / valid_batches

    # Perplexity is exp of the mean cross-entropy loss.
    perplexity = np.exp(avg_loss)

    return {

        'context_len': context_len,

        'loss': avg_loss,

        'perplexity': perplexity,

        'memory_mb': mem_after - mem_before,

        'peak_memory_mb': mem_after,

        'time_seconds': elapsed,

        'tokens_per_second': (context_len * valid_batches) / elapsed

    }
|
|
|
|
|
|
|
|
def run_extrapolation_test(

    model: RippleGPT,

    config: RippleConfig,

    data: torch.Tensor,

    max_context: int = 4096

) -> Dict:

    """
    Execute the progressive extrapolation sweep.

    Evaluates perplexity at 0.5x, 1x, 2x, 4x and 8x the training block size
    (clamped to [64, max_context]) and records, for each context, the
    degradation relative to the 1x (training-length) baseline.

    Returns:
        Dict with 'train_block_size' and a 'tests' list of per-context results.
    """

    train_block_size = config.block_size

    # Candidate contexts; anything under 64 tokens or over the cap is skipped.
    multipliers = [0.5, 1.0, 2.0, 4.0, 8.0]

    contexts = [int(train_block_size * m) for m in multipliers]

    contexts = [c for c in contexts if c <= max_context and c >= 64]

    print(f"\n๐ Testing extrapolation:")

    print(f" Training block size: {train_block_size}")

    print(f" Contexts to test: {contexts}")

    print("-" * 70)

    results = {

        'train_block_size': train_block_size,

        'tests': []

    }

    # Perplexity measured at exactly train_block_size; set during the sweep.
    baseline_perplexity = None

    for ctx_len in contexts:

        is_extrapolation = ctx_len > train_block_size

        marker = "๐ฌ" if is_extrapolation else "๐"

        label = f"({ctx_len/train_block_size:.1f}x)" if ctx_len != train_block_size else "(train)"

        print(f"\n{marker} Context: {ctx_len} tokens {label}")

        result = measure_perplexity(model, data, ctx_len)

        # Record failures (OOM / insufficient data) too, then move on.
        if 'error' in result:

            print(f" โ {result['error']}")

            result['is_extrapolation'] = is_extrapolation

            result['extrapolation_ratio'] = ctx_len / train_block_size

            results['tests'].append(result)

            continue

        if ctx_len == train_block_size:

            baseline_perplexity = result['perplexity']

        # Degradation is 0 until the baseline has been measured; contexts are
        # swept in increasing order, so sub-train contexts also report 0.
        if baseline_perplexity:

            degradation = (result['perplexity'] - baseline_perplexity) / baseline_perplexity * 100

        else:

            degradation = 0

        result['is_extrapolation'] = is_extrapolation

        result['extrapolation_ratio'] = ctx_len / train_block_size

        result['degradation_pct'] = degradation

        # <20% degradation is a pass, <50% a warning, otherwise a failure.
        status = "โ" if degradation < 20 else ("โ ๏ธ" if degradation < 50 else "โ")

        print(f" Loss: {result['loss']:.4f}")

        print(f" Perplexity: {result['perplexity']:.2f}")

        print(f" Degradation vs train: {degradation:+.1f}%")

        print(f" Memory: {result['peak_memory_mb']:.1f} MB")

        print(f" Status: {status}")

        results['tests'].append(result)

    return results
|
|
|
|
|
|
|
|
def print_summary(results: Dict):
    """
    Print a per-context results table plus an overall extrapolation verdict.

    Args:
        results: Output of run_extrapolation_test(); expects
            'train_block_size' and a 'tests' list of per-context dicts.
    """
    print("\n" + "=" * 70)
    print("๐ EXTRAPOLATION TEST SUMMARY")
    print("=" * 70)

    # Only successfully-measured contexts are tabulated.
    tests = [t for t in results['tests'] if 'error' not in t]

    if not tests:
        print("โ No test completed successfully.")
        return

    print(f"\n{'Context':<12} {'Ratio':<8} {'Loss':<10} {'PPL':<10} {'Degrad.':<10} {'Mem (MB)':<12}")
    print("-" * 70)

    for t in tests:
        ctx = t['context_len']
        ratio = f"{t['extrapolation_ratio']:.1f}x"
        loss = f"{t['loss']:.4f}"
        ppl = f"{t['perplexity']:.2f}"
        deg = f"{t.get('degradation_pct', 0):+.1f}%"
        mem = f"{t['peak_memory_mb']:.1f}"

        marker = "๐ฌ" if t['is_extrapolation'] else "๐"
        print(f"{marker} {ctx:<10} {ratio:<8} {loss:<10} {ppl:<10} {deg:<10} {mem:<12}")

    # The verdict only considers contexts beyond the training block size.
    extrapolation_tests = [t for t in tests if t['is_extrapolation']]

    if not extrapolation_tests:
        print("\nโ ๏ธ No extrapolation test was executed.")
        return

    avg_degradation = sum(t.get('degradation_pct', 0) for t in extrapolation_tests) / len(extrapolation_tests)
    # BUGFIX: default=0.0 is required — without it max() raises ValueError
    # when every extrapolation test degrades by >=50% (i.e. exactly the
    # FAIL-verdict case), crashing the summary before the verdict prints.
    max_successful_ratio = max(
        (t['extrapolation_ratio'] for t in extrapolation_tests if t.get('degradation_pct', 100) < 50),
        default=0.0
    )

    print("\n" + "-" * 70)
    print(f"Average degradation in extrapolation: {avg_degradation:.1f}%")
    print(f"Max ratio with <50% degradation: {max_successful_ratio:.1f}x")

    # Verdict thresholds on average degradation:
    # <15% excellent, <30% good, <50% marginal, otherwise fail.
    if avg_degradation < 15:
        print("\n๐ VERDICT: EXCELLENT! Ripple Field extrapolates with quality!")
        print(" The ALiBi architecture is working as expected.")
    elif avg_degradation < 30:
        print("\nโ VERDICT: GOOD. Functional extrapolation with moderate degradation.")
    elif avg_degradation < 50:
        print("\nโ ๏ธ VERDICT: MARGINAL. Extrapolation works, but with significant loss.")
    else:
        print("\nโ VERDICT: FAIL. The model does not extrapolate well beyond training context.")

    print("=" * 70)
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: load model and data, run the sweep, print the summary.

    Returns:
        0 on success, 1 when the checkpoint or validation data is missing.
    """
    parser = argparse.ArgumentParser(description='Ripple Field Extrapolation Test')

    parser.add_argument('--config', type=str, default='medium',

                        choices=['small', 'medium', 'large', 'xlarge'])

    parser.add_argument('--max-context', type=int, default=4096,

                        help='Max context to test')

    args = parser.parse_args()

    print("=" * 70)

    print("๐ฌ EXTRAPOLATION TEST - RippleGPT ALiBi Validation")

    print("=" * 70)

    print("\nโ ๏ธ NOTE: This test validates the central thesis of RippleGPT:")

    print(" 'Train on N tokens, infer on 2N-4N with quality'")

    print(" Memory scales with O(Tยฒ) - OOM expected in very long contexts.")

    # Missing checkpoint is user-fixable: print the hint and exit nonzero.
    try:

        model, config = load_model(args.config)

    except FileNotFoundError as e:

        print(f"\nโ {e}")

        return 1

    # Same for missing validation data.
    try:

        data = load_data()

        print(f"\n๐ Data loaded: {len(data)} tokens")

    except FileNotFoundError as e:

        print(f"\nโ {e}")

        return 1

    results = run_extrapolation_test(model, config, data, args.max_context)

    print_summary(results)

    return 0
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Use sys.exit rather than the site-provided exit() builtin: exit() is
    # meant for interactive sessions and may be absent when the interpreter
    # runs without the site module (python -S) or in embedded contexts.
    sys.exit(main())
|
|
|