# geolip-svae-transformer / bert_trainer.py
# AbstractPhil — "Create bert_trainer.py" (commit 11ffd43, verified)
"""geolip_svae_bert_features.py — patchwork-represent BERT features.
Tests whether the geolip-svae-transformer's lensed rigid frame can compress and
reconstruct real semantic features: each BERT token's hidden vector is treated
as a patch (patch_dim = hidden), the token sequence is the N patches, and the
model patchwork-represents the sequence as omega tokens on the rigid frame.
THE STERILIZE STEP (from the trigram lesson — we control the SVAE's input
distribution): BERT features are strongly anisotropic (a dominant common
direction + a few high-variance axes). We center out the common component and
unit-normalize per token before the frame sees them, so the SVAE gets a clean,
isotropic distribution — far more utilizable, as Phil noted.
SWAPPABLE FEATURE SOURCE: real transformers BERT (default
'google/bert_uncased_L-2_H-128_A-2' for sandbox, swap bert-base-uncased on
Colab) with a simulated-anisotropic fallback so the pipeline always runs.
Snap alongside geolip_svae_transformer.py. Deterministic, Colab-proof run().
"""
from __future__ import annotations
import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from geolip_svae_transformer import (
GeoSVAETransformer, GeoConfig, rigidity_loss, measure_guarantee,
dev_critical,
)
# ════════════════════════════════════════════════════════════════════════
# Corpus + BERT feature extraction (swappable real / simulated)
# ════════════════════════════════════════════════════════════════════════
def get_corpus(source: str, n: int) -> List[str]:
    """Return a list of sentences to featurize.

    Args:
        source: 'builtin' returns a copy of the curated module-level CORPUS
            (``n`` is ignored). 'wikitext' collects real Wikipedia sentences
            from the WikiText-2 raw training split (more diverse → a more
            potent test).
        n: maximum number of WikiText sentences to collect. ``n <= 0``
            yields an empty list without loading the dataset.

    Returns:
        A fresh list that the caller may freely mutate.

    Raises:
        ValueError: if ``source`` is not recognised.
    """
    if source == 'builtin':
        # Copy, so a caller mutating the result cannot corrupt the shared constant.
        return list(CORPUS)
    if source == 'wikitext':
        sents: List[str] = []
        if n > 0:
            # Lazy import: `datasets` is only needed for this source.
            from datasets import load_dataset
            ds = load_dataset('Salesforce/wikitext', 'wikitext-2-raw-v1', split='train')
            for row in ds:
                text = row['text'].strip()
                if len(text) < 40 or text.startswith('='):  # skip headers/short
                    continue
                # Naive sentence split on '. '. Keep mid-length fragments and
                # re-terminate them with a period.
                for s in text.replace('\n', ' ').split('. '):
                    s = s.strip()
                    if 30 < len(s) < 220:
                        sents.append(s if s.endswith('.') else s + '.')
                        if len(sents) >= n:
                            break
                if len(sents) >= n:
                    break
        print(f" [corpus] wikitext: {len(sents)} sentences")
        return sents
    raise ValueError(f"unknown corpus source: {source}")
