# Source: geolip-svae-transformer / transformer_prototype.py
# Committed by AbstractPhil — "Create transformer_prototype.py" (93ea1d2, verified)
"""geolip_svae_transformer.py — prototype v1.
The geolip-svae-transformer: a geometric STRUCTURAL COMPANION that imposes the
lensed rigid structure on a standard-transformer-shaped token interface, for
compression and generation. Not a competitor — a partner that supplies a
uniformly identifiable geometric coordinate system the host attends through.
GROUNDED IN HOW THE SVAE SCALES THE INPUT (geolip_svae/model.py):
image → extract_patches → (B,N,patch_dim=C·ps·ps)
→ encode: enc_in→blocks→enc_out→reshape(V,D)→sphere-normalize rows = M
→ SVD-ish split: M = U·S·Vt, where S (D singular values) is the
data-specific OMEGA TOKEN and U/Vt/the sphere-normalized M is the
UNIFORM GEOMETRIC FRAME (rigid, in-envelope, CV-band, same signature
on every patch).
→ SpectralCrossAttention coordinates S across patches.
→ decode: U·S·Vt → patch → stitch.
THE PROTOTYPE'S MOVE:
The omega token S is already transformer-token-shaped (a D-vector per
position, attended position-to-position). So we:
1. encode patches → sphere M (the geometric frame) [front-end]
2. LENS-FRAME M to D_lens via a rigidity-preserving isometric lift —
this is the guarantee: the frame stays rigid + in-envelope at D_lens,
where native large-D collapses (exp_003). The lens WIDENS the spectral
token for generation capacity while the rigid frame scales with it.
3. read the omega token S from the lensed frame as its per-mode column
norms [spectral]
4. cross-patch SPECTRAL TRANSFORMER over S (multiplicative, alpha-gated
attention) — the relational data selection, "attention through the same
avenues we're already using."
5. the attended S modulates the rigid frame (M·S, mirrors U·S) [decode]
6. decode → stitch → reconstruction.
Compression = patches → omega tokens. Generation/decode = omega tokens →
patches through the rigid frame.
SWAPPABLE FRONT-END: use_real_svae=True wraps the installed geolip_svae.PatchSVAE
(Colab); default lean vendored encoder matches its structure (sandbox + Colab).
Everything else is vendored (protos not installed). Self-contained, deterministic.
"""
from __future__ import annotations
import argparse
import json
import math
import time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
# ════════════════════════════════════════════════════════════════════════
# VENDORED geometry + cv_of (verified against catalog bands D4 .98 / D8 .38 / D16 .21)
# ════════════════════════════════════════════════════════════════════════
def canon(v: torch.Tensor) -> torch.Tensor:
    """Resolve the +/- sign ambiguity of every vector along the last axis.

    For each slice ``v[..., :]`` the pivot is the first entry whose magnitude
    exceeds 1e-6. The slice is negated when that pivot is negative. If no entry
    exceeds 1e-6, the pivot is entry 0 (``argmax`` of an all-zero mask returns
    the first index).

    Since ``x`` and ``-x`` name the same projective direction, this gives
    every direction a single canonical representative.
    """
    significant = v.abs().gt(1e-6).float()
    # argmax returns the first maximal index, i.e. the first significant entry
    pivot_idx = significant.argmax(dim=-1, keepdim=True)
    pivot = v.gather(-1, pivot_idx)
    sign = torch.where(pivot < 0, -1.0, 1.0)
    return v * sign
# Memo of Monte-Carlo baselines keyed by every argument that changes the
# estimate: (dimension D, sample count n, seed).
_UMEAN: Dict[Tuple[int, int, int], float] = {}
def uniform_projective_angle(D: int, n: int = 4096, seed: int = 0) -> float:
    """Monte-Carlo mean pairwise projective angle of uniform directions on S^(D-1).

    Draws ``n`` seeded Gaussian points on the CPU and projects them onto the
    unit sphere. Returns the mean of ``acos(|cos|)`` over all distinct pairs,
    i.e. the angle between the lines through the origin, in radians in
    [0, pi/2].

    Results are memoized per ``(D, n, seed)``.

    Raises:
        ValueError: if ``n < 2`` (no pairs to average).
    """
    if n < 2:
        raise ValueError(f"need at least 2 samples, got n={n}")
    key = (D, n, seed)
    if key in _UMEAN:
        return _UMEAN[key]
    g = torch.Generator().manual_seed(seed)
    pts = torch.randn(n, D, generator=g)
    pts = pts / pts.norm(dim=1, keepdim=True).clamp_min(1e-12)
    # No sign canonicalization needed: acos(|cos|) is unchanged when either
    # point is negated, which is exactly the projective identification.
    cos = (pts @ pts.T).clamp(-1, 1)
    ang = torch.acos(cos.abs())
    iu = torch.triu_indices(n, n, offset=1)
    _UMEAN[key] = float(ang[iu[0], iu[1]].mean())
    return _UMEAN[key]
def dev_critical(D: int, coeff: float = 0.02) -> float:
    """Critical |deviation| for dimension ``D``: ``coeff * sqrt(D)``.

    A frame whose intrinsic deviation magnitude stays below this value is
    treated as in-envelope.
    """
    root_d = math.sqrt(D)
    return coeff * root_d
