"""
Geometric Lookup Flow Matching (GLFM)
========================================
A flow matching variant where velocity prediction is driven by
geometric address lookup on S^15.

Core insight (empirical):
The constellation bottleneck doesn't reconstruct encoder features.
It produces cos_sim ≈ 0 to its input. Instead, the triangulation
profile acts as a continuous ADDRESS on the unit hypersphere,
and the generator produces velocity fields from that address.

This is: v(x_t, t, c) = Generator(Address(x_t), t, c)
where Address(x) = triangulate(project_to_sphere(encode(x)))

GLFM formalizes this into three stages:

Stage 1 — GEOMETRIC ADDRESSING
Encoder maps x_t to multiple resolution embeddings on S^15.
Each resolution captures different spatial frequency information.
Triangulation against fixed anchors produces a structured address.

Stage 2 — ADDRESS CONDITIONING
The geometric address is concatenated with:
- Timestep embedding (sinusoidal)
- Class/text conditioning
- Noise level features
The conditioning modulates WHAT to generate at this address.

Stage 3 — VELOCITY GENERATION
A deep MLP generates the velocity field from the conditioned address.
This is NOT reconstruction — it's generation from a lookup.
The generator never sees the raw encoder features.

Key properties:
- Address space is geometrically structured (Voronoi cells on S^15)
- Anchors self-organize: <0.29 rad = frame holders, >0.29 = task encoders
- Precision-invariant (works at fp8)
- 21× compression with zero velocity quality loss
- Multi-scale addressing captures both coarse and fine structure
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import os |
| import time |
| from tqdm import tqdm |
| from torchvision import datasets, transforms |
| from torchvision.utils import save_image, make_grid |
|
|
# Run on GPU when available; TF32 matmul/conv is enabled for extra throughput
# on Ampere+ GPUs at slightly reduced (but training-safe) precision.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
|
|
|
|
| |
| |
| |
|
|
class GeometricAddressEncoder(nn.Module):
    """
    Maps spatial features to geometric addresses on the unit sphere S^(patch_dim-1).

    Multi-scale: produces addresses at 2 resolutions.
    - Coarse: global pool -> single embed_dim embedding -> 1 address
    - Fine:   per-spatial-position embeddings -> H*W addresses, pooled

    Each embedding is split into P = embed_dim // patch_dim unit-norm patches
    and triangulated against the anchor constellation.  The combined
    triangulation profiles form the full geometric address.

    Args:
        spatial_channels: channel count C of the incoming feature map.
        spatial_size: spatial side length of the feature map (stored for reference).
        embed_dim: total embedding width; must be divisible by patch_dim.
        patch_dim: dimensionality of each on-sphere patch.
        n_anchors: anchors per patch in the constellation.
        n_phases: slerp phases sampled between home and current anchors.
    """
    def __init__(
        self,
        spatial_channels,
        spatial_size,
        embed_dim=256,
        patch_dim=16,
        n_anchors=16,
        n_phases=3,
    ):
        super().__init__()
        # Fail fast: a non-divisible embed_dim would otherwise surface later
        # as a confusing reshape error inside forward().
        if embed_dim % patch_dim != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by patch_dim ({patch_dim})")
        self.spatial_channels = spatial_channels
        self.spatial_size = spatial_size
        self.embed_dim = embed_dim
        self.patch_dim = patch_dim
        self.n_patches = embed_dim // patch_dim
        self.n_anchors = n_anchors
        self.n_phases = n_phases

        P, A, d = self.n_patches, n_anchors, patch_dim

        # Coarse path: one embedding per image (globally pooled features).
        self.coarse_proj = nn.Sequential(
            nn.Linear(spatial_channels, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Fine path: one embedding per spatial position.
        self.fine_proj = nn.Sequential(
            nn.Linear(spatial_channels, embed_dim),
            nn.LayerNorm(embed_dim),
        )

        # Fixed "home" constellation (buffer) plus a learnable copy that may
        # drift away from home during training.
        home = torch.empty(P, A, d)
        nn.init.xavier_normal_(home.view(P * A, d))
        home = F.normalize(home.view(P, A, d), dim=-1)
        self.register_buffer('home', home)
        self.anchors = nn.Parameter(home.clone())

        # One triangulation entry per (patch, anchor, phase).
        self.tri_dim = P * A * n_phases
        # Full address = coarse profile + pooled fine profile.
        self.address_dim = self.tri_dim * 2

    def drift(self):
        """Geodesic angle (radians) between home and current anchors, shape (P, A)."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        # Clamp keeps acos differentiable/finite when vectors are (anti)parallel.
        return torch.acos((h * c).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7))

    def at_phase(self, t):
        """Slerp the constellation: t=0 gives home, t=1 gives current anchors."""
        h = F.normalize(self.home, dim=-1)
        c = F.normalize(self.anchors, dim=-1)
        omega = self.drift().unsqueeze(-1)
        so = omega.sin().clamp(min=1e-7)  # guard division when home == anchor
        return torch.sin((1 - t) * omega) / so * h + torch.sin(t * omega) / so * c

    def triangulate(self, patches_n):
        """Triangulation profile of unit patches: (..., P, d) -> (..., P*A*n_phases).

        Each entry is the cosine distance (1 - cosine similarity) between a
        patch and an anchor, evaluated at n_phases slerp positions.
        """
        shape = patches_n.shape[:-2]
        P, A, d = self.n_patches, self.n_anchors, self.patch_dim
        flat = patches_n.reshape(-1, P, d)
        phases = torch.linspace(0, 1, self.n_phases, device=flat.device).tolist()
        tris = []
        for t in phases:
            at = F.normalize(self.at_phase(t), dim=-1)
            tris.append(1.0 - torch.einsum('bpd,pad->bpa', flat, at))
        tri = torch.cat(tris, dim=-1).reshape(flat.shape[0], -1)
        return tri.reshape(*shape, -1)

    def forward(self, feature_map):
        """
        feature_map: (B, C, H, W) encoder features.
        Returns: (B, address_dim) geometric address.
        """
        B, C, H, W = feature_map.shape

        # Coarse address: global average pool -> project -> triangulate.
        coarse = feature_map.mean(dim=(-2, -1))
        coarse_emb = self.coarse_proj(coarse)
        coarse_patches = F.normalize(
            coarse_emb.reshape(B, self.n_patches, self.patch_dim), dim=-1)
        coarse_addr = self.triangulate(coarse_patches)

        # Fine addresses: per-position projection and triangulation.
        fine = feature_map.permute(0, 2, 3, 1).reshape(B * H * W, C)
        fine_emb = self.fine_proj(fine)
        fine_patches = F.normalize(
            fine_emb.reshape(B * H * W, self.n_patches, self.patch_dim), dim=-1)
        fine_addr = self.triangulate(fine_patches)

        # Pool the fine addresses over positions; the mean+max blend keeps both
        # average structure and the strongest per-channel activation.
        fine_addr = fine_addr.reshape(B, H * W, -1)
        fine_mean = fine_addr.mean(dim=1)
        fine_max = fine_addr.max(dim=1).values
        fine_combined = (fine_mean + fine_max) / 2

        return torch.cat([coarse_addr, fine_combined], dim=-1)
