File size: 7,386 Bytes
148b631
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
"""
train_code.py - Trains RippleGPT on Python code for validation.

This script uses the prepared dataset to train the model on code completion.
The goal is to validate whether the architecture can learn code structure.

Usage:
    python validation/code/train_code.py
"""

import os
import sys
import time
import pickle
import math
import numpy as np
import torch

# Add root directory to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))

from src.model import RippleGPT
from src.config import RippleConfig

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

# --- Directories (resolved relative to this script, not the working directory) ---
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
OUT_DIR = os.path.join(os.path.dirname(__file__), "checkpoints")

# --- Training schedule ---
BATCH_SIZE = 32       # token windows sampled per optimizer step
BLOCK_SIZE = 256      # tokens per window (model context length)
MAX_ITERS = 15_000    # total optimizer steps; kept moderate to avoid saturation
EVAL_INTERVAL = 500   # steps between train/val loss estimates
EVAL_ITERS = 200      # batches averaged for each loss estimate
LOG_INTERVAL = 100    # steps between progress log lines

# --- Model shape ("the sweet spot") ---
N_LAYER = 6
N_HEAD = 8
N_EMBD = 384
DROPOUT = 0.1

# --- Optimization ---
LEARNING_RATE = 0.001  # peak LR; deliberately aggressive so the model learns fast
WARMUP_ITERS = 200     # linear warmup steps before cosine decay

# --- Device: prefer CUDA, then Apple MPS, then fall back to CPU ---
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

# -----------------------------------------------------------------------------
# Helper Functions
# -----------------------------------------------------------------------------

def get_batch(split: str, data_dir: str = DATA_DIR,
              batch_size: int = BATCH_SIZE, block_size: int = BLOCK_SIZE):
    """Sample a random batch of (input, target) token windows from one split.

    Args:
        split: ``'train'`` reads ``train.bin``; any other value reads ``val.bin``.
        data_dir: Directory containing the ``.bin`` token files.
        batch_size: Number of windows to sample (defaults to ``BATCH_SIZE``).
        block_size: Tokens per window (defaults to ``BLOCK_SIZE``).

    Returns:
        ``(x, y)`` int64 tensors of shape ``(batch_size, block_size)`` moved to
        ``DEVICE``; ``y`` is ``x`` shifted one token to the right.

    Raises:
        ValueError: if the split holds too few tokens for even one window.
    """
    filename = 'train.bin' if split == 'train' else 'val.bin'
    # The memmap is re-opened on every call instead of being cached.
    data = np.memmap(os.path.join(data_dir, filename), dtype=np.uint16, mode='r')

    # Each window needs block_size + 1 tokens (inputs plus the shifted target);
    # without this check torch.randint fails with an opaque range error.
    if len(data) <= block_size:
        raise ValueError(
            f"{filename} has {len(data)} tokens; need more than block_size={block_size}"
        )

    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i:i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1:i + 1 + block_size].astype(np.int64)) for i in ix])

    if DEVICE == 'cuda':
        # Pinned host memory lets the host->GPU copy run with non_blocking=True.
        x = x.pin_memory().to(DEVICE, non_blocking=True)
        y = y.pin_memory().to(DEVICE, non_blocking=True)
    else:
        x, y = x.to(DEVICE), y.to(DEVICE)

    return x, y


@torch.no_grad()
def estimate_loss(model, ctx):
    """Average the loss over ``EVAL_ITERS`` random batches of each split.

    The model is switched to eval mode for the measurement. Afterwards it is
    returned to whatever mode it was in before the call, even if a batch raises.

    Args:
        model: Called as ``model(X, Y)``; expected to return ``(logits, loss)``.
        ctx: Context manager wrapped around each forward pass (e.g. autocast).

    Returns:
        dict mapping ``'train'`` and ``'val'`` to a 0-dim tensor with the mean loss.
    """
    out = {}
    was_training = model.training
    model.eval()
    try:
        for split in ('train', 'val'):
            losses = torch.zeros(EVAL_ITERS)
            for k in range(EVAL_ITERS):
                X, Y = get_batch(split)
                with ctx:
                    _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the caller's mode rather than unconditionally forcing train().
        model.train(was_training)
    return out


def get_lr(it: int, *, learning_rate=None, warmup_iters=None, max_iters=None,
           min_lr_ratio: float = 0.1) -> float:
    """Return the learning rate for step ``it``: linear warmup, then cosine decay.

    The keyword arguments default to the module constants ``LEARNING_RATE``,
    ``WARMUP_ITERS`` and ``MAX_ITERS``, read at call time. Existing
    ``get_lr(it)`` callers therefore keep working.

    Schedule:
      * ``it < warmup_iters``: linear ramp from ``lr / (warmup_iters + 1)``
        towards ``lr``.
      * ``warmup_iters <= it < max_iters``: cosine decay from ``lr`` down to
        ``min_lr_ratio * lr``.
      * ``it >= max_iters``: constant floor ``min_lr_ratio * lr``.
    """
    if learning_rate is None:
        learning_rate = LEARNING_RATE
    if warmup_iters is None:
        warmup_iters = WARMUP_ITERS
    if max_iters is None:
        max_iters = MAX_ITERS
    min_lr = learning_rate * min_lr_ratio

    # 1) Linear warmup. The +1 offsets keep step 0 from getting lr == 0,
    #    which would leave the weights unchanged on the first optimizer step.
    if it < warmup_iters:
        return learning_rate * (it + 1) / (warmup_iters + 1)
    # 2) At or past the end, hold the floor. Using >= (not >) yields the same
    #    value at it == max_iters and avoids a 0/0 below when
    #    max_iters == warmup_iters.
    if it >= max_iters:
        return min_lr
    # 3) Cosine decay; decay_ratio lies in [0, 1) here.
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + (learning_rate - min_lr) * coeff