# Curated, topically diverse sentences used when corpus_source == 'builtin'.
# One sentence per line; the block string is split into a plain list.
CORPUS = """
the cat sat quietly on the warm windowsill
quantum fields fluctuate in the vacuum of empty space
she poured the coffee and opened her laptop
the river carved a deep canyon over millions of years
investors watched the market tumble after the announcement
a gentle rain fell across the sleeping village
the algorithm sorts the array in logarithmic time
he tuned the old guitar before the evening show
photosynthesis converts sunlight into chemical energy
the negotiators reached an agreement just before dawn
stars collapse into dense remnants when their fuel runs out
the chef seasoned the broth with ginger and lemongrass
children laughed as the kite climbed into the wind
the contract was signed in a quiet downtown office
neurons fire in cascading waves across the cortex
the train arrived late because of the heavy snow
a sculptor chipped patiently at the block of marble
the spacecraft adjusted its orbit around the moon
they planted tomatoes along the southern fence
the lecture covered the foundations of thermodynamics
fog rolled in from the harbor at first light
the startup pivoted toward enterprise customers
wolves coordinate their movements while hunting
the violinist closed her eyes during the solo
compilers translate source code into machine instructions
the desert blooms briefly after the spring rains
the committee debated the proposal for three hours
electrons occupy discrete energy levels in an atom
he sketched the bridge from across the river
the bakery sold out of bread by mid morning
satellites relay signals across the curved earth
the novel opens in a crowded train station
""".strip().splitlines()
def _real_bert_features(model_name: str, sentences: List[str], max_len: int,
                        device: torch.device):
    """Encode ``sentences`` with a Hugging Face checkpoint; raises on any failure."""
    from transformers import AutoTokenizer, AutoModel, logging
    logging.set_verbosity_error()
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
    except Exception:
        # Fall back to the BERT-specific tokenizer if AutoTokenizer cannot
        # resolve this checkpoint.
        from transformers import BertTokenizer
        tok = BertTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()
    enc = tok(sentences, return_tensors='pt', padding='max_length',
              truncation=True, max_length=max_len)
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        feats = model(**enc).last_hidden_state   # (S, L, H)
    mask = enc['attention_mask'].float()         # (S, L), 1 = real token
    print(f" [features] real BERT '{model_name}' → {tuple(feats.shape)}")
    return feats, mask, model.config.hidden_size


def _simulated_bert_features(n_sentences: int, max_len: int, device: torch.device):
    """Deterministic (seed 0) anisotropic stand-in for BERT features, H = 128.

    The features are a shared offset direction (norm 3), plus 8 random
    high-variance axes, plus isotropic noise. The mask drops roughly 20% of
    positions at random but always keeps position 0.
    """
    H = 128
    g = torch.Generator().manual_seed(0)
    S, L = n_sentences, max_len
    # Draw order is fixed (common, basis, coeffs, noise, mask) so results are reproducible.
    common = 3.0 * F.normalize(torch.randn(H, generator=g), dim=0)  # anisotropy
    basis = F.normalize(torch.randn(8, H, generator=g), dim=1)
    coeffs = torch.randn(S, L, 8, generator=g)
    feats = common + coeffs @ basis + 0.3 * torch.randn(S, L, H, generator=g)
    mask = (torch.rand(S, L, generator=g) > 0.2).float()
    mask[:, 0] = 1.0
    return feats.to(device), mask.to(device), H


def extract_bert_features(model_name: str, sentences: List[str], max_len: int,
                          device: torch.device):
    """Return ``(features (S, L, H), mask (S, L), hidden)`` for ``sentences``.

    Tries the real Hugging Face model first. On *any* exception (missing
    ``transformers``, download failure, a failure during the forward pass)
    it prints the error and returns simulated anisotropic features instead,
    so the pipeline always runs. Both paths place their tensors on ``device``.
    """
    try:
        return _real_bert_features(model_name, sentences, max_len, device)
    except Exception as e:
        print(f" [features] real BERT unavailable ({e}); simulated anisotropic")
        return _simulated_bert_features(len(sentences), max_len, device)
