AbstractPhil's picture
Update trainer.py
9e5a420 verified
#@title Geometric Autoregressive LM - Full Training with HF Upload + TensorBoard (generates valid Shakespeare)
"""
Prototype autoregressive LM built on geometric simplex structures.
Requires SimplexFactory from the geometricvocab package (imported as geovocab2) to build valid simplex templates; without it, the simplex behavior will not learn. Setup, run in a notebook cell:
try:
!pip uninstall -qy geometricvocab
except:
pass
!pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git
License: MIT
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import math
from itertools import combinations
import time
import os
import json
from tqdm.auto import tqdm
from pathlib import Path
# Prefer the CUDA device when one is visible to torch; otherwise run on CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f"Device: {device}")
from geovocab2.shapes.factory.simplex_factory import SimplexFactory
from huggingface_hub import HfApi, create_repo, upload_folder
import tiktoken
# ============================================================================
# CONFIG
# ============================================================================
# Hugging Face Hub repository that receives the final checkpoint upload.
HF_REPO = "AbstractPhil/ksimplex-llm-prototype"
# Unique per-launch run id (Unix seconds), shared by checkpoint and log dirs.
RUN_NAME = f"run_{int(time.time())}"
CHECKPOINT_DIR = Path("./checkpoints") / RUN_NAME
TENSORBOARD_DIR = Path("./runs") / RUN_NAME
# Create both output directories eagerly, at import time.
for _run_dir in (CHECKPOINT_DIR, TENSORBOARD_DIR):
    _run_dir.mkdir(parents=True, exist_ok=True)
del _run_dir
# ============================================================================
# CAYLEY-MENGER VALIDATOR
# ============================================================================
class CMValidator(nn.Module):
    """Squared edge lengths and signed squared volume of a k-simplex.

    Uses the Cayley-Menger determinant: for a simplex with k + 1 vertices,
    ``vol2 = (-1)^(k+1) / (2^k * (k!)^2) * det(CM)``, where ``CM`` is the
    pairwise squared-distance matrix bordered by a row and a column of ones
    with a zero in the top-left corner. The result is NOT clamped, so
    degenerate or numerically noisy simplices can yield vol2 <= 0.

    Args:
        k: Simplex dimension; inputs are expected to carry ``k + 1`` vertices.
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1
        edges = list(combinations(range(self._nv), 2))
        self._npairs = len(edges)
        # Vertex-pair index lists (itertools.combinations order), registered
        # as buffers so they move with the module across devices.
        self.register_buffer('_pi', torch.tensor([a for a, _ in edges], dtype=torch.long))
        self.register_buffer('_pj', torch.tensor([b for _, b in edges], dtype=torch.long))
        self._prefactor = ((-1.0) ** (k + 1)) / ((2.0 ** k) * (math.factorial(k) ** 2))

    def forward(self, verts):
        """Evaluate the simplex geometry.

        Args:
            verts: Tensor ``(..., V, E)`` — V vertex coordinates in E dims.

        Returns:
            ``(d2_pairs, vol2)``: squared length of every vertex pair, shape
            ``(..., npairs)``, and the signed squared volume, shape ``(...)``.
        """
        gram = verts @ verts.transpose(-1, -2)
        sq_norm = gram.diagonal(dim1=-2, dim2=-1)
        # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b ; relu strips tiny negative round-off.
        dist_sq = F.relu(sq_norm[..., :, None] + sq_norm[..., None, :] - 2 * gram)
        edge_sq = dist_sq[..., self._pi, self._pj]
        # Border the distance matrix with ones, then zero the corner entry.
        bordered = F.pad(dist_sq, (1, 0, 1, 0), value=1.0)
        bordered[..., 0, 0] = 0.0
        return edge_sq, self._prefactor * torch.linalg.det(bordered)
# ============================================================================
# K-SIMPLEX CHANNEL ENCODER
# ============================================================================
class KSimplexChannel(nn.Module):
    """Encode a feature vector as a deformed k-simplex plus gated vertex features.

    A fixed regular-simplex template (from ``SimplexFactory``) is perturbed by
    learned per-vertex offsets scaled by ``BASE_DEFORM``. Its Cayley-Menger
    geometry (pairwise squared distances + squared volume) gates the mean of
    per-vertex features, and the geometry vector is appended to the output.

    Args:
        k: Simplex dimension (k + 1 vertices).
        in_dim: Size of the input feature vector.
        edim: Embedding dimension of each vertex coordinate.
        feat_dim: Size of the per-vertex feature vector.
    """

    # Scale applied to the learned vertex offsets around the template.
    BASE_DEFORM = 0.05

    def __init__(self, k, in_dim, edim, feat_dim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._feat_dim = feat_dim
        self._cm = CMValidator(k)
        # Geometry vector = one squared distance per vertex pair + vol2.
        self._geo_dim = self._cm._npairs + 1
        factory = SimplexFactory(k=k, embed_dim=edim, method="regular", scale=1.0)
        # NOTE(review): template presumably has shape (k + 1, edim) — confirm
        # against SimplexFactory.build_torch.
        self.register_buffer('_template', factory.build_torch(dtype=torch.float32))
        self._to_coords = nn.Linear(in_dim, self._nv * edim)
        self._to_feats = nn.Linear(in_dim, self._nv * feat_dim)
        self._geo_gate = nn.Sequential(
            nn.Linear(self._geo_dim, feat_dim),
            nn.Sigmoid(),
        )
        self._out_dim = feat_dim + self._geo_dim

    @property
    def out_dim(self):
        """Width of the channel output: ``feat_dim + npairs + 1``."""
        return self._out_dim

    def forward(self, x):
        """Map ``x`` of shape ``(..., in_dim)`` to simplex-channel features.

        Returns:
            ``(out, vol2, d2_mean)`` with ``out`` of shape ``(..., out_dim)``,
            ``vol2`` the squared volume ``(...)`` and ``d2_mean`` the mean
            squared edge length ``(...)``.
        """
        offsets = self._to_coords(x).unflatten(-1, (self._nv, self._edim))
        verts = self._template + self.BASE_DEFORM * offsets
        per_vertex = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))
        edge_sq, vol_sq = self._cm(verts)
        geo = torch.cat((edge_sq, vol_sq[..., None]), dim=-1)
        gate = self._geo_gate(geo)
        # Steep sigmoid: ~1 for positive squared volume, ~0 otherwise.
        soft_valid = torch.sigmoid(vol_sq * 1e6)[..., None]
        pooled = per_vertex.mean(dim=-2) * gate * soft_valid
        return torch.cat((pooled, geo), dim=-1), vol_sq, edge_sq.mean(dim=-1)
