import os

import torch

from src.config import RippleConfig
from src.model import RippleGPT

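# Prefer a CUDA GPU, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'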

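def load_model(ckpt_path):
    # weights_only=False lets torch.load unpickle arbitrary objects stored in
    # the checkpoint; only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=False)
    config = RippleConfig(**checkpoint['model_args'])
    # Override block_size so the model accepts sequences longer than the 256
    # tokens it was trained on. This works only if no parameter or persistent
    # buffer has a shape tied to block_size (e.g. a learned positional-embedding
    # table); otherwise load_state_dict raises a size-mismatch error.
    config.block_size = 1024
    model = RippleGPT(config)

    # Checkpoints saved from a torch.compile()'d model prefix every key with
    # '_orig_mod.'; strip it so the keys match the uncompiled model.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model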
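
# --- Optional diagnostic (hypothetical helper, not part of the original script) ---
# Overriding block_size in load_model only works if no parameter or buffer
# shape depends on it. This sketch lists every key whose shape differs
# between a checkpoint's state_dict (after the '_orig_mod.' prefix has been
# stripped) and a freshly built model, so a load failure can be traced to a
# specific tensor, e.g. a learned positional-embedding table.
from torch import nn


def find_shape_mismatches(model: nn.Module, state_dict: dict) -> dict:
    """Return {key: (checkpoint_shape, model_shape)} for keys present in both
    state dicts whose tensor shapes differ."""
    model_state = model.state_dict()
    mismatches = {}
    for key, tensor in state_dict.items():
        if key in model_state and model_state[key].shape != tensor.shape:
            mismatches[key] = (tuple(tensor.shape), tuple(model_state[key].shape))
    return mismatches


# Example: a 256-position embedding table does not fit a 1024-position model.
# find_shape_mismatches(nn.Embedding(1024, 8), nn.Embedding(256, 8).state_dict())
# -> {'weight': ((256, 8), (1024, 8))}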

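def measure_perplexity(model, data_tensor, context_len, max_batches=10):
    """
    Measures how surprised the model is by held-out text. Lower is better.

    Splits data_tensor into up to max_batches non-overlapping windows of
    context_len tokens and returns (average loss, perplexity), where
    perplexity = exp(average loss). Passing a context_len larger than the
    training context (256) tests length extrapolation. The loss averages
    over every position in a window, including early positions that see
    only a few tokens of context.
    """
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for i in range(max_batches):
            start = i * context_len
            # Targets are the inputs shifted by one token, so each window
            # needs context_len + 1 tokens of data.
            if start + context_len + 1 > len(data_tensor):
                break

            x = data_tensor[start : start + context_len].unsqueeze(0).to(device)
            y = data_tensor[start + 1 : start + context_len + 1].unsqueeze(0).to(device)

            _, loss = model(x, y)
            total_loss += loss.item()
            n_batches += 1

    if n_batches == 0:
        raise ValueError(
            f"Need at least {context_len + 1} tokens for context_len={context_len}, "
            f"got {len(data_tensor)}"
        )
    # Average over the windows actually evaluated; the loop can stop early
    # when the data runs out.
    avg_loss = total_loss / n_batches
    perplexity = torch.exp(torch.tensor(avg_loss))
    return avg_loss, perplexity.item()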
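
# --- Optional variant (hypothetical helper, not part of the original script) ---
# measure_perplexity above scores every position of non-overlapping windows,
# so the first tokens of each window are predicted from very little context.
# A common alternative is strided (sliding-window) evaluation: windows of
# context_len tokens advance by `stride`, and each window scores only the
# targets no earlier window has scored, so every target after the first
# window sees at least context_len - stride tokens of context. Perplexity is
# exp(mean per-token negative log-likelihood); a model that predicts
# uniformly over V tokens scores exactly V.
# ASSUMPTION: model(x, y) returns (logits, loss) with logits shaped
# (batch, seq_len, vocab_size) for every position, as nanoGPT-style models
# do when targets are passed. Check RippleGPT's forward() before relying on it.
import math

import torch
import torch.nn.functional as F


def measure_perplexity_strided(model, data_tensor, context_len, stride,
                               max_tokens=None, device='cpu'):
    """Return (avg_loss, perplexity) from sliding-window evaluation."""
    if not 0 < stride <= context_len:
        raise ValueError("stride must satisfy 0 < stride <= context_len")
    data = data_tensor if max_tokens is None else data_tensor[:max_tokens]
    n = len(data)
    nll_sum = 0.0
    n_scored = 0
    scored_until = 0  # targets at data positions 1..scored_until are already scored
    begin = 0
    with torch.no_grad():
        while scored_until < n - 1:
            end = min(begin + context_len, n - 1)
            x = data[begin:end].unsqueeze(0).to(device)
            y = data[begin + 1:end + 1].unsqueeze(0).to(device)
            n_new = end - scored_until  # targets this window adds
            logits, _ = model(x, y)
            nll_sum += F.cross_entropy(
                logits[0, -n_new:].float(), y[0, -n_new:], reduction='sum'
            ).item()
            n_scored += n_new
            scored_until = end
            begin += stride
    if n_scored == 0:
        raise ValueError("need at least 2 tokens to score")
    avg_loss = nll_sum / n_scored
    return avg_loss, math.exp(avg_loss)


# Usage inside this script (not run by default):
# loss, ppl = measure_perplexity_strided(model, val_data, 1024, stride=512,
#                                        max_tokens=10 * 1024, device=device)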

# Load Data
print("Loading data...")
dataset_dir = 'data'
val_data_path = os.path.join(dataset_dir, 'val.bin')
meta_path = os.path.join(dataset_dir, 'meta.pkl')

if os.path.exists(val_data_path) and os.path.exists(meta_path):
    print(f"Loading official validation data from {val_data_path}...")
    import numpy as np
    # val.bin is read as a flat array of uint16 token ids; convert to int64
    # (torch.long), the dtype PyTorch expects for token indices and targets.
    val_data_np = np.fromfile(val_data_path, dtype=np.uint16)
    val_data = torch.from_numpy(val_data_np.astype(np.int64))
else:
    print("Official validation data not found. Downloading tinyshakespeare for demo...")
    import requests
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    text = response.text

    # Build a character-level vocabulary from the full text and hold out the
    # last 10% as validation data. These ids only match the checkpoint if it
    # was trained with the same sorted character set.
    chars = sorted(set(text))
    stoi = {ch: i for i, ch in enumerate(chars)}
    encode = lambda s: [stoi[c] for c in s]
    val_data = torch.tensor(encode(text[int(0.9 * len(text)):]), dtype=torch.long)
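
# --- Optional helper (hypothetical, not part of the original script) ---
# The script checks that data/meta.pkl exists but never reads it, and the
# fallback branch rebuilds a vocabulary from scratch. A minimal sketch for
# loading the training vocabulary instead, ASSUMING meta.pkl is a pickled
# dict with 'stoi' (char -> id) and 'itos' (id -> char) entries, as written
# by nanoGPT-style character-level prepare scripts. Verify the keys for your
# data, and only unpickle files you trust.
import pickle


def load_char_codec(meta_path):
    """Return (encode, decode) functions built from a pickled meta dict."""
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']

    def encode(s):
        return [stoi[c] for c in s]

    def decode(ids):
        return ''.join(itos[i] for i in ids)

    return encode, decode


# Usage: encode text with the same ids the checkpoint was trained on.
# encode, decode = load_char_codec(meta_path)
# ids = torch.tensor(encode("To be, or not to be"), dtype=torch.long)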

# Load Model
print("Loading RippleGPT...")
ckpt_path = 'out/ckpt_best.pt' if os.path.exists('out/ckpt_best.pt') else 'out/ckpt.pt'
print(f"Loading checkpoint from {ckpt_path}")
model = load_model(ckpt_path)

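# TEST 1: Standard context (256 tokens, the training length)
loss_256, ppl_256 = measure_perplexity(model, val_data, 256)
print(f"Context 256 (Trained size): Loss {loss_256:.4f}, Perplexity {ppl_256:.2f}")

# TEST 2: Extrapolation to 512 tokens (2x training)
# A forward pass without an exception only shows that the model accepts longer
# inputs; extrapolation works if perplexity stays close to the 256 baseline.
# Each run scores up to 10 windows of context_len tokens, so the runs cover
# different amounts of text and are only roughly comparable.
try:
    loss_512, ppl_512 = measure_perplexity(model, val_data, 512)
    print(f"Context 512 (2x Training):  Loss {loss_512:.4f}, Perplexity {ppl_512:.2f}")
    print(f"✅ Model ran at 2x context length; perplexity is {ppl_512 / ppl_256:.2f}x the 256-token baseline.")
except Exception as e:
    print(f"❌ EXTRAPOLATION FAILED at 512: {e}")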

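# TEST 3: Extreme extrapolation to 1024 tokens (4x training)
try:
    loss_1024, ppl_1024 = measure_perplexity(model, val_data, 1024)
    print(f"Context 1024 (4x Training): Loss {loss_1024:.4f}, Perplexity {ppl_1024:.2f}")
    print(f"✅ Model ran at 4x context length; perplexity is {ppl_1024 / ppl_256:.2f}x the 256-token baseline.")
except Exception as e:
    # Report the failure instead of silently swallowing it.
    print(f"❌ EXTRAPOLATION FAILED at 1024: {e}")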
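
# --- Optional helper (hypothetical, not part of the original script) ---
# A sketch for turning the measurements above into a pass/fail verdict
# instead of treating "no exception" as success. It reports each context
# length's perplexity relative to the training-length baseline and flags
# lengths whose perplexity grew by more than `tolerance` (an arbitrary
# threshold chosen for illustration, default 10%).


def summarize_extrapolation(results, baseline_len, tolerance=0.10):
    """results maps context_len -> perplexity. Returns a list of
    (context_len, ratio_to_baseline, ok) sorted by context length."""
    baseline = results[baseline_len]
    summary = []
    for length in sorted(results):
        ratio = results[length] / baseline
        summary.append((length, ratio, ratio <= 1.0 + tolerance))
    return summary


# Usage with the values measured above (include only lengths that ran):
# results = {256: ppl_256, 512: ppl_512, 1024: ppl_1024}
# for length, ratio, ok in summarize_extrapolation(results, baseline_len=256):
#     print(f"{length:5d}: {ratio:.2f}x baseline {'OK' if ok else 'DEGRADED'}")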