"""geolip_svae_bert_features.py — patchwork-represent BERT features.

Tests whether the geolip-svae-transformer's lensed rigid frame can compress and
reconstruct real semantic features: each BERT token's hidden vector is treated
as a patch (patch_dim = hidden), the token sequence is the N patches, and the
model patchwork-represents the sequence as omega tokens on the rigid frame.

THE STERILIZE STEP (from the trigram lesson — we control the SVAE's input
distribution): BERT features are strongly anisotropic (a dominant common
direction + a few high-variance axes). We center out the common component and
unit-normalize per token before the frame sees them, so the SVAE gets a much
cleaner, more isotropic distribution — far more utilizable, as Phil noted.

SWAPPABLE FEATURE SOURCE: real transformers BERT (default
'google/bert_uncased_L-2_H-128_A-2' for the sandbox; swap in bert-base-uncased
on Colab), with a simulated anisotropic fallback so the pipeline always runs.

Place this file alongside geolip_svae_transformer.py (it imports from it).
Deterministic, Colab-proof run().
"""


from __future__ import annotations

import argparse
import json
import math
import time
import warnings
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from geolip_svae_transformer import (
    GeoSVAETransformer, GeoConfig, rigidity_loss, measure_guarantee,
    dev_critical,
)


def get_corpus(source: str, n: int) -> List[str]:
    """Return up to ``n`` sentences from the requested corpus.

    Args:
        source: ``'builtin'`` for the curated ``CORPUS`` list defined in this
            module, or ``'wikitext'`` for real Wikipedia sentences taken from
            the WikiText-2 raw train split (more diverse → a more demanding
            test).
        n: maximum number of sentences to return; values <= 0 yield ``[]``.

    Returns:
        A new list of sentence strings. For ``'builtin'`` it is a copy of (a
        prefix of) ``CORPUS``, so callers may mutate it without touching the
        module-level constant.

    Raises:
        ValueError: if ``source`` is not a known corpus name.
    """
    limit = max(n, 0)
    if source == 'builtin':
        # Copy: never hand out the module-level list itself.
        return list(CORPUS[:limit])
    if source == 'wikitext':
        # Optional dependency, only needed for this source.
        from datasets import load_dataset
        ds = load_dataset('Salesforce/wikitext', 'wikitext-2-raw-v1', split='train')
        sents: List[str] = []
        for row in ds:
            if len(sents) >= limit:
                break
            text = row['text'].strip()
            # Skip short fragments and lines starting with '=' (section headings).
            if len(text) < 40 or text.startswith('='):
                continue
            # Naive sentence split; re-append the '.' that the split consumed.
            for s in text.replace('\n', ' ').split('. '):
                s = s.strip()
                if 30 < len(s) < 220:
                    sents.append(s if s.endswith('.') else s + '.')
                    if len(sents) >= limit:
                        break
        print(f" [corpus] wikitext: {len(sents)} sentences")
        return sents
    raise ValueError(f"unknown corpus source: {source}")


# Curated, topically diverse sentences used when corpus_source == 'builtin'.
# One sentence per line; split into a plain list of str at import time.
CORPUS = """
the cat sat quietly on the warm windowsill
quantum fields fluctuate in the vacuum of empty space
she poured the coffee and opened her laptop
the river carved a deep canyon over millions of years
investors watched the market tumble after the announcement
a gentle rain fell across the sleeping village
the algorithm sorts the array in logarithmic time
he tuned the old guitar before the evening show
photosynthesis converts sunlight into chemical energy
the negotiators reached an agreement just before dawn
stars collapse into dense remnants when their fuel runs out
the chef seasoned the broth with ginger and lemongrass
children laughed as the kite climbed into the wind
the contract was signed in a quiet downtown office
neurons fire in cascading waves across the cortex
the train arrived late because of the heavy snow
a sculptor chipped patiently at the block of marble
the spacecraft adjusted its orbit around the moon
they planted tomatoes along the southern fence
the lecture covered the foundations of thermodynamics
fog rolled in from the harbor at first light
the startup pivoted toward enterprise customers
wolves coordinate their movements while hunting
the violinist closed her eyes during the solo
compilers translate source code into machine instructions
the desert blooms briefly after the spring rains
the committee debated the proposal for three hours
electrons occupy discrete energy levels in an atom
he sketched the bridge from across the river
the bakery sold out of bread by mid morning
satellites relay signals across the curved earth
the novel opens in a crowded train station
""".strip().splitlines()