# ============================================================================
# TOKEN TO K-SIMPLEX CHANNELS
# ============================================================================
class TokenToKChannels(nn.Module):
    """Project token embeddings into ``depth`` simplex channels (k = 1..depth).

    A shared two-layer MLP maps each embedding to ``hidden`` features; one
    ``KSimplexChannel`` per simplex order then encodes them. Channel outputs
    of different widths are right-padded with zeros to the widest one so
    they can be stacked.

    Args:
        embed_dim: Input embedding size.
        depth: Number of simplex channels; channel i uses a (i + 1)-simplex.
        edim: Vertex coordinate dimension for each simplex.
        feat_dim: Per-vertex feature size for each simplex.
        hidden: Width of the shared projection MLP.
    """

    def __init__(self, embed_dim, depth, edim, feat_dim, hidden=256):
        super().__init__()
        self._depth = depth
        self._proj = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )
        self._k_encoders = nn.ModuleList([
            KSimplexChannel(k=order, in_dim=hidden, edim=edim, feat_dim=feat_dim)
            for order in range(1, depth + 1)
        ])
        self._k_out_dims = [encoder.out_dim for encoder in self._k_encoders]
        self._max_out_dim = max(self._k_out_dims)

    def forward(self, x):
        """Encode ``x`` of shape ``(..., embed_dim)``.

        Returns:
            ``(k_channels, vol2, d2_mean)`` with shapes
            ``(..., depth, max_out_dim)``, ``(..., depth)`` and ``(..., depth)``.
        """
        hidden = self._proj(x)
        padded, vols, dists = [], [], []
        for encoder in self._k_encoders:
            feats, vol2, d2_mean = encoder(hidden)
            deficit = self._max_out_dim - feats.shape[-1]
            padded.append(F.pad(feats, (0, deficit)) if deficit > 0 else feats)
            vols.append(vol2)
            dists.append(d2_mean)
        return (
            torch.stack(padded, dim=-2),
            torch.stack(vols, dim=-1),
            torch.stack(dists, dim=-1),
        )
