|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from src.model import RippleGPT |
|
|
from src.config import RippleConfig |
|
|
import os |
|
|
import pickle |
|
|
|
|
|
# Run on PyTorch's Metal (MPS) backend when it is available, otherwise on CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
|
|
|
|
|
def load_model(ckpt_path, block_size=1024):
    """Load a RippleGPT checkpoint onto ``device`` and switch it to eval mode.

    Args:
        ckpt_path: Path to a checkpoint file readable by ``torch.load``. It must
            hold ``'model_args'`` (keyword arguments for ``RippleConfig``) and
            ``'model'`` (the model state dict).
        block_size: Value assigned to ``config.block_size`` before the model is
            constructed, overriding the value stored in the checkpoint. The
            evaluation in this script goes up to 1024-token contexts, hence the
            default. Pass ``None`` to keep the checkpoint's own value.

    Returns:
        The ``RippleGPT`` instance with the checkpoint weights loaded, moved to
        ``device`` and put in eval mode.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects.
    # Only point this at checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)

    config = RippleConfig(**checkpoint['model_args'])
    if block_size is not None:
        config.block_size = block_size
    model = RippleGPT(config)

    # Strip the '_orig_mod.' key prefix (presumably added when the model was
    # wrapped by torch.compile during training) so keys match the plain module.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict.keys()):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model
|
|
|
|
|
def measure_perplexity(model, data_tensor, context_len, max_batches=10, target_device=None):
    """Estimate average loss and perplexity over consecutive token windows.

    ``data_tensor`` is cut into consecutive, non-overlapping input windows of
    ``context_len`` tokens. Each is paired with the same window shifted by one
    token as targets. Up to ``max_batches`` windows are scored, one forward pass
    each, and their losses are averaged. Lower perplexity is better. The script
    calls this both at and beyond the training context length to probe length
    extrapolation.

    Args:
        model: Called as ``model(x, y)``; must return a pair whose second
            element is a scalar loss tensor (``.item()`` is called on it).
        data_tensor: 1-D tensor of token ids.
        context_len: Number of tokens per window.
        max_batches: Maximum number of windows to score (default 10).
        target_device: Device the windows are moved to. Defaults to the
            module-level ``device``.

    Returns:
        ``(avg_loss, perplexity)`` as Python floats, with
        ``perplexity == exp(avg_loss)``.

    Raises:
        ValueError: If ``data_tensor`` is too short to form a single window
            (it needs at least ``context_len + 1`` tokens).
    """
    if target_device is None:
        target_device = device

    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for i in range(max_batches):
            start = i * context_len
            end = start + context_len
            # The target window extends one token past the input window.
            if end + 1 > len(data_tensor):
                break
            x = data_tensor[start:end].unsqueeze(0).to(target_device)
            y = data_tensor[start + 1:end + 1].unsqueeze(0).to(target_device)
            _, loss = model(x, y)
            total_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError(
            f"data_tensor has {len(data_tensor)} tokens; need at least "
            f"{context_len + 1} for one window of context_len={context_len}"
        )

    # Average over the windows actually scored, not the requested maximum;
    # otherwise a short validation set would deflate the reported loss.
    avg_loss = total_loss / n_batches
    perplexity = torch.exp(torch.tensor(avg_loss))
    return avg_loss, perplexity.item()
|
|
|
|
|
|
|
|
print("Loading data...")
dataset_dir = 'data'
val_data_path = os.path.join(dataset_dir, 'val.bin')
meta_path = os.path.join(dataset_dir, 'meta.pkl')

if os.path.exists(val_data_path) and os.path.exists(meta_path):
    # Prepared validation split: a flat binary file of uint16 token ids.
    print(f"Loading official validation data from {val_data_path}...")
    import numpy as np
    val_data_np = np.fromfile(val_data_path, dtype=np.uint16)
    # Widen to int64 (torch.long) so the ids can be used as indices.
    val_data = torch.from_numpy(val_data_np.astype(np.int64))
else:
    print("Official validation data not found. Downloading tinyshakespeare for demo...")
    import requests
    response = requests.get(
        "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
        timeout=30,  # never hang indefinitely on a stalled connection
    )
    # Fail loudly on HTTP errors rather than tokenizing an error page.
    response.raise_for_status()
    text = response.text

    # Character-level vocabulary built from the full downloaded text.
    # NOTE(review): this must match the vocabulary the checkpoint was trained
    # with; nothing here verifies that.
    chars = sorted(list(set(text)))
    stoi = {ch: i for i, ch in enumerate(chars)}

    def encode(s):
        """Map each character of ``s`` to its integer id."""
        return [stoi[c] for c in s]

    # Hold out the last 10% of the characters as validation data.
    val_data = torch.tensor(encode(text[int(0.9 * len(text)):]), dtype=torch.long)
|
|
|
|
|
|
|
|
print("Loading RippleGPT...")
# Prefer the best checkpoint when it exists, otherwise fall back to the latest.
ckpt_path = 'out/ckpt_best.pt' if os.path.exists('out/ckpt_best.pt') else 'out/ckpt.pt'
print(f"Loading checkpoint from {ckpt_path}")
model = load_model(ckpt_path)

# Baseline at the context length labelled below as the training size.
loss_256, ppl_256 = measure_perplexity(model, val_data, 256)
print(f"Context 256 (Trained size): Loss {loss_256:.4f}, Perplexity {ppl_256:.2f}")

# Longer contexts. "SUCCESSFUL" only means the forward pass completed without
# raising; compare the printed perplexities to judge quality.
try:
    loss_512, ppl_512 = measure_perplexity(model, val_data, 512)
except Exception as e:
    print(f"❌ EXTRAPOLATION FAILED: {e}")
else:
    print(f"Context 512 (2x Training): Loss {loss_512:.4f}, Perplexity {ppl_512:.2f}")
    print("✅ EXTRAPOLATION SUCCESSFUL: Model handled 2x context length!")

try:
    loss_1024, ppl_1024 = measure_perplexity(model, val_data, 1024)
except Exception as e:
    # Report the failure instead of silently swallowing it.
    print(f"❌ EXTRAPOLATION FAILED (4x context): {e}")
else:
    print(f"Context 1024 (4x Training): Loss {loss_1024:.4f}, Perplexity {ppl_1024:.2f}")
    print("✅ EXTRAPOLATION SUCCESSFUL: Model handled 4x context length!")
|
|
|