|
|
|
|
| |
| |
| |
|
|
class AddressConditioner(nn.Module):
    """Fuses a geometric address with timestep, class, and noise-level features.

    Produces a single conditioned address vector sized for the generator.
    """
    def __init__(self, address_dim, cond_dim=256, output_dim=1024):
        super().__init__()
        # Continuous sinusoidal timestep embedding, refined by a 2-layer MLP.
        self.time_emb = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))

        # Discrete 64-bin embedding of the noise level (derived from t).
        self.noise_emb = nn.Embedding(64, cond_dim)

        # Projects [address | time | class | noise] to the generator width.
        self.fuse = nn.Sequential(
            nn.Linear(address_dim + cond_dim * 3, output_dim),
            nn.GELU(),
            nn.LayerNorm(output_dim),
        )

    def forward(self, address, t, class_emb):
        """
        address: (B, address_dim) from geometric encoder
        t: (B,) timestep in [0, 1]
        class_emb: (B, cond_dim) class embedding
        Returns: (B, output_dim) conditioned address
        """
        time_features = self.time_emb(t)
        # Bucket the continuous timestep into one of 64 noise-level bins.
        bins = (t * 63).long().clamp(0, 63)
        noise_features = self.noise_emb(bins)

        fused_input = torch.cat([address, time_features, class_emb, noise_features], dim=-1)
        return self.fuse(fused_input)