# ============================================================================
# K-CHANNEL CROSS-ATTENTION
# ============================================================================
class KChannelCrossAttention(nn.Module):
    """Multi-head self-attention across the K simplex channels of each token.

    Every token attends only among its own channels (no mixing across the
    sequence), followed by a residual connection.

    Args:
        depth: Number of channels (kept for interface symmetry).
        feat_dim: Channel feature size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout on attention weights and on the output projection.

    Raises:
        ValueError: if ``feat_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):
        super().__init__()
        if feat_dim % num_heads != 0:
            raise ValueError(
                f"feat_dim ({feat_dim}) must be divisible by num_heads ({num_heads})"
            )
        self._depth = depth
        self._feat_dim = feat_dim
        self._num_heads = num_heads
        self._head_dim = feat_dim // num_heads
        self._norm_q = nn.LayerNorm(feat_dim)
        self._norm_kv = nn.LayerNorm(feat_dim)
        self._to_q = nn.Linear(feat_dim, feat_dim)
        self._to_k = nn.Linear(feat_dim, feat_dim)
        self._to_v = nn.Linear(feat_dim, feat_dim)
        self._out = nn.Linear(feat_dim, feat_dim)
        self._drop = nn.Dropout(dropout)
        self._scale = self._head_dim ** -0.5

    def forward(self, x):
        """Attend across channels of ``x`` with shape ``(B, T, K, feat_dim)``.

        Returns:
            Tensor of the same shape: ``x`` plus the attention update.
        """
        # Local names avoid shadowing the module-level ``F`` (torch.nn.functional).
        batch, steps, chans, width = x.shape
        # Fold tokens into the batch so attention only spans the K channels.
        x_flat = x.reshape(batch * steps, chans, width)
        q = self._to_q(self._norm_q(x_flat))
        k = self._to_k(self._norm_kv(x_flat))
        v = self._to_v(self._norm_kv(x_flat))
        q, k, v = (
            t.view(-1, chans, self._num_heads, self._head_dim).transpose(1, 2)
            for t in (q, k, v)
        )
        attn = (q @ k.transpose(-2, -1)) * self._scale
        attn = self._drop(attn.softmax(dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(batch * steps, chans, width)
        out = self._drop(self._out(out))
        return x + out.view(batch, steps, chans, width)
# ============================================================================
# CAUSAL SEQUENCE ATTENTION
# ============================================================================
class CausalSequenceAttention(nn.Module):
    """Causal multi-head self-attention along the sequence axis.

    The K channels of each token are flattened into one ``depth * feat_dim``
    vector, attended causally over positions (each position sees itself and
    earlier positions only), and added back residually.

    Args:
        depth: Number of channels K.
        feat_dim: Channel feature size; ``depth * feat_dim`` must be divisible
            by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout on attention weights and on the output projection.
        max_seq_len: Longest supported sequence (size of the causal mask).

    Raises:
        ValueError: on an indivisible head size, or (in ``forward``) when the
            sequence is longer than ``max_seq_len``.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1, max_seq_len=2048):
        super().__init__()
        total_dim = depth * feat_dim
        if total_dim % num_heads != 0:
            raise ValueError(
                f"depth * feat_dim ({total_dim}) must be divisible by num_heads ({num_heads})"
            )
        self._num_heads = num_heads
        self._head_dim = total_dim // num_heads
        self._norm = nn.LayerNorm(total_dim)
        self._to_qkv = nn.Linear(total_dim, 3 * total_dim)
        self._out = nn.Linear(total_dim, total_dim)
        self._drop = nn.Dropout(dropout)
        self._scale = self._head_dim ** -0.5
        # Lower-triangular boolean mask; True = may attend.
        self.register_buffer(
            '_causal_mask',
            torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x):
        """Apply causal attention to ``x`` of shape ``(B, T, K, feat_dim)``.

        Returns:
            Tensor of the same shape: ``x`` plus the attention update.
        """
        # Local names avoid shadowing the module-level ``F`` (torch.nn.functional).
        batch, steps, chans, width = x.shape
        max_len = self._causal_mask.size(0)
        if steps > max_len:
            raise ValueError(f"sequence length {steps} exceeds max_seq_len {max_len}")
        x_flat = x.view(batch, steps, chans * width)
        qkv = self._to_qkv(self._norm(x_flat)).chunk(3, dim=-1)
        q, k, v = (
            t.view(batch, steps, self._num_heads, self._head_dim).transpose(1, 2)
            for t in qkv
        )
        attn = (q @ k.transpose(-2, -1)) * self._scale
        attn = attn.masked_fill(~self._causal_mask[:steps, :steps], float('-inf'))
        attn = self._drop(attn.softmax(dim=-1))
        out = (attn @ v).transpose(1, 2).reshape(batch, steps, chans * width)
        out = self._drop(self._out(out))
        return x + out.view(batch, steps, chans, width)