def intrinsic_deviation(M_rows: torch.Tensor, baseline_D: int) -> float:
    """Mean pairwise projective angle of the rows, minus the uniform baseline.

    The rows of ``M_rows`` (n, d) are unit-normalized. ``acos(|cos|)`` is then
    averaged over all distinct row pairs, and
    ``uniform_projective_angle(baseline_D)`` is subtracted. A positive result
    means the rows are more spread out than uniform directions; a negative
    result means they are more clustered. No gradients are tracked.
    """
    with torch.no_grad():
        unit = M_rows / M_rows.norm(dim=1, keepdim=True).clamp_min(1e-12)
        # Keep |cos| within acos's domain despite rounding above 1.
        abs_cos = (unit @ unit.T).abs().clamp(0, 1 - 1e-7)
        n_rows = M_rows.shape[0]
        rows, cols = torch.triu_indices(n_rows, n_rows, offset=1, device=M_rows.device)
        mean_angle = float(torch.acos(abs_cos[rows, cols]).mean())
        return mean_angle - uniform_projective_angle(baseline_D)
def cayley_menger_sq_vol(coords: torch.Tensor) -> torch.Tensor:
    """Squared volume of the simplex spanned by the rows of ``coords``.

    ``coords`` is (k, d): k points in d dimensions spanning a (k-1)-simplex.
    Uses the Cayley–Menger determinant

        V^2 = (-1)^k * det(CM) / (2^(k-1) * ((k-1)!)^2),

    where CM is the (k+1)x(k+1) matrix of squared pairwise distances bordered
    by a row/column of ones (with a 0 in the corner). For k = 5 (the
    4-simplices sampled by ``cv_of``) this is ``-det(CM) / 9216``.

    The result may be slightly negative for (near-)degenerate point sets
    because of rounding; callers should clamp before taking a square root.
    """
    k = coords.shape[0]
    sq_dist = torch.cdist(coords, coords) ** 2
    CM = torch.ones(k + 1, k + 1, device=coords.device, dtype=coords.dtype)
    CM[0, 0] = 0.0
    CM[1:, 1:] = sq_dist
    det = torch.linalg.det(CM)
    denom = 2.0 ** (k - 1) * float(math.factorial(k - 1)) ** 2
    return (det if k % 2 == 0 else -det) / denom
def cv_of(codebook: torch.Tensor, n_samples: int = 1000, seed: int = 0) -> float:
    """Coefficient of variation of random 4-simplex volumes inside ``codebook``.

    Draws ``n_samples`` random 5-row subsets with a seeded CPU generator. For
    each subset it takes the simplex volume from ``cayley_menger_sq_vol``, with
    negative rounding noise clamped to 0. It returns std / mean, using the
    population std.

    Works on a detached float32 CPU copy. Returns 0.0 when the codebook has
    fewer than 5 rows.
    """
    n_rows = codebook.shape[0]
    if n_rows < 5:
        return 0.0
    gen = torch.Generator(device='cpu').manual_seed(seed)
    points = codebook.detach().cpu().float()

    def _sample_volume() -> float:
        subset = torch.randperm(n_rows, generator=gen)[:5]
        sq_vol = float(cayley_menger_sq_vol(points[subset]))
        return math.sqrt(max(sq_vol, 0.0))

    volumes = torch.tensor([_sample_volume() for _ in range(n_samples)])
    spread = volumes.std(unbiased=False)
    return float(spread / volumes.mean().clamp_min(1e-12))
def cv_band_for(D: int) -> Tuple[float, float]:
    """Expected ``(lo, hi)`` cv_of band for a random codebook of dimension ``D``.

    The bands are piecewise in ``D``: ``D <= 4``, ``D <= 8``, ``D <= 16``, and
    everything larger.
    """
    bands = (
        (4, (0.85, 1.05)),
        (8, (0.32, 0.45)),
        (16, (0.20, 0.23)),
    )
    for upper, band in bands:
        if D <= upper:
            return band
    return (0.0, 0.20)
# Memo of Monte-Carlo baselines keyed by every argument that changes the
# estimate: (dimension D, sample count n, seed).
_UCOS: Dict[Tuple[int, int, int], float] = {}
def uniform_mean_abscos(D: int, n: int = 4096, seed: int = 0) -> float:
    """Mean pairwise |cos| of uniform directions on S^(D-1) — the rigid target.

    Draws ``n`` seeded Gaussian points on the CPU, normalizes them onto the
    unit sphere, and averages |cos| over all distinct pairs. Results are
    memoized per ``(D, n, seed)``.

    Raises:
        ValueError: if ``n < 2`` (no pairs to average).
    """
    if n < 2:
        raise ValueError(f"need at least 2 samples, got n={n}")
    key = (D, n, seed)
    if key in _UCOS:
        return _UCOS[key]
    g = torch.Generator().manual_seed(seed)
    pts = F.normalize(torch.randn(n, D, generator=g), dim=1)
    cos = (pts @ pts.T).abs()
    iu = torch.triu_indices(n, n, offset=1)
    _UCOS[key] = float(cos[iu[0], iu[1]].mean())
    return _UCOS[key]