|
|
|
|
| |
| |
| |
|
|
class VelocityGenerator(nn.Module):
    """Generates flattened spatial velocity features from a conditioned address.

    This is generation from a geometric lookup, not reconstruction: the
    generator never sees raw encoder features, only the conditioned address.
    """
    def __init__(self, cond_address_dim, spatial_dim, hidden=1024, depth=4):
        super().__init__()
        self.spatial_dim = spatial_dim

        # Stem projects the address into the hidden width; residual MLP
        # blocks then refine it.
        stem = nn.Sequential(
            nn.Linear(cond_address_dim, hidden),
            nn.GELU(), nn.LayerNorm(hidden))
        self.blocks = nn.ModuleList([stem] + [ResBlock(hidden) for _ in range(depth)])

        # Output head maps hidden features to the flattened spatial volume.
        self.head = nn.Sequential(
            nn.Linear(hidden, hidden), nn.GELU(),
            nn.Linear(hidden, spatial_dim))

    def forward(self, cond_address):
        """
        cond_address: (B, cond_address_dim)
        Returns: (B, spatial_dim) generated velocity features
        """
        h = cond_address
        for block in self.blocks:
            h = block(h)
        return self.head(h)
|
|
|
|
class ResBlock(nn.Module):
    """Residual MLP block: x + MLP(x), width-preserving.

    The inner MLP is two (Linear -> GELU -> LayerNorm) stages.
    """
    def __init__(self, dim):
        super().__init__()
        stages = []
        for _ in range(2):
            stages += [nn.Linear(dim, dim), nn.GELU(), nn.LayerNorm(dim)]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return x + self.net(x)