def _save_checkpoint(path, model, optimizer, config, iter_num, best_val_loss):
    """Write model/optimizer state plus run metadata to ``path`` via ``torch.save``."""
    torch.save(
        {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'config': config,
            'iter_num': iter_num,
            'best_val_loss': best_val_loss,
        },
        path,
    )


def _evaluate_and_checkpoint(model, optimizer, config, ctx, step, best_val_loss):
    """Log train/val loss at ``step`` and save ``ckpt_best.pt`` if val loss improved.

    Returns the (possibly updated) best validation loss.
    """
    losses = estimate_loss(model, ctx)
    print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
    if losses['val'] < best_val_loss:
        best_val_loss = losses['val']
        _save_checkpoint(os.path.join(OUT_DIR, 'ckpt_best.pt'),
                         model, optimizer, config, step, best_val_loss)
        print(f"   💾 Best model saved! (val_loss: {best_val_loss:.4f})")
    return best_val_loss


def train():
    """Train RippleGPT on the prepared code dataset and write checkpoints.

    Reads ``train.bin``, ``val.bin`` and ``meta.pkl`` from ``DATA_DIR``. It then
    runs ``MAX_ITERS`` AdamW steps using the schedule from ``get_lr``. Evaluation
    happens every ``EVAL_INTERVAL`` steps and once more after the last step.
    Writes ``ckpt_best.pt`` (lowest val loss seen) and ``ckpt_final.pt`` to
    ``OUT_DIR``. Returns early, without training, if any data file is missing.
    """
    print("=" * 60)
    print("🚀 RIPPLEGPT TRAINING FOR CODE COMPLETION")
    print("=" * 60)

    # Check every artifact up front: val.bin is first read only at step
    # EVAL_INTERVAL, so a missing file would otherwise crash mid-run.
    missing = [name for name in ('train.bin', 'val.bin', 'meta.pkl')
               if not os.path.exists(os.path.join(DATA_DIR, name))]
    if missing:
        print("❌ Data not found!")
        print(f"   Missing: {', '.join(missing)}")
        print("   Run first: python validation/code/prepare_code_data.py")
        return

    # Create checkpoints directory
    os.makedirs(OUT_DIR, exist_ok=True)

    # Load vocabulary metadata. NOTE: pickle can execute arbitrary code on load,
    # so meta.pkl must only come from the local data-preparation script.
    with open(os.path.join(DATA_DIR, 'meta.pkl'), 'rb') as f:
        meta = pickle.load(f)
    vocab_size = meta['vocab_size']
    print(f"\n📚 Vocab size: {vocab_size}")

    # Seed for reproducibility
    torch.manual_seed(1337)

    # Initialize model
    print("\n🔧 Initializing model...")
    config = RippleConfig(
        vocab_size=vocab_size,
        block_size=BLOCK_SIZE,
        n_layer=N_LAYER,
        n_head=N_HEAD,
        n_embd=N_EMBD,
        dropout=DROPOUT,
        use_absolute_pos_emb=False  # Use Ripple Field!
    )

    model = RippleGPT(config)
    model.to(DEVICE)

    num_params = model.get_num_params()
    print(f"   Parameters: {num_params / 1e6:.2f}M")
    print(f"   Device: {DEVICE}")
    print(f"   Block size: {BLOCK_SIZE}")
    print(f"   Batch size: {BATCH_SIZE}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # bfloat16 autocast on CUDA only; CPU/MPS run without autocast.
    # NOTE(review): assumes the CUDA device supports bfloat16 — confirm on older GPUs.
    from contextlib import nullcontext
    ctx = nullcontext() if DEVICE in ['cpu', 'mps'] else torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)

    print(f"\n📈 Starting training ({MAX_ITERS} iterations)...")
    print("-" * 60)

    X, Y = get_batch('train')
    t0 = time.time()
    best_val_loss = float('inf')

    for iter_num in range(MAX_ITERS):
        # Learning rate scheduling
        lr = get_lr(iter_num)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Periodic evaluation (after iter_num completed optimizer steps)
        if iter_num % EVAL_INTERVAL == 0 and iter_num > 0:
            best_val_loss = _evaluate_and_checkpoint(
                model, optimizer, config, ctx, iter_num, best_val_loss)

        # Forward/backward
        with ctx:
            _, loss = model(X, Y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Logging (dt also includes evaluation time on eval steps)
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        if iter_num % LOG_INTERVAL == 0:
            decay_stats = model.get_decay_stats()
            print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.2f}ms, lr {lr:.6f}")
            print(f"   Ripple Field Stats -> Mean Decay: {decay_stats['mean']:.4f}, Range: [{decay_stats['min']:.4f}, {decay_stats['max']:.4f}]")

        # Next batch
        X, Y = get_batch('train')

    # Final evaluation: the periodic check never sees the weights produced by
    # the last iterations, so without this the best checkpoint could miss them.
    best_val_loss = _evaluate_and_checkpoint(
        model, optimizer, config, ctx, MAX_ITERS, best_val_loss)

    # Save final checkpoint
    _save_checkpoint(os.path.join(OUT_DIR, 'ckpt_final.pt'),
                     model, optimizer, config, MAX_ITERS, best_val_loss)

    print("-" * 60)
    print("✅ Training complete!")
    print(f"   Best val loss: {best_val_loss:.4f}")
    print(f"   Checkpoints saved to: {OUT_DIR}")
    print("\nNext step: python validation/code/validate_code.py")


if __name__ == "__main__":
    # Start training only when run as a script, never on import.
    train()