def rigidity_loss(M: torch.Tensor, D_base: int) -> torch.Tensor:
    """Differentiable penalty that keeps every patch frame in-envelope.

    For each of the B*N frames (V rows of width D, re-normalized here), it
    takes the mean |cos| over distinct row pairs. The loss is the squared gap
    between that mean and ``uniform_mean_abscos(D_base)``, averaged over all
    frames.

    Args:
        M: (B, N, V, D) per-patch frames.
        D_base: dimension whose uniform-direction baseline is the target.

    Returns:
        Scalar loss tensor.
    """
    batch, n_patch, n_rows, dim = M.shape
    frames = F.normalize(M, dim=-1).reshape(batch * n_patch, n_rows, dim)
    abs_cos = (frames @ frames.transpose(1, 2)).abs()  # (B*N, V, V)
    upper_r, upper_c = torch.triu_indices(n_rows, n_rows, offset=1, device=M.device)
    per_frame = abs_cos[:, upper_r, upper_c].mean(dim=1)
    gap = per_frame - uniform_mean_abscos(D_base)
    return gap.pow(2).mean()
# ════════════════════════════════════════════════════════════════════════
# Patchify / stitch (vendored from geolip_svae/model.py)
# ════════════════════════════════════════════════════════════════════════
def extract_patches(images: torch.Tensor, ps: int):
    """Split an image batch into non-overlapping square patches in row-major order.

    Args:
        images: (B, C, H, W) tensor; H and W must be multiples of ``ps``.
        ps: patch side length in pixels (positive).

    Returns:
        ``(patches, gh, gw)``. ``patches`` is (B, gh*gw, C*ps*ps), with each
        patch flattened in (C, ps, ps) order. ``gh = H // ps`` and
        ``gw = W // ps`` give the patch-grid size.

    Raises:
        ValueError: if ``ps`` is not positive or does not divide H and W. This
            replaces an opaque reshape / division error.
    """
    B, C, H, W = images.shape
    if ps <= 0 or H % ps or W % ps:
        raise ValueError(
            f"patch size {ps} must be positive and divide image size {H}x{W}")
    gh, gw = H // ps, W // ps
    p = images.reshape(B, C, gh, ps, gw, ps)
    # (B, C, gh, ps, gw, ps) -> (B, gh, gw, C, ps, ps): one patch per (gh, gw)
    p = p.permute(0, 2, 4, 1, 3, 5).contiguous()
    return p.reshape(B, gh * gw, C * ps * ps), gh, gw
def stitch_patches(patches: torch.Tensor, gh: int, gw: int, ps: int, C: int):
    """Inverse of ``extract_patches``.

    Reassembles (B, gh*gw, C*ps*ps) row-major patches into a
    (B, C, gh*ps, gw*ps) image batch.
    """
    batch = patches.shape[0]
    grid = patches.reshape(batch, gh, gw, C, ps, ps)
    # (B, gh, gw, C, ps, ps) -> (B, C, gh, ps, gw, ps) so rows/cols interleave
    image = grid.permute(0, 3, 1, 4, 2, 5)
    return image.reshape(batch, C, gh * ps, gw * ps)
# ════════════════════════════════════════════════════════════════════════
# Geometric front-end (swappable) — patch → sphere-normalized M (V, D_base)
# ════════════════════════════════════════════════════════════════════════
class LeanGeometricEncoder(nn.Module):
    """Self-contained encoder that mirrors PatchSVAE's encoder layout.

    Pipeline: ``enc_in`` (+GELU) -> ``depth`` residual MLP blocks ->
    ``enc_out`` (orthogonally initialized) -> reshape to (V, D) -> unit-normalize
    each row onto S^(D-1) -> sign-canonicalize rows with ``canon``.

    Args:
        patch_dim: flattened input size of each patch / token.
        V: number of frame rows per patch.
        D: width of each row.
        hidden: width of the MLP trunk.
        depth: number of residual blocks.
    """

    def __init__(self, patch_dim: int, V: int, D: int, hidden: int, depth: int = 1):
        super().__init__()
        self.V, self.D = V, D
        self.enc_in = nn.Linear(patch_dim, hidden)

        def residual_mlp() -> nn.Sequential:
            return nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                                 nn.Linear(hidden, hidden))

        self.blocks = nn.ModuleList(residual_mlp() for _ in range(depth))
        self.enc_out = nn.Linear(hidden, V * D)
        nn.init.orthogonal_(self.enc_out.weight)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """(B, N, patch_dim) -> (B, N, V, D) with unit-norm, sign-fixed rows."""
        batch, n_tok, _ = patches.shape
        flat = patches.reshape(batch * n_tok, -1)
        h = F.gelu(self.enc_in(flat))
        for block in self.blocks:
            h = block(h) + h
        rows = self.enc_out(h).reshape(batch * n_tok, self.V, self.D)
        rows = F.normalize(rows, dim=-1)
        return canon(rows).reshape(batch, n_tok, self.V, self.D)
class RealSVAEEncoder(nn.Module):
    """Front-end adapter around the installed ``geolip_svae.PatchSVAE``.

    Builds a PatchSVAE with the requested geometry and uses its
    ``encode_patches`` output ``'M'`` as the frame. It returns the same
    (B, N, V, D) unit-row, sign-canonicalized layout as
    ``LeanGeometricEncoder``. Construction raises if ``geolip_svae`` cannot be
    imported. ``patch_dim`` is accepted only for signature parity with the
    lean encoder and is not used here. ``freeze=True`` disables gradients for
    every PatchSVAE parameter.
    """

    def __init__(self, patch_dim: int, V: int, D: int, hidden: int, ps: int,
                 channels: int, depth: int = 1, freeze: bool = False):
        super().__init__()
        from geolip_svae.model import PatchSVAE  # optional dependency (Colab)
        svae_kwargs = dict(V=V, D=D, ps=ps, hidden=hidden, channels=channels,
                           depth=depth, n_cross=1, linear_readout=True,
                           svd_mode='none', match_params=True, row_norm='sphere')
        self.svae = PatchSVAE(**svae_kwargs)
        self.V, self.D = V, D
        if freeze:
            self.svae.requires_grad_(False)

    def forward(self, patches: torch.Tensor) -> torch.Tensor:
        """(B, N, patch_dim) -> (B, N, V, D)."""
        M = self.svae.encode_patches(patches)['M']
        # NOTE(review): handles encode_patches returning either flat
        # (B*N, V, D) or already-batched (B, N, V, D) — confirm against geolip_svae.
        if M.dim() == 3:
            batch, n_tok, _ = patches.shape
            M = M.reshape(batch, n_tok, self.V, self.D)
        return canon(F.normalize(M, dim=-1))
