| """geolip_svae_transformer.py β prototype v1. |
| |
| The geolip-svae-transformer: a geometric STRUCTURAL COMPANION that imposes the |
| lensed rigid structure on a standard-transformer-shaped token interface, for |
| compression and generation. Not a competitor β a partner that supplies a |
| uniformly identifiable geometric coordinate system the host attends through. |
| |
| GROUNDED IN HOW THE SVAE SCALES THE INPUT (geolip_svae/model.py): |
| image β extract_patches β (B,N,patch_dim=CΒ·psΒ·ps) |
| β encode: enc_inβblocksβenc_outβreshape(V,D)βsphere-normalize rows = M |
| β SVD-ish split: M = UΒ·SΒ·Vt, where S (D singular values) is the |
| data-specific OMEGA TOKEN and U/Vt/the sphere-normalized M is the |
| UNIFORM GEOMETRIC FRAME (rigid, in-envelope, CV-band, same signature |
| every patch). |
| β SpectralCrossAttention coordinates S across patches. |
| β decode: UΒ·SΒ·Vt β patch β stitch. |
| |
| THE PROTOTYPE'S MOVE: |
| The omega token S is already transformer-token-shaped (a D-vector per |
| position, attended position-to-position). So we: |
| 1. encode patches β sphere M (the geometric frame) [front-end] |
| 2. LENS-FRAME M to D_lens via a rigidity-preserving isometric lift β |
| this is the guarantee: the frame stays rigid + in-envelope at D_lens, |
| where native large-D collapses (exp_003). The lens WIDENS the spectral |
| token for generation capacity while the rigid frame scales with it. |
| 3. read the omega token S from the lensed frame [spectral] |
| 4. cross-patch SPECTRAL TRANSFORMER over S β the relational data |
| selection, "attention through the same avenues we're already using." |
| 5. the attended S modulates the rigid frame (MΒ·S, mirrors UΒ·S) [decode] |
| 6. decode β stitch β reconstruction. |
| Compression = patches β omega tokens. Generation/decode = omega tokens β |
| patches through the rigid frame. |
| |
| SWAPPABLE FRONT-END: use_real_svae=True wraps the installed geolip_svae.PatchSVAE |
| (Colab); default lean vendored encoder matches its structure (sandbox + Colab). |
| Everything else is vendored (protos not installed). Self-contained, deterministic. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import time |
| from dataclasses import dataclass, asdict |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
| def canon(v: torch.Tensor) -> torch.Tensor: |
| nz = v.abs() > 1e-6 |
| fi = nz.float().argmax(dim=-1, keepdim=True) |
| fv = torch.gather(v, -1, fi) |
| return v * torch.where(fv < 0, -1.0, 1.0) |
|
|
|
|
| _UMEAN: Dict[int, float] = {} |
|
|
|
|
| def uniform_projective_angle(D: int, n: int = 4096, seed: int = 0) -> float: |
| if D in _UMEAN: |
| return _UMEAN[D] |
| g = torch.Generator().manual_seed(seed) |
| pts = torch.randn(n, D, generator=g) |
| pts = canon(pts / pts.norm(dim=1, keepdim=True).clamp_min(1e-12)) |
| cos = (pts @ pts.T).clamp(-1, 1) |
| ang = torch.acos(cos.abs()) |
| iu = torch.triu_indices(n, n, offset=1) |
| _UMEAN[D] = float(ang[iu[0], iu[1]].mean()) |
| return _UMEAN[D] |
|
|
|
|
| def dev_critical(D: int, coeff: float = 0.02) -> float: |
| return coeff * math.sqrt(D) |
|
|
|
|
| def intrinsic_deviation(M_rows: torch.Tensor, baseline_D: int) -> float: |
| with torch.no_grad(): |
| a = M_rows / M_rows.norm(dim=1, keepdim=True).clamp_min(1e-12) |
| cos = (a @ a.T).abs().clamp(0, 1 - 1e-7) |
| n = M_rows.shape[0] |
| iu = torch.triu_indices(n, n, offset=1, device=M_rows.device) |
| mean_ang = float(torch.acos(cos[iu[0], iu[1]]).mean()) |
| return mean_ang - uniform_projective_angle(baseline_D) |
|
|
|
|
| def cayley_menger_sq_vol(coords: torch.Tensor) -> torch.Tensor: |
| D2 = torch.cdist(coords, coords) ** 2 |
| CM = torch.ones(6, 6, device=coords.device, dtype=coords.dtype) |
| CM[0, 0] = 0.0 |
| CM[1:, 1:] = D2 |
| return -torch.linalg.det(CM) / 9216.0 |
|
|
|
|
| def cv_of(codebook: torch.Tensor, n_samples: int = 1000, seed: int = 0) -> float: |
| V = codebook.shape[0] |
| if V < 5: |
| return 0.0 |
| g = torch.Generator(device='cpu').manual_seed(seed) |
| cb = codebook.detach().cpu().float() |
| vols = [] |
| for _ in range(n_samples): |
| idx = torch.randperm(V, generator=g)[:5] |
| sq = float(cayley_menger_sq_vol(cb[idx])) |
| vols.append(math.sqrt(max(sq, 0.0))) |
| v = torch.tensor(vols) |
| return float(v.std(unbiased=False) / v.mean().clamp_min(1e-12)) |
|
|
|
|
| def cv_band_for(D: int) -> Tuple[float, float]: |
| if D <= 4: |
| return (0.85, 1.05) |
| if D <= 8: |
| return (0.32, 0.45) |
| if D <= 16: |
| return (0.20, 0.23) |
| return (0.0, 0.20) |
|
|
|
|
| _UCOS: Dict[int, float] = {} |
|
|
|
|
| def uniform_mean_abscos(D: int, n: int = 4096, seed: int = 0) -> float: |
| """Mean pairwise |cos| of uniform directions on S^(D-1) β the rigid target.""" |
| if D in _UCOS: |
| return _UCOS[D] |
| g = torch.Generator().manual_seed(seed) |
| pts = F.normalize(torch.randn(n, D, generator=g), dim=1) |
| cos = (pts @ pts.T).abs() |
| iu = torch.triu_indices(n, n, offset=1) |
| _UCOS[D] = float(cos[iu[0], iu[1]].mean()) |
| return _UCOS[D] |
|
|
|
|
| def rigidity_loss(M: torch.Tensor, D_base: int) -> torch.Tensor: |
| """Differentiable: pull each patch's per-row |cos| spectrum toward the |
| uniform (rigid) baseline so M stays in-envelope. M: (B,N,V,D).""" |
| B, N, V, D = M.shape |
| Mn = F.normalize(M, dim=-1).reshape(B * N, V, D) |
| cos = (Mn @ Mn.transpose(1, 2)).abs() |
| iu = torch.triu_indices(V, V, offset=1, device=M.device) |
| pair = cos[:, iu[0], iu[1]] |
| target = uniform_mean_abscos(D_base) |
| return (pair.mean(dim=1) - target).pow(2).mean() |
|
|
|
|
| |
| |
| |
|
|
| def extract_patches(images: torch.Tensor, ps: int): |
| B, C, H, W = images.shape |
| gh, gw = H // ps, W // ps |
| p = images.reshape(B, C, gh, ps, gw, ps) |
| p = p.permute(0, 2, 4, 1, 3, 5).contiguous() |
| return p.reshape(B, gh * gw, C * ps * ps), gh, gw |
|
|
|
|
| def stitch_patches(patches: torch.Tensor, gh: int, gw: int, ps: int, C: int): |
| B = patches.shape[0] |
| p = patches.reshape(B, gh, gw, C, ps, ps) |
| return p.permute(0, 3, 1, 4, 2, 5).reshape(B, C, gh * ps, gw * ps) |
|
|
|
|
| |
| |
| |
|
|
| class LeanGeometricEncoder(nn.Module): |
| """Vendored encoder matching PatchSVAE's structure: enc_in β residual |
| blocks β enc_out β reshape(V,D) β sphere row-normalize. Sandbox-runnable.""" |
|
|
| def __init__(self, patch_dim: int, V: int, D: int, hidden: int, depth: int = 1): |
| super().__init__() |
| self.V, self.D = V, D |
| self.enc_in = nn.Linear(patch_dim, hidden) |
| self.blocks = nn.ModuleList([ |
| nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), |
| nn.Linear(hidden, hidden)) |
| for _ in range(depth)]) |
| self.enc_out = nn.Linear(hidden, V * D) |
| nn.init.orthogonal_(self.enc_out.weight) |
|
|
| def forward(self, patches: torch.Tensor) -> torch.Tensor: |
| B, N, _ = patches.shape |
| h = F.gelu(self.enc_in(patches.reshape(B * N, -1))) |
| for blk in self.blocks: |
| h = h + blk(h) |
| M = self.enc_out(h).reshape(B * N, self.V, self.D) |
| M = F.normalize(M, dim=-1) |
| return canon(M).reshape(B, N, self.V, self.D) |
|
|
|
|
| class RealSVAEEncoder(nn.Module): |
| """Adapter around the installed geolip_svae.PatchSVAE (Colab). Uses its |
| H2-stable encode_patches to produce the sphere-normalized M. Returns the |
| same (B,N,V,D) interface as the lean encoder.""" |
|
|
| def __init__(self, patch_dim: int, V: int, D: int, hidden: int, ps: int, |
| channels: int, depth: int = 1, freeze: bool = False): |
| super().__init__() |
| from geolip_svae.model import PatchSVAE |
| self.svae = PatchSVAE(V=V, D=D, ps=ps, hidden=hidden, channels=channels, |
| depth=depth, n_cross=1, linear_readout=True, |
| svd_mode='none', match_params=True, row_norm='sphere') |
| self.V, self.D = V, D |
| if freeze: |
| for p in self.svae.parameters(): |
| p.requires_grad_(False) |
|
|
| def forward(self, patches: torch.Tensor) -> torch.Tensor: |
| out = self.svae.encode_patches(patches) |
| M = out['M'] |
| if M.dim() == 3: |
| B, N, _ = patches.shape |
| M = M.reshape(B, N, self.V, self.D) |
| return canon(F.normalize(M, dim=-1)) |
|
|
|
|
| def make_encoder(patch_dim, V, D, hidden, ps, channels, use_real_svae): |
| if use_real_svae: |
| try: |
| return RealSVAEEncoder(patch_dim, V, D, hidden, ps, channels) |
| except Exception as e: |
| print(f" [front-end] real PatchSVAE unavailable ({e}); " |
| f"falling back to lean encoder") |
| return LeanGeometricEncoder(patch_dim, V, D, hidden) |
|
|
|
|
| |
| |
| |
|
|
| class LensFrame(nn.Module): |
| """Lift the sphere-normalized frame to D_lens via an orthonormal embedding. |
| β¨Ex,Ecβ© = β¨x,cβ© β pairwise projective angles preserved EXACTLY β the |
| frame's rigidity (in-envelope at D_base) is carried up to D_lens intact. |
| This is the guarantee the lens-framed features provide.""" |
|
|
| def __init__(self, D_base: int, D_lens: int, seed: int = 0): |
| super().__init__() |
| assert D_lens >= D_base |
| self.D_base, self.D_lens = D_base, D_lens |
| g = torch.Generator().manual_seed(seed) |
| Q = torch.linalg.qr(torch.randn(D_lens, D_base, generator=g))[0] |
| self.register_buffer('E', Q) |
|
|
| def forward(self, M: torch.Tensor) -> torch.Tensor: |
| """M: (B,N,V,D_base) β (B,N,V,D_lens), rows still on the sphere.""" |
| M_lens = M @ self.E.T |
| return canon(F.normalize(M_lens, dim=-1)) |
|
|
|
|
| |
| |
| |
|
|
| class SpectralAlphaAttention(nn.Module): |
| """Faithful to the SVAE's SpectralCrossAttention β the ONLY attention the |
| omegas are aligned to behave with: |
| |
| S_out = S Β· (1 + Ξ±_d Β· tanh(out_proj(SDPA(qkv(norm(S))))_d)) |
| |
| MULTIPLICATIVE (not additive), per-mode Ξ± bounded to [0, max_alpha] and |
| initialized near zero (sigmoid(-2)Β·0.2 β 0.024) so the attention starts as |
| near-identity and ENGAGES GRADUALLY as the omegas form β curation, not |
| forced convergence. max_alpha / alpha_init are the curation knobs.""" |
|
|
| def __init__(self, D: int, n_heads: int = 4, max_alpha: float = 0.2, |
| alpha_init: float = -2.0): |
| super().__init__() |
| assert D % n_heads == 0, f"D={D} must be divisible by n_heads={n_heads}" |
| self.n_heads = n_heads |
| self.head_dim = D // n_heads |
| self.max_alpha = max_alpha |
| self.qkv = nn.Linear(D, 3 * D) |
| self.out_proj = nn.Linear(D, D) |
| self.norm = nn.LayerNorm(D) |
| self.alpha_logits = nn.Parameter(torch.full((D,), float(alpha_init))) |
|
|
| @property |
| def alpha(self) -> torch.Tensor: |
| return self.max_alpha * torch.sigmoid(self.alpha_logits) |
|
|
| def forward(self, S: torch.Tensor) -> torch.Tensor: |
| B, N, D = S.shape |
| S_n = self.norm(S) |
| qkv = self.qkv(S_n).reshape(B, N, 3, self.n_heads, self.head_dim) |
| qkv = qkv.permute(2, 0, 3, 1, 4) |
| q, k, v = qkv[0], qkv[1], qkv[2] |
| out = F.scaled_dot_product_attention(q, k, v) |
| out = out.transpose(1, 2).reshape(B, N, D) |
| gate = torch.tanh(self.out_proj(out)) |
| return S * (1.0 + self.alpha.view(1, 1, -1) * gate) |
|
|
|
|
| class SpectralAlphaStack(nn.Module): |
| """Stack of spectral-alpha layers coordinating omegas across patches.""" |
|
|
| def __init__(self, D: int, n_heads: int, n_layers: int, |
| max_alpha: float = 0.2, alpha_init: float = -2.0): |
| super().__init__() |
| self.layers = nn.ModuleList([ |
| SpectralAlphaAttention(D, n_heads, max_alpha, alpha_init) |
| for _ in range(n_layers)]) |
|
|
| def forward(self, S: torch.Tensor) -> torch.Tensor: |
| for layer in self.layers: |
| S = layer(S) |
| return S |
|
|
| def mean_alpha(self) -> float: |
| """The bloom signal: mean engaged alpha across layers (init β 0.024).""" |
| with torch.no_grad(): |
| return float(torch.stack([l.alpha.mean() for l in self.layers]).mean()) |
|
|
|
|
| |
| |
| |
|
|
| class GeoDecoder(nn.Module): |
| """Decode the attended-omega-modulated frame back to a patch. |
| M_dec = M_lens Β· S_attended (broadcast over V rows) mirrors UΒ·S; a small |
| net reads the modulated frame to patch_dim.""" |
|
|
| def __init__(self, V: int, D_lens: int, patch_dim: int, hidden: int, |
| depth: int = 1): |
| super().__init__() |
| self.dec_in = nn.Linear(V * D_lens, hidden) |
| self.blocks = nn.ModuleList([ |
| nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), |
| nn.Linear(hidden, hidden)) |
| for _ in range(depth)]) |
| self.dec_out = nn.Linear(hidden, patch_dim) |
|
|
| def forward(self, M_lens: torch.Tensor, S_att: torch.Tensor) -> torch.Tensor: |
| |
| B, N, V, D = M_lens.shape |
| M_dec = M_lens * S_att.unsqueeze(2) |
| h = F.gelu(self.dec_in(M_dec.reshape(B * N, V * D))) |
| for blk in self.blocks: |
| h = h + blk(h) |
| return self.dec_out(h).reshape(B, N, -1) |
|
|
|
|
| |
| |
| |
|
|
| @dataclass |
| class GeoConfig: |
| img_size: int = 32 |
| channels: int = 3 |
| ps: int = 4 |
| V: int = 32 |
| D_base: int = 4 |
| D_lens: int = 16 |
| hidden: int = 64 |
| n_heads: int = 4 |
| n_layers: int = 2 |
| max_alpha: float = 0.2 |
| alpha_init: float = -2.0 |
| use_real_svae: bool = False |
| patch_dim_override: Optional[int] = None |
|
|
|
|
| class GeoSVAETransformer(nn.Module): |
| def __init__(self, cfg: GeoConfig): |
| super().__init__() |
| self.cfg = cfg |
| patch_dim = cfg.patch_dim_override or (cfg.channels * cfg.ps * cfg.ps) |
| self.patch_dim = patch_dim |
| self.feature_mode = cfg.patch_dim_override is not None |
| self.encoder = make_encoder(patch_dim, cfg.V, cfg.D_base, cfg.hidden, |
| cfg.ps, cfg.channels, |
| cfg.use_real_svae and not self.feature_mode) |
| self.lens = LensFrame(cfg.D_base, cfg.D_lens) |
| |
| |
| self.transformer = SpectralAlphaStack(cfg.D_lens, cfg.n_heads, |
| cfg.n_layers, cfg.max_alpha, |
| cfg.alpha_init) |
| self.decoder = GeoDecoder(cfg.V, cfg.D_lens, patch_dim, cfg.hidden) |
|
|
| def omega_token(self, M_lens: torch.Tensor) -> torch.Tensor: |
| """The omega: per-mode spectral magnitude (column norms) of the rigid |
| frame. Forms through curation as the encoder shapes M β never forced. |
| M_lens (B,N,V,D_lens) β S (B,N,D_lens).""" |
| return M_lens.norm(dim=-2) |
|
|
| def forward_patches(self, patches: torch.Tensor) -> Dict: |
| """Core path on (B, N, patch_dim) β image patches OR feature tokens.""" |
| M = self.encoder(patches) |
| M_lens = self.lens(M) |
| S = self.omega_token(M_lens) |
| S_att = self.transformer(S) |
| dec_patches = self.decoder(M_lens, S_att) |
| return {'recon_patches': dec_patches, 'M': M, 'M_lens': M_lens, |
| 'omega': S_att, 'mean_alpha': self.transformer.mean_alpha()} |
|
|
| def forward(self, images: torch.Tensor) -> Dict: |
| cfg = self.cfg |
| patches, gh, gw = extract_patches(images, cfg.ps) |
| out = self.forward_patches(patches) |
| recon = stitch_patches(out['recon_patches'], gh, gw, cfg.ps, cfg.channels) |
| out['recon'] = recon |
| return out |
|
|
|
|
| |
| |
| |
|
|
| def make_batch(B, img_size, channels, step, seed, use_real_svae): |
| sd = seed * 100000 + step |
| if use_real_svae: |
| try: |
| from geolip_svae.inference import gen_sixteen_noise |
| x = gen_sixteen_noise(n=B, size=img_size, seed=sd) |
| if x.shape[1] != channels: |
| x = x[:, :channels] if x.shape[1] > channels else \ |
| torch.cat([x, x[:, :1].repeat(1, channels - x.shape[1], 1, 1)], 1) |
| return x.clamp(-4, 4) |
| except Exception: |
| pass |
| g = torch.Generator().manual_seed(sd) |
| |
| base = torch.randn(B, channels, img_size, img_size, generator=g) |
| smooth = F.avg_pool2d(base, 4, stride=1, padding=2)[..., :img_size, :img_size] |
| return (0.6 * smooth + 0.4 * base).clamp(-4, 4) |
|
|
|
|
| def measure_guarantee(M_lens: torch.Tensor, D_base: int) -> Dict: |
| """The lens-framed rigidity guarantee at D_lens: take one patch's frame, |
| measure cv_of + intrinsic deviation vs the D_base baseline (preserved by |
| the isometric lift).""" |
| cb = M_lens.detach()[0, 0] |
| cv = cv_of(cb) |
| dev = intrinsic_deviation(cb, D_base) |
| lo, hi = cv_band_for(D_base) |
| return {'cv_of': cv, 'in_d_base_band': lo <= cv <= hi, |
| 'deviation': dev, 'in_envelope': abs(dev) < dev_critical(D_base)} |
|
|
|
|
| |
| |
| |
|
|
| @dataclass |
| class TrainConfig: |
| epochs: int = 8 |
| steps_per_epoch: int = 150 |
| batch_size: int = 64 |
| lr: float = 2e-3 |
| rigid_weight: float = 0.5 |
| out_dir: str = './geo_svae_transformer_results' |
| seed: int = 0 |
|
|
|
|
| def run_train(geo: GeoConfig, tr: TrainConfig) -> Dict: |
| out_dir = Path(tr.out_dir) |
| out_dir.mkdir(parents=True, exist_ok=True) |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
| torch.manual_seed(tr.seed) |
| if torch.cuda.is_available(): |
| torch.cuda.manual_seed_all(tr.seed) |
|
|
| print("=" * 70) |
| print("geolip-svae-transformer prototype β compression train + guarantee") |
| print(f" img{geo.img_size} ps{geo.ps} β {(geo.img_size//geo.ps)**2} patches | " |
| f"V{geo.V} D_base{geo.D_base} β lens D{geo.D_lens} | " |
| f"{geo.n_layers}LΓ{geo.n_heads}h | device={device}") |
| print(f" front-end: {'REAL PatchSVAE' if geo.use_real_svae else 'lean vendored'}") |
| print("=" * 70) |
|
|
| model = GeoSVAETransformer(geo).to(device) |
| n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| print(f" trainable params: {n_params:,}") |
| opt = torch.optim.Adam(model.parameters(), lr=tr.lr) |
| sched = torch.optim.lr_scheduler.CosineAnnealingLR( |
| opt, T_max=tr.epochs * tr.steps_per_epoch) |
|
|
| model.train() |
| step, best = 0, float('inf') |
| history = [] |
| for epoch in range(tr.epochs): |
| ep_losses = [] |
| for _ in range(tr.steps_per_epoch): |
| images = make_batch(tr.batch_size, geo.img_size, geo.channels, |
| step, tr.seed, geo.use_real_svae).to(device) |
| opt.zero_grad() |
| out = model(images) |
| recon_loss = F.mse_loss(out['recon'], images) |
| rig_loss = rigidity_loss(out['M'], geo.D_base) |
| loss = recon_loss + tr.rigid_weight * rig_loss |
| loss.backward() |
| opt.step() |
| sched.step() |
| ep_losses.append(float(recon_loss.detach())) |
| best = min(best, float(recon_loss.detach())) |
| step += 1 |
| with torch.no_grad(): |
| probe = make_batch(tr.batch_size, geo.img_size, geo.channels, |
| 99999, tr.seed, geo.use_real_svae).to(device) |
| g = measure_guarantee(model(probe)['M_lens'], geo.D_base) |
| mean_mse = sum(ep_losses[-tr.steps_per_epoch//4:]) / max(1, tr.steps_per_epoch//4) |
| history.append({'epoch': epoch, 'mean_mse': mean_mse, 'best_mse': best, |
| 'mean_alpha': out['mean_alpha'], 'guarantee': g}) |
| print(f" epoch {epoch:2d}: mse={mean_mse:.5f} (best {best:.5f}) | " |
| f"Ξ±={out['mean_alpha']:.4f} | lens-frame: cv_of={g['cv_of']:.3f} " |
| f"dev={g['deviation']:+.4f} in_env={g['in_envelope']}") |
|
|
| final_g = history[-1]['guarantee'] |
| verdict = { |
| 'compresses': best < 0.05, |
| 'best_mse': best, |
| 'guarantee_holds': final_g['in_envelope'], |
| 'cv_in_band': final_g['in_d_base_band'], |
| 'final_guarantee': final_g, |
| 'trainable_params': n_params, |
| } |
| report = {'geo_config': asdict(geo), 'train_config': asdict(tr), |
| 'history': history, 'verdict': verdict} |
| with open(out_dir / 'geo_svae_transformer.json', 'w') as f: |
| json.dump(report, f, indent=2) |
|
|
| print("\n" + "=" * 70) |
| print("PROTOTYPE VERDICT") |
| print("=" * 70) |
| print(f" {'β' if verdict['compresses'] else 'β'} compresses: best MSE " |
| f"{best:.5f} (input β omega tokens β reconstruction)") |
| print(f" {'β' if verdict['guarantee_holds'] else 'β'} GUARANTEE holds: " |
| f"lens-framed frame in rigidity envelope at D_lens={geo.D_lens} " |
| f"(dev {final_g['deviation']:+.4f}, crit Β±{dev_critical(geo.D_base):.3f})") |
| print(f" {'Β·'} cv_of {final_g['cv_of']:.3f} " |
| f"({'in' if verdict['cv_in_band'] else 'just above'} d{geo.D_base} " |
| f"random-codebook band β secondary signature)") |
| print(f" β omega tokens are transformer-shaped (D_lens-vectors, cross-patch") |
| print(f" attended); the rigid frame is imposed underneath and survives the") |
| print(f" lens in-envelope. Compression trained; decode = generation seed.") |
| print(f" report: {out_dir / 'geo_svae_transformer.json'}") |
| return report |
|
|
|
|
| |
| |
| |
|
|
| def _is_jupyter_kernel(): |
| try: |
| from IPython import get_ipython |
| ip = get_ipython() |
| return ip is not None and 'IPKernelApp' in ip.config |
| except Exception: |
| return False |
|
|
|
|
| def _filter_jupyter_args(argv): |
| out, skip = [], False |
| for a in argv: |
| if skip: |
| skip = False |
| continue |
| if a == '-f': |
| skip = True |
| continue |
| if a.startswith('-f=') or a.endswith('.json'): |
| continue |
| out.append(a) |
| return out |
|
|
|
|
| def run(**kwargs): |
| """Notebook entry: |
| from geolip_svae_transformer import run |
| run() # lean front-end (works anywhere) |
| run(use_real_svae=True) # real PatchSVAE (Colab, geolip_core) |
| run(D_lens=64, n_layers=4, epochs=12) |
| """ |
| geo = GeoConfig(**{k: v for k, v in kwargs.items() |
| if k in GeoConfig.__dataclass_fields__}) |
| tr = TrainConfig(**{k: v for k, v in kwargs.items() |
| if k in TrainConfig.__dataclass_fields__}) |
| return run_train(geo, tr) |
|
|
|
|
| def main(argv=None): |
| import sys |
| if argv is None: |
| argv = sys.argv[1:] |
| if _is_jupyter_kernel(): |
| argv = _filter_jupyter_args(argv) |
| p = argparse.ArgumentParser() |
| p.add_argument('--D-lens', type=int, default=16) |
| p.add_argument('--n-layers', type=int, default=2) |
| p.add_argument('--epochs', type=int, default=8) |
| p.add_argument('--use-real-svae', action='store_true') |
| p.add_argument('--out-dir', default='./geo_svae_transformer_results') |
| args, _unknown = p.parse_known_args(argv) |
| geo = GeoConfig(D_lens=args.D_lens, n_layers=args.n_layers, |
| use_real_svae=args.use_real_svae) |
| tr = TrainConfig(epochs=args.epochs, out_dir=args.out_dir) |
| return run_train(geo, tr) |
|
|
|
|
| if __name__ == '__main__': |
| main() |