|
|
|
|
|
""" |
|
|
Prototype LM for geometric simplex structures. |
|
|
|
|
|
Requires SimplexFactory from the geometricvocab package to build valid simplex representations; without it the simplex behavior will not be learned.
|
|
|
|
|
try: |
|
|
!pip uninstall -qy geometricvocab |
|
|
except: |
|
|
pass |
|
|
|
|
|
!pip install -q git+https://github.com/AbstractEyes/lattice_vocabulary.git |
|
|
|
|
|
License: MIT |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
import math |
|
|
from itertools import combinations |
|
|
import time |
|
|
import os |
|
|
import json |
|
|
from tqdm.auto import tqdm |
|
|
from pathlib import Path |
|
|
|
|
|
# Pick the compute device once at import time: CUDA when available, CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Device: {device}")
|
|
|
|
|
from geovocab2.shapes.factory.simplex_factory import SimplexFactory |
|
|
from huggingface_hub import HfApi, create_repo, upload_folder |
|
|
import tiktoken |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run-scoped configuration. RUN_NAME embeds the integer Unix time at import,
# and both output directories are created eagerly as an import side effect.
HF_REPO = "AbstractPhil/ksimplex-llm-prototype"
RUN_NAME = f"run_{int(time.time())}"
CHECKPOINT_DIR = Path("./checkpoints") / RUN_NAME
TENSORBOARD_DIR = Path("./runs") / RUN_NAME

for _run_dir in (CHECKPOINT_DIR, TENSORBOARD_DIR):
    _run_dir.mkdir(parents=True, exist_ok=True)
del _run_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CMValidator(nn.Module):
    """Cayley-Menger evaluator for a k-simplex.

    Given vertex coordinates whose last two axes are (vertex, coordinate),
    returns the squared distance for every unordered vertex pair and the
    prefactor-scaled determinant of the bordered squared-distance matrix
    (the Cayley-Menger expression for the squared k-volume).
    """

    def __init__(self, k):
        super().__init__()
        self._k = k
        self._nv = k + 1

        # Index buffers listing every vertex pair (i, j) with i < j.
        vertex_pairs = list(combinations(range(self._nv), 2))
        self._npairs = len(vertex_pairs)
        first = [i for i, _ in vertex_pairs]
        second = [j for _, j in vertex_pairs]
        self.register_buffer('_pi', torch.tensor(first, dtype=torch.long))
        self.register_buffer('_pj', torch.tensor(second, dtype=torch.long))

        # (-1)^(k+1) / (2^k * (k!)^2)
        sign = -1.0 if k % 2 == 0 else 1.0
        fact = math.factorial(k)
        self._prefactor = sign / ((2.0 ** k) * (fact ** 2))

    def forward(self, verts):
        """Return ``(d2_pairs, vol2)`` for ``verts`` shaped ``(..., V, E)``."""
        # Pairwise squared distances from the Gram matrix:
        # |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, clamped at zero.
        gram = torch.einsum('...ve,...we->...vw', verts, verts)
        sq_norms = gram.diagonal(dim1=-2, dim2=-1)
        d2_mat = F.relu(sq_norms.unsqueeze(-1) + sq_norms.unsqueeze(-2) - 2 * gram)

        d2_pairs = d2_mat[..., self._pi, self._pj]

        # Border the distance matrix with a leading row/column of ones and a
        # zero in the corner.
        cm = F.pad(d2_mat, (1, 0, 1, 0), value=1.0)
        cm[..., 0, 0] = 0.0

        vol2 = self._prefactor * torch.linalg.det(cm)
        return d2_pairs, vol2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KSimplexChannel(nn.Module):
    """Encode a hidden vector as a deformed k-simplex plus per-vertex features.

    The input is projected to vertex offsets, which are scaled by
    ``BASE_DEFORM`` and added to a fixed template built by ``SimplexFactory``
    (presumably shaped ``(k + 1, edim)`` -- confirm against the factory).
    The geometry of the deformed simplex, measured by ``CMValidator``, gates
    the averaged vertex features.
    """

    BASE_DEFORM = 0.05

    def __init__(self, k, in_dim, edim, feat_dim):
        super().__init__()
        self._k = k
        self._nv = k + 1
        self._edim = edim
        self._feat_dim = feat_dim

        self._cm = CMValidator(k)
        # Geometry vector = all pairwise squared distances + one volume term.
        self._geo_dim = self._cm._npairs + 1

        template = SimplexFactory(
            k=k, embed_dim=edim, method="regular", scale=1.0
        ).build_torch(dtype=torch.float32)
        self.register_buffer('_template', template)

        self._to_coords = nn.Linear(in_dim, self._nv * edim)
        self._to_feats = nn.Linear(in_dim, self._nv * feat_dim)
        self._geo_gate = nn.Sequential(
            nn.Linear(self._geo_dim, feat_dim),
            nn.Sigmoid(),
        )

        self._out_dim = feat_dim + self._geo_dim

    @property
    def out_dim(self):
        """Width of the last axis of the first tensor returned by ``forward``."""
        return self._out_dim

    def forward(self, x):
        """Return ``(out, vol2, mean_d2)`` for input ``x`` of shape ``(..., in_dim)``."""
        offsets = self._to_coords(x).unflatten(-1, (self._nv, self._edim))
        vertices = self._template + self.BASE_DEFORM * offsets

        per_vertex = self._to_feats(x).unflatten(-1, (self._nv, self._feat_dim))

        d2, vol2 = self._cm(vertices)
        geo = torch.cat([d2, vol2.unsqueeze(-1)], dim=-1)

        gate = self._geo_gate(geo)
        # Very steep sigmoid: close to a step function on the sign of vol2.
        validity = torch.sigmoid(vol2 * 1e6).unsqueeze(-1)

        pooled = per_vertex.mean(dim=-2) * gate * validity

        out = torch.cat([pooled, geo], dim=-1)
        return out, vol2, d2.mean(dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenToKChannels(nn.Module):
    """Project token embeddings and encode them with one simplex channel per k.

    Builds ``depth`` ``KSimplexChannel`` encoders for k = 1 .. depth. Their
    outputs have different widths, so each is zero-padded on the right to the
    widest one before being stacked along a new channel axis.
    """

    def __init__(self, embed_dim, depth, edim, feat_dim, hidden=256):
        super().__init__()
        self._depth = depth

        self._proj = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
        )

        encoders = []
        for order in range(1, depth + 1):
            encoders.append(
                KSimplexChannel(k=order, in_dim=hidden, edim=edim, feat_dim=feat_dim)
            )
        self._k_encoders = nn.ModuleList(encoders)

        self._k_out_dims = [enc.out_dim for enc in self._k_encoders]
        self._max_out_dim = max(self._k_out_dims)

    def forward(self, x):
        """Return ``(k_channels, vol2, d2_mean)`` stacked over the k encoders."""
        hidden = self._proj(x)

        channel_outs, volumes, mean_dists = [], [], []
        for encoder in self._k_encoders:
            encoded, vol2, d2_mean = encoder(hidden)

            # Right-pad narrower channels so every channel has the same width.
            missing = self._max_out_dim - encoded.shape[-1]
            if missing > 0:
                encoded = F.pad(encoded, (0, missing))

            channel_outs.append(encoded)
            volumes.append(vol2)
            mean_dists.append(d2_mean)

        return (
            torch.stack(channel_outs, dim=-2),
            torch.stack(volumes, dim=-1),
            torch.stack(mean_dists, dim=-1),
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class KChannelCrossAttention(nn.Module):
    """Multi-head self-attention across the K channel axis of each token.

    Input and output are shaped ``(B, T, K, feat_dim)``. Every (batch, time)
    position is treated independently; attention mixes only the K channels
    at that position. The result is added to the input (residual).

    Args:
        depth: Number of channels; stored but not otherwise used here.
        feat_dim: Feature width of each channel.
        num_heads: Number of attention heads; must divide ``feat_dim``.
        dropout: Dropout probability applied to the attention weights and to
            the output projection.

    Raises:
        ValueError: If ``feat_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1):
        super().__init__()
        if feat_dim % num_heads != 0:
            # Fail early with a clear message rather than with an opaque
            # reshape error inside forward().
            raise ValueError(
                f"feat_dim ({feat_dim}) must be divisible by num_heads ({num_heads})"
            )

        self._depth = depth
        self._feat_dim = feat_dim
        self._num_heads = num_heads
        self._head_dim = feat_dim // num_heads

        self._norm_q = nn.LayerNorm(feat_dim)
        self._norm_kv = nn.LayerNorm(feat_dim)

        self._to_q = nn.Linear(feat_dim, feat_dim)
        self._to_k = nn.Linear(feat_dim, feat_dim)
        self._to_v = nn.Linear(feat_dim, feat_dim)
        self._out = nn.Linear(feat_dim, feat_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

    def forward(self, x):
        """Attend over channels; ``x`` is ``(B, T, K, feat_dim)``, same shape out."""
        # `D` rather than `F`, so the module-level torch.nn.functional alias
        # is not shadowed inside this method.
        B, T, K, D = x.shape

        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(B * T, K, D)

        # Keys and values share one normalization; compute it once.
        q_in = self._norm_q(x_flat)
        kv_in = self._norm_kv(x_flat)

        def split_heads(t):
            # (B*T, K, D) -> (B*T, heads, K, head_dim)
            return t.view(B * T, K, self._num_heads, self._head_dim).transpose(1, 2)

        q = split_heads(self._to_q(q_in))
        k = split_heads(self._to_k(kv_in))
        v = split_heads(self._to_v(kv_in))

        attn = (q @ k.transpose(-2, -1)) * self._scale
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B * T, K, D)
        out = self._out(out)
        out = self._drop(out)

        return x + out.view(B, T, K, D)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CausalSequenceAttention(nn.Module):
    """Causal multi-head self-attention along the sequence axis.

    The K channels of each token are flattened into one vector of width
    ``depth * feat_dim``. Attention then runs over time under a
    lower-triangular mask, so position ``t`` only sees positions ``<= t``.
    The result is added to the input (residual).

    Args:
        depth: Number of channels per token (K).
        feat_dim: Feature width of each channel.
        num_heads: Number of heads; must divide ``depth * feat_dim``.
        dropout: Dropout probability for attention weights and output.
        max_seq_len: Largest sequence length the precomputed mask supports.

    Raises:
        ValueError: If ``depth * feat_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, depth, feat_dim, num_heads=4, dropout=0.1, max_seq_len=2048):
        super().__init__()
        total_dim = depth * feat_dim
        if total_dim % num_heads != 0:
            raise ValueError(
                f"depth * feat_dim ({total_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self._num_heads = num_heads
        self._head_dim = total_dim // num_heads

        self._norm = nn.LayerNorm(total_dim)
        self._to_qkv = nn.Linear(total_dim, 3 * total_dim)
        self._out = nn.Linear(total_dim, total_dim)
        self._drop = nn.Dropout(dropout)

        self._scale = self._head_dim ** -0.5

        # True where attention is allowed (key index <= query index).
        self.register_buffer(
            '_causal_mask',
            torch.tril(torch.ones(max_seq_len, max_seq_len)).bool()
        )

    def forward(self, x):
        """Attend causally over time; ``x`` is ``(B, T, K, feat_dim)``.

        Raises:
            ValueError: If ``T`` exceeds the ``max_seq_len`` the mask was
                built for.
        """
        # `D` rather than `F`, so torch.nn.functional is not shadowed locally.
        B, T, K, D = x.shape

        max_len = self._causal_mask.size(0)
        if T > max_len:
            # Slicing the mask would silently give a (max_len, max_len) block
            # and fail later with a confusing broadcast error.
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len {max_len}"
            )

        # reshape (not view) so non-contiguous inputs are accepted.
        x_flat = x.reshape(B, T, K * D)
        x_norm = self._norm(x_flat)

        qkv = self._to_qkv(x_norm).chunk(3, dim=-1)
        q, k, v = [
            t.view(B, T, self._num_heads, self._head_dim).transpose(1, 2)
            for t in qkv
        ]

        attn = (q @ k.transpose(-2, -1)) * self._scale

        mask = self._causal_mask[:T, :T]
        attn = attn.masked_fill(~mask, float('-inf'))
        attn = attn.softmax(dim=-1)
        attn = self._drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, T, K * D)
        out = self._out(out)
        out = self._drop(out)

        return x + out.view(B, T, K, D)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GeoBlock(nn.Module):
    """One block: channel attention, causal sequence attention, residual MLP.

    Operates on tensors shaped ``(B, T, K, feat_dim)`` and returns the same
    shape. The MLP sees each token as one flattened vector of width
    ``depth * feat_dim``.
    """

    def __init__(self, depth, feat_dim, num_heads, mlp_ratio=4.0, dropout=0.1, max_seq_len=2048):
        super().__init__()

        self._k_attn = KChannelCrossAttention(depth, feat_dim, num_heads, dropout)
        self._seq_attn = CausalSequenceAttention(depth, feat_dim, num_heads, dropout, max_seq_len)

        width = depth * feat_dim
        mlp_width = int(width * mlp_ratio)
        self._norm = nn.LayerNorm(width)
        self._mlp = nn.Sequential(
            nn.Linear(width, mlp_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_width, width),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply both attention stages, then the pre-norm residual MLP."""
        batch, steps, channels, feats = x.shape

        x = self._seq_attn(self._k_attn(x))

        merged = x.view(batch, steps, channels * feats)
        merged = merged + self._mlp(self._norm(merged))
        return merged.view(batch, steps, channels, feats)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GeometricLM(nn.Module):
    """Language model whose token states are stacks of k-simplex channels.

    Pipeline: token + learned positional embeddings -> ``TokenToKChannels``
    -> linear projection to ``feat_dim`` per channel -> ``num_blocks``
    ``GeoBlock`` layers -> LayerNorm -> vocabulary logits.

    Args:
        vocab_size: Number of token ids.
        max_seq_len: Longest sequence accepted by ``forward``.
        embed_dim: Token/positional embedding width.
        depth: Number of simplex channels (k = 1 .. depth).
        edim: Coordinate dimension of each simplex vertex.
        feat_dim: Feature width per channel inside the blocks.
        hidden: Hidden width of the token-to-channel projection.
        num_heads: Attention heads used in every block.
        num_blocks: Number of ``GeoBlock`` layers.
        dropout: Dropout probability passed to every block.
    """

    def __init__(
        self,
        vocab_size,
        max_seq_len=512,
        embed_dim=256,
        depth=4,
        edim=16,
        feat_dim=64,
        hidden=256,
        num_heads=8,
        num_blocks=8,
        dropout=0.1,
    ):
        super().__init__()

        self._vocab_size = vocab_size
        self._max_seq_len = max_seq_len
        self._depth = depth
        self._feat_dim = feat_dim

        self._tok_embed = nn.Embedding(vocab_size, embed_dim)
        self._pos_embed = nn.Embedding(max_seq_len, embed_dim)

        self._tok_to_k = TokenToKChannels(embed_dim, depth, edim, feat_dim, hidden)
        self._max_out_dim = self._tok_to_k._max_out_dim

        # Map the padded channel width down to the per-channel block width.
        self._proj = nn.Linear(self._max_out_dim, feat_dim)

        self._blocks = nn.ModuleList([
            GeoBlock(depth, feat_dim, num_heads, dropout=dropout, max_seq_len=max_seq_len)
            for _ in range(num_blocks)
        ])

        total_dim = depth * feat_dim
        self._norm = nn.LayerNorm(total_dim)
        self._lm_head = nn.Linear(total_dim, vocab_size, bias=False)

        # Constructor arguments (plus the derived total_dim) kept for
        # logging and checkpoint metadata.
        self._config = {
            'vocab_size': vocab_size,
            'max_seq_len': max_seq_len,
            'embed_dim': embed_dim,
            'depth': depth,
            'edim': edim,
            'feat_dim': feat_dim,
            'hidden': hidden,
            'num_heads': num_heads,
            'num_blocks': num_blocks,
            'dropout': dropout,
            'total_dim': total_dim,
        }

    def forward(self, tokens):
        """Return ``(logits, info)`` for integer ``tokens`` of shape ``(B, T)``.

        ``info`` holds the per-channel ``'vol2'`` and ``'d2_mean'`` tensors
        produced by the simplex encoders.

        Raises:
            ValueError: If ``T`` exceeds ``max_seq_len``.
        """
        B, T = tokens.shape
        if T > self._max_seq_len:
            # Otherwise the positional-embedding lookup fails with an
            # out-of-range index error.
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len {self._max_seq_len}"
            )

        pos = torch.arange(T, device=tokens.device)
        x = self._tok_embed(tokens) + self._pos_embed(pos)

        k_channels, vol2, d2_mean = self._tok_to_k(x)
        k_channels = self._proj(k_channels)

        for blk in self._blocks:
            k_channels = blk(k_channels)

        out = k_channels.flatten(-2)
        logits = self._lm_head(self._norm(out))

        return logits, {'vol2': vol2, 'd2_mean': d2_mean}

    @torch.no_grad()
    def generate(self, prompt_tokens, max_new_tokens=100, temperature=1.0, top_k=50):
        """Sample ``max_new_tokens`` tokens autoregressively after the prompt.

        Switches the model to eval mode and leaves it there; callers that
        are training must call ``train()`` again afterwards.

        Args:
            prompt_tokens: Integer tensor of shape ``(B, T0)``.
            max_new_tokens: Number of tokens to append.
            temperature: Positive softmax temperature.
            top_k: Keep only the ``top_k`` highest logits before sampling.
                ``None`` or ``0`` disables the filter; values larger than the
                vocabulary are clamped.

        Returns:
            Tensor of shape ``(B, T0 + max_new_tokens)``.

        Raises:
            ValueError: If ``temperature`` is not positive.
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        self.eval()
        tokens = prompt_tokens.clone()

        for _ in range(max_new_tokens):
            # Only the most recent max_seq_len tokens fit in the context.
            ctx = tokens[:, -self._max_seq_len:]
            logits, _ = self(ctx)
            logits = logits[:, -1, :] / temperature

            if top_k is not None and top_k > 0:
                # torch.topk raises if k is larger than the dimension size.
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[:, [-1]]
                logits = logits.masked_fill(logits < kth_best, float('-inf'))

            probs = F.softmax(logits, dim=-1)
            next_tok = torch.multinomial(probs, num_samples=1)
            tokens = torch.cat([tokens, next_tok], dim=1)

        return tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TokenizedDataset(Dataset):
    """Sliding-window next-token dataset over a flat token sequence.

    Sample ``i`` covers ``tokens[i*stride : i*stride + seq_len + 1]``; the
    input ``x`` is that window without its last token and the target ``y``
    is the same window shifted by one.

    Args:
        tokens: Indexable, sliceable sequence of integer token ids.
        seq_len: Length of ``x`` and ``y``; must be at least 1.
        stride: Step between window starts. Defaults to ``seq_len // 2``,
            but never less than 1.

    Raises:
        ValueError: If ``seq_len`` < 1 or an explicit ``stride`` is < 1.
    """

    def __init__(self, tokens, seq_len, stride=None):
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        if stride is not None and stride != 0 and stride < 1:
            raise ValueError(f"stride must be >= 1, got {stride}")

        self._tokens = tokens
        self._seq_len = seq_len
        # A falsy stride (None or 0) selects the default. Clamp to 1 so that
        # seq_len == 1 does not produce a zero stride.
        self._stride = stride if stride else max(1, seq_len // 2)

    def __len__(self):
        # A window needs seq_len + 1 tokens. Count every start position
        # whose window fits, including the last one.
        spare = len(self._tokens) - (self._seq_len + 1)
        if spare < 0:
            return 0
        return spare // self._stride + 1

    def __getitem__(self, idx):
        """Return ``(x, y)`` long tensors of length ``seq_len`` for sample ``idx``."""
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            # Raising IndexError also lets plain iteration terminate.
            raise IndexError(f"index {idx} out of range for {count} samples")

        start = idx * self._stride
        chunk = self._tokens[start:start + self._seq_len + 1]
        x = torch.tensor(chunk[:-1], dtype=torch.long)
        y = torch.tensor(chunk[1:], dtype=torch.long)
        return x, y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def lm_loss(logits, targets, info, ce_weight=1.0, validity_weight=0.1):
    """Return ``(total, ce, validity)`` for one batch.

    ``ce`` is the token-level cross-entropy of ``logits`` (``(B, T, V)``)
    against ``targets`` (``(B, T)``). ``validity`` is the mean of
    ``relu(-vol2)``, so only negative squared volumes are penalized.
    ``total`` is their weighted sum.
    """
    batch, steps, vocab = logits.shape
    rows = batch * steps

    flat_logits = logits.view(rows, vocab)
    flat_targets = targets.view(rows)
    ce = F.cross_entropy(flat_logits, flat_targets)

    negative_volume = F.relu(-info['vol2'])
    validity = negative_volume.mean()

    total = ce_weight * ce + validity_weight * validity
    return total, ce, validity
|
|
|
|
|
|
|
|
@torch.no_grad()
def compute_metrics(info, depth):
    """Summarize simplex geometry as plain floats.

    Reads ``info['vol2']`` and ``info['d2_mean']``, whose last axis indexes
    the k channels. Returns the overall fraction of positive ``vol2``
    entries (``'valid_rate'``) plus, for each channel ``k1 .. k{depth}``,
    its positive fraction, mean ``vol2`` and mean ``d2``.
    """
    vol2 = info['vol2']
    d2_mean = info['d2_mean']

    is_positive = (vol2 > 0).float()
    m = {'valid_rate': is_positive.mean().item()}

    for channel in range(depth):
        tag = f'k{channel+1}'
        m[f'{tag}_valid'] = is_positive[..., channel].mean().item()
        m[f'{tag}_vol2'] = vol2[..., channel].mean().item()
        m[f'{tag}_d2'] = d2_mean[..., channel].mean().item()

    return m
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def sanity_check(model, enc, device):
    """Verify no information leak.

    Runs three checks and prints a report:

    1. Random inputs against random targets should give a cross-entropy
       close to ``log(vocab)``.
    2. Changing only the second half of a sequence must leave the logits of
       the first half unchanged (causal masking) and change the rest.
    3. ``TokenizedDataset`` targets must be the inputs shifted by one.

    Sequence length and token ids are clamped to the model's
    ``_max_seq_len`` (when present) and to ``enc.n_vocab``, so the check
    also runs on models smaller than the default configuration. Assumes a
    usable context of at least 2 tokens.

    The model is left in training mode on return.

    Returns:
        True when all three checks pass.
    """
    vocab = enc.n_vocab
    # 256 is the probe length used for the default config; never exceed the
    # model's positional-embedding range.
    seq = min(256, getattr(model, '_max_seq_len', 256))
    half = seq // 2
    rule = "=" * 60

    print("\n" + rule)
    print("SANITY CHECK")
    print(rule)

    model.eval()

    # --- Test 1: random input should score near chance level. ---
    random_tokens = torch.randint(0, min(1000, vocab), (4, seq), device=device)
    logits, _ = model(random_tokens)
    random_targets = torch.randint(0, vocab, (4, seq), device=device)
    ce = F.cross_entropy(logits.view(-1, vocab), random_targets.view(-1))

    expected_ce = math.log(vocab)
    print("Test 1 - Random input:")
    print(f" CE: {ce.item():.2f} (expected ~{expected_ce:.2f})")
    print(f" PPL: {math.exp(min(ce.item(), 20)):.0f} (expected ~{enc.n_vocab})")

    # 8.0 is the threshold for large vocabularies; for small ones log(vocab)
    # itself can be below 8, so scale the bar with the chance-level CE.
    ce_floor = min(8.0, 0.75 * expected_ce)
    test1_pass = ce.item() > ce_floor
    print(f" Status: {'✓ PASS' if test1_pass else '✗ FAIL'}")

    # --- Test 2: perturbing later tokens must not affect earlier logits. ---
    tokens1 = torch.zeros(1, seq, dtype=torch.long, device=device)
    tokens2 = torch.zeros(1, seq, dtype=torch.long, device=device)
    tokens2[0, half:] = min(999, vocab - 1)

    logits1, _ = model(tokens1)
    logits2, _ = model(tokens2)

    diff_early = (logits1[0, :half] - logits2[0, :half]).abs().max().item()
    diff_late = (logits1[0, half:] - logits2[0, half:]).abs().max().item()

    print("\nTest 2 - Causal mask:")
    print(f" Early positions diff: {diff_early:.6f} (should be ~0)")
    print(f" Late positions diff: {diff_late:.6f} (should be >0)")

    test2_pass = diff_early < 1e-5 and diff_late > 1e-3
    print(f" Status: {'✓ PASS' if test2_pass else '✗ FAIL'}")

    # --- Test 3: dataset targets are inputs shifted by one token. ---
    print("\nTest 3 - Dataset offset:")
    test_tokens = list(range(100))
    ds = TokenizedDataset(test_tokens, seq_len=10)
    x, y = ds[0]
    offset_correct = bool(torch.equal(x + 1, y))
    print(f" x: {x[:5].tolist()}...")
    print(f" y: {y[:5].tolist()}...")
    print(f" Offset correct: {'✓ PASS' if offset_correct else '✗ FAIL'}")

    print(rule)

    all_pass = test1_pass and test2_pass and offset_correct
    if not all_pass:
        print("⚠️ WARNING: Some sanity checks failed!")
    else:
        print("✓ All sanity checks passed!")

    print(rule + "\n")

    model.train()
    return all_pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Fixed prompts used to generate qualitative samples during training.
PROMPTS = [
    "ROMEO: ", "JULIET: ",
    "To be or not to be",
    "The king ",
    "Once upon a time",
    "First Citizen:\n",
    "What light through yonder",
    "Friends, Romans, countrymen",
    "Now is the winter of",
    "All the world's a stage",
]
|
|
|
|
|
@torch.no_grad()
def generate_samples(model, enc, device, epoch, writer=None):
    """Generate samples from all prompts.

    Each entry of ``PROMPTS`` is encoded with ``enc``, extended by 100
    sampled tokens (temperature 0.8, top-k 50), decoded and printed
    (truncated to 300 characters). When ``writer`` is given, all samples
    (truncated to 500 characters each) are logged as one text entry.
    The model is put in eval mode for generation and back in training mode
    before returning.

    Returns:
        List of ``{'prompt': ..., 'generated': ...}`` dicts.
    """
    model.eval()

    print(f"\n{'='*60}")
    print(f"GENERATION SAMPLES - Epoch {epoch}")
    print(f"{'='*60}")

    samples = []
    for i, prompt in enumerate(PROMPTS):
        prompt_ids = enc.encode(prompt)
        prompt_tokens = torch.tensor([prompt_ids], device=device)

        out_tokens = model.generate(
            prompt_tokens, max_new_tokens=100, temperature=0.8, top_k=50
        )
        generated = enc.decode(out_tokens[0].tolist())
        samples.append({'prompt': prompt, 'generated': generated})

        print(f"\n--- Prompt {i+1}: '{prompt.strip()}' ---")
        print(generated[:300])
        if len(generated) > 300:
            print("...")

    print(f"{'='*60}\n")

    if writer:
        entries = []
        for s in samples:
            entries.append(
                f"**Prompt:** {s['prompt']}\n**Generated:**\n{s['generated'][:500]}"
            )
        writer.add_text("samples/generated", "\n\n".join(entries), epoch)

    model.train()
    return samples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(model, optimizer, scheduler, epoch, config, metrics, checkpoint_dir):
    """Save checkpoint locally.

    Writes three files into ``checkpoint_dir``:
    ``checkpoint_epoch_XXX.pt``, ``checkpoint_latest.pt`` (same content) and
    ``config.json``.

    Args:
        model: Module to save. If it exposes ``_orig_mod``, that inner
            module's state dict is saved instead.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scheduler: LR scheduler whose ``state_dict()`` is stored.
        epoch: Epoch number, stored and used in the file name.
        config: JSON-serializable configuration, stored in the checkpoint
            and written to ``config.json``.
        metrics: Arbitrary metrics object stored in the checkpoint.
        checkpoint_dir: Target directory (``str`` or ``Path``); created if
            it does not exist.

    Returns:
        Path of the per-epoch checkpoint file.
    """
    # Accept plain strings and do not rely on the caller having created
    # the directory.
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Save the inner module when the model has an `_orig_mod` wrapper attribute.
    module = getattr(model, '_orig_mod', model)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'config': config,
        'metrics': metrics,
    }

    path = checkpoint_dir / f"checkpoint_epoch_{epoch:03d}.pt"
    torch.save(checkpoint, path)

    # Stable name that always points at the most recent checkpoint.
    torch.save(checkpoint, checkpoint_dir / "checkpoint_latest.pt")

    with open(checkpoint_dir / "config.json", 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    print(f"Saved checkpoint: {path}")
    return path
|
|
|
|
|
|
|
|
def upload_to_hf(checkpoint_dir, repo_id, epoch):
    """Upload checkpoint directory to HuggingFace.

    Best-effort: any failure is printed and reported through the return
    value instead of being raised. A failure while creating the repository
    is only printed; the upload is still attempted.

    Returns:
        True if the folder upload completed, False otherwise.
    """
    try:
        api = HfApi()

        try:
            create_repo(repo_id, exist_ok=True, repo_type="model")
        except Exception as repo_err:
            print(f"Repo creation note: {repo_err}")

        upload_args = {
            'folder_path': str(checkpoint_dir),
            'repo_id': repo_id,
            'commit_message': f"Epoch {epoch} checkpoint",
        }
        api.upload_folder(**upload_args)

        print(f"Uploaded to HuggingFace: {repo_id}")
        return True
    except Exception as upload_err:
        print(f"HuggingFace upload failed: {upload_err}")
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_training(writer):
    """Run data prep, model build and the epoch loop, logging to ``writer``."""
    import urllib.request

    print(f"TensorBoard logs: {TENSORBOARD_DIR}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"HuggingFace repo: {HF_REPO}")

    # ---- Data: download tiny Shakespeare once, then read it. ----
    data_path = './data/shakespeare.txt'
    if not os.path.exists(data_path):
        os.makedirs('./data', exist_ok=True)
        url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
        print("Downloading Shakespeare...")
        urllib.request.urlretrieve(url, data_path)

    # Explicit encoding so the result does not depend on the platform locale.
    with open(data_path, 'r', encoding='utf-8') as f:
        text = f.read()

    print(f"Text length: {len(text):,} chars")

    # ---- Tokenization ----
    print("Loading tokenizer...")
    enc = tiktoken.get_encoding("gpt2")

    print("Tokenizing...")
    tokens = enc.encode(text)
    print(f"Token count: {len(tokens):,}")
    print(f"Vocab size: {enc.n_vocab:,}")
    print(f"Compression ratio: {len(text) / len(tokens):.2f}x")

    # ---- Train/validation split (90/10 by token position) ----
    seq_len = 256
    split_idx = int(len(tokens) * 0.9)
    train_tokens = tokens[:split_idx]
    val_tokens = tokens[split_idx:]

    train_ds = TokenizedDataset(train_tokens, seq_len)
    val_ds = TokenizedDataset(val_tokens, seq_len)

    if len(train_ds) == 0 or len(val_ds) == 0:
        # Without this the run would fail much later with a ZeroDivisionError
        # or IndexError while averaging over an empty loader.
        raise ValueError(
            f"Not enough tokens to build train/val sequences of length {seq_len}"
        )

    batch_size = 12
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True, persistent_workers=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

    print(f"Train sequences: {len(train_ds):,} ({len(train_dl)} batches)")
    print(f"Val sequences: {len(val_ds):,} ({len(val_dl)} batches)")

    # ---- Configuration ----
    model_config = {
        'vocab_size': enc.n_vocab,
        'max_seq_len': seq_len,
        'embed_dim': 384,
        'depth': 4,
        'edim': 16,
        'feat_dim': 96,
        'hidden': 384,
        'num_heads': 8,
        'num_blocks': 8,
        'dropout': 0.1,
    }

    train_config = {
        'batch_size': batch_size,
        'seq_len': seq_len,
        'lr': 3e-4,
        'weight_decay': 0.1,
        'num_epochs': 14,
        'grad_clip': 1.0,
        'ce_weight': 1.0,
        'validity_weight': 0.1,
    }

    full_config = {
        'model': model_config,
        'training': train_config,
        'data': {
            'train_tokens': len(train_tokens),
            'val_tokens': len(val_tokens),
            'vocab_size': enc.n_vocab,
        },
        'run_name': RUN_NAME,
    }

    with open(CHECKPOINT_DIR / "config.json", 'w', encoding='utf-8') as f:
        json.dump(full_config, f, indent=2)

    # ---- Model ----
    print("\nBuilding model...")
    model = GeometricLM(**model_config).to(device)

    print("\nConfig:")
    for key, value in model._config.items():
        print(f" {key}: {value}")

    params = sum(p.numel() for p in model.parameters())
    print(f" params: {params:,}")
    # full_config['model'] is model_config itself, so this key is added only
    # after the model has been constructed from it.
    full_config['model']['params'] = params

    sanity_check(model, enc, device)

    # NOTE: no compile step is performed here; the message is kept so the
    # console output is unchanged.
    print("\nCompiling...")

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=train_config['lr'],
        weight_decay=train_config['weight_decay']
    )
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=train_config['num_epochs'])

    best_val = float('inf')
    best_ppl = float('inf')
    global_step = 0
    num_epochs = train_config['num_epochs']
    depth = model._config['depth']

    print("\nTraining...")
    print("=" * 120)

    epoch_pbar = tqdm(range(num_epochs), desc="Epochs", position=0)

    for ep in epoch_pbar:
        epoch_start = time.time()
        is_last_epoch = ep == num_epochs - 1

        # ---- Train ----
        model.train()
        ce_sum, validity_sum, n = 0, 0, 0

        train_pbar = tqdm(train_dl, desc=f"Train {ep+1}", leave=False, position=1)
        for x, y in train_pbar:
            x, y = x.to(device), y.to(device)

            opt.zero_grad()
            logits, info = model(x)
            loss, ce, validity = lm_loss(
                logits, y, info,
                ce_weight=train_config['ce_weight'],
                validity_weight=train_config['validity_weight']
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config['grad_clip'])
            opt.step()

            ce_value = ce.item()
            validity_value = validity.item()
            ce_sum += ce_value * x.size(0)
            validity_sum += validity_value * x.size(0)
            n += x.size(0)

            # Perplexity is capped at exp(10) to keep the display and plots readable.
            batch_ppl = math.exp(min(ce_value, 10))

            if global_step % 100 == 0:
                writer.add_scalar("train/ce_batch", ce_value, global_step)
                writer.add_scalar("train/ppl_batch", batch_ppl, global_step)
                writer.add_scalar("train/validity_batch", validity_value, global_step)
                writer.add_scalar("train/lr", sched.get_last_lr()[0], global_step)

            global_step += 1

            train_pbar.set_postfix({
                'CE': f'{ce_value:.3f}',
                'PPL': f'{batch_ppl:.1f}'
            })

        tr_ce = ce_sum / n
        tr_ppl = math.exp(min(tr_ce, 10))
        tr_val = validity_sum / n

        # ---- Validate ----
        model.eval()
        ce_sum, n = 0, 0
        metrics_agg = []

        val_pbar = tqdm(val_dl, desc=f"Val {ep+1}", leave=False, position=1)
        with torch.no_grad():
            for x, y in val_pbar:
                x, y = x.to(device), y.to(device)
                logits, info = model(x)
                _, ce, _ = lm_loss(logits, y, info)
                ce_value = ce.item()
                ce_sum += ce_value * x.size(0)
                n += x.size(0)
                metrics_agg.append(compute_metrics(info, depth))

                val_pbar.set_postfix({
                    'CE': f'{ce_value:.3f}',
                    'PPL': f'{math.exp(min(ce_value, 10)):.1f}'
                })

        va_ce = ce_sum / n
        va_ppl = math.exp(min(va_ce, 10))

        sched.step()

        if va_ce < best_val:
            best_val = va_ce
            best_ppl = va_ppl

        # Unweighted mean of the per-batch geometry metrics.
        m = {k: sum(d[k] for d in metrics_agg) / len(metrics_agg) for k in metrics_agg[0]}

        epoch_time = time.time() - epoch_start

        # ---- Logging ----
        writer.add_scalar("epoch/train_ce", tr_ce, ep)
        writer.add_scalar("epoch/train_ppl", tr_ppl, ep)
        writer.add_scalar("epoch/val_ce", va_ce, ep)
        writer.add_scalar("epoch/val_ppl", va_ppl, ep)
        writer.add_scalar("epoch/best_ppl", best_ppl, ep)
        writer.add_scalar("epoch/validity_loss", tr_val, ep)
        writer.add_scalar("epoch/time", epoch_time, ep)

        for k in range(depth):
            writer.add_scalar(f"geometry/k{k+1}_valid", m[f'k{k+1}_valid'], ep)
            writer.add_scalar(f"geometry/k{k+1}_vol2", m[f'k{k+1}_vol2'], ep)
            writer.add_scalar(f"geometry/k{k+1}_d2", m[f'k{k+1}_d2'], ep)

        writer.add_scalar("geometry/valid_rate", m['valid_rate'], ep)

        epoch_pbar.set_postfix({
            'TrPPL': f'{tr_ppl:.1f}',
            'VaPPL': f'{va_ppl:.1f}',
            'Best': f'{best_ppl:.1f}',
            'Valid': f"{m['valid_rate']:.0%}"
        })

        tqdm.write(
            f"\nEp {ep+1:3d} | TrCE {tr_ce:.4f} | VaCE {va_ce:.4f} | "
            f"TrPPL {tr_ppl:7.2f} | VaPPL {va_ppl:7.2f} | BestPPL {best_ppl:.2f} | "
            f"Time {epoch_time:.1f}s"
        )

        # One entry per simplex channel, built from the configured depth so
        # the line stays correct if depth is changed from 4.
        geo_parts = []
        for k in range(1, depth + 1):
            k_valid = m[f'k{k}_valid']
            k_vol2 = m[f'k{k}_vol2']
            geo_parts.append(f"k{k} {k_valid:5.1%} vol²={k_vol2:.2e}")
        tqdm.write(" | " + " | ".join(geo_parts))

        # ---- Samples ----
        if ep % 25 == 0 or is_last_epoch:
            samples = generate_samples(model, enc, device, ep + 1, writer)

            with open(CHECKPOINT_DIR / f"samples_epoch_{ep+1:03d}.json", 'w', encoding='utf-8') as f:
                json.dump(samples, f, indent=2)

        # ---- Checkpoint ----
        metrics = {
            'epoch': ep + 1,
            'train_ce': tr_ce,
            'train_ppl': tr_ppl,
            'val_ce': va_ce,
            'val_ppl': va_ppl,
            'best_ppl': best_ppl,
            'geometry': m,
        }

        if ep % 2 == 0 or is_last_epoch:
            save_checkpoint(model, opt, sched, ep + 1, full_config, metrics, CHECKPOINT_DIR)

        # Upload once, after the final epoch's checkpoint has been written.
        if is_last_epoch:
            upload_to_hf(CHECKPOINT_DIR, HF_REPO, ep + 1)

    print("\n" + "=" * 120)
    print("Training complete!")
    print(f"Best val CE: {best_val:.4f}, PPL: {best_ppl:.2f}")
    print(f"Checkpoints: {CHECKPOINT_DIR}")
    print(f"TensorBoard: {TENSORBOARD_DIR}")
    print(f"HuggingFace: https://huggingface.co/{HF_REPO}")
    print("=" * 120)

    return model, enc


def train():
    """Train ``GeometricLM`` on the tiny Shakespeare corpus.

    Downloads the corpus if it is missing, tokenizes it with the GPT-2
    tiktoken encoding, trains for the configured number of epochs with
    TensorBoard logging, periodic sample generation and checkpointing, and
    uploads the checkpoint directory to the HuggingFace repo after the last
    epoch.

    The TensorBoard writer is closed even when training raises.

    Returns:
        ``(model, enc)``: the trained model and the tokenizer.

    Raises:
        ValueError: If the corpus is too short to build at least one
            training and one validation sequence.
    """
    writer = SummaryWriter(log_dir=str(TENSORBOARD_DIR))
    try:
        return _run_training(writer)
    finally:
        # Flush and release the event file on success and on failure alike.
        writer.close()
|
|
|
|
|
|
|
|
# Script entry point: run the full training pipeline when executed directly
# and keep the returned model and tokenizer as module-level names.
if __name__ == "__main__":
    model, tokenizer = train()