def make_encoder(patch_dim, V, D, hidden, ps, channels, use_real_svae):
    """Build the geometric front-end.

    When ``use_real_svae`` is false, returns a ``LeanGeometricEncoder``.
    Otherwise it tries to build a ``RealSVAEEncoder``. On any construction
    error (e.g. ``geolip_svae`` is not installed) it prints the reason and
    falls back to the lean encoder.
    """
    if not use_real_svae:
        return LeanGeometricEncoder(patch_dim, V, D, hidden)
    try:
        return RealSVAEEncoder(patch_dim, V, D, hidden, ps, channels)
    except Exception as e:
        print(f" [front-end] real PatchSVAE unavailable ({e}); "
              f"falling back to lean encoder")
    return LeanGeometricEncoder(patch_dim, V, D, hidden)
# ════════════════════════════════════════════════════════════════════════
# Lens frame — rigidity-preserving isometric lift D_base → D_lens (the guarantee)
# ════════════════════════════════════════════════════════════════════════
class LensFrame(nn.Module):
    """Lift the sphere-normalized frame from D_base to D_lens isometrically.

    A fixed buffer ``E`` of shape (D_lens, D_base) with orthonormal columns
    (reduced QR of a seeded Gaussian matrix) embeds each row, and
    ``<E x, E c> = <x, c>``. Pairwise projective angles, and hence the
    frame's rigidity measured at D_base, are therefore preserved exactly at
    D_lens. This is the guarantee the lens-framed features provide.

    Args:
        D_base: input row width.
        D_lens: output row width; must be >= D_base.
        seed: seed for ``E`` (same seed gives the same embedding).

    Raises:
        ValueError: if ``D_lens < D_base`` (no isometric lift exists).
    """

    def __init__(self, D_base: int, D_lens: int, seed: int = 0):
        super().__init__()
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if D_lens < D_base:
            raise ValueError(f"D_lens={D_lens} must be >= D_base={D_base}")
        self.D_base, self.D_lens = D_base, D_lens
        g = torch.Generator().manual_seed(seed)
        Q = torch.linalg.qr(torch.randn(D_lens, D_base, generator=g))[0]
        self.register_buffer('E', Q)  # (D_lens, D_base), E.T @ E = I

    def forward(self, M: torch.Tensor) -> torch.Tensor:
        """M: (B, N, V, D_base) -> (B, N, V, D_lens), rows unit-norm and sign-canonical."""
        M_lens = M @ self.E.T
        # E preserves norms; re-normalizing guarantees unit rows even if the
        # input rows were not exactly on the sphere.
        return canon(F.normalize(M_lens, dim=-1))