# ============================================================================
# TRANSFORMER BLOCK
# ============================================================================
class GeoBlock(nn.Module):
    """One transformer layer over ``(B, T, K, feat_dim)`` simplex channels.

    Order: cross-channel attention (within a token) -> causal sequence
    attention (across tokens) -> pre-norm residual MLP on the flattened
    ``K * feat_dim`` token vector. Each sub-layer is residual.

    Args:
        depth: Number of channels K.
        feat_dim: Per-channel feature size.
        num_heads: Heads for both attention sub-layers.
        mlp_ratio: MLP hidden width as a multiple of ``depth * feat_dim``.
        dropout: Dropout rate for attention and MLP.
        max_seq_len: Longest sequence supported by the causal mask.
    """

    def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0, dropout=0.1, max_seq_len=2048):
        super().__init__()
        self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)
        self._seq_attn = CausalSequenceAttention(depth, feat_dim, num_heads, dropout, max_seq_len)
        total_dim = depth * feat_dim
        mlp_hidden = int(total_dim * mlp_ratio)
        self._norm = nn.LayerNorm(total_dim)
        self._mlp = nn.Sequential(
            nn.Linear(total_dim, mlp_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_hidden, total_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Transform ``x`` of shape ``(B, T, K, feat_dim)``; shape is preserved."""
        batch, steps, chans, width = x.shape
        x = self._seq_attn(self._k_attn(x))
        flat = x.view(batch, steps, chans * width)
        flat = flat + self._mlp(self._norm(flat))
        return flat.view(batch, steps, chans, width)
# ============================================================================
# GEOMETRIC LM
# ============================================================================
class GeometricLM(nn.Module):
    """Decoder-only language model over per-token k-simplex channels.

    Pipeline: token + learned position embeddings -> ``TokenToKChannels``
    (one channel per simplex order k = 1..depth) -> linear projection of each
    channel to ``feat_dim`` -> ``num_blocks`` ``GeoBlock`` layers -> flatten
    the channels -> LayerNorm -> bias-free LM head.
    """

    def __init__(
        self,
        vocab_size,
        max_seq_len=512,
        embed_dim=256,
        depth=4,
        edim=16,
        feat_dim=64,
        hidden=256,
        num_heads=8,
        num_blocks=8,
        dropout=0.1,
    ):
        super().__init__()
        self._vocab_size = vocab_size
        self._max_seq_len = max_seq_len
        self._depth = depth
        self._feat_dim = feat_dim
        self._tok_embed = nn.Embedding(vocab_size, embed_dim)
        self._pos_embed = nn.Embedding(max_seq_len, embed_dim)
        self._tok_to_k = TokenToKChannels(embed_dim, depth, edim, feat_dim, hidden)
        self._max_out_dim = self._tok_to_k._max_out_dim
        # Bring every (zero-padded) channel down to the common feat_dim width.
        self._proj = nn.Linear(self._max_out_dim, feat_dim)
        self._blocks = nn.ModuleList([
            GeoBlock(depth, feat_dim, num_heads, dropout=dropout, max_seq_len=max_seq_len)
            for _ in range(num_blocks)
        ])
        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._lm_head = nn.Linear(total_dim, vocab_size, bias=False)
        self._config = {
            'vocab_size': vocab_size,
            'max_seq_len': max_seq_len,
            'embed_dim': embed_dim,
            'depth': depth,
            'edim': edim,
            'feat_dim': feat_dim,
            'hidden': hidden,
            'num_heads': num_heads,
            'num_blocks': num_blocks,
            'dropout': dropout,
            'total_dim': total_dim,
        }

    def forward(self, tokens):
        """Compute next-token logits.

        Args:
            tokens: Long tensor ``(B, T)`` with ``T <= max_seq_len``.

        Returns:
            ``(logits, info)``: logits ``(B, T, vocab_size)`` and a dict with
            ``'vol2'`` and ``'d2_mean'`` from the simplex encoder, each with
            one entry per simplex order on the last axis.

        Raises:
            ValueError: if ``T`` exceeds ``max_seq_len`` (there is no position
                embedding for those positions; on CUDA the raw lookup would
                otherwise fail as a device-side assert).
        """
        _, seq_len = tokens.shape
        if seq_len > self._max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {self._max_seq_len}"
            )
        pos = torch.arange(seq_len, device=tokens.device)
        x = self._tok_embed(tokens) + self._pos_embed(pos)
        k_channels, vol2, d2_mean = self._tok_to_k(x)
        k_channels = self._proj(k_channels)
        for blk in self._blocks:
            k_channels = blk(k_channels)
        logits = self._lm_head(self._norm(k_channels.flatten(-2)))
        return logits, {'vol2': vol2, 'd2_mean': d2_mean}

    @torch.no_grad()
    def generate(self, prompt_tokens, max_new_tokens=100, temperature=1.0, top_k=50):
        """Autoregressively sample ``max_new_tokens`` tokens after a prompt.

        Only the last ``max_seq_len`` tokens are fed back as context. Sampling
        uses temperature scaling and, when ``top_k > 0``, top-k filtering
        (``top_k`` is capped at the vocabulary size). The model's train/eval
        mode on entry is restored on exit.

        Args:
            prompt_tokens: Long tensor ``(B, T0)``.
            max_new_tokens: Number of tokens to append.
            temperature: Logit divisor; must be > 0.
            top_k: Keep only the k most likely tokens; ``<= 0`` disables.

        Returns:
            Long tensor ``(B, T0 + max_new_tokens)`` — prompt followed by samples.
        """
        was_training = self.training
        self.eval()
        try:
            tokens = prompt_tokens.clone()
            for _ in range(max_new_tokens):
                ctx = tokens[:, -self._max_seq_len:]
                logits, _ = self(ctx)
                logits = logits[:, -1, :] / temperature
                if top_k > 0:
                    k = min(top_k, logits.size(-1))
                    kth_best = torch.topk(logits, k).values[:, [-1]]
                    logits = logits.masked_fill(logits < kth_best, float('-inf'))
                probs = F.softmax(logits, dim=-1)
                next_tok = torch.multinomial(probs, num_samples=1)
                tokens = torch.cat([tokens, next_tok], dim=1)
        finally:
            self.train(was_training)
        return tokens
# ============================================================================
# DATASET
# ============================================================================
class TokenizedDataset(Dataset):
    """Sliding-window next-token dataset over a flat token sequence.

    Window ``i`` starts at ``i * stride`` and spans ``seq_len + 1`` tokens;
    the input is the first ``seq_len`` tokens and the target is the same
    window shifted by one.

    Args:
        tokens: Indexable, sliceable sequence of token ids (e.g. a list).
        seq_len: Model context length.
        stride: Step between window starts; defaults to ``seq_len // 2``
            (50% overlap), and never less than 1.
    """

    def __init__(self, tokens, seq_len, stride=None):
        self._tokens = tokens
        self._seq_len = seq_len
        # max(1, ...) avoids a zero stride (division by zero) when seq_len == 1.
        self._stride = stride if stride else max(1, seq_len // 2)

    def __len__(self):
        # Valid starts s satisfy s + seq_len + 1 <= len(tokens); counting
        # 0, stride, 2*stride, ... up to that bound includes the final window.
        span = len(self._tokens) - (self._seq_len + 1)
        if span < 0:
            return 0
        return span // self._stride + 1

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors of length ``seq_len``.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        start = pos * self._stride
        chunk = self._tokens[start:start + self._seq_len + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
# ============================================================================
# LOSS & METRICS
# ============================================================================
def lm_loss(logits, targets, info, ce_weight=1.0, validity_weight=0.1):
    """Weighted sum of token cross-entropy and a simplex-validity penalty.

    Args:
        logits: ``(B, T, V)`` next-token logits.
        targets: ``(B, T)`` target token ids.
        info: Dict whose ``'vol2'`` tensor holds squared simplex volumes.
        ce_weight: Weight on the cross-entropy term.
        validity_weight: Weight on the validity penalty.

    Returns:
        ``(total, ce, validity)`` scalar tensors; ``validity`` is the mean of
        ``relu(-vol2)``, i.e. only negative squared volumes are penalised.
    """
    batch, steps, vocab = logits.shape
    n_tokens = batch * steps
    ce = F.cross_entropy(logits.view(n_tokens, vocab), targets.view(n_tokens))
    validity = F.relu(-info['vol2']).mean()
    return ce_weight * ce + validity_weight * validity, ce, validity
@torch.no_grad()
def compute_metrics(info, depth):
    """Summarise simplex validity statistics for one batch.

    Args:
        info: Dict with ``'vol2'`` and ``'d2_mean'`` tensors whose last axis
            indexes the simplex order (``depth`` entries).
        depth: Number of simplex orders to report.

    Returns:
        Dict of Python floats: ``valid_rate`` (fraction of vol2 > 0 overall)
        and, per order ``k`` (1-based), ``k{k}_valid``, ``k{k}_vol2`` (mean
        squared volume) and ``k{k}_d2`` (mean of the per-simplex mean
        squared edge length).
    """
    vol2, d2_mean = info['vol2'], info['d2_mean']
    metrics = {'valid_rate': (vol2 > 0).float().mean().item()}
    for order in range(1, depth + 1):
        vol_k = vol2[..., order - 1]
        metrics[f'k{order}_valid'] = (vol_k > 0).float().mean().item()
        metrics[f'k{order}_vol2'] = vol_k.mean().item()
        metrics[f'k{order}_d2'] = d2_mean[..., order - 1].mean().item()
    return metrics
# ============================================================================
# SANITY CHECK
# ============================================================================
@torch.no_grad()
def sanity_check(model, enc, device):
    """Verify no information leak.

    Runs three checks with ``model`` in eval mode and prints a report:

    1. Random inputs scored against random targets give a cross-entropy
       near ``ln(vocab)`` (pass threshold: CE > 8.0, tuned for GPT-2's vocab).
    2. Causality: changing tokens in the second half of a sequence leaves the
       first half's logits unchanged but changes the second half's.
    3. ``TokenizedDataset`` targets are the inputs shifted by one.

    The probe length is 256, shortened to the model's ``_max_seq_len`` when
    that is smaller. The model's train/eval mode on entry is restored.

    Args:
        model: Model returning ``(logits, info)`` for a ``(B, T)`` token batch.
        enc: Tokenizer exposing ``n_vocab``.
        device: Device for the probe tensors.

    Returns:
        True if all checks pass.
    """
    print("\n" + "=" * 60)
    print("SANITY CHECK")
    print("=" * 60)
    was_training = model.training
    model.eval()
    try:
        vocab = enc.n_vocab
        seq_len = min(256, getattr(model, '_max_seq_len', 256))
        half = seq_len // 2

        # Test 1: Random input should give high CE
        random_tokens = torch.randint(0, min(1000, vocab), (4, seq_len), device=device)
        logits, _ = model(random_tokens)
        random_targets = torch.randint(0, vocab, (4, seq_len), device=device)
        ce = F.cross_entropy(logits.reshape(-1, logits.size(-1)), random_targets.reshape(-1))
        expected_ce = math.log(vocab)
        print(f"Test 1 - Random input:")
        print(f" CE: {ce.item():.2f} (expected ~{expected_ce:.2f})")
        print(f" PPL: {math.exp(min(ce.item(), 20)):.0f} (expected ~{enc.n_vocab})")
        test1_pass = ce.item() > 8.0  # Should be close to ln(50257) ≈ 10.8
        print(f" Status: {'✓ PASS' if test1_pass else '✗ FAIL'}")

        # Test 2: Causal mask - early positions shouldn't depend on late tokens
        tokens1 = torch.zeros(1, seq_len, dtype=torch.long, device=device)
        tokens2 = torch.zeros(1, seq_len, dtype=torch.long, device=device)
        tokens2[0, half:] = min(999, vocab - 1)  # Change later tokens
        logits1, _ = model(tokens1)
        logits2, _ = model(tokens2)
        diff_early = (logits1[0, :half] - logits2[0, :half]).abs().max().item()
        diff_late = (logits1[0, half:] - logits2[0, half:]).abs().max().item()
        print(f"\nTest 2 - Causal mask:")
        print(f" Early positions diff: {diff_early:.6f} (should be ~0)")
        print(f" Late positions diff: {diff_late:.6f} (should be >0)")
        test2_pass = diff_early < 1e-5 and diff_late > 1e-3
        print(f" Status: {'✓ PASS' if test2_pass else '✗ FAIL'}")

        # Test 3: Dataset sanity - x and y should be offset by 1
        print(f"\nTest 3 - Dataset offset:")
        ds = TokenizedDataset(list(range(100)), seq_len=10)
        x, y = ds[0]
        offset_correct = torch.equal(x + 1, y)
        print(f" x: {x[:5].tolist()}...")
        print(f" y: {y[:5].tolist()}...")
        print(f" Offset correct: {'✓ PASS' if offset_correct else '✗ FAIL'}")

        print("=" * 60)
        all_pass = test1_pass and test2_pass and offset_correct
        if not all_pass:
            print("⚠️ WARNING: Some sanity checks failed!")
        else:
            print("✓ All sanity checks passed!")
        print("=" * 60 + "\n")
    finally:
        model.train(was_training)
    return all_pass
# ============================================================================
# GENERATION SAMPLING
# ============================================================================
# Prompts for the qualitative samples printed by ``generate_samples``:
# speaker tags in the corpus' "NAME:" style mixed with well-known openings.
# Order matters only for the printed numbering.
PROMPTS = [
    "ROMEO: ", "JULIET: ",
    "To be or not to be", "The king ", "Once upon a time",
    "First Citizen:\n",
    "What light through yonder", "Friends, Romans, countrymen",
    "Now is the winter of", "All the world's a stage",
]
@torch.no_grad()
def generate_samples(model, enc, device, epoch, writer=None):
    """Generate samples from all prompts.

    For each entry of ``PROMPTS``, samples 100 new tokens (temperature 0.8,
    top-k 50), prints the first 300 decoded characters and, if ``writer`` is
    given, logs all samples (each truncated to 500 characters) as one
    TensorBoard text entry at step ``epoch``. The model's train/eval mode on
    entry is restored on exit, even if generation raises.

    Args:
        model: Model exposing ``generate(prompt_tokens, max_new_tokens, temperature, top_k)``.
        enc: Tokenizer with ``encode`` / ``decode``.
        device: Device for prompt tensors.
        epoch: Epoch label for the header and TensorBoard step.
        writer: Optional TensorBoard ``SummaryWriter``.

    Returns:
        List of ``{'prompt': str, 'generated': str}``; with
        ``GeometricLM.generate`` the generated text includes the prompt.
    """
    was_training = model.training
    model.eval()
    try:
        samples = []
        print(f"\n{'='*60}")
        print(f"GENERATION SAMPLES - Epoch {epoch}")
        print(f"{'='*60}")
        for i, prompt in enumerate(PROMPTS, start=1):
            prompt_tokens = torch.tensor([enc.encode(prompt)], device=device)
            out_tokens = model.generate(
                prompt_tokens,
                max_new_tokens=100,
                temperature=0.8,
                top_k=50
            )
            generated = enc.decode(out_tokens[0].tolist())
            samples.append({'prompt': prompt, 'generated': generated})
            print(f"\n--- Prompt {i}: '{prompt.strip()}' ---")
            print(generated[:300])
            if len(generated) > 300:
                print("...")
        print(f"{'='*60}\n")
        # Log to tensorboard
        if writer is not None:
            sample_text = "\n\n".join([
                f"**Prompt:** {s['prompt']}\n**Generated:**\n{s['generated'][:500]}"
                for s in samples
            ])
            writer.add_text("samples/generated", sample_text, epoch)
    finally:
        model.train(was_training)
    return samples
# ============================================================================
# CHECKPOINTING & HF UPLOAD
# ============================================================================
def save_checkpoint(model, optimizer, scheduler, epoch, config, metrics, checkpoint_dir):
    """Save checkpoint locally.

    Writes the same payload to ``checkpoint_epoch_{epoch:03d}.pt`` and to
    ``checkpoint_latest.pt`` inside ``checkpoint_dir`` (a ``pathlib.Path``),
    then (over)writes ``config.json`` there with ``config``. If ``model``
    has an ``_orig_mod`` attribute (a ``torch.compile`` wrapper), the wrapped
    module's state dict is stored instead.

    Returns:
        Path of the epoch-numbered checkpoint file.
    """
    net = model._orig_mod if hasattr(model, '_orig_mod') else model
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=net.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        config=config,
        metrics=metrics,
    )
    path = checkpoint_dir / f"checkpoint_epoch_{epoch:03d}.pt"
    # Epoch-numbered file first, then the rolling "latest" copy.
    for target in (path, checkpoint_dir / "checkpoint_latest.pt"):
        torch.save(checkpoint, target)
    with open(checkpoint_dir / "config.json", 'w') as fh:
        json.dump(config, fh, indent=2)
    print(f"Saved checkpoint: {path}")
    return path
def upload_to_hf(checkpoint_dir, repo_id, epoch):
    """Upload checkpoint directory to HuggingFace.

    Best effort: creates the model repo if needed (a failure there is only
    printed), then uploads the whole ``checkpoint_dir`` in one commit. Any
    error is printed instead of raised.

    Returns:
        True on success, False if the upload failed.
    """
    try:
        api = HfApi()
        try:
            create_repo(repo_id, exist_ok=True, repo_type="model")
        except Exception as e:
            # Non-fatal: the repo may already exist or creation may be restricted.
            print(f"Repo creation note: {e}")
        api.upload_folder(
            folder_path=str(checkpoint_dir),
            repo_id=repo_id,
            commit_message=f"Epoch {epoch} checkpoint",
        )
    except Exception as e:
        print(f"HuggingFace upload failed: {e}")
        return False
    print(f"Uploaded to HuggingFace: {repo_id}")
    return True
# ============================================================================
# TRAIN
# ============================================================================
def _load_shakespeare(data_path):
    """Return the Tiny Shakespeare corpus, downloading it on first use.

    Args:
        data_path: Local cache path; its parent directory is created if needed.
    """
    import urllib.request

    if not os.path.exists(data_path):
        os.makedirs(os.path.dirname(data_path) or '.', exist_ok=True)
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        print("Downloading Shakespeare...")
        urllib.request.urlretrieve(url, data_path)
    # Explicit UTF-8: the platform default encoding is not guaranteed to be UTF-8.
    with open(data_path, 'r', encoding='utf-8') as f:
        return f.read()


def _train_one_epoch(model, train_dl, opt, sched, writer, train_config, ep, global_step):
    """Run one optimisation pass over ``train_dl``.

    Logs batch-level scalars to TensorBoard every 100 global steps.

    Returns:
        ``(mean_ce, mean_validity, global_step)`` — means are weighted by
        batch size; ``global_step`` is advanced once per batch.
    """
    model.train()
    ce_sum, val_sum, n = 0.0, 0.0, 0
    train_pbar = tqdm(train_dl, desc=f"Train {ep+1}", leave=False, position=1)
    for x, y in train_pbar:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        logits, info = model(x)
        loss, ce, val = lm_loss(
            logits, y, info,
            ce_weight=train_config['ce_weight'],
            validity_weight=train_config['validity_weight']
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
        opt.step()

        ce_item, val_item = ce.item(), val.item()
        ce_sum += ce_item * x.size(0)
        val_sum += val_item * x.size(0)
        n += x.size(0)

        # TensorBoard - batch level
        if global_step % 100 == 0:
            writer.add_scalar("train/ce_batch", ce_item, global_step)
            writer.add_scalar("train/ppl_batch", math.exp(min(ce_item, 10)), global_step)
            writer.add_scalar("train/validity_batch", val_item, global_step)
            writer.add_scalar("train/lr", sched.get_last_lr()[0], global_step)
        global_step += 1
        train_pbar.set_postfix({
            'CE': f'{ce_item:.3f}',
            'PPL': f'{math.exp(min(ce_item, 10)):.1f}'
        })
    return ce_sum / n, val_sum / n, global_step


@torch.no_grad()
def _evaluate(model, val_dl, ep):
    """Evaluate on ``val_dl``.

    Returns:
        ``(mean_ce, metrics)`` — CE weighted by batch size, and the
        ``compute_metrics`` geometry stats averaged over batches.
    """
    model.eval()
    ce_sum, n = 0.0, 0
    metrics_agg = []
    val_pbar = tqdm(val_dl, desc=f"Val {ep+1}", leave=False, position=1)
    for x, y in val_pbar:
        x, y = x.to(device), y.to(device)
        logits, info = model(x)
        _, ce, _ = lm_loss(logits, y, info)
        ce_item = ce.item()
        ce_sum += ce_item * x.size(0)
        n += x.size(0)
        metrics_agg.append(compute_metrics(info, model._config['depth']))
        val_pbar.set_postfix({
            'CE': f'{ce_item:.3f}',
            'PPL': f'{math.exp(min(ce_item, 10)):.1f}'
        })
    m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}
    return ce_sum / n, m


def _format_geometry(m, depth):
    """Format per-order validity rate and mean squared volume as one log line."""
    parts = []
    for k in range(1, depth + 1):
        valid = m[f'k{k}_valid']
        vol2 = m[f'k{k}_vol2']
        parts.append(f"k{k} {valid:5.1%} vol²={vol2:.2e}")
    return " | " + " | ".join(parts)


def train():
    """Train GeometricLM on Tiny Shakespeare with TensorBoard, checkpoints and HF upload.

    Downloads/tokenizes the corpus (GPT-2 BPE), splits it 90/10 into
    train/val, runs the epoch loop, periodically samples text and saves
    checkpoints, and uploads the checkpoint directory after the last epoch.
    The TensorBoard writer is closed even if training raises.

    Returns:
        ``(model, enc)`` — the trained model and the tiktoken encoder.

    Raises:
        ValueError: if the train or validation split yields no windows.
    """
    # TensorBoard
    writer = SummaryWriter(log_dir=str(TENSORBOARD_DIR))
    print(f"TensorBoard logs: {TENSORBOARD_DIR}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"HuggingFace repo: {HF_REPO}")
    try:
        # Data
        text = _load_shakespeare('./data/shakespeare.txt')
        print(f"Text length: {len(text):,} chars")

        # Tokenizer
        print("Loading tokenizer...")
        enc = tiktoken.get_encoding("gpt2")
        print("Tokenizing...")
        tokens = enc.encode(text)
        print(f"Token count: {len(tokens):,}")
        print(f"Vocab size: {enc.n_vocab:,}")
        print(f"Compression ratio: {len(text) / len(tokens):.2f}x")

        # Split (contiguous 90/10, so validation text is unseen in training)
        seq_len = 256
        split_idx = int(len(tokens) * 0.9)
        train_tokens = tokens[:split_idx]
        val_tokens = tokens[split_idx:]
        train_ds = TokenizedDataset(train_tokens, seq_len)
        val_ds = TokenizedDataset(val_tokens, seq_len)
        if len(train_ds) == 0 or len(val_ds) == 0:
            raise ValueError(
                f"Corpus too small for seq_len={seq_len}: "
                f"{len(train_ds)} train / {len(val_ds)} val windows"
            )
        batch_size = 12
        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True)
        val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)
        print(f"Train sequences: {len(train_ds):,} ({len(train_dl)} batches)")
        print(f"Val sequences: {len(val_ds):,} ({len(val_dl)} batches)")

        # Model config
        model_config = {
            'vocab_size': enc.n_vocab,
            'max_seq_len': seq_len,
            'embed_dim': 384,
            'depth': 4,
            'edim': 16,
            'feat_dim': 96,
            'hidden': 384,
            'num_heads': 8,
            'num_blocks': 8,
            'dropout': 0.1,
        }
        # Training config
        train_config = {
            'batch_size': batch_size,
            'seq_len': seq_len,
            'lr': 3e-4,
            'weight_decay': 0.1,
            'num_epochs': 14,
            'grad_clip': 1.0,
            'ce_weight': 1.0,
            'validity_weight': 0.1,
        }
        full_config = {
            'model': model_config,
            'training': train_config,
            'data': {
                'train_tokens': len(train_tokens),
                'val_tokens': len(val_tokens),
                'vocab_size': enc.n_vocab,
            },
            'run_name': RUN_NAME,
        }
        # Save config
        with open(CHECKPOINT_DIR / "config.json", 'w') as f:
            json.dump(full_config, f, indent=2)

        # Model
        print("\nBuilding model...")
        model = GeometricLM(**model_config).to(device)
        print(f"\nConfig:")
        for k, v in model._config.items():
            print(f" {k}: {v}")
        params = sum(p.numel() for p in model.parameters())
        print(f" params: {params:,}")
        # Recorded after construction (this also updates model_config, which
        # is the same dict); later config.json writes include it.
        full_config['model']['params'] = params

        # Sanity check BEFORE compile
        sanity_check(model, enc, device)
        print("\nCompiling...")
        # Optional: model = torch.compile(model, mode="reduce-overhead")

        # Optimizer; cosine schedule stepped once per epoch
        opt = torch.optim.AdamW(
            model.parameters(),
            lr=train_config['lr'],
            weight_decay=train_config['weight_decay']
        )
        sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=train_config['num_epochs'])
        # Optional model graph logging:
        # writer.add_graph(model, torch.zeros(1, seq_len, dtype=torch.long, device=device))

        best_val = float('inf')
        best_ppl = float('inf')
        global_step = 0
        depth = model._config['depth']
        last_epoch = train_config['num_epochs'] - 1
        print("\nTraining...")
        print("=" * 120)
        epoch_pbar = tqdm(range(train_config['num_epochs']), desc="Epochs", position=0)
        for ep in epoch_pbar:
            epoch_start = time.time()

            # ==================== TRAIN ====================
            tr_ce, tr_val, global_step = _train_one_epoch(
                model, train_dl, opt, sched, writer, train_config, ep, global_step
            )
            # PPL exponent capped at 10 to avoid overflow on early epochs.
            tr_ppl = math.exp(min(tr_ce, 10))

            # ==================== VAL ====================
            va_ce, m = _evaluate(model, val_dl, ep)
            va_ppl = math.exp(min(va_ce, 10))
            sched.step()
            if va_ce < best_val:
                best_val = va_ce
                best_ppl = va_ppl
            epoch_time = time.time() - epoch_start

            # ==================== TENSORBOARD - EPOCH ====================
            writer.add_scalar("epoch/train_ce", tr_ce, ep)
            writer.add_scalar("epoch/train_ppl", tr_ppl, ep)
            writer.add_scalar("epoch/val_ce", va_ce, ep)
            writer.add_scalar("epoch/val_ppl", va_ppl, ep)
            writer.add_scalar("epoch/best_ppl", best_ppl, ep)
            writer.add_scalar("epoch/validity_loss", tr_val, ep)
            writer.add_scalar("epoch/time", epoch_time, ep)
            for k in range(depth):
                writer.add_scalar(f"geometry/k{k+1}_valid", m[f'k{k+1}_valid'], ep)
                writer.add_scalar(f"geometry/k{k+1}_vol2", m[f'k{k+1}_vol2'], ep)
                writer.add_scalar(f"geometry/k{k+1}_d2", m[f'k{k+1}_d2'], ep)
            writer.add_scalar("geometry/valid_rate", m['valid_rate'], ep)

            # ==================== LOGGING ====================
            epoch_pbar.set_postfix({
                'TrPPL': f'{tr_ppl:.1f}',
                'VaPPL': f'{va_ppl:.1f}',
                'Best': f'{best_ppl:.1f}',
                'Valid': f"{m['valid_rate']:.0%}"
            })
            tqdm.write(
                f"\nEp {ep+1:3d} | TrCE {tr_ce:.4f} | VaCE {va_ce:.4f} | "
                f"TrPPL {tr_ppl:7.2f} | VaPPL {va_ppl:7.2f} | BestPPL {best_ppl:.2f} | "
                f"Time {epoch_time:.1f}s"
            )
            # Works for any depth (previously hard-coded to k1..k4).
            tqdm.write(_format_geometry(m, depth))

            # ==================== GENERATE SAMPLES ====================
            if ep % 25 == 0 or ep == last_epoch:
                samples = generate_samples(model, enc, device, ep + 1, writer)
                with open(CHECKPOINT_DIR / f"samples_epoch_{ep+1:03d}.json", 'w') as f:
                    json.dump(samples, f, indent=2)

            # ==================== CHECKPOINT ====================
            metrics = {
                'epoch': ep + 1,
                'train_ce': tr_ce,
                'train_ppl': tr_ppl,
                'val_ce': va_ce,
                'val_ppl': va_ppl,
                'best_ppl': best_ppl,
                'geometry': m,
            }
            if ep % 2 == 0 or ep == last_epoch:
                save_checkpoint(model, opt, sched, ep + 1, full_config, metrics, CHECKPOINT_DIR)

            # ==================== HF UPLOAD ====================
            if ep == last_epoch:
                upload_to_hf(CHECKPOINT_DIR, HF_REPO, ep + 1)
    finally:
        writer.close()

    # ==================== FINAL ====================
    print("\n" + "=" * 120)
    print(f"Training complete!")
    print(f"Best val CE: {best_val:.4f}, PPL: {best_ppl:.2f}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"TensorBoard: {TENSORBOARD_DIR}")
    print(f"HuggingFace: https://huggingface.co/{HF_REPO}")
    print("=" * 120)
    return model, enc
if __name__ == "__main__":
    # Entry point when this file is executed directly: run the full training.
    # `model` and `tokenizer` are bound as module-level globals, presumably so
    # they stay available for interactive use afterwards — confirm before renaming.
    model, tokenizer = train()