def extract_bert_features(model_name: str, sentences: List[str], max_len: int,
                          device: torch.device
                          ) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """Embed ``sentences`` with a BERT model, falling back to simulated data.

    Args:
        model_name: model id passed to ``from_pretrained``.
        sentences: input texts; each becomes one row of the output.
        max_len: every sentence is padded/truncated to exactly this length.
        device: device the returned tensors live on.

    Returns:
        ``(features, mask, hidden)``: ``features`` is ``(S, L, H)``, ``mask``
        is ``(S, L)`` float (the tokenizer's attention mask for real BERT,
        random holes with token 0 always kept for the simulation) and
        ``hidden`` is ``H``; ``S = len(sentences)``, ``L = max_len``.

    Any exception on the real-BERT path (``transformers`` missing, download
    failure, ...) is printed and replaced by deterministic simulated
    anisotropic features, so the pipeline always runs.
    """
    try:
        return _real_bert_features(model_name, sentences, max_len, device)
    except Exception as e:  # deliberate best-effort fallback
        print(f" [features] real BERT unavailable ({e}); simulated anisotropic")
        return _simulated_bert_features(len(sentences), max_len, device)


def _real_bert_features(model_name, sentences, max_len, device):
    """Run ``model_name`` over ``sentences``; return (last hidden state, mask, H)."""
    from transformers import AutoTokenizer, AutoModel, logging
    logging.set_verbosity_error()
    try:
        tok = AutoTokenizer.from_pretrained(model_name)
    except Exception:
        # Fallback for checkpoints AutoTokenizer cannot resolve.
        from transformers import BertTokenizer
        tok = BertTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device).eval()
    enc = tok(sentences, return_tensors='pt', padding='max_length',
              truncation=True, max_length=max_len)
    enc = {k: v.to(device) for k, v in enc.items()}
    with torch.no_grad():
        out = model(**enc)
    feats = out.last_hidden_state
    mask = enc['attention_mask'].float()
    print(f" [features] real BERT '{model_name}' → {tuple(feats.shape)}")
    return feats, mask, model.config.hidden_size


def _simulated_bert_features(n_sentences, max_len, device, hidden=128):
    """Deterministic stand-in: strong common direction + 8 random axes + noise."""
    g = torch.Generator().manual_seed(0)
    S, L, H = n_sentences, max_len, hidden
    common = 3.0 * F.normalize(torch.randn(H, generator=g), dim=0)
    basis = F.normalize(torch.randn(8, H, generator=g), dim=1)
    coeffs = torch.randn(S, L, 8, generator=g)
    feats = common + coeffs @ basis + 0.3 * torch.randn(S, L, H, generator=g)
    # Roughly 20% random padding holes; the first token is always real.
    mask = (torch.rand(S, L, generator=g) > 0.2).float()
    mask[:, 0] = 1.0
    return feats.to(device), mask.to(device), H


def sterilize(feats: torch.Tensor, mask: torch.Tensor,
              n_remove: int = 0) -> torch.Tensor:
    """Clean the distribution the frame sees before encoding.

    Steps:
      1. Subtract the mean vector computed over real tokens only (the
         dominant common component).
      2. Optionally (``n_remove > 0``) project out the top ``n_remove``
         principal directions of the centered real tokens, damping the few
         high-variance axes that survive centering.
      3. Unit-normalize every token vector.

    Args:
        feats: ``(S, L, H)`` token features.
        mask: ``(S, L)`` token weights; tokens with ``mask > 0`` count as real.
        n_remove: number of leading principal directions to remove. The
            default 0 keeps the plain mean-center + normalize behaviour.

    Returns:
        Tensor shaped like ``feats``. Padded positions are transformed too
        (centered with the real-token mean and normalized), so callers must
        keep applying ``mask`` downstream.
    """
    m = mask.unsqueeze(-1)
    # Masked mean; the clamp avoids 0/0 when no token is real.
    mean = (feats * m).sum(dim=(0, 1)) / m.sum(dim=(0, 1)).clamp_min(1.0)
    centered = feats - mean
    if n_remove > 0:
        real = centered[mask > 0]  # (N_real, H)
        k = min(int(n_remove), real.shape[0], real.shape[-1])
        if k > 0:
            # Rows of vh are principal directions, sorted by singular value.
            _, _, vh = torch.linalg.svd(real, full_matrices=False)
            top = vh[:k]  # (k, H)
            centered = centered - (centered @ top.T) @ top
    return F.normalize(centered, dim=-1)