|
|
|
|
| |
| |
| |
|
|
class SinusoidalPosEmb(nn.Module):
    """Transformer-style sinusoidal embedding of a scalar timestep.

    Maps t of shape (B,) to (B, dim): dim//2 sine features followed by
    dim//2 cosine features over a geometric frequency ladder (base 10000).
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half = self.dim // 2
        scale = -math.log(10000) / (half - 1)
        freqs = torch.exp(scale * torch.arange(half, device=t.device, dtype=t.dtype))
        angles = t.unsqueeze(-1) * freqs.unsqueeze(0)
        return torch.cat([angles.sin(), angles.cos()], dim=-1)
|
|
|
|
class AdaGroupNorm(nn.Module):
    """Adaptive GroupNorm: normalize, then apply conditioning-derived scale/shift.

    The projection is zero-initialized so the module starts out as plain
    (affine-free) GroupNorm and learns its modulation from there.
    """
    def __init__(self, ch, cond_dim, groups=8):
        super().__init__()
        self.gn = nn.GroupNorm(min(groups, ch), ch, affine=False)
        self.proj = nn.Linear(cond_dim, ch * 2)
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, cond):
        normed = self.gn(x)
        # Broadcast (B, 2*ch) over the spatial dims, then split scale/shift.
        scale, shift = self.proj(cond)[..., None, None].chunk(2, dim=1)
        return normed * (1 + scale) + shift
|
|
|
|
class ConvBlock(nn.Module):
    """ConvNeXt-style conditioned block with a residual connection.

    depthwise 7x7 -> AdaGroupNorm(cond) -> 1x1 expand (4x) -> GELU -> 1x1 project.
    """
    def __init__(self, ch, cond_dim):
        super().__init__()
        self.dw = nn.Conv2d(ch, ch, 7, padding=3, groups=ch)
        self.norm = AdaGroupNorm(ch, cond_dim)
        self.pw1 = nn.Conv2d(ch, ch * 4, 1)
        self.pw2 = nn.Conv2d(ch * 4, ch, 1)
        self.act = nn.GELU()

    def forward(self, x, cond):
        residual = x
        h = self.dw(x)
        h = self.norm(h, cond)
        h = self.act(self.pw1(h))
        return residual + self.pw2(h)
|
|
|
|
class Downsample(nn.Module):
    """Halves spatial resolution with a stride-2 3x3 convolution."""
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)
|
|
|
|
class Upsample(nn.Module):
    """Doubles spatial resolution: nearest-neighbor upsample then 3x3 conv."""
    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, 3, padding=1)

    def forward(self, x):
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)
|
|
|
|
| |
| |
| |
|
|
class GLFMUNet(nn.Module):
    """
    Geometric Lookup Flow Matching UNet.

    Encoder -> GeometricAddress -> Conditioner -> VelocityGenerator -> Decoder

    The middle of the UNet is the three-stage GLFM pipeline.
    No attention. No reconstruction. Pure geometric lookup.
    """
    def __init__(
        self,
        in_ch=3,
        base_ch=64,
        ch_mults=(1, 2, 4),
        n_classes=10,
        cond_dim=256,
        embed_dim=256,
        n_anchors=16,
        n_phases=3,
        gen_hidden=1024,
        gen_depth=4,
    ):
        super().__init__()
        self.ch_mults = ch_mults

        # Class conditioning shared by encoder, decoder, and the GLFM core.
        self.class_emb = nn.Embedding(n_classes, cond_dim)

        # Timestep embedding for the encoder half.
        self.enc_time = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))

        self.in_conv = nn.Conv2d(in_ch, base_ch, 3, padding=1)

        # Encoder: two conditioned ConvBlocks per level, Downsample between levels.
        self.enc = nn.ModuleList()
        self.enc_down = nn.ModuleList()
        ch = base_ch
        enc_channels = [base_ch]

        for i, m in enumerate(ch_mults):
            ch_out = base_ch * m
            self.enc.append(nn.ModuleList([
                # When the channel multiplier changes, a 1x1 conv adapts width
                # before the first ConvBlock of the level.
                ConvBlock(ch, cond_dim) if ch == ch_out
                else nn.Sequential(nn.Conv2d(ch, ch_out, 1), ConvBlock(ch_out, cond_dim)),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            enc_channels.append(ch)
            if i < len(ch_mults) - 1:
                self.enc_down.append(Downsample(ch))

        # Bottleneck geometry.
        # NOTE(review): the hard-coded 32 assumes 32x32 inputs (CIFAR-10);
        # other resolutions would need this parameterized.
        mid_ch = ch
        H_mid = 32 // (2 ** (len(ch_mults) - 1))
        spatial_dim = mid_ch * H_mid * H_mid
        self.mid_spatial = (mid_ch, H_mid, H_mid)

        # Stage 1: geometric addressing of bottleneck features.
        self.geo_encoder = GeometricAddressEncoder(
            spatial_channels=mid_ch,
            spatial_size=H_mid,
            embed_dim=embed_dim,
            patch_dim=16,
            n_anchors=n_anchors,
            n_phases=n_phases,
        )

        # Stage 2: fuse the address with timestep/class conditioning.
        self.conditioner = AddressConditioner(
            address_dim=self.geo_encoder.address_dim,
            cond_dim=cond_dim,
            output_dim=gen_hidden,
        )

        # Stage 3: generate bottleneck velocity features from the address.
        self.generator = VelocityGenerator(
            cond_address_dim=gen_hidden,
            spatial_dim=spatial_dim,
            hidden=gen_hidden,
            depth=gen_depth,
        )

        # Decoder mirrors the encoder, merging skip connections per level.
        self.dec_up = nn.ModuleList()
        self.dec_skip = nn.ModuleList()
        self.dec = nn.ModuleList()

        # Separate timestep embedding for the decoder half.
        self.dec_time = nn.Sequential(
            SinusoidalPosEmb(cond_dim),
            nn.Linear(cond_dim, cond_dim), nn.GELU(),
            nn.Linear(cond_dim, cond_dim))

        for i in range(len(ch_mults) - 1, -1, -1):
            ch_out = base_ch * ch_mults[i]
            skip_ch = enc_channels.pop()
            # 1x1 conv fuses [decoder features | skip] down to the level width.
            self.dec_skip.append(nn.Conv2d(ch + skip_ch, ch_out, 1))
            self.dec.append(nn.ModuleList([
                ConvBlock(ch_out, cond_dim),
                ConvBlock(ch_out, cond_dim),
            ]))
            ch = ch_out
            if i > 0:
                self.dec_up.append(Upsample(ch))

        # Zero-initialized output conv: the network starts by predicting a
        # zero velocity field.
        self.out_norm = nn.GroupNorm(8, ch)
        self.out_conv = nn.Conv2d(ch, in_ch, 3, padding=1)
        nn.init.zeros_(self.out_conv.weight)
        nn.init.zeros_(self.out_conv.bias)

    def forward(self, x, t, class_labels):
        """
        x: (B, in_ch, H, W) noisy sample; t: (B,) timestep; class_labels: (B,) int labels.
        Returns the predicted velocity field with the same shape as x.
        """
        # Encoder and decoder each get their own time embedding, both summed
        # with the shared class embedding.
        enc_cond = self.enc_time(t) + self.class_emb(class_labels)
        dec_cond = self.dec_time(t) + self.class_emb(class_labels)
        cls_emb = self.class_emb(class_labels)

        h = self.in_conv(x)
        # NOTE(review): this first skip (in_conv output) is pushed but never
        # popped by the decoder below — confirm whether that is intentional.
        skips = [h]

        # Encoder sweep: per-level blocks, record skip, then downsample.
        for i in range(len(self.ch_mults)):
            for block in self.enc[i]:
                if isinstance(block, ConvBlock): h = block(h, enc_cond)
                elif isinstance(block, nn.Sequential):
                    h = block[0](h); h = block[1](h, enc_cond)
            skips.append(h)
            if i < len(self.enc_down):
                h = self.enc_down[i](h)

        # GLFM core: address -> condition -> generate, reshaped back to spatial.
        B = h.shape[0]
        address = self.geo_encoder(h)
        cond_addr = self.conditioner(address, t, cls_emb)
        h = self.generator(cond_addr)
        h = h.reshape(B, *self.mid_spatial)

        # Decoder sweep: upsample (except at the deepest level), merge skip,
        # fuse channels, refine.
        for i in range(len(self.ch_mults)):
            skip = skips.pop()
            if i > 0:
                h = self.dec_up[i - 1](h)
            h = torch.cat([h, skip], dim=1)
            h = self.dec_skip[i](h)
            for block in self.dec[i]:
                h = block(h, dec_cond)

        return self.out_conv(F.silu(self.out_norm(h)))
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def sample(model, n=64, steps=50, cls=None, n_cls=10):
    """Draw n 32x32 images by Euler-integrating the velocity field from
    t=1 (pure noise) down to t=0, returning (images in [-1, 1], labels)."""
    model.eval()
    x = torch.randn(n, 3, 32, 32, device=DEVICE)
    if cls is None:
        labels = torch.randint(0, n_cls, (n,), device=DEVICE)
    else:
        labels = torch.full((n,), cls, dtype=torch.long, device=DEVICE)
    step_size = 1.0 / steps
    for k in range(steps):
        t_now = torch.full((n,), 1.0 - k * step_size, device=DEVICE)
        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            velocity = model(x, t_now, labels)
        # x_t = (1-t)*data + t*noise, so moving toward t=0 subtracts v*dt.
        x = x - velocity.float() * step_size
    return x.clamp(-1, 1), labels
|
|
|
|
| |
| |
| |
|
|
# ───────────────────────── training script ─────────────────────────
# Fixes applied: the best-checkpoint marker string and the drift print were
# garbled multibyte characters split across lines (a SyntaxError); mojibake
# arrows/ticks in other prints replaced with ASCII equivalents.

BATCH = 128
EPOCHS = 80
LR = 3e-4
SAMPLE_EVERY = 5  # epochs between sample grids


print("=" * 70)
print("GEOMETRIC LOOKUP FLOW MATCHING (GLFM)")
print(" Three-stage: Address -> Condition -> Generate")
print(" Multi-scale: coarse (global) + fine (per-position)")
print(f" Device: {DEVICE}")
print("=" * 70)


# CIFAR-10 in [-1, 1] with horizontal-flip augmentation.
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * 3, (0.5,) * 3),
])
train_ds = datasets.CIFAR10('./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(
    train_ds, batch_size=BATCH, shuffle=True,
    num_workers=4, pin_memory=True, drop_last=True)


model = GLFMUNet(
    in_ch=3, base_ch=64, ch_mults=(1, 2, 4),
    n_classes=10, cond_dim=256, embed_dim=256,
    n_anchors=16, n_phases=3,
    gen_hidden=1024, gen_depth=4,
).to(DEVICE)


# Parameter accounting, broken down by GLFM stage.
n_params = sum(p.numel() for p in model.parameters())
n_geo = sum(p.numel() for p in model.geo_encoder.parameters())
n_cond = sum(p.numel() for p in model.conditioner.parameters())
n_gen = sum(p.numel() for p in model.generator.parameters())
n_anchor = sum(p.numel() for name, p in model.named_parameters() if 'anchor' in name)


print(f" Total: {n_params:,}")
print(f" Geo Encoder: {n_geo:,} (Stage 1 - address)")
print(f" Conditioner: {n_cond:,} (Stage 2 - fuse)")
print(f" Generator: {n_gen:,} (Stage 3 - velocity)")
print(f" Anchors: {n_anchor:,}")
print(f" Address dim: {model.geo_encoder.address_dim} "
      f"(coarse {model.geo_encoder.tri_dim} + fine {model.geo_encoder.tri_dim})")
print(f" Compression: {model.generator.spatial_dim} -> "
      f"{model.geo_encoder.address_dim} "
      f"({model.generator.spatial_dim / model.geo_encoder.address_dim:.1f}x)")


# Quick shape sanity check before committing to training.
with torch.no_grad():
    d = torch.randn(2, 3, 32, 32, device=DEVICE)
    o = model(d, torch.rand(2, device=DEVICE), torch.randint(0, 10, (2,), device=DEVICE))
print(f" Shape: {d.shape} -> {o.shape} OK")
print(f" Train: {len(train_ds):,}")


optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=EPOCHS * len(train_loader), eta_min=1e-6)
# NOTE(review): GradScaler targets fp16 underflow and is effectively redundant
# under bfloat16 autocast; kept so the autocast dtype can change without edits.
scaler = torch.amp.GradScaler("cuda")


os.makedirs("samples_glfm", exist_ok=True)
os.makedirs("checkpoints", exist_ok=True)


print(f"\n{'='*70}")
print(f"TRAINING - {EPOCHS} epochs")
print(f"{'='*70}")


best_loss = float('inf')
bn = model.geo_encoder  # shorthand for drift diagnostics


for epoch in range(EPOCHS):
    model.train()
    t0 = time.time()
    total_loss = 0
    n = 0

    pbar = tqdm(train_loader, desc=f"E{epoch+1:3d}/{EPOCHS}", unit="b")
    for images, labels in pbar:
        images = images.to(DEVICE, non_blocking=True)
        labels = labels.to(DEVICE, non_blocking=True)
        B = images.shape[0]

        # Rectified-flow targets: x_t linearly interpolates data (t=0) and
        # noise (t=1); the target velocity is eps - images = d(x_t)/dt.
        t = torch.rand(B, device=DEVICE)
        eps = torch.randn_like(images)
        t_b = t.view(B, 1, 1, 1)
        x_t = (1 - t_b) * images + t_b * eps
        v_target = eps - images

        with torch.amp.autocast("cuda", dtype=torch.bfloat16):
            v_pred = model(x_t, t, labels)
            loss = F.mse_loss(v_pred, v_target)

        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        # Unscale before clipping so the 1.0 threshold applies to true grads.
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        n += 1
        if n % 20 == 0:
            pbar.set_postfix(loss=f"{total_loss/n:.4f}", lr=f"{scheduler.get_last_lr()[0]:.1e}")

    elapsed = time.time() - t0
    avg_loss = total_loss / n

    # Checkpoint on best average epoch loss; '*' marks the log line.
    mk = ""
    if avg_loss < best_loss:
        best_loss = avg_loss
        torch.save({
            'state_dict': model.state_dict(),
            'epoch': epoch + 1, 'loss': avg_loss,
        }, 'checkpoints/glfm_best.pt')
        mk = " *"

    print(f" E{epoch+1:3d}: loss={avg_loss:.4f} lr={scheduler.get_last_lr()[0]:.1e} "
          f"({elapsed:.0f}s){mk}")

    # Anchor-drift diagnostics: fraction of anchors near / past the
    # empirical 0.29-rad self-organization threshold.
    if (epoch + 1) % 10 == 0:
        with torch.no_grad():
            drift = bn.drift().detach()
            near = (drift - 0.29154).abs().lt(0.05).float().mean().item()
            crossed = (drift > 0.29154).float().mean().item()
        print(f" drift: mean={drift.mean():.4f} max={drift.max():.4f} "
              f"near_0.29={near:.1%} crossed={crossed:.1%}")

    # Periodic unconditional sample grid (plus one at the first epoch).
    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == 0:
        imgs, _ = sample(model, 64, 50)
        save_image(make_grid((imgs + 1) / 2, nrow=8), f'samples_glfm/epoch_{epoch+1:03d}.png')
        print(f" -> samples_glfm/epoch_{epoch+1:03d}.png")

    # Less frequent per-class grids.
    if (epoch + 1) % 20 == 0:
        names = ['plane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
        for c in range(10):
            cs, _ = sample(model, 8, 50, cls=c)
            save_image(make_grid((cs + 1) / 2, nrow=8),
                       f'samples_glfm/epoch_{epoch+1:03d}_{names[c]}.png')
        print(f" -> per-class samples")


print(f"\n{'='*70}")
print(f"GEOMETRIC LOOKUP FLOW MATCHING - COMPLETE")
print(f" Best loss: {best_loss:.4f}")
print(f" Total: {n_params:,}")
with torch.no_grad():
    drift = bn.drift().detach()
    near = (drift - 0.29154).abs().lt(0.05).float().mean().item()
    crossed = (drift > 0.29154).float().mean().item()
print(f" Final drift: mean={drift.mean():.4f} max={drift.max():.4f}")
print(f" Near 0.29: {near:.1%} Crossed: {crossed:.1%}")
print(f"{'='*70}")