"""geolip_svae_transformer.py — prototype v1. The geolip-svae-transformer: a geometric STRUCTURAL COMPANION that imposes the lensed rigid structure on a standard-transformer-shaped token interface, for compression and generation. Not a competitor — a partner that supplies a uniformly identifiable geometric coordinate system the host attends through. GROUNDED IN HOW THE SVAE SCALES THE INPUT (geolip_svae/model.py): image → extract_patches → (B,N,patch_dim=C·ps·ps) → encode: enc_in→blocks→enc_out→reshape(V,D)→sphere-normalize rows = M → SVD-ish split: M = U·S·Vt, where S (D singular values) is the data-specific OMEGA TOKEN and U/Vt/the sphere-normalized M is the UNIFORM GEOMETRIC FRAME (rigid, in-envelope, CV-band, same signature every patch). → SpectralCrossAttention coordinates S across patches. → decode: U·S·Vt → patch → stitch. THE PROTOTYPE'S MOVE: The omega token S is already transformer-token-shaped (a D-vector per position, attended position-to-position). So we: 1. encode patches → sphere M (the geometric frame) [front-end] 2. LENS-FRAME M to D_lens via a rigidity-preserving isometric lift — this is the guarantee: the frame stays rigid + in-envelope at D_lens, where native large-D collapses (exp_003). The lens WIDENS the spectral token for generation capacity while the rigid frame scales with it. 3. read the omega token S from the lensed frame [spectral] 4. cross-patch SPECTRAL TRANSFORMER over S — the relational data selection, "attention through the same avenues we're already using." 5. the attended S modulates the rigid frame (M·S, mirrors U·S) [decode] 6. decode → stitch → reconstruction. Compression = patches → omega tokens. Generation/decode = omega tokens → patches through the rigid frame. SWAPPABLE FRONT-END: use_real_svae=True wraps the installed geolip_svae.PatchSVAE (Colab); default lean vendored encoder matches its structure (sandbox + Colab). Everything else is vendored (protos not installed). Self-contained, deterministic. 
""" from __future__ import annotations import argparse import json import math import time from dataclasses import dataclass, asdict from pathlib import Path from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn import torch.nn.functional as F # ════════════════════════════════════════════════════════════════════════ # VENDORED geometry + cv_of (verified against catalog bands D4 .98 / D8 .38 / D16 .21) # ════════════════════════════════════════════════════════════════════════ def canon(v: torch.Tensor) -> torch.Tensor: nz = v.abs() > 1e-6 fi = nz.float().argmax(dim=-1, keepdim=True) fv = torch.gather(v, -1, fi) return v * torch.where(fv < 0, -1.0, 1.0) _UMEAN: Dict[int, float] = {} def uniform_projective_angle(D: int, n: int = 4096, seed: int = 0) -> float: if D in _UMEAN: return _UMEAN[D] g = torch.Generator().manual_seed(seed) pts = torch.randn(n, D, generator=g) pts = canon(pts / pts.norm(dim=1, keepdim=True).clamp_min(1e-12)) cos = (pts @ pts.T).clamp(-1, 1) ang = torch.acos(cos.abs()) iu = torch.triu_indices(n, n, offset=1) _UMEAN[D] = float(ang[iu[0], iu[1]].mean()) return _UMEAN[D] def dev_critical(D: int, coeff: float = 0.02) -> float: return coeff * math.sqrt(D) def intrinsic_deviation(M_rows: torch.Tensor, baseline_D: int) -> float: with torch.no_grad(): a = M_rows / M_rows.norm(dim=1, keepdim=True).clamp_min(1e-12) cos = (a @ a.T).abs().clamp(0, 1 - 1e-7) n = M_rows.shape[0] iu = torch.triu_indices(n, n, offset=1, device=M_rows.device) mean_ang = float(torch.acos(cos[iu[0], iu[1]]).mean()) return mean_ang - uniform_projective_angle(baseline_D) def cayley_menger_sq_vol(coords: torch.Tensor) -> torch.Tensor: D2 = torch.cdist(coords, coords) ** 2 CM = torch.ones(6, 6, device=coords.device, dtype=coords.dtype) CM[0, 0] = 0.0 CM[1:, 1:] = D2 return -torch.linalg.det(CM) / 9216.0 def cv_of(codebook: torch.Tensor, n_samples: int = 1000, seed: int = 0) -> float: V = codebook.shape[0] if V < 5: return 0.0 g = 
torch.Generator(device='cpu').manual_seed(seed) cb = codebook.detach().cpu().float() vols = [] for _ in range(n_samples): idx = torch.randperm(V, generator=g)[:5] sq = float(cayley_menger_sq_vol(cb[idx])) vols.append(math.sqrt(max(sq, 0.0))) v = torch.tensor(vols) return float(v.std(unbiased=False) / v.mean().clamp_min(1e-12)) def cv_band_for(D: int) -> Tuple[float, float]: if D <= 4: return (0.85, 1.05) if D <= 8: return (0.32, 0.45) if D <= 16: return (0.20, 0.23) return (0.0, 0.20) _UCOS: Dict[int, float] = {} def uniform_mean_abscos(D: int, n: int = 4096, seed: int = 0) -> float: """Mean pairwise |cos| of uniform directions on S^(D-1) — the rigid target.""" if D in _UCOS: return _UCOS[D] g = torch.Generator().manual_seed(seed) pts = F.normalize(torch.randn(n, D, generator=g), dim=1) cos = (pts @ pts.T).abs() iu = torch.triu_indices(n, n, offset=1) _UCOS[D] = float(cos[iu[0], iu[1]].mean()) return _UCOS[D] def rigidity_loss(M: torch.Tensor, D_base: int) -> torch.Tensor: """Differentiable: pull each patch's per-row |cos| spectrum toward the uniform (rigid) baseline so M stays in-envelope. 
M: (B,N,V,D).""" B, N, V, D = M.shape Mn = F.normalize(M, dim=-1).reshape(B * N, V, D) cos = (Mn @ Mn.transpose(1, 2)).abs() # (BN, V, V) iu = torch.triu_indices(V, V, offset=1, device=M.device) pair = cos[:, iu[0], iu[1]] # (BN, V*(V-1)/2) target = uniform_mean_abscos(D_base) return (pair.mean(dim=1) - target).pow(2).mean() # ════════════════════════════════════════════════════════════════════════ # Patchify / stitch (vendored from geolip_svae/model.py) # ════════════════════════════════════════════════════════════════════════ def extract_patches(images: torch.Tensor, ps: int): B, C, H, W = images.shape gh, gw = H // ps, W // ps p = images.reshape(B, C, gh, ps, gw, ps) p = p.permute(0, 2, 4, 1, 3, 5).contiguous() return p.reshape(B, gh * gw, C * ps * ps), gh, gw def stitch_patches(patches: torch.Tensor, gh: int, gw: int, ps: int, C: int): B = patches.shape[0] p = patches.reshape(B, gh, gw, C, ps, ps) return p.permute(0, 3, 1, 4, 2, 5).reshape(B, C, gh * ps, gw * ps) # ════════════════════════════════════════════════════════════════════════ # Geometric front-end (swappable) — patch → sphere-normalized M (V, D_base) # ════════════════════════════════════════════════════════════════════════ class LeanGeometricEncoder(nn.Module): """Vendored encoder matching PatchSVAE's structure: enc_in → residual blocks → enc_out → reshape(V,D) → sphere row-normalize. 
Sandbox-runnable.""" def __init__(self, patch_dim: int, V: int, D: int, hidden: int, depth: int = 1): super().__init__() self.V, self.D = V, D self.enc_in = nn.Linear(patch_dim, hidden) self.blocks = nn.ModuleList([ nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(), nn.Linear(hidden, hidden)) for _ in range(depth)]) self.enc_out = nn.Linear(hidden, V * D) nn.init.orthogonal_(self.enc_out.weight) def forward(self, patches: torch.Tensor) -> torch.Tensor: B, N, _ = patches.shape h = F.gelu(self.enc_in(patches.reshape(B * N, -1))) for blk in self.blocks: h = h + blk(h) M = self.enc_out(h).reshape(B * N, self.V, self.D) M = F.normalize(M, dim=-1) # rows on S^(D-1) return canon(M).reshape(B, N, self.V, self.D) class RealSVAEEncoder(nn.Module): """Adapter around the installed geolip_svae.PatchSVAE (Colab). Uses its H2-stable encode_patches to produce the sphere-normalized M. Returns the same (B,N,V,D) interface as the lean encoder.""" def __init__(self, patch_dim: int, V: int, D: int, hidden: int, ps: int, channels: int, depth: int = 1, freeze: bool = False): super().__init__() from geolip_svae.model import PatchSVAE # Colab-only import self.svae = PatchSVAE(V=V, D=D, ps=ps, hidden=hidden, channels=channels, depth=depth, n_cross=1, linear_readout=True, svd_mode='none', match_params=True, row_norm='sphere') self.V, self.D = V, D if freeze: for p in self.svae.parameters(): p.requires_grad_(False) def forward(self, patches: torch.Tensor) -> torch.Tensor: out = self.svae.encode_patches(patches) M = out['M'] if M.dim() == 3: # (B*N, V, D) B, N, _ = patches.shape M = M.reshape(B, N, self.V, self.D) return canon(F.normalize(M, dim=-1)) def make_encoder(patch_dim, V, D, hidden, ps, channels, use_real_svae): if use_real_svae: try: return RealSVAEEncoder(patch_dim, V, D, hidden, ps, channels) except Exception as e: print(f" [front-end] real PatchSVAE unavailable ({e}); " f"falling back to lean encoder") return LeanGeometricEncoder(patch_dim, V, D, hidden) # 
# ════════════════════════════════════════════════════════════════════════
# Lens frame — rigidity-preserving isometric lift D_base → D_lens (the guarantee)
# ════════════════════════════════════════════════════════════════════════

class LensFrame(nn.Module):
    """Lift the sphere-normalized frame to D_lens via an orthonormal embedding.
    ⟨Ex,Ec⟩ = ⟨x,c⟩ ⇒ pairwise projective angles preserved EXACTLY ⇒ the
    frame's rigidity (in-envelope at D_base) is carried up to D_lens intact.
    This is the guarantee the lens-framed features provide.

    The embedding E (D_lens, D_base) is the Q factor of a seeded Gaussian
    matrix, stored as a non-trainable buffer.
    """

    def __init__(self, D_base: int, D_lens: int, seed: int = 0):
        super().__init__()
        assert D_lens >= D_base
        self.D_base, self.D_lens = D_base, D_lens
        g = torch.Generator().manual_seed(seed)
        # Reduced QR of a tall matrix → orthonormal columns (EᵀE = I).
        Q = torch.linalg.qr(torch.randn(D_lens, D_base, generator=g))[0]
        self.register_buffer('E', Q)             # (D_lens, D_base)

    def forward(self, M: torch.Tensor) -> torch.Tensor:
        """M: (B,N,V,D_base) → (B,N,V,D_lens), rows still on the sphere."""
        M_lens = M @ self.E.T
        return canon(F.normalize(M_lens, dim=-1))


# ════════════════════════════════════════════════════════════════════════
# Spectral-alpha attention — the SVAE's actual mechanism (dot-alpha MHA)
# ════════════════════════════════════════════════════════════════════════

class SpectralAlphaAttention(nn.Module):
    """Faithful to the SVAE's SpectralCrossAttention — the ONLY attention the
    omegas are aligned to behave with:

        S_out = S · (1 + α_d · tanh(out_proj(SDPA(qkv(norm(S))))_d))

    MULTIPLICATIVE (not additive), per-mode α bounded to [0, max_alpha] and
    initialized near zero (sigmoid(-2)·0.2 ≈ 0.024) so the attention starts as
    near-identity and ENGAGES GRADUALLY as the omegas form — curation, not
    forced convergence. max_alpha / alpha_init are the curation knobs.
    """

    def __init__(self, D: int, n_heads: int = 4, max_alpha: float = 0.2,
                 alpha_init: float = -2.0):
        super().__init__()
        assert D % n_heads == 0, f"D={D} must be divisible by n_heads={n_heads}"
        self.n_heads = n_heads
        self.head_dim = D // n_heads
        self.max_alpha = max_alpha
        self.qkv = nn.Linear(D, 3 * D)
        self.out_proj = nn.Linear(D, D)
        self.norm = nn.LayerNorm(D)
        self.alpha_logits = nn.Parameter(torch.full((D,), float(alpha_init)))

    @property
    def alpha(self) -> torch.Tensor:
        """Per-mode gate strength, ``max_alpha * sigmoid(alpha_logits)``."""
        return self.max_alpha * torch.sigmoid(self.alpha_logits)  # [0, max_alpha]

    def forward(self, S: torch.Tensor) -> torch.Tensor:
        """S: (B, N, D) → (B, N, D), each entry scaled by a factor in [1-α, 1+α]."""
        B, N, D = S.shape
        S_n = self.norm(S)
        qkv = self.qkv(S_n).reshape(B, N, 3, self.n_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)         # (3, B, heads, N, head_dim)
        q, k, v = qkv[0], qkv[1], qkv[2]
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(B, N, D)
        gate = torch.tanh(self.out_proj(out))
        return S * (1.0 + self.alpha.view(1, 1, -1) * gate)  # multiplicative


class SpectralAlphaStack(nn.Module):
    """Stack of spectral-alpha layers coordinating omegas across patches."""

    def __init__(self, D: int, n_heads: int, n_layers: int,
                 max_alpha: float = 0.2, alpha_init: float = -2.0):
        super().__init__()
        self.layers = nn.ModuleList([
            SpectralAlphaAttention(D, n_heads, max_alpha, alpha_init)
            for _ in range(n_layers)])

    def forward(self, S: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            S = layer(S)
        return S

    def mean_alpha(self) -> float:
        """The bloom signal: mean engaged alpha across layers (init ≈ 0.024).

        Returns 0.0 for an empty stack (n_layers == 0), where ``forward`` is
        the identity; ``torch.stack`` on an empty list would otherwise raise.
        """
        if len(self.layers) == 0:
            return 0.0
        with torch.no_grad():
            return float(torch.stack([l.alpha.mean() for l in self.layers]).mean())


# ════════════════════════════════════════════════════════════════════════
# Decoder — modulated rigid frame → patch (the generation path)
# ════════════════════════════════════════════════════════════════════════

class GeoDecoder(nn.Module):
    """Decode the attended-omega-modulated frame back to a patch.
    M_dec = M_lens · S_attended (broadcast over V rows) mirrors U·S; a small
    net reads the modulated frame to patch_dim.
    """

    def __init__(self, V: int, D_lens: int, patch_dim: int, hidden: int,
                 depth: int = 1):
        super().__init__()
        self.dec_in = nn.Linear(V * D_lens, hidden)
        self.blocks = nn.ModuleList([
            nn.Sequential(nn.Linear(hidden, hidden), nn.GELU(),
                          nn.Linear(hidden, hidden))
            for _ in range(depth)])
        self.dec_out = nn.Linear(hidden, patch_dim)

    def forward(self, M_lens: torch.Tensor, S_att: torch.Tensor) -> torch.Tensor:
        # M_lens:(B,N,V,D_lens)  S_att:(B,N,D_lens)  →  (B,N,patch_dim)
        B, N, V, D = M_lens.shape
        M_dec = M_lens * S_att.unsqueeze(2)      # modulate frame by token
        h = F.gelu(self.dec_in(M_dec.reshape(B * N, V * D)))
        for blk in self.blocks:
            h = h + blk(h)                       # residual MLP block
        return self.dec_out(h).reshape(B, N, -1)


# ════════════════════════════════════════════════════════════════════════
# The geolip-svae-transformer
# ════════════════════════════════════════════════════════════════════════

@dataclass
class GeoConfig:
    img_size: int = 32
    channels: int = 3
    ps: int = 4
    V: int = 32
    D_base: int = 4
    D_lens: int = 16
    hidden: int = 64
    n_heads: int = 4
    n_layers: int = 2
    max_alpha: float = 0.2        # spectral-alpha ceiling (curation knob)
    alpha_init: float = -2.0      # near-zero engaged alpha at init (≈0.024)
    use_real_svae: bool = False
    patch_dim_override: Optional[int] = None   # set = feature mode (e.g. BERT hidden)


class GeoSVAETransformer(nn.Module):
    """Encoder → lens lift → omega read-out → spectral-alpha stack → decoder.

    ``forward`` takes images (B,C,H,W); ``forward_patches`` takes
    (B,N,patch_dim) directly and also serves feature mode
    (``cfg.patch_dim_override``), where the PatchSVAE front-end is never used.
    """

    def __init__(self, cfg: GeoConfig):
        super().__init__()
        self.cfg = cfg
        patch_dim = cfg.patch_dim_override or (cfg.channels * cfg.ps * cfg.ps)
        self.patch_dim = patch_dim
        self.feature_mode = cfg.patch_dim_override is not None
        self.encoder = make_encoder(patch_dim, cfg.V, cfg.D_base, cfg.hidden,
                                    cfg.ps, cfg.channels,
                                    cfg.use_real_svae and not self.feature_mode)
        self.lens = LensFrame(cfg.D_base, cfg.D_lens)
        # the omega EMERGES as the spectral magnitude of the rigid frame —
        # not a learned squash. No parameters here.
        self.transformer = SpectralAlphaStack(cfg.D_lens, cfg.n_heads,
                                              cfg.n_layers, cfg.max_alpha,
                                              cfg.alpha_init)
        self.decoder = GeoDecoder(cfg.V, cfg.D_lens, patch_dim, cfg.hidden)

    def omega_token(self, M_lens: torch.Tensor) -> torch.Tensor:
        """The omega: per-mode spectral magnitude (column norms) of the rigid
        frame. Forms through curation as the encoder shapes M — never forced.
        M_lens (B,N,V,D_lens) → S (B,N,D_lens)."""
        return M_lens.norm(dim=-2)

    def forward_patches(self, patches: torch.Tensor) -> Dict:
        """Core path on (B, N, patch_dim) — image patches OR feature tokens.

        Returns a dict with 'recon_patches', 'M', 'M_lens', 'omega' (the
        attended token) and 'mean_alpha' (a Python float).
        """
        M = self.encoder(patches)                # (B,N,V,D_base)  sphere
        M_lens = self.lens(M)                    # (B,N,V,D_lens)  rigid
        S = self.omega_token(M_lens)             # (B,N,D_lens)    emergent omega
        S_att = self.transformer(S)              # spectral-alpha coordination
        dec_patches = self.decoder(M_lens, S_att)  # (B,N,patch_dim)
        return {'recon_patches': dec_patches, 'M': M, 'M_lens': M_lens,
                'omega': S_att, 'mean_alpha': self.transformer.mean_alpha()}

    def forward(self, images: torch.Tensor) -> Dict:
        """Patchify, run ``forward_patches``, stitch; adds 'recon' (B,C,H,W)."""
        cfg = self.cfg
        patches, gh, gw = extract_patches(images, cfg.ps)   # (B,N,patch_dim)
        out = self.forward_patches(patches)
        recon = stitch_patches(out['recon_patches'], gh, gw, cfg.ps, cfg.channels)
        out['recon'] = recon
        return out


# ════════════════════════════════════════════════════════════════════════
# Data + rigidity guarantee monitor
# ════════════════════════════════════════════════════════════════════════

def make_batch(B, img_size, channels, step, seed, use_real_svae):
    """Return a deterministic (B, channels, img_size, img_size) batch in [-4, 4].

    The RNG seed is ``seed * 100000 + step``. With ``use_real_svae`` the
    geolip_svae noise generator is tried first and its channel count is
    cropped or padded (repeating channel 0) to ``channels``. Any failure there
    is a deliberate best-effort: it silently falls through to the vendored
    structured noise below.
    """
    sd = seed * 100000 + step
    if use_real_svae:
        try:
            from geolip_svae.inference import gen_sixteen_noise
            x = gen_sixteen_noise(n=B, size=img_size, seed=sd)
            if x.shape[1] != channels:
                x = x[:, :channels] if x.shape[1] > channels else \
                    torch.cat([x, x[:, :1].repeat(1, channels - x.shape[1], 1, 1)], 1)
            return x.clamp(-4, 4)
        except Exception:
            pass
    g = torch.Generator().manual_seed(sd)
    # structured noise: low-rank + sparse so there's compressible structure
    base = torch.randn(B, channels, img_size, img_size, generator=g)
    # kernel 4 / padding 2 / stride 1 yields img_size+1 per side; crop back.
    smooth = F.avg_pool2d(base, 4, stride=1, padding=2)[..., :img_size, :img_size]
    return (0.6 * smooth + 0.4 * base).clamp(-4, 4)


def measure_guarantee(M_lens: torch.Tensor, D_base: int) -> Dict:
    """The lens-framed rigidity guarantee at D_lens: take one patch's frame,
    measure cv_of + intrinsic deviation vs the D_base baseline (preserved by
    the isometric lift).

    Only the first patch of the first batch item is inspected.
    """
    cb = M_lens.detach()[0, 0]                   # (V, D_lens), one patch
    cv = cv_of(cb)
    dev = intrinsic_deviation(cb, D_base)
    lo, hi = cv_band_for(D_base)                 # skeleton stays D_base-class
    return {'cv_of': cv, 'in_d_base_band': lo <= cv <= hi,
            'deviation': dev, 'in_envelope': abs(dev) < dev_critical(D_base)}


# ════════════════════════════════════════════════════════════════════════
# Prototype training (compression) + guarantee monitoring
# ════════════════════════════════════════════════════════════════════════

@dataclass
class TrainConfig:
    epochs: int = 8
    steps_per_epoch: int = 150
    batch_size: int = 64
    lr: float = 2e-3
    rigid_weight: float = 0.5
    out_dir: str = './geo_svae_transformer_results'
    seed: int = 0


def _tail_mean(values: List[float], window: int) -> float:
    """Mean of the last ``max(1, window)`` entries of ``values`` (0.0 if empty)."""
    if not values:
        return 0.0
    tail = values[-max(1, window):]
    return sum(tail) / len(tail)


def run_train(geo: GeoConfig, tr: TrainConfig) -> Dict:
    """Train the prototype on generated batches and monitor the lens guarantee.

    Each step minimizes ``mse(recon, images) + rigid_weight * rigidity_loss``.
    After every epoch a fixed probe batch (step index 99999) is pushed through
    the model and its first lens frame is measured. A JSON report is written
    to ``<out_dir>/geo_svae_transformer.json`` and also returned.

    Raises:
        ValueError: if ``tr.epochs`` or ``tr.steps_per_epoch`` is < 1 (there
            would be no history entry to build the verdict from).
    """
    if tr.epochs < 1 or tr.steps_per_epoch < 1:
        raise ValueError(
            f"epochs and steps_per_epoch must be >= 1 "
            f"(got epochs={tr.epochs}, steps_per_epoch={tr.steps_per_epoch})")

    out_dir = Path(tr.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.manual_seed(tr.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(tr.seed)

    print("=" * 70)
    print("geolip-svae-transformer prototype — compression train + guarantee")
    print(f" img{geo.img_size} ps{geo.ps} → {(geo.img_size//geo.ps)**2} patches | "
          f"V{geo.V} D_base{geo.D_base} → lens D{geo.D_lens} | "
          f"{geo.n_layers}L×{geo.n_heads}h | device={device}")
    print(f" front-end: {'REAL PatchSVAE' if geo.use_real_svae else 'lean vendored'}")
    print("=" * 70)

    model = GeoSVAETransformer(geo).to(device)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" trainable params: {n_params:,}")
    opt = torch.optim.Adam(model.parameters(), lr=tr.lr)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=tr.epochs * tr.steps_per_epoch)

    model.train()
    step, best = 0, float('inf')
    history = []
    for epoch in range(tr.epochs):
        ep_losses = []
        for _ in range(tr.steps_per_epoch):
            images = make_batch(tr.batch_size, geo.img_size, geo.channels,
                                step, tr.seed, geo.use_real_svae).to(device)
            opt.zero_grad()
            out = model(images)
            recon_loss = F.mse_loss(out['recon'], images)
            rig_loss = rigidity_loss(out['M'], geo.D_base)
            loss = recon_loss + tr.rigid_weight * rig_loss
            loss.backward()
            opt.step()
            sched.step()
            ep_losses.append(float(recon_loss.detach()))
            best = min(best, float(recon_loss.detach()))
            step += 1

        with torch.no_grad():
            probe = make_batch(tr.batch_size, geo.img_size, geo.channels,
                               99999, tr.seed, geo.use_real_svae).to(device)
            g = measure_guarantee(model(probe)['M_lens'], geo.D_base)
        # Mean over the last quarter of the epoch. The former slice
        # `ep_losses[-tr.steps_per_epoch//4:]` parsed as (-steps)//4, which
        # floors away from zero and took one element more than the divisor
        # whenever steps_per_epoch was not a multiple of 4.
        mean_mse = _tail_mean(ep_losses, tr.steps_per_epoch // 4)
        # 'mean_alpha' is the value recorded by the last training forward pass.
        history.append({'epoch': epoch, 'mean_mse': mean_mse, 'best_mse': best,
                        'mean_alpha': out['mean_alpha'], 'guarantee': g})
        print(f" epoch {epoch:2d}: mse={mean_mse:.5f} (best {best:.5f}) | "
              f"α={out['mean_alpha']:.4f} | lens-frame: cv_of={g['cv_of']:.3f} "
              f"dev={g['deviation']:+.4f} in_env={g['in_envelope']}")

    final_g = history[-1]['guarantee']
    verdict = {
        'compresses': best < 0.05,
        'best_mse': best,
        'guarantee_holds': final_g['in_envelope'],     # the rigidity formula
        'cv_in_band': final_g['in_d_base_band'],       # secondary signature
        'final_guarantee': final_g,
        'trainable_params': n_params,
    }
    report = {'geo_config': asdict(geo), 'train_config': asdict(tr),
              'history': history, 'verdict': verdict}
    with open(out_dir / 'geo_svae_transformer.json', 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    # Describe where cv_of sits relative to the band instead of always
    # claiming "just above" when it is outside.
    _band_lo, band_hi = cv_band_for(geo.D_base)
    if verdict['cv_in_band']:
        band_word = 'in'
    elif final_g['cv_of'] > band_hi:
        band_word = 'just above'
    else:
        band_word = 'below'

    print("\n" + "=" * 70)
    print("PROTOTYPE VERDICT")
    print("=" * 70)
    print(f" {'✓' if verdict['compresses'] else '✗'} compresses: best MSE "
          f"{best:.5f} (input → omega tokens → reconstruction)")
    print(f" {'✓' if verdict['guarantee_holds'] else '✗'} GUARANTEE holds: "
          f"lens-framed frame in rigidity envelope at D_lens={geo.D_lens} "
          f"(dev {final_g['deviation']:+.4f}, crit ±{dev_critical(geo.D_base):.3f})")
    print(f" {'·'} cv_of {final_g['cv_of']:.3f} "
          f"({band_word} d{geo.D_base} "
          f"random-codebook band — secondary signature)")
    print(f" → omega tokens are transformer-shaped (D_lens-vectors, cross-patch")
    print(f" attended); the rigid frame is imposed underneath and survives the")
    print(f" lens in-envelope. Compression trained; decode = generation seed.")
    print(f" report: {out_dir / 'geo_svae_transformer.json'}")
    return report


# ════════════════════════════════════════════════════════════════════════
# Colab-proof entry points
# ════════════════════════════════════════════════════════════════════════

def _is_jupyter_kernel():
    """True when running under an IPython kernel (IPKernelApp in its config)."""
    try:
        from IPython import get_ipython
        ip = get_ipython()
        return ip is not None and 'IPKernelApp' in ip.config
    except Exception:
        return False


def _filter_jupyter_args(argv):
    """Drop kernel-injected arguments: ``-f <file>``, ``-f=<file>`` and any
    argument ending in ``.json``."""
    out, skip = [], False
    for a in argv:
        if skip:                                 # value following a bare -f
            skip = False
            continue
        if a == '-f':
            skip = True
            continue
        if a.startswith('-f=') or a.endswith('.json'):
            continue
        out.append(a)
    return out


def run(**kwargs):
    """Notebook entry:
        from geolip_svae_transformer import run
        run()                        # lean front-end (works anywhere)
        run(use_real_svae=True)      # real PatchSVAE (Colab, geolip_core)
        run(D_lens=64, n_layers=4, epochs=12)

    Keyword arguments are routed to GeoConfig / TrainConfig by field name;
    names matching neither are ignored.
    """
    geo = GeoConfig(**{k: v for k, v in kwargs.items()
                       if k in GeoConfig.__dataclass_fields__})
    tr = TrainConfig(**{k: v for k, v in kwargs.items()
                        if k in TrainConfig.__dataclass_fields__})
    return run_train(geo, tr)


def main(argv=None):
    """CLI entry point; unknown arguments are tolerated (parse_known_args)."""
    import sys
    if argv is None:
        argv = sys.argv[1:]
    if _is_jupyter_kernel():
        argv = _filter_jupyter_args(argv)
    p = argparse.ArgumentParser()
    p.add_argument('--D-lens', type=int, default=16)
    p.add_argument('--n-layers', type=int, default=2)
    p.add_argument('--epochs', type=int, default=8)
    p.add_argument('--use-real-svae', action='store_true')
    p.add_argument('--out-dir', default='./geo_svae_transformer_results')
    args, _unknown = p.parse_known_args(argv)
    geo = GeoConfig(D_lens=args.D_lens, n_layers=args.n_layers,
                    use_real_svae=args.use_real_svae)
    tr = TrainConfig(epochs=args.epochs, out_dir=args.out_dir)
    return run_train(geo, tr)


if __name__ == '__main__':
    main()