@dataclass
class BertConfig:
    """Hyper-parameters for one BERT-feature run (consumed by ``run_bert``).

    Field order is part of the positional constructor signature; append new
    fields at the end. Invalid values raise ``ValueError`` at construction
    instead of failing obscurely mid-run.
    """
    # --- feature source ---
    model_name: str = 'google/bert_uncased_L-2_H-128_A-2'
    max_len: int = 16            # tokens per sentence (pad/truncate)
    # --- frame / architecture (forwarded to GeoConfig by run_bert) ---
    V: int = 32
    D_base: int = 4
    D_lens: int = 16
    hidden: int = 64
    n_heads: int = 4
    n_layers: int = 2
    # --- optimisation ---
    epochs: int = 40
    batch_size: int = 8
    lr: float = 2e-3
    rigid_weight: float = 0.5    # weight of rigidity_loss added to the recon MSE
    mask_ratio: float = 0.0      # > 0 switches to masked-token reconstruction
    # --- data / output ---
    corpus_source: str = 'builtin'   # 'builtin' or 'wikitext'
    n_sentences: int = 256
    save_checkpoint: bool = False
    out_dir: str = './geo_svae_bert_results'
    seed: int = 0

    def __post_init__(self):
        """Reject values that would otherwise crash or degenerate mid-run."""
        if self.max_len < 1:
            raise ValueError(f"max_len must be >= 1, got {self.max_len}")
        if self.epochs < 1:
            # run_bert reads history[-1]; zero epochs leaves it empty.
            raise ValueError(f"epochs must be >= 1, got {self.epochs}")
        if self.batch_size < 1:
            # run_bert computes S // batch_size.
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.n_sentences < 1:
            raise ValueError(f"n_sentences must be >= 1, got {self.n_sentences}")
        if not 0.0 <= self.mask_ratio < 1.0:
            # A ratio of 1.0 would drop every real token from the input.
            raise ValueError(f"mask_ratio must be in [0, 1), got {self.mask_ratio}")


def masked_recon(recon, target, mask):
    """Masked reconstruction metrics.

    Args:
        recon: predicted token vectors, shaped like ``target``.
        target: reference token vectors; the last axis is the feature dim.
        mask: per-token weights matching ``target.shape[:-1]``
            (1.0 = score this token, 0.0 = ignore it).

    Returns:
        ``(mse, cos)``. ``mse`` is a differentiable scalar tensor: the masked
        squared error averaged over scored tokens and feature dims. ``cos``
        is a Python float: the masked mean cosine similarity (no gradient).
    """
    weights = mask.unsqueeze(-1)
    n_scored = weights.sum().clamp_min(1.0)  # avoid 0/0 on an all-zero mask
    mse = ((recon - target).pow(2) * weights).sum() / n_scored / target.shape[-1]
    with torch.no_grad():
        per_token_cos = (F.normalize(recon.detach(), dim=-1)
                         * F.normalize(target, dim=-1)).sum(-1)
        mean_cos = (per_token_cos * mask).sum() / mask.sum().clamp_min(1.0)
    return mse, float(mean_cos)


