# Antigravity / remote-gpu-client / examples / deep_nanogpt_resumable.py
# Deep-NanoGPT experiment (Phase 1 & 2): resumable training, inference, 72-layer models
import sys
import traceback
import os
import time
import shutil
print("πŸ”Ή DEBUG: Script Wrapper Started")
try:
import torch
import torch.nn as nn
from torch.nn import functional as F
import requests
import matplotlib.pyplot as plt
print("πŸ”Ή DEBUG: Imports successful")
# --- Hyperparameters ---
batch_size = 32 # Reduced to fit 2x 72-layer models in 16GB
block_size = 256
learning_rate = 3e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"πŸ”Ή DEBUG: Device selected: {device}")
# HARD MODE CONFIG
n_embd = 128
n_head = 4
n_layer = 72 # Deep!
dropout = 0.1
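# Rough size estimate (assumption, not measured): each block carries ~198K parameters
# (4 heads x 3 x 128x32 QKV, the 128x128 projection, and the 4x FFN), so 72 layers come to
# roughly 14M parameters per model before embeddings and the LM head.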
# Training Loop Config
TOTAL_ITERS = 200 # Target
CHUNK_ITERS = 20 # Steps per execution (safe < 300s)
eval_interval = 20 # Log / Plot every 20 steps
eval_iters = 10 # Fast eval
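# Each execution has to stay under ~300 s (see CHUNK_ITERS above), so training is split into
# chunks: every invocation resumes from the saved checkpoints, trains CHUNK_ITERS steps,
# saves again, and signals the client whether another run is needed.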
torch.manual_seed(1337)
# --- Persistence Paths ---
storage_dir = "/home/user/app/storage/deep_experiment_v2"
os.makedirs(storage_dir, exist_ok=True)
ckpt_path_a = os.path.join(storage_dir, 'ckpt_a.pt')
ckpt_path_b = os.path.join(storage_dir, 'ckpt_b.pt')
history_path = os.path.join(storage_dir, 'history.pt')
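# Everything needed to resume lives under storage_dir: model and optimizer state dicts for
# both models, plus the loss history (whose last recorded step doubles as the resume point).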
# --- Data Loading ---
def get_data():
if not os.path.exists('input.txt'):
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
# print(f"Downloading {url}...")
data = requests.get(url).text
with open('input.txt', 'w') as f:
f.write(data)
else:
with open('input.txt', 'r') as f:
data = f.read()
return data
data = get_data()
chars = sorted(list(set(data)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])
data_train = torch.tensor(encode(data), dtype=torch.long)
n = int(0.9 * len(data_train))
train_data = data_train[:n]
val_data = data_train[n:]
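# get_batch draws batch_size random windows of block_size characters; the targets are the same
# windows shifted right by one character (next-token prediction).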
def get_batch(split):
data = train_data if split == 'train' else val_data
ix = torch.randint(len(data) - block_size, (batch_size,))
x = torch.stack([data[i:i+block_size] for i in ix])
y = torch.stack([data[i+1:i+block_size+1] for i in ix])
x, y = x.to(device), y.to(device)
return x, y
@torch.no_grad()
def estimate_loss(model):
out = {}
model.eval()
for split in ['train', 'val']:
losses = torch.zeros(eval_iters)
for k in range(eval_iters):
X, Y = get_batch(split)
logits, loss = model(X, Y)
losses[k] = loss.item()
out[split] = losses.mean()
model.train()
return out
# --- Components ---
class Head(nn.Module):
def __init__(self, head_size):
super().__init__()
self.key = nn.Linear(n_embd, head_size, bias=False)
self.query = nn.Linear(n_embd, head_size, bias=False)
self.value = nn.Linear(n_embd, head_size, bias=False)
self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
self.dropout = nn.Dropout(dropout)
def forward(self, x):
B,T,C = x.shape
k = self.key(x)
q = self.query(x)
wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # scale by head dimension, not full embedding dim
wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei = self.dropout(wei)
v = self.value(x)
out = wei @ v
return out
class MultiHeadAttention(nn.Module):
def __init__(self, num_heads, head_size):
super().__init__()
self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
self.proj = nn.Linear(n_embd, n_embd)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
out = torch.cat([h(x) for h in self.heads], dim=-1)
out = self.dropout(self.proj(out))
return out
class FeedForward(nn.Module):
def __init__(self, n_embd):
super().__init__()
self.net = nn.Sequential(
nn.Linear(n_embd, 4 * n_embd),
nn.ReLU(),
nn.Linear(4 * n_embd, n_embd),
nn.Dropout(dropout),
)
def forward(self, x):
return self.net(x)
# --- Standard Block ---
class BlockStandard(nn.Module):
def __init__(self, n_embd, n_head):
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size)
self.ffwd = FeedForward(n_embd)
self.ln1 = nn.LayerNorm(n_embd)
self.ln2 = nn.LayerNorm(n_embd)
def forward(self, x):
x = x + self.sa(self.ln1(x))
x = x + self.ffwd(self.ln2(x))
return x
# --- mHC Block ---
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
return self._norm(x.float()).type_as(x) * self.weight
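# BlockMHC differs from BlockStandard in two ways: the residual path is a learned scalar mix
# (alpha * x + beta * sublayer(x)) instead of a plain sum, and RMSNorm is applied after the mix
# (post-norm) rather than before each sublayer.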
class BlockMHC(nn.Module):
def __init__(self, n_embd, n_head):
super().__init__()
head_size = n_embd // n_head
self.sa = MultiHeadAttention(n_head, head_size)
self.ffwd = FeedForward(n_embd)
self.alpha1 = nn.Parameter(torch.tensor(0.9))
self.beta1 = nn.Parameter(torch.tensor(0.1))
self.ln1 = RMSNorm(n_embd)
self.alpha2 = nn.Parameter(torch.tensor(0.9))
self.beta2 = nn.Parameter(torch.tensor(0.1))
self.ln2 = RMSNorm(n_embd)
def forward(self, x):
mix1 = self.alpha1 * x + self.beta1 * self.sa(x)
x = self.ln1(mix1)
mix2 = self.alpha2 * x + self.beta2 * self.ffwd(x)
x = self.ln2(mix2)
return x
class GPT(nn.Module):
def __init__(self, arch_type='standard'):
super().__init__()
self.arch_type = arch_type
self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
self.position_embedding_table = nn.Embedding(block_size, n_embd)
if arch_type == 'standard':
self.blocks = nn.Sequential(*[BlockStandard(n_embd, n_head) for _ in range(n_layer)])
self.ln_f = nn.LayerNorm(n_embd)
elif arch_type == 'mhc':
self.blocks = nn.Sequential(*[BlockMHC(n_embd, n_head) for _ in range(n_layer)])
self.ln_f = RMSNorm(n_embd)
else:
raise ValueError(f"Unknown arch_type: {arch_type}")
self.lm_head = nn.Linear(n_embd, vocab_size)
def forward(self, idx, targets=None):
B, T = idx.shape
tok_emb = self.token_embedding_table(idx)
pos_emb = self.position_embedding_table(torch.arange(T, device=device))
x = tok_emb + pos_emb
x = self.blocks(x)
x = self.ln_f(x)
logits = self.lm_head(x)
if targets is None:
loss = None
else:
B, T, C = logits.shape
logits = logits.view(B*T, C)
targets = targets.view(B*T)
loss = F.cross_entropy(logits, targets)
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens):
for _ in range(max_new_tokens):
idx_cond = idx[:, -block_size:]
logits, loss = self(idx_cond)
logits = logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
idx = torch.cat((idx, idx_next), dim=1)
return idx
# --- Resumable Logic ---
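# History is only appended at eval steps; because eval_interval == CHUNK_ITERS, the last step of
# every chunk gets recorded, which is what the resume logic below relies on to recover current_iter.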
history = {'a': [], 'b': [], 'steps': []}
current_iter = 0
if os.path.exists(history_path):
print(f"πŸ”„ Resuming from history: {history_path}")
try:
history = torch.load(history_path)
if history['steps']:
current_iter = history['steps'][-1]
print(f" Last Checkpoint: Step {current_iter}")
except Exception:
print(" ⚠️ Error loading history, starting fresh.")
current_iter = 0
model_a = GPT(arch_type='standard').to(device)
opt_a = torch.optim.AdamW(model_a.parameters(), lr=learning_rate)
model_b = GPT(arch_type='mhc').to(device)
opt_b = torch.optim.AdamW(model_b.parameters(), lr=learning_rate)
# Load model and optimizer state if checkpoints exist (they are saved together below)
if os.path.exists(ckpt_path_a) and os.path.exists(os.path.join(storage_dir, 'opt_a.pt')):
model_a.load_state_dict(torch.load(ckpt_path_a, map_location=device))
opt_a.load_state_dict(torch.load(os.path.join(storage_dir, 'opt_a.pt'), map_location=device))
if os.path.exists(ckpt_path_b) and os.path.exists(os.path.join(storage_dir, 'opt_b.pt')):
model_b.load_state_dict(torch.load(ckpt_path_b, map_location=device))
opt_b.load_state_dict(torch.load(os.path.join(storage_dir, 'opt_b.pt'), map_location=device))
# --- Training Chunk ---
print(f"πŸš€ Training Chunk: Steps {current_iter} -> {current_iter + CHUNK_ITERS} (Target: {TOTAL_ITERS})")
start_chunk = time.time()
for i in range(CHUNK_ITERS):
global_step = current_iter + i + 1
if global_step > TOTAL_ITERS:
break
# Eval & Log
if global_step % eval_interval == 0 or global_step == 1:
la = estimate_loss(model_a)
lb = estimate_loss(model_b)
print(f"Step {global_step}: Loss A={la['train']:.4f}, Loss B={lb['train']:.4f}")
history['steps'].append(global_step)
history['a'].append(la['train'].item())
history['b'].append(lb['train'].item())
# Save History
torch.save(history, history_path)
# Update Plots inside the chunk to give realtime feedback
plt.figure(figsize=(10, 6))
plt.plot(history['steps'], history['a'], label='Standard GPT', marker='o')
plt.plot(history['steps'], history['b'], label='DeepMHC', marker='x')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.title(f'Training Stability: Deep-NanoGPT (Layers={n_layer})')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(storage_dir, 'comparison_loss_v2.png'))
plt.close()
# Step A
xb, yb = get_batch('train')
logits, loss = model_a(xb, yb)
opt_a.zero_grad(set_to_none=True)
loss.backward()
# Optional: Clip grad to prevent immediate explosion, but let's see it fail naturally
# torch.nn.utils.clip_grad_norm_(model_a.parameters(), 1.0)
opt_a.step()
# Step B
xb, yb = get_batch('train')
logits, loss = model_b(xb, yb)
opt_b.zero_grad(set_to_none=True)
loss.backward()
opt_b.step()
print(f"🏁 Chunk Done in {time.time()-start_chunk:.2f}s. Global Step: {global_step}")
# --- Save Checkpoints ---
torch.save(model_a.state_dict(), ckpt_path_a)
torch.save(opt_a.state_dict(), os.path.join(storage_dir, 'opt_a.pt'))
torch.save(model_b.state_dict(), ckpt_path_b)
torch.save(opt_b.state_dict(), os.path.join(storage_dir, 'opt_b.pt'))
torch.save(history, history_path)
print("πŸ’Ύ Checkpoints Saved.")
# Signal to the client whether another chunk is still needed
if global_step < TOTAL_ITERS:
print("CONTINUE_TRAINING") # Signal to client
else:
print("TRAINING_COMPLETE")
# Final Generation (switch to eval mode so dropout is disabled while sampling)
model_a.eval()
model_b.eval()
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print("Generating final samples...")
out_a = decode(model_a.generate(context, max_new_tokens=200)[0].tolist())
out_b = decode(model_b.generate(context, max_new_tokens=200)[0].tolist())
sp = os.path.join(storage_dir, 'generation_sample_v2.txt')
with open(sp, 'w') as f:
f.write(f"--- MODEL A (Standard, L={n_layer}) ---\n{out_a}\n\n")
f.write(f"--- MODEL B (mHC, L={n_layer}) ---\n{out_b}\n\n")
# Copy to cwd for download
os.system(f"cp {os.path.join(storage_dir, 'comparison_loss_v2.png')} .")
os.system(f"cp {sp} .")
print("βœ… Final Artifacts prepared for download.")
except Exception as e:
print(f"\n❌ FATAL SCRIPT ERROR: {e}")
traceback.print_exc()
sys.exit(1)