def sterilize(feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Center and unit-normalize token features before the frame sees them.

    The common component is the mean over real tokens only: features are
    weighted by ``mask`` and summed over the sentence and token axes. That
    mean is subtracted from *every* position, padding included. Each vector
    is then L2-normalized along the last axis; all-zero vectors stay zero.

    Args:
        feats: token features, presumably (S, L, H) as produced by
            extract_bert_features — TODO confirm for other callers.
        mask: (S, L) 0/1 weights marking real tokens.

    Returns:
        A tensor shaped like ``feats``.
    """
    weights = mask[..., None]                         # (S, L, 1)
    n_real = weights.sum(dim=(0, 1)).clamp_min(1.0)   # guard an all-padding mask
    common = (feats * weights).sum(dim=(0, 1)) / n_real
    return F.normalize(feats - common, dim=-1)
# ════════════════════════════════════════════════════════════════════════
# Train the lensed transformer to patchwork-represent the features
# ════════════════════════════════════════════════════════════════════════
@dataclass
class BertConfig:
    """Settings for one BERT-feature reconstruction run (consumed by run_bert).

    The whole config is serialised into the JSON report via ``asdict``, in
    field order.
    """
    # Hugging Face checkpoint passed to extract_bert_features.
    model_name: str = 'google/bert_uncased_L-2_H-128_A-2' # bert-base-uncased on Colab
    # Tokenizer max_length: every sentence is padded/truncated to this many tokens.
    max_len: int = 16
    # V, D_base, D_lens, hidden, n_heads and n_layers are forwarded to GeoConfig.
    # The run banner labels V and D_base as the frame and D_lens as the lens width.
    V: int = 32
    # D_base is also passed to rigidity_loss, measure_guarantee and dev_critical.
    D_base: int = 4
    D_lens: int = 16
    # Width passed as GeoConfig.hidden. This is NOT the BERT feature width,
    # which is read from the model and passed as patch_dim_override.
    hidden: int = 64
    n_heads: int = 4
    n_layers: int = 2
    # Training loop: full passes over the shuffled sentences. Only whole
    # batches are used; the remainder is skipped each epoch.
    epochs: int = 40
    batch_size: int = 8
    # Adam learning rate, cosine-annealed over epochs * steps_per_epoch.
    lr: float = 2e-3
    # Weight of rigidity_loss added to the reconstruction MSE.
    rigid_weight: float = 0.5
    mask_ratio: float = 0.0 # >0 = masked reconstruction (forces cross-patch α)
    # Passed to get_corpus.
    corpus_source: str = 'builtin' # 'builtin' | 'wikitext'
    n_sentences: int = 256 # for wikitext
    # If True, also save model/config/verdict to <out_dir>/geolip_svae_transformer.pt.
    save_checkpoint: bool = False
    # Directory for geo_svae_bert.json (created if missing).
    out_dir: str = './geo_svae_bert_results'
    # Passed to torch.manual_seed (and torch.cuda.manual_seed_all when CUDA is available).
    seed: int = 0
def masked_recon(recon, target, mask):
    """Compute masked reconstruction metrics.

    Args:
        recon: reconstructed features with trailing feature dim D.
        target: reference features, same shape as ``recon``.
        mask: per-token weights with shape ``recon.shape[:-1]``. Zero-weight
            tokens are ignored.

    Returns:
        ``(mse, cos)``:

        - ``mse`` is a differentiable scalar tensor: the weighted squared
          error summed over tokens, divided by the total weight (at least 1)
          and by D.
        - ``cos`` is a Python float (no gradient): the weighted mean cosine
          similarity between ``recon`` and ``target`` tokens.
    """
    weight = mask[..., None]
    denom = weight.sum().clamp_min(1.0)
    mse = ((recon - target) ** 2 * weight).sum() / denom / target.shape[-1]
    with torch.no_grad():
        unit_recon = F.normalize(recon.detach(), dim=-1)
        unit_target = F.normalize(target, dim=-1)
        per_token = (unit_recon * unit_target).sum(-1)
        mean_cos = (per_token * mask).sum() / mask.sum().clamp_min(1.0)
    return mse, float(mean_cos)
def _mean_or_nan(values):
    """Arithmetic mean of ``values``, or NaN when the list is empty."""
    return sum(values) / len(values) if values else float('nan')


def _train_epoch(model, opt, sched, feats, mask, cfg, steps, device):
    """Run one shuffled epoch of optimisation; return per-batch (cos, mse) lists.

    Only batches with at least one scored token contribute metrics. With
    ``mask_ratio > 0`` a batch can, by chance, mask nothing, and recording
    its 0.0 cosine would drag the epoch mean down.
    """
    perm = torch.randperm(feats.shape[0], device=device)
    ep_cos, ep_mse = [], []
    for bi in range(steps):
        idx = perm[bi * cfg.batch_size:(bi + 1) * cfg.batch_size]
        x = feats[idx]  # (B, L, hidden)
        mk = mask[idx]
        if cfg.mask_ratio > 0:
            # Mask a fraction of REAL tokens and score only those. They are
            # recoverable only from neighbours, which demands cross-patch
            # attention.
            drop = (torch.rand_like(mk) < cfg.mask_ratio) * mk
            x_in = x * (1.0 - drop).unsqueeze(-1)  # zero the masked tokens
            score = drop                            # score masked positions
        else:
            x_in, score = x, mk
        opt.zero_grad()
        out = model.forward_patches(x_in)
        recon = out['recon_patches']
        # `recon` is an activation computed before opt.step(), so this cosine
        # equals what a second masked_recon call after the step would return.
        mse, cos = masked_recon(recon, x, score)
        loss = mse + cfg.rigid_weight * rigidity_loss(out['M'], cfg.D_base)
        loss.backward()
        opt.step()
        sched.step()
        if score.sum() > 0:
            ep_cos.append(cos)
            ep_mse.append(float(mse.detach()))
    return ep_cos, ep_mse


def _build_verdict(history, best_cos, hidden, n_params):
    """Summarise the final epoch and the best cosine into the report verdict."""
    final_g = history[-1]['guarantee']
    final_alpha = history[-1]['mean_alpha']
    mechanism = ('A (alpha-engaged cross-patch coordination)' if final_alpha > 0.05
                 else 'B (encoder mode-concentration, alpha near-identity)')
    return {
        'represents_bert': best_cos > 0.7,
        'best_recon_cos': best_cos,
        'final_recon_cos': history[-1]['recon_cos'],
        'guarantee_holds': final_g['in_envelope'],
        'final_guarantee': final_g,
        'final_mean_alpha': final_alpha,
        'mechanism': mechanism,
        'hidden': hidden, 'params': n_params,
    }


def _print_verdict(verdict, D_base):
    """Print the human-readable BERT-feature verdict block."""
    final_g = verdict['final_guarantee']
    print("\n" + "=" * 70)
    print("BERT-FEATURE VERDICT")
    print("=" * 70)
    print(f" {'✓' if verdict['represents_bert'] else '✗'} patchwork-represents "
          f"BERT features: recon cosine {verdict['best_recon_cos']:.4f} "
          f"(1.0 = perfect feature reconstruction)")
    print(f" {'✓' if verdict['guarantee_holds'] else '✗'} rigidity guarantee "
          f"holds while representing real features: dev {final_g['deviation']:+.4f} "
          f"(crit ±{dev_critical(D_base):.3f})")
    print(f" · spectral-alpha α={verdict['final_mean_alpha']:.4f} → "
          f"Mechanism {verdict['mechanism']}")
    print(f" → BERT's {verdict['hidden']}-d per-token features are sterilized, then encoded")
    print(" as omega tokens on the rigid lensed frame and reconstructed.")


def run_bert(cfg: BertConfig) -> Dict:
    """Train the lensed SVAE transformer to reconstruct sterilized BERT features.

    Pipeline:

    1. Load the corpus.
    2. Extract per-token BERT features (real or simulated).
    3. Sterilize the features.
    4. Train GeoSVAETransformer on (optionally masked) per-token
       reconstruction plus a rigidity penalty.
    5. Write ``geo_svae_bert.json`` (and optionally a checkpoint) into
       ``cfg.out_dir``.

    Args:
        cfg: run configuration.

    Returns:
        The report dict ``{'config', 'history', 'verdict'}``, which is also
        written to ``<out_dir>/geo_svae_bert.json``.

    Raises:
        ValueError: if ``cfg.epochs`` or ``cfg.batch_size`` is < 1.
    """
    # Fail fast: with no epochs there is no history to summarise, and a zero
    # batch size would divide by zero when computing steps per epoch.
    if cfg.epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {cfg.epochs}")
    if cfg.batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {cfg.batch_size}")

    out_dir = Path(cfg.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)
    print("=" * 70)
    print("geolip-svae-transformer — patchwork-represent BERT features")
    print(f" model={cfg.model_name} max_len={cfg.max_len}")
    print(f" frame V{cfg.V} D_base{cfg.D_base}→lens D{cfg.D_lens} "
          f"{cfg.n_layers}L×{cfg.n_heads}h | device={device}")
    print("=" * 70)

    corpus = get_corpus(cfg.corpus_source, cfg.n_sentences)
    feats, mask, hidden = extract_bert_features(cfg.model_name, corpus,
                                                cfg.max_len, device)
    feats = sterilize(feats, mask)  # the input cleaning
    print(f" sterilized features: {tuple(feats.shape)} (hidden={hidden}), "
          f"each token = one patch (patch_dim={hidden})")

    geo = GeoConfig(V=cfg.V, D_base=cfg.D_base, D_lens=cfg.D_lens,
                    hidden=cfg.hidden, n_heads=cfg.n_heads, n_layers=cfg.n_layers,
                    patch_dim_override=hidden)
    model = GeoSVAETransformer(geo).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" params: {n_params:,} (trainable: {n_trainable:,})")
    if cfg.mask_ratio > 0:
        print(f" TASK: masked reconstruction (mask_ratio={cfg.mask_ratio}) — "
              f"masked tokens recoverable ONLY via cross-patch attention")
    else:
        print(" TASK: per-token reconstruction (solvable per-patch)")
    print()

    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    # Whole batches only: the S % batch_size remainder is skipped each epoch.
    # The permutation is reshuffled every epoch, so different samples are skipped.
    steps = max(1, feats.shape[0] // cfg.batch_size)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs * steps)

    model.train()
    best_cos, history = float('-inf'), []
    for epoch in range(cfg.epochs):
        ep_cos, ep_mse = _train_epoch(model, opt, sched, feats, mask, cfg,
                                      steps, device)
        # Measure on the full set in eval mode (disables dropout etc., if the
        # model has any), then switch back for the next epoch.
        model.eval()
        with torch.no_grad():
            full = model.forward_patches(feats)
            g = measure_guarantee(full['M_lens'], cfg.D_base)
        model.train()
        # float() keeps the JSON report serialisable if the model returns a
        # 0-d tensor here. It is a no-op for a plain Python number.
        mean_alpha = float(full['mean_alpha'])
        mc = _mean_or_nan(ep_cos)
        mse_mean = _mean_or_nan(ep_mse)
        if mc > best_cos:  # NaN (no scored batches) never wins
            best_cos = mc
        history.append({'epoch': epoch, 'recon_cos': mc, 'mse': mse_mean,
                        'mean_alpha': mean_alpha, 'guarantee': g})
        if epoch % 5 == 0 or epoch == cfg.epochs - 1:
            print(f" epoch {epoch:2d}: recon_cos={mc:.4f} mse={mse_mean:.5f} | "
                  f"α={mean_alpha:.4f} | frame dev={g['deviation']:+.4f} "
                  f"in_env={g['in_envelope']} cv_of={g['cv_of']:.3f}")

    verdict = _build_verdict(history, best_cos, hidden, n_params)
    report = {'config': asdict(cfg), 'history': history, 'verdict': verdict}
    report_path = out_dir / 'geo_svae_bert.json'
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    _print_verdict(verdict, cfg.D_base)
    if cfg.save_checkpoint:
        ckpt = {
            'model_state_dict': model.state_dict(),
            'geo_config': asdict(geo),
            'bert_config': asdict(cfg),
            'hidden': hidden,
            'verdict': verdict,
        }
        ckpt_path = out_dir / 'geolip_svae_transformer.pt'
        torch.save(ckpt, ckpt_path)
        print(f" checkpoint: {ckpt_path}")
    print(f" report: {report_path}")
    return report
# ════════════════════════════════════════════════════════════════════════
# Colab-proof
# ════════════════════════════════════════════════════════════════════════
def _is_jupyter_kernel():
try:
from IPython import get_ipython
ip = get_ipython()
return ip is not None and 'IPKernelApp' in ip.config
except Exception:
return False
def _filter_jupyter_args(argv):
out, skip = [], False
for a in argv:
if skip:
skip = False
continue
if a == '-f':
skip = True
continue
if a.startswith('-f=') or a.endswith('.json'):
continue
out.append(a)
return out
def run(**kwargs):
    """Notebook-friendly entry point: build a BertConfig from kwargs and train.

        from geolip_svae_bert_features import run
        run()                                                      # small BERT (sandbox)
        run(model_name='bert-base-uncased', D_lens=64, epochs=60)  # Colab

    Keyword arguments that are not BertConfig fields are still ignored, so
    stray kwargs do not crash a run. They are now reported, so a typo such
    as ``epoch=60`` is no longer silently dropped.

    Returns:
        The report dict produced by run_bert.
    """
    known = BertConfig.__dataclass_fields__
    ignored = sorted(k for k in kwargs if k not in known)
    if ignored:
        print(f" [run] ignoring unknown config keys: {', '.join(ignored)}")
    cfg = BertConfig(**{k: v for k, v in kwargs.items() if k in known})
    return run_bert(cfg)
def main(argv=None):
    """Command-line entry point; returns the run_bert report dict.

    Argument defaults are read from BertConfig, so the CLI and the dataclass
    cannot drift apart. Unknown arguments are tolerated via parse_known_args,
    which keeps stray notebook/launcher flags from crashing the run, but they
    are reported. Inside a Jupyter kernel the kernel's own arguments are
    stripped first.
    """
    import sys
    if argv is None:
        argv = sys.argv[1:]
    if _is_jupyter_kernel():
        argv = _filter_jupyter_args(argv)
    p = argparse.ArgumentParser(
        description='Patchwork-represent BERT features with geolip-svae-transformer.')
    p.add_argument('--model-name', default=BertConfig.model_name)
    p.add_argument('--D-lens', type=int, default=BertConfig.D_lens)
    p.add_argument('--epochs', type=int, default=BertConfig.epochs)
    p.add_argument('--out-dir', default=BertConfig.out_dir)
    p.add_argument('--corpus-source', choices=('builtin', 'wikitext'),
                   default=BertConfig.corpus_source)
    p.add_argument('--n-sentences', type=int, default=BertConfig.n_sentences)
    p.add_argument('--mask-ratio', type=float, default=BertConfig.mask_ratio)
    p.add_argument('--seed', type=int, default=BertConfig.seed)
    p.add_argument('--save-checkpoint', action='store_true')
    args, unknown = p.parse_known_args(argv)
    if unknown:
        print(f" [main] ignoring unrecognised arguments: {' '.join(unknown)}")
    return run_bert(BertConfig(model_name=args.model_name, D_lens=args.D_lens,
                               epochs=args.epochs, out_dir=args.out_dir,
                               corpus_source=args.corpus_source,
                               n_sentences=args.n_sentences,
                               mask_ratio=args.mask_ratio, seed=args.seed,
                               save_checkpoint=args.save_checkpoint))
if __name__ == "__main__":
    # Script entry point (CLI). From a notebook, call run(...) instead.
    main()