def run_bert(cfg: BertConfig) -> Dict:
    """Train the lensed rigid frame to reconstruct sterilized BERT features.

    Pipeline: load the corpus → extract BERT (or simulated) token features →
    sterilize → train ``GeoSVAETransformer`` on per-token reconstruction (or,
    with ``cfg.mask_ratio > 0``, masked-token reconstruction) plus
    ``cfg.rigid_weight`` × rigidity loss → write ``geo_svae_bert.json`` and,
    optionally, a checkpoint into ``cfg.out_dir``.

    Returns:
        The report dict ``{'config', 'history', 'verdict'}`` that is also
        written to ``<out_dir>/geo_svae_bert.json``.

    Raises:
        ValueError: if ``cfg.epochs < 1`` or the corpus yields no sentences.
    """
    if cfg.epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {cfg.epochs}")
    out_dir = Path(cfg.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(cfg.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(cfg.seed)

    print("=" * 70)
    print("geolip-svae-transformer — patchwork-represent BERT features")
    print(f" model={cfg.model_name} max_len={cfg.max_len}")
    print(f" frame V{cfg.V} D_base{cfg.D_base}→lens D{cfg.D_lens} "
          f"{cfg.n_layers}L×{cfg.n_heads}h | device={device}")
    print("=" * 70)

    # ---- data: corpus → features → sterilized patches ----
    corpus = get_corpus(cfg.corpus_source, cfg.n_sentences)
    if not corpus:
        raise ValueError(f"corpus source '{cfg.corpus_source}' produced no sentences")
    feats, mask, hidden = extract_bert_features(cfg.model_name, corpus,
                                                cfg.max_len, device)
    feats = sterilize(feats, mask)
    print(f" sterilized features: {tuple(feats.shape)} (hidden={hidden}), "
          f"each token = one patch (patch_dim={hidden})")

    # ---- model ----
    geo = GeoConfig(V=cfg.V, D_base=cfg.D_base, D_lens=cfg.D_lens,
                    hidden=cfg.hidden, n_heads=cfg.n_heads, n_layers=cfg.n_layers,
                    patch_dim_override=hidden)
    model = GeoSVAETransformer(geo).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f" trainable params: {n_params:,}")
    if cfg.mask_ratio > 0:
        print(f" TASK: masked reconstruction (mask_ratio={cfg.mask_ratio}) — "
              f"masked tokens recoverable ONLY via cross-patch attention")
    else:
        print(" TASK: per-token reconstruction (solvable per-patch)")
    print()

    # ---- training ----
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr)
    steps = max(1, feats.shape[0] // cfg.batch_size)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs * steps)

    model.train()
    history = []
    for epoch in range(cfg.epochs):
        mc, mean_mse = _train_one_epoch(model, opt, sched, feats, mask, cfg, steps)
        g, mean_alpha = _measure_frame(model, feats, cfg.D_base)
        history.append({'epoch': epoch, 'recon_cos': mc, 'mse': mean_mse,
                        'mean_alpha': mean_alpha, 'guarantee': g})
        if epoch % 5 == 0 or epoch == cfg.epochs - 1:
            # Epoch-mean mse, matching what is stored in history.
            print(f" epoch {epoch:2d}: recon_cos={mc:.4f} mse={mean_mse:.5f} | "
                  f"α={mean_alpha:.4f} | frame dev={g['deviation']:+.4f} "
                  f"in_env={g['in_envelope']} cv_of={g['cv_of']:.3f}")

    # ---- verdict + report ----
    best_cos = max(h['recon_cos'] for h in history)  # not floored at 0.0
    final = history[-1]
    final_g, final_alpha = final['guarantee'], final['mean_alpha']
    mechanism = ('A (alpha-engaged cross-patch coordination)' if final_alpha > 0.05
                 else 'B (encoder mode-concentration, alpha near-identity)')
    verdict = {
        'represents_bert': best_cos > 0.7,
        'best_recon_cos': best_cos,
        'final_recon_cos': final['recon_cos'],
        'guarantee_holds': final_g['in_envelope'],
        'final_guarantee': final_g,
        'final_mean_alpha': final_alpha,
        'mechanism': mechanism,
        'hidden': hidden, 'params': n_params,
    }
    report = {'config': asdict(cfg), 'history': history, 'verdict': verdict}
    report_path = out_dir / 'geo_svae_bert.json'
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    print("\n" + "=" * 70)
    print("BERT-FEATURE VERDICT")
    print("=" * 70)
    print(f" {'✓' if verdict['represents_bert'] else '✗'} patchwork-represents "
          f"BERT features: recon cosine {best_cos:.4f} "
          f"(1.0 = perfect feature reconstruction)")
    print(f" {'✓' if verdict['guarantee_holds'] else '✗'} rigidity guarantee "
          f"holds while representing real features: dev {final_g['deviation']:+.4f} "
          f"(crit ±{dev_critical(cfg.D_base):.3f})")
    print(f" · spectral-alpha α={final_alpha:.4f} → Mechanism {mechanism}")
    print(f" → BERT's {hidden}-d per-token features are sterilized, then encoded")
    print(" as omega tokens on the rigid lensed frame and reconstructed.")
    if cfg.save_checkpoint:
        ckpt_path = out_dir / 'geolip_svae_transformer.pt'
        torch.save({
            'model_state_dict': model.state_dict(),
            'geo_config': asdict(geo),
            'bert_config': asdict(cfg),
            'hidden': hidden,
            'verdict': verdict,
        }, ckpt_path)
        print(f" checkpoint: {ckpt_path}")

    print(f" report: {report_path}")
    return report


def _train_one_epoch(model, opt, sched, feats, mask, cfg, steps):
    """One shuffled pass of ``steps`` mini-batches; return (mean cos, mean mse)."""
    perm = torch.randperm(feats.shape[0], device=feats.device)
    ep_cos, ep_mse = [], []
    for bi in range(steps):
        idx = perm[bi * cfg.batch_size:(bi + 1) * cfg.batch_size]
        x, mk = feats[idx], mask[idx]
        if cfg.mask_ratio > 0:
            # Masked task: zero a random subset of real tokens in the input
            # and score only those, so they must be inferred from other patches.
            drop = (torch.rand_like(mk) < cfg.mask_ratio) * mk
            x_in = x * (1.0 - drop).unsqueeze(-1)
            score = drop
        else:
            x_in, score = x, mk
        opt.zero_grad()
        out = model.forward_patches(x_in)
        # masked_recon already returns the no-grad cosine; no second call needed.
        mse, cos = masked_recon(out['recon_patches'], x, score)
        loss = mse + cfg.rigid_weight * rigidity_loss(out['M'], cfg.D_base)
        loss.backward()
        opt.step()
        sched.step()
        ep_cos.append(cos)
        ep_mse.append(float(mse.detach()))
    return sum(ep_cos) / len(ep_cos), sum(ep_mse) / len(ep_mse)


def _measure_frame(model, feats, D_base):
    """Eval-mode, no-grad pass over all features; return (guarantee, mean alpha)."""
    was_training = model.training
    # Eval mode so any dropout / norm running stats in the model do not
    # perturb (or get updated by) the measurement pass.
    model.eval()
    try:
        with torch.no_grad():
            full = model.forward_patches(feats)
            g = measure_guarantee(full['M_lens'], D_base)
            # float(): the value is formatted and JSON-dumped; assumes a scalar
            # (Python float or 0-d tensor) — TODO confirm in geolip_svae_transformer.
            mean_alpha = float(full['mean_alpha'])
    finally:
        model.train(was_training)
    return g, mean_alpha


| def _is_jupyter_kernel(): |
| try: |
| from IPython import get_ipython |
| ip = get_ipython() |
| return ip is not None and 'IPKernelApp' in ip.config |
| except Exception: |
| return False |
|
|
|
|
| def _filter_jupyter_args(argv): |
| out, skip = [], False |
| for a in argv: |
| if skip: |
| skip = False |
| continue |
| if a == '-f': |
| skip = True |
| continue |
| if a.startswith('-f=') or a.endswith('.json'): |
| continue |
| out.append(a) |
| return out |
|
|
|
|
def run(**kwargs):
    """Notebook-friendly entry point: build a ``BertConfig`` and call ``run_bert``.

    Examples::

        from geolip_svae_bert_features import run
        run()                                                      # small BERT (sandbox)
        run(model_name='bert-base-uncased', D_lens=64, epochs=60)  # Colab

    Keyword arguments that are not ``BertConfig`` fields are still ignored
    (backward compatible), but a warning names them, so typos such as
    ``epoch=60`` no longer silently fall back to defaults.

    Returns:
        The report dict produced by ``run_bert``.
    """
    known = BertConfig.__dataclass_fields__
    ignored = sorted(k for k in kwargs if k not in known)
    if ignored:
        warnings.warn(
            f"run(): ignoring unknown BertConfig option(s): {', '.join(ignored)}",
            stacklevel=2)
    cfg = BertConfig(**{k: v for k, v in kwargs.items() if k in known})
    return run_bert(cfg)


def main(argv=None):
    """Command-line entry point; every ``BertConfig`` field is exposed as a flag.

    Flag names are the field names with ``_`` replaced by ``-``. The original
    ``--model-name``, ``--D-lens``, ``--epochs`` and ``--out-dir`` keep
    working, and options such as ``--mask-ratio``, ``--corpus-source`` and
    ``--n-sentences`` become available. Defaults come from ``BertConfig``.
    A boolean field defaulting to False (``save_checkpoint``) is an on-switch.
    Unknown arguments are ignored, and Jupyter kernel arguments are stripped
    when running inside a kernel.

    Returns:
        The report dict produced by ``run_bert``.
    """
    import sys
    if argv is None:
        argv = sys.argv[1:]
    if _is_jupyter_kernel():
        argv = _filter_jupyter_args(argv)
    cfg_fields = BertConfig.__dataclass_fields__
    p = argparse.ArgumentParser()
    for name, fld in cfg_fields.items():
        flag = name.replace('_', '-')
        default = fld.default
        if isinstance(default, bool):
            # '--x' turns a False default on; '--no-x' turns a True default off.
            if default:
                p.add_argument(f'--no-{flag}', dest=name, action='store_false')
            else:
                p.add_argument(f'--{flag}', dest=name, action='store_true')
        else:
            p.add_argument(f'--{flag}', dest=name, type=type(default),
                           default=default)
    args, _unknown = p.parse_known_args(argv)
    return run_bert(BertConfig(**{name: getattr(args, name) for name in cfg_fields}))


if __name__ == '__main__':
    # Script entry point; notebooks can call run(...) directly instead.
    main()