|
|
""" |
|
|
train_code.py - Trains RippleGPT on Python code for validation. |
|
|
|
|
|
This script uses the prepared dataset to train the model in code completion. |
|
|
The focus is to validate if the architecture can learn code structures. |
|
|
|
|
|
Usage: |
|
|
    python validation/code/train_code.py
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import pickle |
|
|
import math |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
# Make the repository root importable so `from src.model import ...` resolves
# when this script is run directly. Three dirname() calls walk two directory
# levels above this file's folder — assumes the script lives two levels below
# the repo root (e.g. <root>/validation/code/) — TODO confirm.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|
|
|
|
from src.model import RippleGPT |
|
|
from src.config import RippleConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Paths — dataset (train.bin / val.bin / meta.pkl) and checkpoint output,
# both resolved relative to this script's own directory.
# ---------------------------------------------------------------------------
DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')

OUT_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')


# ---------------------------------------------------------------------------
# Training loop hyperparameters
# ---------------------------------------------------------------------------
BATCH_SIZE = 32      # sequences per optimization step

BLOCK_SIZE = 256     # context length in tokens

MAX_ITERS = 15000    # total optimization steps

EVAL_INTERVAL = 500  # run estimate_loss() (and maybe checkpoint) every N steps

EVAL_ITERS = 200     # batches averaged per split inside estimate_loss()

LOG_INTERVAL = 100   # console-logging frequency in steps


# ---------------------------------------------------------------------------
# Model size — fed into RippleConfig inside train()
# ---------------------------------------------------------------------------
N_LAYER = 6

N_HEAD = 8

N_EMBD = 384

DROPOUT = 0.1


# ---------------------------------------------------------------------------
# Optimizer schedule — peak LR and linear-warmup length; see get_lr()
# ---------------------------------------------------------------------------
LEARNING_RATE = 1e-3

WARMUP_ITERS = 200


# Device preference: CUDA first, then Apple-silicon MPS, else CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_batch(split: str, data_dir: str = DATA_DIR):
    """Sample one random batch of (input, target) token windows from a split.

    The memmap is re-opened on every call; data is read as uint16 token ids
    and widened to int64 for the embedding lookup. Targets are the inputs
    shifted one position to the right.
    """
    bin_name = 'train.bin' if split == 'train' else 'val.bin'
    data = np.memmap(os.path.join(data_dir, bin_name), dtype=np.uint16, mode='r')

    # Random window start offsets, one per batch element.
    starts = torch.randint(len(data) - BLOCK_SIZE, (BATCH_SIZE,))
    x = torch.stack(
        [torch.from_numpy(data[s:s + BLOCK_SIZE].astype(np.int64)) for s in starts]
    )
    y = torch.stack(
        [torch.from_numpy(data[s + 1:s + 1 + BLOCK_SIZE].astype(np.int64)) for s in starts]
    )

    if DEVICE == 'cuda':
        # Pinned host memory + non_blocking lets the copy overlap with compute.
        return (x.pin_memory().to(DEVICE, non_blocking=True),
                y.pin_memory().to(DEVICE, non_blocking=True))

    return x.to(DEVICE), y.to(DEVICE)
|
|
|
|
|
|
|
|
@torch.no_grad()
def estimate_loss(model, ctx):
    """Average the model loss over EVAL_ITERS random batches per split.

    Returns a dict mapping 'train' and 'val' to a 0-dim tensor holding the
    mean loss. Toggles the model to eval mode for the duration and restores
    train mode before returning.
    """
    results = {}
    model.eval()

    for split in ('train', 'val'):
        split_losses = torch.zeros(EVAL_ITERS)
        for step in range(EVAL_ITERS):
            xb, yb = get_batch(split)
            with ctx:
                _, batch_loss = model(xb, yb)
            split_losses[step] = batch_loss.item()
        results[split] = split_losses.mean()

    model.train()
    return results
|
|
|
|
|
|
|
|
def get_lr(it: int) -> float: |
|
|
"""Learning rate with linear warmup and cosine decay.""" |
|
|
|
|
|
if it < WARMUP_ITERS: |
|
|
return LEARNING_RATE * it / WARMUP_ITERS |
|
|
|
|
|
if it > MAX_ITERS: |
|
|
return LEARNING_RATE * 0.1 |
|
|
|
|
|
decay_ratio = (it - WARMUP_ITERS) / (MAX_ITERS - WARMUP_ITERS) |
|
|
assert 0 <= decay_ratio <= 1 |
|
|
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) |
|
|
return LEARNING_RATE * (0.1 + 0.9 * coeff) |
|
|
|
|
|
|
|
|
def train():
    """Main training loop: train RippleGPT on the prepared code dataset.

    Expects train.bin / val.bin / meta.pkl under DATA_DIR (produced by
    prepare_code_data.py). Saves `ckpt_best.pt` (lowest validation loss so
    far) and `ckpt_final.pt` (last iteration) under OUT_DIR.

    Fix: the console strings were mojibake (UTF-8 emoji mis-decoded; the
    final "Training complete" literal was even split mid-string) — restored
    to the intended characters.
    """
    print("=" * 60)
    print("🚀 RIPPLEGPT TRAINING FOR CODE COMPLETION")
    print("=" * 60)

    # Fail fast with a hint if the dataset has not been prepared.
    if not os.path.exists(os.path.join(DATA_DIR, 'train.bin')):
        print("❌ Data not found!")
        print("   Run first: python validation/code/prepare_code_data.py")
        return

    os.makedirs(OUT_DIR, exist_ok=True)

    # Vocabulary size comes from the tokenizer metadata written at prep time.
    meta_path = os.path.join(DATA_DIR, 'meta.pkl')
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    vocab_size = meta['vocab_size']
    print(f"\n📊 Vocab size: {vocab_size}")

    torch.manual_seed(1337)  # reproducible init and batch sampling

    print(f"\n🔧 Initializing model...")
    config = RippleConfig(
        vocab_size=vocab_size,
        block_size=BLOCK_SIZE,
        n_layer=N_LAYER,
        n_head=N_HEAD,
        n_embd=N_EMBD,
        dropout=DROPOUT,
        use_absolute_pos_emb=False
    )

    model = RippleGPT(config)
    model.to(DEVICE)

    num_params = model.get_num_params()
    print(f"   Parameters: {num_params / 1e6:.2f}M")
    print(f"   Device: {DEVICE}")
    print(f"   Block size: {BLOCK_SIZE}")
    print(f"   Batch size: {BATCH_SIZE}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # bfloat16 autocast on CUDA; plain fp32 context on cpu / mps.
    from contextlib import nullcontext
    ctx = nullcontext() if DEVICE in ['cpu', 'mps'] else torch.amp.autocast(device_type=DEVICE, dtype=torch.bfloat16)

    print(f"\n🚀 Starting training ({MAX_ITERS} iterations)...")
    print("-" * 60)

    X, Y = get_batch('train')  # prefetch the first batch before the loop
    t0 = time.time()
    best_val_loss = float('inf')

    for iter_num in range(MAX_ITERS):

        # Apply the scheduled learning rate for this step.
        lr = get_lr(iter_num)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Periodic evaluation and best-checkpoint saving (skipped at iter 0).
        if iter_num % EVAL_INTERVAL == 0 and iter_num > 0:
            losses = estimate_loss(model, ctx)
            print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

            if losses['val'] < best_val_loss:
                best_val_loss = losses['val']
                checkpoint = {
                    'model': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'iter_num': iter_num,
                    'best_val_loss': best_val_loss,
                }
                torch.save(checkpoint, os.path.join(OUT_DIR, 'ckpt_best.pt'))
                print(f"   💾 Best model saved! (val_loss: {best_val_loss:.4f})")

        # Forward / backward / update.
        with ctx:
            logits, loss = model(X, Y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Per-iteration wall-clock timing.
        t1 = time.time()
        dt = t1 - t0
        t0 = t1

        if iter_num % LOG_INTERVAL == 0:
            # get_decay_stats() is a RippleGPT-specific diagnostic —
            # presumably summarizes the learned decay field; see src/model.py.
            decay_stats = model.get_decay_stats()
            print(f"iter {iter_num}: loss {loss.item():.4f}, time {dt*1000:.2f}ms, lr {lr:.6f}")
            print(f"   Ripple Field Stats -> Mean Decay: {decay_stats['mean']:.4f}, Range: [{decay_stats['min']:.4f}, {decay_stats['max']:.4f}]")

        # Fetch the next batch for the following iteration.
        X, Y = get_batch('train')

    # Always save a final checkpoint, regardless of validation loss.
    checkpoint = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'config': config,
        'iter_num': MAX_ITERS,
        'best_val_loss': best_val_loss,
    }
    torch.save(checkpoint, os.path.join(OUT_DIR, 'ckpt_final.pt'))

    print("-" * 60)
    print(f"✅ Training complete!")
    print(f"   Best val loss: {best_val_loss:.4f}")
    print(f"   Checkpoints saved to: {OUT_DIR}")
    print(f"\nNext step: python validation/code/validate_code.py")
|
|
|
|
|
|
|
|
# Script entry point: run training only when executed directly.
if __name__ == '__main__':

    train()
|
|
|