"""geolip_svae_bert_features.py — patchwork-represent BERT features. Tests whether the geolip-svae-transformer's lensed rigid frame can compress and reconstruct real semantic features: each BERT token's hidden vector is treated as a patch (patch_dim = hidden), the token sequence is the N patches, and the model patchwork-represents the sequence as omega tokens on the rigid frame. THE STERILIZE STEP (from the trigram lesson — we control the SVAE's input distribution): BERT features are strongly anisotropic (a dominant common direction + a few high-variance axes). We center out the common component and unit-normalize per token before the frame sees them, so the SVAE gets a clean, isotropic distribution — far more utilizable, as Phil noted. SWAPPABLE FEATURE SOURCE: real transformers BERT (default 'google/bert_uncased_L-2_H-128_A-2' for sandbox, swap bert-base-uncased on Colab) with a simulated-anisotropic fallback so the pipeline always runs. Snap alongside geolip_svae_transformer.py. Deterministic, Colab-proof run(). """ from __future__ import annotations import argparse import json import math import time from dataclasses import dataclass, asdict from pathlib import Path from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F from geolip_svae_transformer import ( GeoSVAETransformer, GeoConfig, rigidity_loss, measure_guarantee, dev_critical, ) # ════════════════════════════════════════════════════════════════════════ # Corpus + BERT feature extraction (swappable real / simulated) # ════════════════════════════════════════════════════════════════════════ def get_corpus(source: str, n: int) -> List[str]: """Return a list of sentences. 'builtin' = the curated set; 'wikitext' = real Wikipedia sentences (more diverse → more potent test).""" if source == 'builtin': return CORPUS if source == 'wikitext': from datasets import load_dataset ds = load_dataset('Salesforce/wikitext', 'wikitext-2-raw-v1', split='train') sents = [] for row in ds: t = row['text'].strip() if len(t) < 40 or t.startswith('='): # skip headers/short continue for s in t.replace('\n', ' ').split('. '): s = s.strip() if 30 < len(s) < 220: sents.append(s if s.endswith('.') else s + '.') if len(sents) >= n: break if len(sents) >= n: break print(f" [corpus] wikitext: {len(sents)} sentences") return sents raise ValueError(f"unknown corpus source: {source}") CORPUS = [ "the cat sat quietly on the warm windowsill", "quantum fields fluctuate in the vacuum of empty space", "she poured the coffee and opened her laptop", "the river carved a deep canyon over millions of years", "investors watched the market tumble after the announcement", "a gentle rain fell across the sleeping village", "the algorithm sorts the array in logarithmic time", "he tuned the old guitar before the evening show", "photosynthesis converts sunlight into chemical energy", "the negotiators reached an agreement just before dawn", "stars collapse into dense remnants when their fuel runs out", "the chef seasoned the broth with ginger and lemongrass", "children laughed as the kite climbed into the wind", "the contract was signed in a quiet downtown office", "neurons fire in cascading waves across the cortex", "the train arrived late because of the heavy snow", "a sculptor chipped patiently at the block of marble", "the spacecraft adjusted its orbit around the moon", "they planted tomatoes along the southern fence", "the lecture covered the foundations of thermodynamics", "fog rolled in from the harbor at first light", "the startup pivoted toward enterprise customers", "wolves coordinate their movements while hunting", "the violinist closed her eyes during the solo", "compilers translate source code into machine instructions", "the desert blooms briefly after the spring rains", "the committee debated the proposal for three hours", "electrons occupy discrete energy levels in an atom", "he sketched the bridge from across the river", "the bakery sold out of bread by mid morning", "satellites relay signals across the curved earth", "the novel opens in a crowded train station", ] def extract_bert_features(model_name: str, sentences: List[str], max_len: int, device: torch.device): """Returns (features (S,L,H), mask (S,L), hidden). Real BERT; simulated on failure.""" try: from transformers import AutoTokenizer, AutoModel, logging logging.set_verbosity_error() try: tok = AutoTokenizer.from_pretrained(model_name) except Exception: from transformers import BertTokenizer tok = BertTokenizer.from_pretrained(model_name) model = AutoModel.from_pretrained(model_name).to(device).eval() enc = tok(sentences, return_tensors='pt', padding='max_length', truncation=True, max_length=max_len) enc = {k: v.to(device) for k, v in enc.items()} with torch.no_grad(): out = model(**enc) feats = out.last_hidden_state # (S,L,H) mask = enc['attention_mask'].float() # (S,L) print(f" [features] real BERT '{model_name}' → {tuple(feats.shape)}") return feats, mask, model.config.hidden_size except Exception as e: print(f" [features] real BERT unavailable ({e}); simulated anisotropic") H = 128 g = torch.Generator().manual_seed(0) S, L = len(sentences), max_len common = 3.0 * F.normalize(torch.randn(H, generator=g), dim=0) # anisotropy basis = F.normalize(torch.randn(8, H, generator=g), dim=1) coeffs = torch.randn(S, L, 8, generator=g) feats = common + coeffs @ basis + 0.3 * torch.randn(S, L, H, generator=g) mask = (torch.rand(S, L, generator=g) > 0.2).float() mask[:, 0] = 1.0 return feats.to(device), mask.to(device), H def sterilize(feats: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: """Clean the distribution the frame sees: remove the anisotropic common component (mean over real tokens) then unit-normalize per token.""" m = mask.unsqueeze(-1) mean = (feats * m).sum(dim=(0, 1)) / m.sum(dim=(0, 1)).clamp_min(1.0) centered = feats - mean return F.normalize(centered, dim=-1) # ════════════════════════════════════════════════════════════════════════ # Train the lensed transformer to patchwork-represent the features # ════════════════════════════════════════════════════════════════════════ @dataclass class BertConfig: model_name: str = 'google/bert_uncased_L-2_H-128_A-2' # bert-base-uncased on Colab max_len: int = 16 V: int = 32 D_base: int = 4 D_lens: int = 16 hidden: int = 64 n_heads: int = 4 n_layers: int = 2 epochs: int = 40 batch_size: int = 8 lr: float = 2e-3 rigid_weight: float = 0.5 mask_ratio: float = 0.0 # >0 = masked reconstruction (forces cross-patch α) corpus_source: str = 'builtin' # 'builtin' | 'wikitext' n_sentences: int = 256 # for wikitext save_checkpoint: bool = False out_dir: str = './geo_svae_bert_results' seed: int = 0 def masked_recon(recon, target, mask): m = mask.unsqueeze(-1) mse = ((recon - target) ** 2 * m).sum() / m.sum().clamp_min(1.0) / target.shape[-1] with torch.no_grad(): rn = F.normalize(recon.detach(), dim=-1) tn = F.normalize(target, dim=-1) cos = (rn * tn).sum(-1) # (S,L) cos = (cos * mask).sum() / mask.sum().clamp_min(1.0) return mse, float(cos) def run_bert(cfg: BertConfig) -> Dict: out_dir = Path(cfg.out_dir) out_dir.mkdir(parents=True, exist_ok=True) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') torch.manual_seed(cfg.seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(cfg.seed) print("=" * 70) print("geolip-svae-transformer — patchwork-represent BERT features") print(f" model={cfg.model_name} max_len={cfg.max_len}") print(f" frame V{cfg.V} D_base{cfg.D_base}→lens D{cfg.D_lens} " f"{cfg.n_layers}L×{cfg.n_heads}h | device={device}") print("=" * 70) corpus = get_corpus(cfg.corpus_source, cfg.n_sentences) feats, mask, hidden = extract_bert_features(cfg.model_name, corpus, cfg.max_len, device) feats = sterilize(feats, mask) # the input cleaning print(f" sterilized features: {tuple(feats.shape)} (hidden={hidden}), " f"each token = one patch (patch_dim={hidden})") geo = GeoConfig(V=cfg.V, D_base=cfg.D_base, D_lens=cfg.D_lens, hidden=cfg.hidden, n_heads=cfg.n_heads, n_layers=cfg.n_layers, patch_dim_override=hidden) model = GeoSVAETransformer(geo).to(device) n_params = sum(p.numel() for p in model.parameters()) print(f" trainable params: {n_params:,}") if cfg.mask_ratio > 0: print(f" TASK: masked reconstruction (mask_ratio={cfg.mask_ratio}) — " f"masked tokens recoverable ONLY via cross-patch attention") else: print(f" TASK: per-token reconstruction (solvable per-patch)") print() opt = torch.optim.Adam(model.parameters(), lr=cfg.lr) S = feats.shape[0] steps = max(1, S // cfg.batch_size) sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs * steps) model.train() best_cos, history = 0.0, [] for epoch in range(cfg.epochs): perm = torch.randperm(S, device=device) ep_cos, ep_mse = [], [] for bi in range(steps): idx = perm[bi * cfg.batch_size:(bi + 1) * cfg.batch_size] x = feats[idx] # (B,L,hidden) mk = mask[idx] if cfg.mask_ratio > 0: # mask a fraction of REAL tokens; score only those (recoverable # only from neighbors → demands cross-patch attention) drop = (torch.rand_like(mk) < cfg.mask_ratio) * mk x_in = x * (1.0 - drop).unsqueeze(-1) # zero the masked tokens score = drop # score masked positions else: x_in, score = x, mk opt.zero_grad() out = model.forward_patches(x_in) recon = out['recon_patches'] mse, _ = masked_recon(recon, x, score) loss = mse + cfg.rigid_weight * rigidity_loss(out['M'], cfg.D_base) loss.backward() opt.step() sched.step() with torch.no_grad(): _, cos = masked_recon(recon, x, score) ep_cos.append(cos) ep_mse.append(float(mse.detach())) with torch.no_grad(): full = model.forward_patches(feats) g = measure_guarantee(full['M_lens'], cfg.D_base) mean_alpha = full['mean_alpha'] mc = sum(ep_cos) / len(ep_cos) best_cos = max(best_cos, mc) history.append({'epoch': epoch, 'recon_cos': mc, 'mse': sum(ep_mse) / len(ep_mse), 'mean_alpha': mean_alpha, 'guarantee': g}) if epoch % 5 == 0 or epoch == cfg.epochs - 1: print(f" epoch {epoch:2d}: recon_cos={mc:.4f} mse={ep_mse[-1]:.5f} | " f"α={mean_alpha:.4f} | frame dev={g['deviation']:+.4f} " f"in_env={g['in_envelope']} cv_of={g['cv_of']:.3f}") final_g = history[-1]['guarantee'] final_alpha = history[-1]['mean_alpha'] mechanism = 'A (alpha-engaged cross-patch coordination)' if final_alpha > 0.05 \ else 'B (encoder mode-concentration, alpha near-identity)' verdict = { 'represents_bert': best_cos > 0.7, 'best_recon_cos': best_cos, 'final_recon_cos': history[-1]['recon_cos'], 'guarantee_holds': final_g['in_envelope'], 'final_guarantee': final_g, 'final_mean_alpha': final_alpha, 'mechanism': mechanism, 'hidden': hidden, 'params': n_params, } report = {'config': asdict(cfg), 'history': history, 'verdict': verdict} with open(out_dir / 'geo_svae_bert.json', 'w') as f: json.dump(report, f, indent=2) print("\n" + "=" * 70) print("BERT-FEATURE VERDICT") print("=" * 70) print(f" {'✓' if verdict['represents_bert'] else '✗'} patchwork-represents " f"BERT features: recon cosine {best_cos:.4f} " f"(1.0 = perfect feature reconstruction)") print(f" {'✓' if verdict['guarantee_holds'] else '✗'} rigidity guarantee " f"holds while representing real features: dev {final_g['deviation']:+.4f} " f"(crit ±{dev_critical(cfg.D_base):.3f})") print(f" · spectral-alpha α={final_alpha:.4f} → Mechanism {mechanism}") print(f" → BERT's {hidden}-d per-token features are sterilized, then encoded") print(f" as omega tokens on the rigid lensed frame and reconstructed.") if cfg.save_checkpoint: ckpt = { 'model_state_dict': model.state_dict(), 'geo_config': asdict(geo), 'bert_config': asdict(cfg), 'hidden': hidden, 'verdict': verdict, } ckpt_path = out_dir / 'geolip_svae_transformer.pt' torch.save(ckpt, ckpt_path) print(f" checkpoint: {ckpt_path}") print(f" report: {out_dir / 'geo_svae_bert.json'}") return report # ════════════════════════════════════════════════════════════════════════ # Colab-proof # ════════════════════════════════════════════════════════════════════════ def _is_jupyter_kernel(): try: from IPython import get_ipython ip = get_ipython() return ip is not None and 'IPKernelApp' in ip.config except Exception: return False def _filter_jupyter_args(argv): out, skip = [], False for a in argv: if skip: skip = False continue if a == '-f': skip = True continue if a.startswith('-f=') or a.endswith('.json'): continue out.append(a) return out def run(**kwargs): """from geolip_svae_bert_features import run run() # small BERT (sandbox) run(model_name='bert-base-uncased', D_lens=64, epochs=60) # Colab """ cfg = BertConfig(**{k: v for k, v in kwargs.items() if k in BertConfig.__dataclass_fields__}) return run_bert(cfg) def main(argv=None): import sys if argv is None: argv = sys.argv[1:] if _is_jupyter_kernel(): argv = _filter_jupyter_args(argv) p = argparse.ArgumentParser() p.add_argument('--model-name', default='google/bert_uncased_L-2_H-128_A-2') p.add_argument('--D-lens', type=int, default=16) p.add_argument('--epochs', type=int, default=40) p.add_argument('--out-dir', default='./geo_svae_bert_results') args, _unknown = p.parse_known_args(argv) return run_bert(BertConfig(model_name=args.model_name, D_lens=args.D_lens, epochs=args.epochs, out_dir=args.out_dir)) if __name__ == '__main__': main()