RippleGPT-Nano / test_extrapolation.py
Tavernari's picture
Upload folder using huggingface_hub
b47957e verified
import torch
import torch.nn.functional as F
from src.model import RippleGPT
from src.config import RippleConfig
import os
import pickle
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
def load_model(ckpt_path):
checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
config = RippleConfig(**checkpoint['model_args'])
# FORCE the model to accept longer context than training (256)
# We set block_size to 1024 to see if it breaks or works
config.block_size = 1024
model = RippleGPT(config)
state_dict = checkpoint['model']
unwanted_prefix = '_orig_mod.'
for k,v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
model.load_state_dict(state_dict)
model.to(device)
model.eval()
return model
def measure_perplexity(model, data_tensor, context_len):
"""
Measures how surprised the model is. Lower is better.
We test on a context length LARGER than training.
"""
max_batches = 10
total_loss = 0
with torch.no_grad():
for i in range(max_batches):
# Grab a chunk of size 'context_len'
# If model was trained on 256, and we test 1024, this validates Ripple
if i * context_len + context_len + 1 > len(data_tensor): break
x = data_tensor[i*context_len : i*context_len + context_len].unsqueeze(0).to(device)
y = data_tensor[i*context_len+1 : i*context_len + context_len+1].unsqueeze(0).to(device)
_, loss = model(x, y)
total_loss += loss.item()
avg_loss = total_loss / max_batches
perplexity = torch.exp(torch.tensor(avg_loss))
return avg_loss, perplexity.item()
# Load Data
print("Loading data...")
dataset_dir = 'data'
val_data_path = os.path.join(dataset_dir, 'val.bin')
meta_path = os.path.join(dataset_dir, 'meta.pkl')
if os.path.exists(val_data_path) and os.path.exists(meta_path):
print(f"Loading official validation data from {val_data_path}...")
import numpy as np
val_data_np = np.fromfile(val_data_path, dtype=np.uint16)
val_data = torch.from_numpy(val_data_np.astype(np.int64))
else:
print("Official validation data not found. Downloading tinyshakespeare for demo...")
# Load validation text (simulate loading validation data)
import requests
text = requests.get("https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt").text
# Simple encoding if meta not found
chars = sorted(list(set(text)))
stoi = { ch:i for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
val_data = torch.tensor(encode(text[int(0.9*len(text)):]), dtype=torch.long)
# Load Model
print("Loading RippleGPT...")
ckpt_path = 'out/ckpt_best.pt' if os.path.exists('out/ckpt_best.pt') else 'out/ckpt.pt'
print(f"Loading checkpoint from {ckpt_path}")
model = load_model(ckpt_path)
# TEST 1: Standard Context (256)
loss_256, ppl_256 = measure_perplexity(model, val_data, 256)
print(f"Context 256 (Trained size): Loss {loss_256:.4f}, Perplexity {ppl_256:.2f}")
# TEST 2: Extrapolation (512) - The Scientific Proof
try:
loss_512, ppl_512 = measure_perplexity(model, val_data, 512)
print(f"Context 512 (2x Training): Loss {loss_512:.4f}, Perplexity {ppl_512:.2f}")
print("✅ EXTRAPOLATION SUCCESSFUL: Model handled 2x context length!")
except Exception as e:
print(f"❌ EXTRAPOLATION FAILED: {e}")
# TEST 3: Extreme Extrapolation (1024)
try:
loss_1024, ppl_1024 = measure_perplexity(model, val_data, 1024)
print(f"Context 1024 (4x Training): Loss {loss_1024:.4f}, Perplexity {ppl_1024:.2f}")
except Exception as e:
pass