# ════════════════════════════════════════════════════════════════════════
# Spectral-alpha attention — the SVAE's actual mechanism (dot-alpha MHA)
# ════════════════════════════════════════════════════════════════════════
class SpectralAlphaAttention(nn.Module):
    """Multiplicative, alpha-gated self-attention over omega tokens.

    Faithful to the SVAE's SpectralCrossAttention update:

        S_out = S · (1 + α_d · tanh(out_proj(SDPA(qkv(norm(S))))_d))

    The update is MULTIPLICATIVE, not additive. Each mode d has its own gate
    ``α_d = max_alpha · sigmoid(alpha_logits_d)`` in [0, max_alpha]. With the
    defaults every gate starts at sigmoid(-2)·0.2 ≈ 0.024, so the layer starts
    near-identity and only gains influence as training raises the logits.
    ``max_alpha`` and ``alpha_init`` are the curation knobs.

    Args:
        D: token width; must be divisible by ``n_heads``.
        n_heads: number of attention heads (positive).
        max_alpha: upper bound of every per-mode gate.
        alpha_init: initial value of every alpha logit.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide ``D``.
    """

    def __init__(self, D: int, n_heads: int = 4, max_alpha: float = 0.2,
                 alpha_init: float = -2.0):
        super().__init__()
        # Explicit raises rather than `assert` (stripped under `python -O`);
        # checking n_heads first also avoids a ZeroDivisionError in the modulo.
        if n_heads <= 0:
            raise ValueError(f"n_heads={n_heads} must be positive")
        if D % n_heads != 0:
            raise ValueError(f"D={D} must be divisible by n_heads={n_heads}")
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.alpha_logits = nn.Parameter(torch.full((D,), float(alpha_init)))

    @property
    def alpha(self) -> torch.Tensor:
        """Per-mode gate strengths, shape (D,), each in [0, max_alpha]."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)

    def forward(self, S: torch.Tensor) -> torch.Tensor:
        """S: (B, N, D) -> (B, N, D); each entry is scaled by a factor in
        [1 - max_alpha, 1 + max_alpha]."""
        B, N, D = S.shape
        S_n = self.norm(S)
        # (B, N, 3*D) -> (3, B, heads, N, head_dim) as expected by SDPA
        qkv = self.qkv(S_n).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        out = F.scaled_dot_product_attention(q, k, v)  # (B, heads, N, head_dim)
        out = out.transpose(1, 2).reshape(B, N, D)
        gate = torch.tanh(self.out_proj(out))  # in (-1, 1)
        return S * (1.0 + self.alpha.view(1, 1, -1) * gate)  # multiplicative
class SpectralAlphaStack(nn.Module):
    """Sequential stack of ``SpectralAlphaAttention`` layers over omega tokens.

    Each layer receives the previous layer's (B, N, D) output, coordinating the
    omegas across patches. With ``n_layers=0`` the stack is the identity.
    """

    def __init__(self, D: int, n_heads: int, n_layers: int,
                 max_alpha: float = 0.2, alpha_init: float = -2.0):
        super().__init__()
        self.layers = nn.ModuleList([
            SpectralAlphaAttention(D, n_heads, max_alpha, alpha_init)
            for _ in range(n_layers)])

    def forward(self, S: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            S = layer(S)
        return S

    def mean_alpha(self) -> float:
        """The bloom signal: mean engaged alpha over all layers and modes.

        It is about 0.024 at default initialization. Returns 0.0 for an empty
        stack, where no attention is applied, instead of failing on
        ``torch.stack([])``.
        """
        if len(self.layers) == 0:
            return 0.0
        with torch.no_grad():
            per_layer = [layer.alpha.mean() for layer in self.layers]
            return float(torch.stack(per_layer).mean())
# ════════════════════════════════════════════════════════════════════════
# Decoder — modulated rigid frame → patch (the generation path)
# ════════════════════════════════════════════════════════════════════════
class GeoDecoder(nn.Module):
    """Decode the omega-modulated rigid frame back to a flat patch.

    Every frame row is scaled element-wise by its patch's attended omega
    (``M_lens * S_att`` broadcast over the V rows, analogous to U·S). The
    flattened (V * D_lens) result then passes through an MLP with ``depth``
    residual blocks to ``patch_dim`` outputs.
    """

    def __init__(self, V: int, D_lens: int, patch_dim: int, hidden: int,
                 depth: int = 1):
        super().__init__()
        self.dec_in = nn.Linear(V * D_lens, hidden)

        def residual_mlp() -> nn.Sequential:
            return nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                                 nn.Linear(hidden, hidden))

        self.blocks = nn.ModuleList(residual_mlp() for _ in range(depth))
        self.dec_out = nn.Linear(hidden, patch_dim)

    def forward(self, M_lens: torch.Tensor, S_att: torch.Tensor) -> torch.Tensor:
        """M_lens: (B, N, V, D_lens), S_att: (B, N, D_lens) -> (B, N, patch_dim)."""
        batch, n_tok, n_rows, width = M_lens.shape
        modulated = M_lens * S_att[:, :, None, :]  # same omega for all V rows
        h = self.dec_in(modulated.reshape(batch * n_tok, n_rows * width))
        h = F.gelu(h)
        for block in self.blocks:
            h = h + block(h)
        decoded = self.dec_out(h)
        return decoded.reshape(batch, n_tok, -1)
# ════════════════════════════════════════════════════════════════════════
# The geolip-svae-transformer
# ════════════════════════════════════════════════════════════════════════
@dataclass
class GeoConfig:
    """Model hyper-parameters for ``GeoSVAETransformer``."""
    img_size: int = 32  # square image side in pixels (image mode)
    channels: int = 3  # image channels C
    ps: int = 4  # patch side; image-mode patch_dim = channels * ps * ps
    V: int = 32  # rows per patch frame M
    D_base: int = 4  # row width produced by the encoder
    D_lens: int = 16  # row width after the isometric lens lift (LensFrame needs >= D_base)
    hidden: int = 64  # MLP width of encoder and decoder
    n_heads: int = 4  # attention heads; SpectralAlphaAttention needs D_lens % n_heads == 0
    n_layers: int = 2  # number of spectral-alpha attention layers
    max_alpha: float = 0.2 # spectral-alpha ceiling (curation knob)
    alpha_init: float = -2.0 # alpha logit at init: 0.2 * sigmoid(-2) ≈ 0.024 engaged alpha
    use_real_svae: bool = False  # try geolip_svae.PatchSVAE front-end (falls back to lean encoder)
    patch_dim_override: Optional[int] = None # set = feature mode (e.g. BERT hidden)
class GeoSVAETransformer(nn.Module):
    """Encoder -> isometric lens -> omega tokens -> spectral-alpha attention -> decoder.

    Image mode (default): ``forward(images)`` patchifies (B, C, H, W) images,
    runs ``forward_patches`` and stitches the reconstruction back.

    Feature mode (``cfg.patch_dim_override`` set): call ``forward_patches``
    directly on (B, N, patch_dim_override) token features. The real PatchSVAE
    front-end is never requested in this mode.

    Raises:
        ValueError: if ``cfg.patch_dim_override`` is set but not positive.
    """

    def __init__(self, cfg: GeoConfig):
        super().__init__()
        self.cfg = cfg
        self.feature_mode = cfg.patch_dim_override is not None
        if self.feature_mode and cfg.patch_dim_override <= 0:
            raise ValueError(
                f"patch_dim_override must be positive, got {cfg.patch_dim_override}")
        # Use the same `is not None` test as feature_mode so the two always
        # agree. A plain `or` would treat an override of 0 as "unset" for
        # patch_dim but "set" for feature_mode.
        patch_dim = (cfg.patch_dim_override if self.feature_mode
                     else cfg.channels * cfg.ps * cfg.ps)
        self.patch_dim = patch_dim
        self.encoder = make_encoder(patch_dim, cfg.V, cfg.D_base, cfg.hidden,
                                    cfg.ps, cfg.channels,
                                    cfg.use_real_svae and not self.feature_mode)
        self.lens = LensFrame(cfg.D_base, cfg.D_lens)
        # The omega EMERGES as the spectral magnitude of the rigid frame
        # (omega_token); it is not a learned squash, so no parameters here.
        self.transformer = SpectralAlphaStack(cfg.D_lens, cfg.n_heads,
                                              cfg.n_layers, cfg.max_alpha,
                                              cfg.alpha_init)
        self.decoder = GeoDecoder(cfg.V, cfg.D_lens, patch_dim, cfg.hidden)

    def omega_token(self, M_lens: torch.Tensor) -> torch.Tensor:
        """The omega: per-mode spectral magnitude (column norms over the V rows).

        M_lens (B, N, V, D_lens) -> S (B, N, D_lens). It has no parameters of
        its own; it changes only as the encoder reshapes M.
        """
        return M_lens.norm(dim=-2)

    def forward_patches(self, patches: torch.Tensor) -> Dict:
        """Core path on (B, N, patch_dim): image patches OR feature tokens.

        Returns a dict with ``recon_patches`` (B, N, patch_dim), the sphere
        frame ``M`` (B, N, V, D_base), the lensed frame ``M_lens``
        (B, N, V, D_lens), the attended ``omega`` (B, N, D_lens) and the float
        ``mean_alpha`` of the attention stack.
        """
        M = self.encoder(patches)                   # (B,N,V,D_base) sphere
        M_lens = self.lens(M)                       # (B,N,V,D_lens) rigid
        S = self.omega_token(M_lens)                # (B,N,D_lens) emergent omega
        S_att = self.transformer(S)                 # spectral-alpha coordination
        dec_patches = self.decoder(M_lens, S_att)   # (B,N,patch_dim)
        return {'recon_patches': dec_patches, 'M': M, 'M_lens': M_lens,
                'omega': S_att, 'mean_alpha': self.transformer.mean_alpha()}

    def forward(self, images: torch.Tensor) -> Dict:
        """Image-mode path: (B, C, H, W) images, H and W multiples of ``cfg.ps``.

        Returns the ``forward_patches`` dict plus ``recon``, the stitched
        (B, C, H, W) reconstruction.
        """
        cfg = self.cfg
        patches, gh, gw = extract_patches(images, cfg.ps)  # (B,N,patch_dim)
        out = self.forward_patches(patches)
        out['recon'] = stitch_patches(out['recon_patches'], gh, gw, cfg.ps,
                                      cfg.channels)
        return out
# ════════════════════════════════════════════════════════════════════════
# Data + rigidity guarantee monitor
# ════════════════════════════════════════════════════════════════════════
def make_batch(B, img_size, channels, step, seed, use_real_svae):
    """Deterministic training batch of shape (B, channels, img_size, img_size).

    The per-call seed is ``seed * 100000 + step``. With ``use_real_svae`` it
    first tries ``geolip_svae.inference.gen_sixteen_noise``, adjusting the
    channel count by truncation or by repeating channel 0. If that fails for
    any reason, a RuntimeWarning is issued and the synthetic fallback is used.

    The fallback mixes 4x4 box-filtered Gaussian noise (spatially
    correlated, hence compressible) with white Gaussian noise. All outputs
    are clamped to [-4, 4].
    """
    sd = seed * 100000 + step
    if use_real_svae:
        try:
            from geolip_svae.inference import gen_sixteen_noise
            x = gen_sixteen_noise(n=B, size=img_size, seed=sd)
            if x.shape[1] != channels:
                x = x[:, :channels] if x.shape[1] > channels else \
                    torch.cat([x, x[:, :1].repeat(1, channels - x.shape[1], 1, 1)], 1)
            return x.clamp(-4, 4)
        except Exception as exc:
            # Best-effort by design, but never silent: training on the
            # fallback while believing the real generator is active would
            # invalidate results. The default warnings filter shows a repeated
            # identical message only once per call site, so this won't spam.
            import warnings
            warnings.warn(f"gen_sixteen_noise unavailable ({exc!r}); "
                          f"using synthetic fallback batches",
                          RuntimeWarning, stacklevel=2)
    g = torch.Generator().manual_seed(sd)
    # structured noise: smoothed (spatially correlated) + white Gaussian noise
    base = torch.randn(B, channels, img_size, img_size, generator=g)
    # kernel 4, stride 1, padding 2 yields size+1; crop back to img_size
    smooth = F.avg_pool2d(base, 4, stride=1, padding=2)[..., :img_size, :img_size]
    return (0.6 * smooth + 0.4 * base).clamp(-4, 4)
def measure_guarantee(M_lens: torch.Tensor, D_base: int) -> Dict:
    """Rigidity diagnostics for one lens-framed patch frame.

    Uses the frame of batch item 0, patch 0 (shape (V, D_lens)) and reports:
      * ``cv_of`` and ``in_d_base_band``: whether it lies in the D_base band
        from ``cv_band_for``, since the isometric lift keeps the skeleton
        D_base-class;
      * ``deviation``: intrinsic projective-angle deviation from the D_base
        uniform baseline, and ``in_envelope``:
        ``|deviation| < dev_critical(D_base)``.
    """
    frame = M_lens.detach()[0, 0]
    cv = cv_of(frame)
    deviation = intrinsic_deviation(frame, D_base)
    band_lo, band_hi = cv_band_for(D_base)
    return {
        'cv_of': cv,
        'in_d_base_band': band_lo <= cv <= band_hi,
        'deviation': deviation,
        'in_envelope': abs(deviation) < dev_critical(D_base),
    }
# ════════════════════════════════════════════════════════════════════════
# Prototype training (compression) + guarantee monitoring
# ════════════════════════════════════════════════════════════════════════
@dataclass
class TrainConfig:
    """Optimization and bookkeeping settings for ``run_train``."""
    epochs: int = 8  # number of epochs
    steps_per_epoch: int = 150  # optimizer steps per epoch (fresh batch each step)
    batch_size: int = 64  # images per step (also the probe batch size)
    lr: float = 2e-3  # Adam learning rate, cosine-annealed over all steps
    rigid_weight: float = 0.5  # weight of rigidity_loss added to the reconstruction MSE
    out_dir: str = './geo_svae_transformer_results'  # directory for the JSON report
    seed: int = 0  # seeds torch RNGs and the deterministic batch generator
def _print_run_header(geo: GeoConfig, device: torch.device) -> None:
    """Print the run banner: geometry, layer layout, device and requested front-end."""
    print("=" * 70)
    print("geolip-svae-transformer prototype — compression train + guarantee")
    print(f" img{geo.img_size} ps{geo.ps} → {(geo.img_size//geo.ps)**2} patches | "
          f"V{geo.V} D_base{geo.D_base} → lens D{geo.D_lens} | "
          f"{geo.n_layers}L×{geo.n_heads}h | device={device}")
    print(f" front-end: {'REAL PatchSVAE' if geo.use_real_svae else 'lean vendored'}")
    print("=" * 70)


def _print_verdict(geo: GeoConfig, verdict: Dict, report_path: Path) -> None:
    """Print the human-readable summary of a ``run_train`` verdict."""
    final_g = verdict['final_guarantee']
    best = verdict['best_mse']
    cv = final_g['cv_of']
    _, band_hi = cv_band_for(geo.D_base)
    # Outside the band can mean either side; say which one.
    placement = 'in' if verdict['cv_in_band'] else ('above' if cv > band_hi else 'below')
    print("\n" + "=" * 70)
    print("PROTOTYPE VERDICT")
    print("=" * 70)
    print(f" {'✓' if verdict['compresses'] else '✗'} compresses: best MSE "
          f"{best:.5f} (input → omega tokens → reconstruction)")
    print(f" {'✓' if verdict['guarantee_holds'] else '✗'} GUARANTEE holds: "
          f"lens-framed frame in rigidity envelope at D_lens={geo.D_lens} "
          f"(dev {final_g['deviation']:+.4f}, crit ±{dev_critical(geo.D_base):.3f})")
    print(f" · cv_of {cv:.3f} "
          f"({placement} d{geo.D_base} "
          f"random-codebook band — secondary signature)")
    print(" → omega tokens are transformer-shaped (D_lens-vectors, cross-patch")
    print(" attended); the rigid frame is imposed underneath and survives the")
    print(" lens in-envelope. Compression trained; decode = generation seed.")
    print(f" report: {report_path}")


def run_train(geo: GeoConfig, tr: TrainConfig) -> Dict:
    """Train the prototype for reconstruction and monitor the rigidity guarantee.

    Each step draws a fresh deterministic batch (``make_batch``) and minimizes
    ``MSE(recon, images) + tr.rigid_weight * rigidity_loss(M, D_base)`` with
    Adam plus per-step cosine annealing, tracking the best single-step MSE.
    After every epoch, a fixed probe batch (step id 99999) is run and
    ``measure_guarantee`` is evaluated on its lens-framed frame.

    A JSON report (configs, per-epoch history, verdict) is written to
    ``<tr.out_dir>/geo_svae_transformer.json`` and also returned.

    Raises:
        ValueError: if ``tr.epochs`` or ``tr.steps_per_epoch`` is < 1, since
            there would be no history to report.
    """
    if tr.epochs < 1 or tr.steps_per_epoch < 1:
        raise ValueError(
            f"need epochs >= 1 and steps_per_epoch >= 1, got "
            f"epochs={tr.epochs}, steps_per_epoch={tr.steps_per_epoch}")
    out_dir = Path(tr.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(tr.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(tr.seed)

    _print_run_header(geo, device)
    model = GeoSVAETransformer(geo).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" trainable params: {n_params:,}")

    opt = torch.optim.Adam(model.parameters(), lr=tr.lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=tr.epochs * tr.steps_per_epoch)
    model.train()
    step, best = 0, float('inf')
    history = []
    # Epoch summary = mean over the last quarter of the epoch (>= 1 step).
    # Clamp before slicing: `xs[-0:]` is the whole list, so a window of 0
    # would sum every loss yet divide by 1.
    window = max(1, tr.steps_per_epoch // 4)
    for epoch in range(tr.epochs):
        ep_losses = []
        for _ in range(tr.steps_per_epoch):
            images = make_batch(tr.batch_size, geo.img_size, geo.channels,
                                step, tr.seed, geo.use_real_svae).to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            rig_loss = rigidity_loss(out['M'], geo.D_base)
            loss = recon_loss + tr.rigid_weight * rig_loss
            loss.backward()
            opt.step()
            sched.step()
            step_mse = float(recon_loss.detach())  # one device sync per step
            ep_losses.append(step_mse)
            best = min(best, step_mse)
            step += 1
        with torch.no_grad():
            probe = make_batch(tr.batch_size, geo.img_size, geo.channels,
                               99999, tr.seed, geo.use_real_svae).to(device)
            g = measure_guarantee(model(probe)['M_lens'], geo.D_base)
        tail = ep_losses[-window:]
        mean_mse = sum(tail) / len(tail)
        history.append({'epoch': epoch, 'mean_mse': mean_mse, 'best_mse': best,
                        'mean_alpha': out['mean_alpha'], 'guarantee': g})
        print(f" epoch {epoch:2d}: mse={mean_mse:.5f} (best {best:.5f}) | "
              f"α={out['mean_alpha']:.4f} | lens-frame: cv_of={g['cv_of']:.3f} "
              f"dev={g['deviation']:+.4f} in_env={g['in_envelope']}")

    final_g = history[-1]['guarantee']
    verdict = {
        'compresses': best < 0.05,
        'best_mse': best,
        'guarantee_holds': final_g['in_envelope'],    # the rigidity formula
        'cv_in_band': final_g['in_d_base_band'],      # secondary signature
        'final_guarantee': final_g,
        'trainable_params': n_params,
    }
    report = {'geo_config': asdict(geo), 'train_config': asdict(tr),
              'history': history, 'verdict': verdict}
    report_path = out_dir / 'geo_svae_transformer.json'
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)
    _print_verdict(geo, verdict, report_path)
    return report
# ════════════════════════════════════════════════════════════════════════
# Colab-proof entry points
# ════════════════════════════════════════════════════════════════════════
def _is_jupyter_kernel():
try:
from IPython import get_ipython
ip = get_ipython()
return ip is not None and 'IPKernelApp' in ip.config
except Exception:
return False
def _filter_jupyter_args(argv):
out, skip = [], False
for a in argv:
if skip:
skip = False
continue
if a == '-f':
skip = True
continue
if a.startswith('-f=') or a.endswith('.json'):
continue
out.append(a)
return out
def run(**kwargs):
    """Notebook entry: build the configs from keyword overrides and train.

    Each keyword is routed to ``GeoConfig`` or ``TrainConfig`` by field name;
    unspecified fields keep their dataclass defaults. Unrecognized keywords
    (e.g. the typo ``epoch=12``) trigger a warning and are ignored rather than
    dropped silently.

        from geolip_svae_transformer import run
        run()                          # lean front-end (works anywhere)
        run(use_real_svae=True)        # real PatchSVAE (Colab, geolip_core)
        run(D_lens=64, n_layers=4, epochs=12)
    """
    geo_fields = GeoConfig.__dataclass_fields__
    tr_fields = TrainConfig.__dataclass_fields__
    unknown = sorted(k for k in kwargs if k not in geo_fields and k not in tr_fields)
    if unknown:
        import warnings
        warnings.warn(f"run(): ignoring unrecognized keyword(s) {unknown}",
                      stacklevel=2)
    geo = GeoConfig(**{k: v for k, v in kwargs.items() if k in geo_fields})
    tr = TrainConfig(**{k: v for k, v in kwargs.items() if k in tr_fields})
    return run_train(geo, tr)
def main(argv=None):
    """CLI entry point (also safe to call from a notebook cell).

    Parses a subset of the ``GeoConfig`` / ``TrainConfig`` fields. Defaults are
    read from the dataclasses themselves, so the CLI cannot drift from them.
    Inside a Jupyter kernel, kernel launch arguments are filtered out first.
    Remaining unknown arguments are tolerated but reported on stderr.
    """
    import sys
    if argv is None:
        argv = sys.argv[1:]
    if _is_jupyter_kernel():
        argv = _filter_jupyter_args(argv)
    p = argparse.ArgumentParser(description='geolip-svae-transformer prototype')
    p.add_argument('--D-lens', type=int, default=GeoConfig.D_lens)
    p.add_argument('--n-layers', type=int, default=GeoConfig.n_layers)
    p.add_argument('--epochs', type=int, default=TrainConfig.epochs)
    p.add_argument('--steps-per-epoch', type=int, default=TrainConfig.steps_per_epoch)
    p.add_argument('--batch-size', type=int, default=TrainConfig.batch_size)
    p.add_argument('--lr', type=float, default=TrainConfig.lr)
    p.add_argument('--seed', type=int, default=TrainConfig.seed)
    p.add_argument('--use-real-svae', action='store_true')
    p.add_argument('--out-dir', default=TrainConfig.out_dir)
    args, unknown = p.parse_known_args(argv)
    if unknown:
        print(f" [cli] ignoring unrecognized arguments: {unknown}", file=sys.stderr)
    geo = GeoConfig(D_lens=args.D_lens, n_layers=args.n_layers,
                    use_real_svae=args.use_real_svae)
    tr = TrainConfig(epochs=args.epochs, steps_per_epoch=args.steps_per_epoch,
                     batch_size=args.batch_size, lr=args.lr, seed=args.seed,
                     out_dir=args.out_dir)
    return run_train(geo, tr)
if __name__ == '__main__':
    # Script entry: parse CLI arguments and train (notebooks call run() instead).
    main()