diff --git "a/spectral/notebooks/experiment_8_eigen_svd3_kernel.ipynb" "b/spectral/notebooks/experiment_8_eigen_svd3_kernel.ipynb" new file mode 100644--- /dev/null +++ "b/spectral/notebooks/experiment_8_eigen_svd3_kernel.ipynb" @@ -0,0 +1,2585 @@ +{ + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + }, + "accelerator": "GPU", + "colab": { + "provenance": [], + "gpuType": "T4" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Experiment 8: Eigen / SVD3 Kernel Variations\n", + "\n", + "All experiments use the fused Triton SVD3 kernel (3\u00d73 Jacobi eigensolver in registers).\n", + "Each patch is an (M, 3) matrix. The kernel returns full U, S, Vh.\n", + "We test 10 different feature extraction strategies from the same decomposition.\n", + "\n", + "**Key insight from NB1-7**: spatial structure is the entire game (68.64% vs 47% plateau).\n", + "These experiments preserve per-patch features where possible.\n", + "\n", + "**Experiments:**\n", + "- 8.1 \u2014 Grassmannian V-features (right singular vectors as subspace coordinates)\n", + "- 8.2 \u2014 Full SVD signature (S + V + U projection norms)\n", + "- 8.3 \u2014 Per-patch Procrustes alignment (color-space SO(3) per patch)\n", + "- 8.4 \u2014 Singular value spatial map (S values preserve patch position)\n", + "- 8.5 \u2014 Subspace angles between adjacent patches\n", + "- 8.6 \u2014 SVD reconstruction residuals at each truncation\n", + "- 8.7 \u2014 Grassmannian + Procrustes joint features\n", + "- 8.8 \u2014 Channel mixing matrix V with spatial statistics\n", + "- 8.9 \u2014 Per-patch spectral contrast normalization\n", + "- 8.10 \u2014 Full kitchen sink: all SVD features combined" + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Install Dependencies\n", + "!pip install -q kymatio 
torch torchvision tensorboard matplotlib scikit-learn huggingface_hub triton\n", + "%load_ext tensorboard\n", + "import torch\n", + "print(f\"PyTorch {torch.__version__}, CUDA available: {torch.cuda.is_available()}\")\n", + "if torch.cuda.is_available():\n", + " print(f\"GPU: {torch.cuda.get_device_name()}\")\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title GeoLIP Architecture \u2014 Full Pipeline (geolip_core + losses + encoder)\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# This cell embeds the COMPLETE GeoLIP architecture verbatim.\n", + "# DO NOT MODIFY \u2014 this is the canonical reference implementation.\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "import os, sys, time, math, json, datetime, warnings\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader\n", + "from torch.utils.tensorboard import SummaryWriter\n", + "import torchvision\n", + "import torchvision.transforms as T\n", + 
"import matplotlib\n", + "matplotlib.use(\"Agg\")\n", + "import matplotlib.pyplot as plt\n", + "from collections import defaultdict\n", + "\n", + "warnings.filterwarnings(\"ignore\", category=UserWarning)\n", + "\n", + "# \u2500\u2500 Performance: TF32 + cudnn benchmark + MAGMA for batched linalg \u2500\u2500\n", + "torch.backends.cuda.matmul.allow_tf32 = True\n", + "torch.backends.cudnn.allow_tf32 = True\n", + "torch.backends.cudnn.benchmark = True\n", + "# MAGMA is much faster than cusolver for batched small-matrix SVD/QR/eig\n", + "try:\n", + " torch.backends.cuda.preferred_linalg_library('magma')\n", + " _linalg_lib = 'magma'\n", + "except Exception:\n", + " _linalg_lib = 'default'\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"[DEVICE] {device}\" + (f\" \u2014 {torch.cuda.get_device_name()}\" if device.type == \"cuda\" else \"\"))\n", + "if device.type == \"cuda\":\n", + " print(f\"[PERF] TF32={torch.backends.cuda.matmul.allow_tf32}, cudnn.benchmark={torch.backends.cudnn.benchmark}, linalg={_linalg_lib}\")\n", + "\n", + "# \u2500\u2500 Fused Triton SVD kernel for batched M\u00d73 matrices \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", + "# cuSOLVER dispatch overhead dominates for tiny (64,3) patches.\n", + "# This kernel fuses G=A^T A, 3\u00d73 Jacobi eigensolver (in scalar registers),\n", + "# and U recovery into a single kernel launch. 
# ~10,000x faster than cuSOLVER for this workload (per the authors' note above).
_HAS_TRITON_SVD3 = False
try:
    import triton
    import triton.language as tl

    @triton.jit
    def _svd3_kernel(
        A_ptr, U_ptr, S_ptr, Vh_ptr,
        M: tl.constexpr, BLOCK_M: tl.constexpr,
        JACOBI_ITERS: tl.constexpr, EPS: tl.constexpr,
    ):
        # One Triton program per batch element: program_id(0) indexes the
        # (M, 3) matrix this instance decomposes. All Gram/eigen state lives
        # in scalar registers; only the row loops touch memory.
        bid = tl.program_id(0)
        # Stage 1: G = A^T A (6 accumulators, symmetric — only the upper
        # triangle of the 3x3 Gram matrix is kept).
        g00 = tl.zeros([], dtype=tl.float32)
        g01 = tl.zeros([], dtype=tl.float32)
        g02 = tl.zeros([], dtype=tl.float32)
        g11 = tl.zeros([], dtype=tl.float32)
        g12 = tl.zeros([], dtype=tl.float32)
        g22 = tl.zeros([], dtype=tl.float32)
        # A is laid out row-major contiguous as (B, M, 3), so this matrix
        # starts at flat offset bid*M*3.
        base = bid * M * 3
        for block_start in range(0, M, BLOCK_M):
            offs = tl.arange(0, BLOCK_M)
            row_idx = block_start + offs
            mask = row_idx < M  # guard the ragged final block when M % BLOCK_M != 0
            ptr0 = base + row_idx * 3 + 0
            ptr1 = base + row_idx * 3 + 1
            ptr2 = base + row_idx * 3 + 2
            a0 = tl.load(A_ptr + ptr0, mask=mask, other=0.0).to(tl.float32)
            a1 = tl.load(A_ptr + ptr1, mask=mask, other=0.0).to(tl.float32)
            a2 = tl.load(A_ptr + ptr2, mask=mask, other=0.0).to(tl.float32)
            g00 += tl.sum(a0 * a0); g01 += tl.sum(a0 * a1); g02 += tl.sum(a0 * a2)
            g11 += tl.sum(a1 * a1); g12 += tl.sum(a1 * a2); g22 += tl.sum(a2 * a2)
        # Stage 2: 3x3 Jacobi eigensolver (all in scalar registers).
        # V starts as identity and accumulates the Givens rotations; its
        # columns converge to the eigenvectors of G, i.e. the right singular
        # vectors of A. Each sweep annihilates the three off-diagonal pairs
        # (0,1), (0,2), (1,2) in turn via G' = J^T G J.
        v00 = 1.0; v01 = 0.0; v02 = 0.0
        v10 = 0.0; v11 = 1.0; v12 = 0.0
        v20 = 0.0; v21 = 0.0; v22 = 1.0
        for _sweep in range(JACOBI_ITERS):
            # pair (0,1): classic stable rotation — tau = (g11-g00)/(2*g01),
            # t = sign(tau)/(|tau| + sqrt(1+tau^2)) picks the smaller angle.
            # The EPS guard makes the rotation the identity (t=0) when the
            # off-diagonal is already negligible, avoiding 0/0.
            off_diag = g01; diag_diff = g11 - g00; abs_off = tl.abs(off_diag)
            tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
            t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)
            c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c
            ng00 = c*c*g00 - 2.0*s*c*g01 + s*s*g11; ng11 = s*s*g00 + 2.0*s*c*g01 + c*c*g11
            ng02 = c*g02 - s*g12; ng12 = s*g02 + c*g12
            g00 = ng00; g11 = ng11; g01 = 0.0; g02 = ng02; g12 = ng12
            nv00 = c*v00-s*v01; nv01 = s*v00+c*v01; nv10 = c*v10-s*v11; nv11 = s*v10+c*v11
            nv20 = c*v20-s*v21; nv21 = s*v20+c*v21
            v00=nv00; v01=nv01; v10=nv10; v11=nv11; v20=nv20; v21=nv21
            # pair (0,2): same scheme; note the cross terms g01/g12 also mix.
            off_diag = g02; diag_diff = g22 - g00; abs_off = tl.abs(off_diag)
            tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
            t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)
            c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c
            ng00 = c*c*g00 - 2.0*s*c*g02 + s*s*g22; ng22 = s*s*g00 + 2.0*s*c*g02 + c*c*g22
            ng01 = c*g01 - s*g12; ng12b = s*g01 + c*g12
            g00 = ng00; g22 = ng22; g02 = 0.0; g01 = ng01; g12 = ng12b
            nv00 = c*v00-s*v02; nv02 = s*v00+c*v02; nv10 = c*v10-s*v12; nv12 = s*v10+c*v12
            nv20 = c*v20-s*v22; nv22 = s*v20+c*v22
            v00=nv00; v02=nv02; v10=nv10; v12=nv12; v20=nv20; v22=nv22
            # pair (1,2)
            off_diag = g12; diag_diff = g22 - g11; abs_off = tl.abs(off_diag)
            tau = tl.where(abs_off > EPS, diag_diff / (2.0 * off_diag), 0.0)
            t = tl.where(abs_off > EPS, tl.where(tau >= 0, 1.0, -1.0) / (tl.abs(tau) + tl.sqrt(1.0 + tau * tau)), 0.0)
            c = 1.0 / tl.sqrt(1.0 + t * t); s = t * c
            ng11 = c*c*g11 - 2.0*s*c*g12 + s*s*g22; ng22 = s*s*g11 + 2.0*s*c*g12 + c*c*g22
            ng01 = c*g01 - s*g02; ng02b = s*g01 + c*g02
            g11 = ng11; g22 = ng22; g12 = 0.0; g01 = ng01; g02 = ng02b
            nv01 = c*v01-s*v02; nv02 = s*v01+c*v02; nv11 = c*v11-s*v12; nv12 = s*v11+c*v12
            nv21 = c*v21-s*v22; nv22 = s*v21+c*v22
            v01=nv01; v02=nv02; v11=nv11; v12=nv12; v21=nv21; v22=nv22
        # Sort eigenvalues descending + permute V columns.
        # Singular values are sqrt of the (clamped non-negative) eigenvalues
        # of G; the three compare-and-swaps below form a sorting network for
        # (s0, s1, s2), with V's columns permuted in lockstep.
        s0 = tl.sqrt(tl.maximum(g00, EPS)); s1 = tl.sqrt(tl.maximum(g11, EPS)); s2 = tl.sqrt(tl.maximum(g22, EPS))
        do_swap = s0 < s1
        s0, s1 = tl.where(do_swap, s1, s0), tl.where(do_swap, s0, s1)
        tv=v00; v00=tl.where(do_swap,v01,v00); v01=tl.where(do_swap,tv,v01)
        tv=v10; v10=tl.where(do_swap,v11,v10); v11=tl.where(do_swap,tv,v11)
        tv=v20; v20=tl.where(do_swap,v21,v20); v21=tl.where(do_swap,tv,v21)
        do_swap = s0 < s2
        s0, s2 = tl.where(do_swap, s2, s0), tl.where(do_swap, s0, s2)
        tv=v00; v00=tl.where(do_swap,v02,v00); v02=tl.where(do_swap,tv,v02)
        tv=v10; v10=tl.where(do_swap,v12,v10); v12=tl.where(do_swap,tv,v12)
        tv=v20; v20=tl.where(do_swap,v22,v20); v22=tl.where(do_swap,tv,v22)
        do_swap = s1 < s2
        s1, s2 = tl.where(do_swap, s2, s1), tl.where(do_swap, s1, s2)
        tv=v01; v01=tl.where(do_swap,v02,v01); v02=tl.where(do_swap,tv,v02)
        tv=v11; v11=tl.where(do_swap,v12,v11); v12=tl.where(do_swap,tv,v12)
        tv=v21; v21=tl.where(do_swap,v22,v21); v22=tl.where(do_swap,tv,v22)
        # Write S: three floats per batch element.
        s_base = bid * 3
        tl.store(S_ptr + s_base + 0, s0); tl.store(S_ptr + s_base + 1, s1); tl.store(S_ptr + s_base + 2, s2)
        # Write Vh = V^T: rows of Vh are columns of V (matches torch.linalg.svd's Vh).
        vh_base = bid * 9
        tl.store(Vh_ptr+vh_base+0,v00); tl.store(Vh_ptr+vh_base+1,v10); tl.store(Vh_ptr+vh_base+2,v20)
        tl.store(Vh_ptr+vh_base+3,v01); tl.store(Vh_ptr+vh_base+4,v11); tl.store(Vh_ptr+vh_base+5,v21)
        tl.store(Vh_ptr+vh_base+6,v02); tl.store(Vh_ptr+vh_base+7,v12); tl.store(Vh_ptr+vh_base+8,v22)
        # Stage 3: U = A @ V @ diag(1/S), streamed over rows again.
        # NOTE(review): when a singular value is ~0, inv_s degenerates to
        # ~1/EPS and the corresponding U column is numerically meaningless
        # (not re-orthogonalized) — callers relying on U for rank-deficient
        # patches should be aware.
        inv_s0 = 1.0/(s0+EPS); inv_s1 = 1.0/(s1+EPS); inv_s2 = 1.0/(s2+EPS)
        for block_start in range(0, M, BLOCK_M):
            offs = tl.arange(0, BLOCK_M)
            row_idx = block_start + offs
            mask = row_idx < M
            a0 = tl.load(A_ptr+base+row_idx*3+0, mask=mask, other=0.0).to(tl.float32)
            a1 = tl.load(A_ptr+base+row_idx*3+1, mask=mask, other=0.0).to(tl.float32)
            a2 = tl.load(A_ptr+base+row_idx*3+2, mask=mask, other=0.0).to(tl.float32)
            u0 = (a0*v00 + a1*v10 + a2*v20) * inv_s0
            u1 = (a0*v01 + a1*v11 + a2*v21) * inv_s1
            u2 = (a0*v02 + a1*v12 + a2*v22) * inv_s2
            u_base = bid * M * 3  # same flat offset as `base`; U mirrors A's layout
            tl.store(U_ptr+u_base+row_idx*3+0, u0, mask=mask)
            tl.store(U_ptr+u_base+row_idx*3+1, u1, mask=mask)
            tl.store(U_ptr+u_base+row_idx*3+2, u2, mask=mask)

    def batched_svd3(A, block_m=128, jacobi_iters=6):
        """Fused Triton SVD for (B, M, 3) tensors. Returns U, S, Vh.

        One kernel program per batch element; inputs are upcast to a
        contiguous float32 copy before launch, and all outputs are float32
        regardless of A's dtype.

        Args:
            A: tensor of shape (B, M, 3) on a CUDA device.
            block_m: rows processed per inner-loop block; passed to the
                kernel as BLOCK_M. NOTE(review): tl.arange requires a
                power-of-2 extent — keep block_m a power of two.
            jacobi_iters: number of full Jacobi sweeps (6 by default).
        Returns:
            (U, S, Vh): shapes (B, M, 3), (B, 3), (B, 3, 3), float32 —
            mirroring torch.linalg.svd(..., full_matrices=False).
        """
        assert A.ndim == 3 and A.shape[2] == 3
        B, M, _ = A.shape
        A_f32 = A.contiguous().float()
        U = torch.empty((B, M, 3), dtype=torch.float32, device=A.device)
        S = torch.empty((B, 3), dtype=torch.float32, device=A.device)
        Vh = torch.empty((B, 3, 3), dtype=torch.float32, device=A.device)
        _svd3_kernel[(B,)](A_f32, U, S, Vh, M=M, BLOCK_M=block_m, JACOBI_ITERS=jacobi_iters, EPS=1e-12)
        return U, S, Vh

    _HAS_TRITON_SVD3 = True
    print("[PERF] Triton SVD3 kernel loaded — fused 3×3 Jacobi eigensolver")
except ImportError:
    print("[PERF] Triton not available — falling back to torch.linalg.svd")

    def batched_svd3(A, block_m=128, jacobi_iters=6):
        """Fallback: torch.linalg.svd for (B, M, 3) tensors.

        Same (U, S, Vh) contract as the Triton path; `block_m` and
        `jacobi_iters` are accepted for signature compatibility but unused.
        Sign conventions of U/Vh columns may differ from the Jacobi kernel
        (both are valid SVDs).
        """
        return torch.linalg.svd(A.float(), full_matrices=False)

# ═════════════════════════════════════════════════════════════════════════════
# GEOLIP CORE — Geometric Building Blocks
# ═════════════════════════════════════════════════════════════════════════════
\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# \u2500\u2500 ACTIVATIONS \u2500\u2500\n", + "class SquaredReLU(nn.Module):\n", + " def forward(self, x): return F.relu(x) ** 2\n", + "class StarReLU(nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.scale = nn.Parameter(torch.ones(1) * 0.8944)\n", + " self.bias = nn.Parameter(torch.zeros(1) - 0.4472)\n", + " def forward(self, x): return F.relu(x) ** 2 * self.scale + self.bias\n", + "ACTIVATIONS = {\n", + " 'squared_relu': SquaredReLU, 'star_relu': StarReLU,\n", + " 'gelu': lambda: nn.GELU(), 'relu': lambda: nn.ReLU(), 'sigmoid': lambda: nn.Sigmoid(),\n", + "}\n", + "def make_activation(name='squared_relu'):\n", + " if name not in ACTIVATIONS:\n", + " raise ValueError(f\"Unknown activation '{name}'. 
Choose from: {list(ACTIVATIONS.keys())}\")\n", + " return ACTIVATIONS[name]()\n", + "# \u2500\u2500 ANCHOR INITIALIZATION \u2500\u2500\n", + "def init_anchors_xavier(n, d):\n", + " w = torch.empty(n, d); nn.init.xavier_normal_(w); return F.normalize(w, dim=-1)\n", + "def init_anchors_orthogonal(n, d):\n", + " if n <= d:\n", + " Q, _ = torch.linalg.qr(torch.randn(d, n)); return Q.T.contiguous()\n", + " else:\n", + " Q, _ = torch.linalg.qr(torch.randn(d, d))\n", + " return torch.cat([Q.T, F.normalize(torch.randn(n - d, d), dim=-1)], dim=0)\n", + "def init_anchors_repulsion(n, d, iters=200, lr=0.05):\n", + " vecs = F.normalize(init_anchors_orthogonal(n, d), dim=-1)\n", + " for _ in range(iters):\n", + " sim = vecs @ vecs.T; sim.fill_diagonal_(-2.0)\n", + " vecs = F.normalize(vecs - lr * vecs[sim.argmax(dim=1)], dim=-1)\n", + " return vecs\n", + "INIT_METHODS = {'xavier': init_anchors_xavier, 'orthogonal': init_anchors_orthogonal, 'repulsion': init_anchors_repulsion}\n", + "# \u2500\u2500 CONSTELLATION \u2500\u2500\n", + "class Constellation(nn.Module):\n", + " \"\"\"Anchors on S^(d-1). 
Triangulates input embeddings.\"\"\"\n", + " def __init__(self, n_anchors, dim, anchor_drop=0.0, anchor_init='repulsion'):\n", + " super().__init__()\n", + " self.anchors = nn.Parameter(INIT_METHODS[anchor_init](n_anchors, dim))\n", + " self.anchor_drop = anchor_drop\n", + " self.n_anchors = n_anchors\n", + " self.dim = dim\n", + " def triangulate(self, emb, training=False):\n", + " anchors = F.normalize(self.anchors, dim=-1)\n", + " if training and self.anchor_drop > 0:\n", + " mask = torch.rand(anchors.shape[0], device=anchors.device) > self.anchor_drop\n", + " if mask.sum() < 2: mask[:2] = True\n", + " anchors = anchors[mask]\n", + " cos = emb @ anchors.T; tri = 1.0 - cos\n", + " _, nl = cos.max(dim=-1)\n", + " return tri, mask.nonzero(as_tuple=True)[0][nl]\n", + " cos = emb @ anchors.T; tri = 1.0 - cos; _, nearest = cos.max(dim=-1)\n", + " return tri, nearest\n", + " def forward(self, emb, training=False): return self.triangulate(emb, training)\n", + "# \u2500\u2500 PATCHWORK \u2500\u2500\n", + "class Patchwork(nn.Module):\n", + " \"\"\"Round-robin compartments reading diverse anchor subsets.\"\"\"\n", + " def __init__(self, n_anchors, n_comp=8, d_comp=64, activation='squared_relu'):\n", + " super().__init__()\n", + " self.n_comp, self.d_comp = n_comp, d_comp\n", + " self.output_dim = n_comp * d_comp\n", + " self.register_buffer('asgn', torch.arange(n_anchors) % n_comp)\n", + " apc = n_anchors // n_comp\n", + " self.comps = nn.ModuleList([\n", + " nn.Sequential(nn.Linear(apc, d_comp*2), make_activation(activation),\n", + " nn.Linear(d_comp*2, d_comp), nn.LayerNorm(d_comp))\n", + " for _ in range(n_comp)])\n", + " def forward(self, tri):\n", + " return torch.cat([self.comps[k](tri[:, self.asgn == k]) for k in range(self.n_comp)], dim=-1)\n", + "# \u2500\u2500 RELAY LAYER \u2500\u2500\n", + "class RelayLayer(nn.Module):\n", + " \"\"\"Single constellation relay. 
Vectorized, gated, no attention.\n", + " Patches -> S^(patch_dim-1) -> triangulate at 3 SLERP phases -> patchwork -> gated residual.\"\"\"\n", + " def __init__(self, input_dim, patch_dim=16, n_anchors=16, n_phases=3, pw_hidden=32, gate_init=-3.0):\n", + " super().__init__()\n", + " assert input_dim % patch_dim == 0\n", + " self.input_dim, self.patch_dim = input_dim, patch_dim\n", + " self.n_patches = input_dim // patch_dim\n", + " self.n_anchors, self.n_phases = n_anchors, n_phases\n", + " P, A, d = self.n_patches, n_anchors, patch_dim\n", + " home = torch.empty(P, A, d); nn.init.xavier_normal_(home.view(P*A, d))\n", + " home = F.normalize(home.view(P, A, d), dim=-1)\n", + " self.register_buffer('home', home)\n", + " self.anchors = nn.Parameter(home.clone())\n", + " tri_dim = n_phases * A\n", + " self.pw_w1 = nn.Parameter(torch.empty(P, tri_dim, pw_hidden))\n", + " self.pw_b1 = nn.Parameter(torch.zeros(1, P, pw_hidden))\n", + " self.pw_w2 = nn.Parameter(torch.empty(P, pw_hidden, d))\n", + " self.pw_b2 = nn.Parameter(torch.zeros(1, P, d))\n", + " for p in range(P):\n", + " nn.init.xavier_normal_(self.pw_w1.data[p])\n", + " nn.init.xavier_normal_(self.pw_w2.data[p])\n", + " self.pw_norm = nn.LayerNorm(d)\n", + " self.gates = nn.Parameter(torch.full((P,), gate_init))\n", + " self.norm = nn.LayerNorm(input_dim)\n", + " def drift(self):\n", + " h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)\n", + " return torch.acos((h * c).sum(dim=-1).clamp(-1+1e-7, 1-1e-7))\n", + " def at_phase(self, t):\n", + " h, c = F.normalize(self.home, dim=-1), F.normalize(self.anchors, dim=-1)\n", + " omega = self.drift().unsqueeze(-1); so = omega.sin().clamp(min=1e-7)\n", + " return torch.sin((1-t)*omega)/so * h + torch.sin(t*omega)/so * c\n", + " def forward(self, x):\n", + " B, D = x.shape; P, A, d = self.n_patches, self.n_anchors, self.patch_dim\n", + " patches = self.norm(x).reshape(B, P, d)\n", + " patches_n = F.normalize(patches, dim=-1)\n", + " tris = []\n", + 
" for t in [0.0, 1/3, 2/3]:\n", + " at = F.normalize(self.at_phase(t), dim=-1)\n", + " tris.append(1.0 - torch.einsum('bpd,pad->bpa', patches_n, at))\n", + " tri = torch.cat(tris, dim=-1)\n", + " h = F.gelu(torch.einsum('bpt,pth->bph', tri, self.pw_w1) + self.pw_b1)\n", + " pw = self.pw_norm(torch.einsum('bph,phd->bpd', h, self.pw_w2) + self.pw_b2)\n", + " gate = self.gates.sigmoid().unsqueeze(0).unsqueeze(-1)\n", + " return x + (gate * pw + (1-gate) * patches).reshape(B, D)\n", + "# \u2500\u2500 CONSTELLATION RELAY (sequence-aware) \u2500\u2500\n", + "class ConstellationRelay(nn.Module):\n", + " \"\"\"Per-token geometric processing. O(S). Handles (B,D) and (B,S,D).\"\"\"\n", + " def __init__(self, dim, n_anchors=16, n_comp=8, d_comp=64,\n", + " gate_init=-3.0, anchor_init='repulsion', activation='squared_relu'):\n", + " super().__init__()\n", + " self.dim = dim; self.norm = nn.LayerNorm(dim)\n", + " self.constellation = Constellation(n_anchors, dim, anchor_init=anchor_init)\n", + " self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)\n", + " self.proj = nn.Linear(self.patchwork.output_dim, dim)\n", + " self.gate = nn.Parameter(torch.full((dim,), gate_init))\n", + " def forward(self, x):\n", + " squeeze = x.dim() == 2\n", + " if squeeze: x = x.unsqueeze(1)\n", + " B, S, D = x.shape; residual = x\n", + " h_flat = F.normalize(self.norm(x).reshape(B*S, D), dim=-1)\n", + " tri, _ = self.constellation.triangulate(h_flat)\n", + " update = self.proj(self.patchwork(tri)).reshape(B, S, D)\n", + " out = residual + torch.sigmoid(self.gate) * update\n", + " return out.squeeze(1) if squeeze else out\n", + "# \u2500\u2500 MAGNITUDE FLOW \u2500\u2500\n", + "class MagnitudeFlow(nn.Module):\n", + " \"\"\"Relay-stack per-compartment magnitude. 
No attention.\"\"\"\n", + " def __init__(self, dim, n_anchors, hidden_dim=64, n_heads=4,\n", + " n_layers=2, mag_min=0.1, mag_max=5.0, n_comp=8):\n", + " super().__init__()\n", + " self.dim, self.n_anchors = dim, n_anchors\n", + " self.mag_min, self.mag_max, self.n_comp, self.n_layers = mag_min, mag_max, n_comp, n_layers\n", + " patch_dim = 16; relay_dim = n_comp * patch_dim\n", + " self.patch_dim, self.relay_dim = patch_dim, relay_dim\n", + " self.emb_proj = nn.Linear(dim, relay_dim // 2)\n", + " self.tri_proj = nn.Linear(n_anchors, relay_dim // 4)\n", + " self.ctx_proj = nn.Linear(relay_dim // 2 + relay_dim // 4 + 1, relay_dim)\n", + " self.relays = nn.ModuleList([\n", + " RelayLayer(relay_dim, patch_dim, 16, 3, hidden_dim, -3.0) for _ in range(n_layers)])\n", + " self.mag_heads = nn.ModuleList([\n", + " nn.Sequential(nn.Linear(patch_dim, patch_dim//2), nn.GELU(), nn.Linear(patch_dim//2, 1))\n", + " for _ in range(n_comp)])\n", + " self.register_buffer('stats_bias_cached', torch.zeros(n_comp), persistent=False)\n", + " def update_stats(self, push_diag, anchor_push):\n", + " with torch.no_grad():\n", + " device = self.stats_bias_cached.device\n", + " if anchor_push.strategy == 'momentum' and anchor_push.accumulator is not None:\n", + " mn = anchor_push.accumulator.norm(dim=-1)\n", + " apc = self.n_anchors // self.n_comp\n", + " self.stats_bias_cached = torch.stack([\n", + " mn[k*apc : (k+1)*apc if k < self.n_comp-1 else self.n_anchors].mean()\n", + " for k in range(self.n_comp)])\n", + " else: self.stats_bias_cached.zero_()\n", + " def forward(self, emb, triangulation, raw_magnitude):\n", + " B, A = emb.shape[0], self.n_anchors\n", + " x = self.ctx_proj(torch.cat([self.emb_proj(emb), self.tri_proj(triangulation), raw_magnitude], -1))\n", + " for relay in self.relays: x = relay(x)\n", + " patches = x.reshape(B, self.n_comp, self.patch_dim)\n", + " mc = torch.cat([self.mag_heads[k](patches[:, k]) for k in range(self.n_comp)], -1)\n", + " mc = self.mag_min + 
(self.mag_max - self.mag_min) * torch.sigmoid(mc + self.stats_bias_cached)\n", + " apc = A // self.n_comp\n", + " mag = torch.cat([mc[:, k:k+1].expand(-1, apc if k < self.n_comp-1 else A - k*apc)\n", + " for k in range(self.n_comp)], -1)\n", + " return mag, mc\n", + " def get_relay_diagnostics(self):\n", + " return [{'layer': i, 'drift_mean': r.drift().mean().item(),\n", + " 'gate_mean': r.gates.sigmoid().mean().item()} for i, r in enumerate(self.relays)]\n", + "# \u2500\u2500 ANCHOR PUSH \u2500\u2500\n", + "def _project_tangent(vec, point):\n", + " return vec - (vec * point).sum(dim=-1, keepdim=True) * point\n", + "def _compute_centroids_and_assign(anchors_n, emb_n, label_buffer, device):\n", + " n_a = anchors_n.shape[0]; classes = label_buffer.unique(); n_cls = classes.shape[0]\n", + " centroids = torch.cat([F.normalize(emb_n[label_buffer==c].mean(0, keepdim=True), dim=-1)\n", + " for c in classes if (label_buffer==c).sum() > 0], dim=0)\n", + " if centroids.shape[0] == 0: return None, None, None, None\n", + " cos = anchors_n @ centroids.T; apc = n_a // n_cls\n", + " assigned = torch.full((n_a,), -1, dtype=torch.long, device=device)\n", + " cc = torch.zeros(n_cls, dtype=torch.long, device=device)\n", + " for idx in cos.flatten().sort(descending=True).indices:\n", + " a, c = (idx // n_cls).item(), (idx % n_cls).item()\n", + " if assigned[a] >= 0 or cc[c] >= apc + 1: continue\n", + " assigned[a] = c; cc[c] += 1\n", + " if (assigned >= 0).all(): break\n", + " u = (assigned < 0).nonzero(as_tuple=True)[0]\n", + " if len(u) > 0: assigned[u] = (anchors_n[u] @ centroids.T).argmax(1)\n", + " nearest = (emb_n @ anchors_n.T).argmax(1)\n", + " util = torch.bincount(nearest, minlength=n_a).float()\n", + " return centroids, assigned, util / util.sum().clamp(min=1), classes\n", + "def _perturb_target(target, apc, rank):\n", + " if apc > 1 and rank > 0:\n", + " noise = torch.randn_like(target) * 0.05\n", + " return F.normalize(target + noise - (noise * target).sum() * target, 
dim=-1)\n", + " return target\n", + "class AnchorPush:\n", + " \"\"\"Configurable anchor push. Strategies: raw, gru, momentum.\"\"\"\n", + " def __init__(self, strategy, n_anchors, dim, **kw):\n", + " self.strategy, self.n_anchors, self.dim, self.push_count = strategy, n_anchors, dim, 0\n", + " if strategy == 'raw': self.lr = kw.get('lr', 0.1)\n", + " elif strategy == 'momentum':\n", + " self.decay, self.alpha, self.beta = kw.get('decay', 0.9), kw.get('alpha', 0.1), kw.get('beta', 0.05)\n", + " self.util_floor, self.accumulator = kw.get('util_floor', 0.001), None\n", + " elif strategy == 'gru':\n", + " self.ema_decay = kw.get('ema_decay', 0.9); self.z_scale = kw.get('z_scale', 3.0)\n", + " self.r_scale = kw.get('r_scale', 5.0)\n", + " self.prev_pos = self.util_ema = self.drift_ema = None\n", + " @torch.no_grad()\n", + " def push(self, core, emb_buf, lbl_buf):\n", + " anchors = core.constellation.anchors.data; n_a = anchors.shape[0]; device = anchors.device\n", + " emb_n = F.normalize(emb_buf, dim=-1); anchors_n = F.normalize(anchors, dim=-1)\n", + " centroids, assigned, util, classes = _compute_centroids_and_assign(anchors_n, emb_n, lbl_buf, device)\n", + " if centroids is None: return {'moved': 0}\n", + " if hasattr(core, 'anchor_classes'):\n", + " for a in range(n_a): core.anchor_classes[a] = classes[assigned[a]]\n", + " if hasattr(core, 'class_centroids'):\n", + " for i, c in enumerate(classes): core.class_centroids[c] = centroids[i]\n", + " apc = n_a // centroids.shape[0]\n", + " targets = torch.stack([_perturb_target(centroids[assigned[a].item()], apc,\n", + " (assigned[:a]==assigned[a]).sum().item()) for a in range(n_a)])\n", + " if self.strategy == 'raw':\n", + " for a in range(n_a): anchors[a] = F.normalize(anchors_n[a] + self.lr*(targets[a]-anchors_n[a]), dim=-1)\n", + " d = torch.acos((anchors_n * F.normalize(anchors, dim=-1)).sum(-1).clamp(-1+1e-6, 1-1e-6))\n", + " diag = {'drift_mean': d.mean().item(), 'drift_max': d.max().item()}\n", + " elif 
self.strategy == 'momentum':\n", + " if self.accumulator is None: self.accumulator = torch.zeros(n_a, self.dim, device=device)\n", + " res = _project_tangent(targets - anchors_n, anchors_n)\n", + " self.accumulator = self.decay * _project_tangent(self.accumulator, anchors_n) + res\n", + " corr = self.alpha * res + self.beta * self.accumulator\n", + " dead = util < self.util_floor\n", + " if dead.any(): corr[dead] = res[dead] * 0.5\n", + " new = F.normalize(anchors_n + corr, dim=-1)\n", + " d = torch.acos((anchors_n * new).sum(-1).clamp(-1+1e-6, 1-1e-6))\n", + " anchors.copy_(new)\n", + " diag = {'drift_mean': d.mean().item(), 'drift_max': d.max().item(),\n", + " 'momentum_mean': self.accumulator.norm(dim=-1).mean().item(), 'dead_count': dead.sum().item()}\n", + " else:\n", + " diag = {}\n", + " diag.update({'moved': n_a, 'n_active': (util > 0).sum().item(),\n", + " 'util_min': util.min().item(), 'util_max': util.max().item()})\n", + " self.push_count += 1; return diag\n", + "# \u2500\u2500 FLOW ATTENTION (historical) \u2500\u2500\n", + "class FlowAttention(nn.Module):\n", + " \"\"\"3-step Euler flow in tangent space. 
Superseded by relay.\"\"\"\n", + " def __init__(self, dim, n_anchors, flow_dim=64, n_steps=3, time_dim=32, gate_init=-3.0):\n", + " super().__init__()\n", + " self.dim, self.flow_dim, self.n_anchors, self.n_steps, self.time_dim = dim, flow_dim, n_anchors, n_steps, time_dim\n", + " self.to_flow = nn.Sequential(nn.Linear(n_anchors+dim, flow_dim), nn.LayerNorm(flow_dim))\n", + " self.time_mlp = nn.Sequential(nn.Linear(time_dim, flow_dim), nn.GELU())\n", + " self.stats_proj = nn.Linear(3, flow_dim, bias=False)\n", + " self.velocity = nn.Sequential(nn.Linear(flow_dim, flow_dim*2), nn.GELU(), nn.Linear(flow_dim*2, flow_dim))\n", + " self.to_correction = nn.Linear(flow_dim, dim, bias=False)\n", + " self.gate = nn.Parameter(torch.full((dim,), gate_init))\n", + " self.register_buffer('stats_bias_cached', torch.zeros(flow_dim), persistent=False)\n", + " def update_stats(self, push_diag, anchor_push):\n", + " with torch.no_grad():\n", + " dev = self.stats_proj.weight.device\n", + " mn = anchor_push.accumulator.norm(dim=-1) if (anchor_push.strategy=='momentum' and anchor_push.accumulator is not None) else torch.zeros(self.n_anchors, device=dev)\n", + " dr = torch.tensor(push_diag.get('drift_mean',0.0), device=dev).expand(self.n_anchors)\n", + " ut = torch.tensor(push_diag.get('util_max',0.0), device=dev).expand(self.n_anchors)\n", + " self.stats_bias_cached = self.stats_proj(torch.stack([mn, ut, dr], -1)).mean(0)\n", + " def forward(self, emb, constellation):\n", + " B, D, dev = *emb.shape, emb.device\n", + " tri = emb @ F.normalize(constellation.anchors, dim=-1).T\n", + " z = self.to_flow(torch.cat([tri, emb], -1)); dt = 1.0/self.n_steps\n", + " half = self.time_dim // 2\n", + " freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=dev) / half)\n", + " for s in range(self.n_steps):\n", + " args = (s*dt)*freqs; t_emb = torch.cat([args.sin(), args.cos()])\n", + " z = z + dt * (self.velocity(z + self.time_mlp(t_emb)) + self.stats_bias_cached)\n", + " c = 
self.to_correction(z); c = c - (c*emb).sum(-1,keepdim=True)*emb\n", + " return F.normalize(emb + torch.sigmoid(self.gate)*c, dim=-1)\n", + "# \u2500\u2500 GEOMETRIC AUTOGRAD \u2500\u2500\n", + "class GeometricAutograd(torch.autograd.Function):\n", + " \"\"\"Manifold-aware gradient correction on S^(D-1). Forward: identity.\"\"\"\n", + " @staticmethod\n", + " def forward(ctx, emb, anchors, tang_strength, sep_strength):\n", + " ctx.save_for_backward(emb, anchors); ctx.tang, ctx.sep = tang_strength, sep_strength\n", + " return emb\n", + " @staticmethod\n", + " def backward(ctx, grad):\n", + " emb, anchors = ctx.saved_tensors\n", + " dot = (grad * emb).sum(-1, keepdim=True)\n", + " corrected = grad - ctx.tang * dot * emb\n", + " if ctx.sep > 0:\n", + " an = F.normalize(anchors.detach(), dim=-1)\n", + " nearest = an[(emb @ an.T).argmax(-1)]\n", + " toward = (corrected * nearest).sum(-1, keepdim=True)\n", + " corrected = corrected - ctx.sep * F.relu(toward) * nearest\n", + " return corrected, None, None, None\n", + "# \u2500\u2500 UTILITIES \u2500\u2500\n", + "def param_count(module, name=\"\"):\n", + " t = sum(p.numel() for p in module.parameters())\n", + " tr = sum(p.numel() for p in module.parameters() if p.requires_grad)\n", + " if name: print(f\" {name}: {t:,} ({tr:,} trainable)\")\n", + " return t, tr\n", + "def model_summary(model):\n", + " total = sum(p.numel() for p in model.parameters())\n", + " print(f\" Total: {total:,}\")\n", + " for n, m in model.named_children():\n", + " c = sum(p.numel() for p in m.parameters())\n", + " if c > 0: print(f\" {n}: {c:,}\")\n", + " return total\n", + "\n", + "\n", + "# 
\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# GEOLIP LOSSES \u2014 All loss functions and metrics\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# UTILITY \u2014 deferred .item() for logging only\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def ld_to_scalars(ld):\n", + " \"\"\"Convert a loss dict from tensors to 
Python floats for logging.\n", + " Call this ONCE per epoch at logging time, NOT per batch.\"\"\"\n", + " out = {}\n", + " for k, v in ld.items():\n", + " if isinstance(v, torch.Tensor):\n", + " out[k] = v.item() if v.dim() == 0 else v.detach().cpu().tolist()\n", + " else:\n", + " out[k] = v\n", + " return out\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# CV \u2014 Coefficient of Variation of Pentachoron Volumes\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def _batch_pentachoron_volumes(emb, n_samples=200, n_points=5):\n", + " \"\"\"Batched pentachoron volumes \u2014 zero Python loops.\"\"\"\n", + " N, D = emb.shape\n", + " device, dtype = emb.device, emb.dtype\n", + " pool = min(N, 512)\n", + " indices = torch.rand(n_samples, pool, device=device).argsort(dim=1)[:, :n_points]\n", + " pts = emb[:pool][indices]\n", + " gram = torch.bmm(pts, pts.transpose(1, 2))\n", + " norms = torch.diagonal(gram, dim1=1, dim2=2)\n", + " d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)\n", + " M = n_points + 1\n", + " cm = torch.zeros(n_samples, M, M, device=device, dtype=dtype)\n", + " cm[:, 0, 1:] = 1.0; cm[:, 1:, 0] = 1.0; cm[:, 1:, 1:] = d2\n", + " k = n_points - 1\n", + " pf = ((-1.0) ** (k + 1)) / ((2.0 ** k) * 
(math.factorial(k) ** 2))\n", + " dets = pf * torch.linalg.det(cm.float())\n", + " valid = dets > 1e-20\n", + " return dets[valid].to(dtype).sqrt()\n", + "\n", + "\n", + "def cv_loss(emb, target=0.22, n_samples=32, n_points=5, batched=True):\n", + " \"\"\"Differentiable CV loss. Returns (CV - target)\u00b2 as TENSOR.\n", + " Default n_samples=32 for training speed.\"\"\"\n", + " if emb.shape[0] < n_points:\n", + " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n", + " vols = _batch_pentachoron_volumes(emb, n_samples, n_points)\n", + " if vols.shape[0] < 5:\n", + " return torch.tensor(0.0, device=emb.device, requires_grad=True)\n", + " cv = vols.std() / (vols.mean() + 1e-8)\n", + " return (cv - target).pow(2)\n", + "\n", + "\n", + "@torch.no_grad()\n", + "def cv_metric(emb, n_samples=200, n_points=5):\n", + " \"\"\"Non-differentiable CV for monitoring. Returns float.\"\"\"\n", + " vols = _batch_pentachoron_volumes(emb, n_samples, n_points)\n", + " if vols.shape[0] < 10:\n", + " return 0.0\n", + " return (vols.std() / (vols.mean() + 1e-8)).item()\n", + "\n", + "\n", + "def cv_multi_scale(emb, scales=(3, 4, 5, 6, 7, 8), n_samples=100):\n", + " \"\"\"CV at multiple simplex sizes. Returns dict.\"\"\"\n", + " results = {}\n", + " with torch.no_grad():\n", + " for n_pts in scales:\n", + " vols = _batch_pentachoron_volumes(emb, n_samples, n_pts)\n", + " results[n_pts] = round((vols.std() / (vols.mean() + 1e-8)).item(), 4) if vols.shape[0] >= 10 else None\n", + " return results\n", + "\n", + "\n", + "def cayley_menger_vol2(points):\n", + " \"\"\"Squared simplex volume. 
points: (B, N, D) \u2192 (B,).\"\"\"\n", + " B, N, D = points.shape\n", + " gram = torch.bmm(points, points.transpose(1, 2))\n", + " norms = torch.diagonal(gram, dim1=1, dim2=2)\n", + " d2 = F.relu(norms.unsqueeze(2) + norms.unsqueeze(1) - 2 * gram)\n", + " cm = torch.zeros(B, N + 1, N + 1, device=points.device, dtype=points.dtype)\n", + " cm[:, 0, 1:] = 1; cm[:, 1:, 0] = 1; cm[:, 1:, 1:] = d2\n", + " k = N - 1\n", + " sign = (-1.0) ** (k + 1)\n", + " fact = math.factorial(k)\n", + " return sign * torch.linalg.det(cm.float()).to(points.dtype) / ((2 ** k) * (fact ** 2))\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# NCE \u2014 InfoNCE contrastive loss (NO .item() \u2014 returns tensors)\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def nce_loss(z1, z2, temperature=0.07, normalize=True):\n", + " \"\"\"Symmetric InfoNCE. Returns (loss_tensor, acc_tensor). 
No CUDA sync.\"\"\"\n", + " if normalize:\n", + " z1 = F.normalize(z1, dim=-1); z2 = F.normalize(z2, dim=-1)\n", + " B = z1.shape[0]\n", + " labels = torch.arange(B, device=z1.device)\n", + " sim = z1 @ z2.T / temperature\n", + " loss = F.cross_entropy(sim, labels)\n", + " acc = (sim.argmax(1) == labels).float().mean() # TENSOR, no .item()\n", + " return loss, acc\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# CLASSIFICATION (NO .item())\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def ce_loss(logits, targets):\n", + " \"\"\"Cross-entropy + accuracy. Returns (loss_tensor, acc_tensor).\"\"\"\n", + " loss = F.cross_entropy(logits, targets)\n", + " acc = (logits.argmax(-1) == targets).float().mean() # TENSOR\n", + " return loss, acc\n", + "\n", + "\n", + "def ce_loss_paired(logits1, logits2, targets):\n", + " \"\"\"Averaged CE over two views. 
Returns (loss_tensor, acc_tensor).\"\"\"\n", + " l1 = F.cross_entropy(logits1, targets)\n", + " l2 = F.cross_entropy(logits2, targets)\n", + " acc = (logits1.argmax(-1) == targets).float().mean() # TENSOR\n", + " return (l1 + l2) / 2, acc\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# BRIDGE (NO .item())\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def bridge_loss(bridge_logits, assign_targets, detach_targets=True):\n", + " \"\"\"Soft CE: patchwork predicts constellation's soft assignment. Returns (loss, acc) tensors.\"\"\"\n", + " if detach_targets: assign_targets = assign_targets.detach()\n", + " loss = -(assign_targets * F.log_softmax(bridge_logits, dim=-1)).sum(-1).mean()\n", + " acc = (bridge_logits.argmax(-1) == assign_targets.argmax(-1)).float().mean() # TENSOR\n", + " return loss, acc\n", + "\n", + "\n", + "def bridge_loss_paired(bridge1, bridge2, assign1, assign2, detach_targets=True):\n", + " \"\"\"Bridge loss averaged over two views. 
Returns (loss, acc) tensors.\"\"\"\n", + " l1, acc = bridge_loss(bridge1, assign1, detach_targets)\n", + " l2, _ = bridge_loss(bridge2, assign2, detach_targets)\n", + " return (l1 + l2) / 2, acc\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# ASSIGNMENT (NO .item())\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def assign_bce_loss(soft_assign, cos_to_anchors):\n", + " \"\"\"Assignment crispness. Returns (loss_tensor, entropy_tensor).\"\"\"\n", + " nearest = cos_to_anchors.argmax(dim=-1)\n", + " hard = torch.zeros_like(soft_assign)\n", + " hard.scatter_(1, nearest.unsqueeze(1), 1.0)\n", + " with torch.amp.autocast(\"cuda\", enabled=False):\n", + " loss = F.binary_cross_entropy(\n", + " soft_assign.float().clamp(1e-7, 1 - 1e-7), hard.float(), reduction='mean')\n", + " entropy = -(soft_assign * soft_assign.clamp(min=1e-8).log()).sum(-1).mean() # TENSOR\n", + " return loss, entropy\n", + "\n", + "\n", + "def assign_nce_loss(assign1, assign2, temperature=0.1):\n", + " \"\"\"Assignment consistency NCE. 
Returns (loss, acc) tensors.\"\"\"\n", + " B = assign1.shape[0]\n", + " labels = torch.arange(B, device=assign1.device)\n", + " sim = assign1 @ assign2.T / temperature\n", + " loss = F.cross_entropy(sim, labels)\n", + " acc = (sim.argmax(1) == labels).float().mean() # TENSOR\n", + " return loss, acc\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# ATTRACTION (NO .item())\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def attraction_loss(cos_to_anchors):\n", + " \"\"\"Pull embeddings toward nearest anchor. 
Returns (loss, nearest_cos) tensors.\"\"\"\n", + " nearest_cos = cos_to_anchors.max(dim=1).values\n", + " loss = (1.0 - nearest_cos).mean()\n", + " return loss, nearest_cos.mean() # TENSOR\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# SPREAD\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def spread_loss(anchors, target_cos=0.0):\n", + " \"\"\"Repulsion loss keeping anchors spread on S^(d-1).\"\"\"\n", + " a = F.normalize(anchors, dim=-1)\n", + " sim = a @ a.T\n", + " mask = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)\n", + " return F.relu(sim[mask] - target_cos).mean()\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# kNN \u2014 validation only (this one CAN .item() \u2014 it's not in training)\n", + "# 
\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "@torch.no_grad()\n", + "def knn_accuracy(embeddings, targets, k=1):\n", + " \"\"\"k-NN classification accuracy. Returns float (called outside training loop).\"\"\"\n", + " sim = embeddings @ embeddings.T\n", + " sim.fill_diagonal_(-1)\n", + " if k == 1:\n", + " nn_idx = sim.argmax(dim=1)\n", + " return (targets[nn_idx] == targets).float().mean().item()\n", + " else:\n", + " _, topk_idx = sim.topk(k, dim=1)\n", + " nn_labels = targets[topk_idx]\n", + " pred = nn_labels.mode(dim=1).values\n", + " return (pred == targets).float().mean().item()\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# THREE-DOMAIN COMPOUND LOSS (NO .item() in hot path)\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def three_domain_loss(output, targets, constellation, cv_target=0.22,\n", + " infonce_temp=0.07, assign_temp=0.1,\n", + " w_ce=1.0, 
w_nce_emb=0.5,\n", + " w_nce_pw=1.0, w_bridge=1.0,\n", + " w_assign=0.5, w_assign_nce=0.25,\n", + " w_nce_tri=0.5, w_attract=0.25,\n", + " w_cv=0.01, w_spread=0.01,\n", + " cv_batched=True):\n", + " \"\"\"Full three-domain cooperative loss. Returns (total_loss, loss_dict).\n", + " loss_dict values are TENSORS. Call ld_to_scalars() for logging.\"\"\"\n", + " ld = {}\n", + " emb1, emb2 = output['embedding'], output['embedding_aug']\n", + "\n", + " # \u2500\u2500 EXTERNAL \u2500\u2500\n", + " l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)\n", + " ld['ce'], ld['acc'] = l_ce, acc\n", + " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, infonce_temp, normalize=False)\n", + " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n", + "\n", + " # \u2500\u2500 GEOMETRIC \u2500\u2500\n", + " l_nce_pw, nce_pw_acc = nce_loss(output['patchwork1'], output['patchwork1_aug'], assign_temp, normalize=True)\n", + " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n", + " l_bridge, bridge_acc = bridge_loss_paired(output['bridge1'], output['bridge2'], output['assign1'], output['assign2'])\n", + " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n", + "\n", + " # \u2500\u2500 INTERNAL \u2500\u2500\n", + " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n", + " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n", + " l_assign_nce, assign_nce_acc = assign_nce_loss(output['assign1'], output['assign2'], assign_temp)\n", + " ld['assign_nce'], ld['assign_nce_acc'] = l_assign_nce, assign_nce_acc\n", + " l_nce_tri, nce_tri_acc = nce_loss(output['tri1'], output['tri2'], 0.1, normalize=True)\n", + " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n", + " l_attract, nearest_cos = attraction_loss(output['cos1'])\n", + " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n", + " l_cv = cv_loss(emb1, target=cv_target, n_samples=32, batched=cv_batched)\n", + " ld['cv'] = l_cv\n", + " l_spread = 
spread_loss(constellation.anchors)\n", + " ld['spread'] = l_spread\n", + "\n", + " # \u2500\u2500 TOTAL (all tensor ops, zero sync) \u2500\u2500\n", + " loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb\n", + " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n", + " loss_internal = (w_assign * l_assign + w_assign_nce * l_assign_nce\n", + " + w_nce_tri * l_nce_tri + w_attract * l_attract\n", + " + w_cv * l_cv + w_spread * l_spread)\n", + " loss = loss_external + loss_geometric + loss_internal\n", + "\n", + " ld['loss_external'] = loss_external\n", + " ld['loss_geometric'] = loss_geometric\n", + " ld['loss_internal'] = loss_internal\n", + " ld['total'] = loss\n", + " return loss, ld\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# OBSERVER LOSS (NO .item() in hot path)\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def observer_loss(output, anchors, targets=None,\n", + " infonce_temp=0.07, assign_temp=0.1, cv_target=0.22,\n", + " w_nce_emb=0.5,\n", + " w_nce_pw=1.0, w_bridge=1.0,\n", + " w_assign=0.5, w_assign_nce=0.25,\n", + " w_nce_tri=0.5, w_attract=0.25,\n", + " w_cv=0.01, w_spread=0.01,\n", + " cv_batched=True):\n", + " \"\"\"Observer self-organization loss. No CE. 
No task loss.\n", + " Returns (loss, loss_dict) where loss_dict values are TENSORS.\"\"\"\n", + " ld = {}\n", + " emb1, emb2 = output['embedding'], output['embedding_aug']\n", + "\n", + " l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, infonce_temp, normalize=False)\n", + " ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc\n", + "\n", + " l_nce_pw, nce_pw_acc = nce_loss(\n", + " output['patchwork1'], output['patchwork1_aug'], assign_temp, normalize=True)\n", + " ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc\n", + " l_bridge, bridge_acc = bridge_loss_paired(\n", + " output['bridge1'], output['bridge2'], output['assign1'], output['assign2'])\n", + " ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc\n", + "\n", + " l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])\n", + " ld['assign'], ld['assign_entropy'] = l_assign, assign_ent\n", + " l_assign_nce, assign_nce_acc = assign_nce_loss(\n", + " output['assign1'], output['assign2'], assign_temp)\n", + " ld['assign_nce'], ld['assign_nce_acc'] = l_assign_nce, assign_nce_acc\n", + " l_nce_tri, nce_tri_acc = nce_loss(output['tri1'], output['tri2'], 0.1, normalize=True)\n", + " ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc\n", + " l_attract, nearest_cos = attraction_loss(output['cos1'])\n", + " ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos\n", + " l_cv = cv_loss(emb1, target=cv_target, n_samples=32, batched=cv_batched)\n", + " ld['cv'] = l_cv\n", + " l_spread = spread_loss(anchors)\n", + " ld['spread'] = l_spread\n", + "\n", + " if targets is not None:\n", + " ld['knn_acc'] = knn_accuracy(emb1, targets)\n", + "\n", + " loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge\n", + " loss_internal = (w_assign * l_assign + w_assign_nce * l_assign_nce\n", + " + w_nce_tri * l_nce_tri + w_attract * l_attract\n", + " + w_cv * l_cv + w_spread * l_spread)\n", + " loss = w_nce_emb * l_nce_emb + loss_geometric + loss_internal\n", + "\n", + " ld['loss_geometric'] = 
loss_geometric\n", + " ld['loss_internal'] = loss_internal\n", + " ld['total'] = loss\n", + " return loss, ld\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# GEOLIP ENCODER \u2014 InternalConstellationCore + GeoLIPImageEncoder\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\"\"\"\n", + "GeoLIP Image Encoder - CONV variation\n", + "=====================================\n", + "Complete trainable model: conv encoder \u2192 S^(d-1) \u2192 magnitude \u2192 constellation \u2192 classify.\n", + "Classes:\n", + " ConvEncoder: 8-layer conv \u2192 D-dim projection\n", + " InternalConstellationCore: Three-domain head (external + geometric + internal)\n", + " GeoLIPImageEncoder: Full pipeline: encoder + MagnitudeFlow + core\n", + "Usage:\n", + " from geolip_encoder import GeoLIPImageEncoder\n", + " model = GeoLIPImageEncoder(num_classes=100, output_dim=384, n_anchors=2048)\n", + " out = model.forward_paired(v1, v2)\n", + " loss, ld = model.compute_loss(out, targets)\n", + "Author: AbstractPhil + Claude Opus 4.6\n", + "License: Apache 2.0\n", + "\"\"\"\n", + "import torch\n", + "import 
torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# CONV ENCODER \u2014 8-layer, proven on CIFAR-100\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "class ConvEncoder(nn.Module):\n", + " \"\"\"8-layer conv \u2192 D-dim projection on S^(d-1).\n", + " Architecture: 4 blocks of (conv-BN-GELU, conv-BN-GELU, MaxPool)\n", + " Channels: 64 \u2192 128 \u2192 256 \u2192 384\n", + " Output: (B, output_dim) after linear + LayerNorm\n", + " Note: L2 normalization is NOT applied here \u2014 the caller decides\n", + " when to normalize (preserving raw magnitude for MagnitudeFlow).\n", + " \"\"\"\n", + " def __init__(self, output_dim=256):\n", + " super().__init__()\n", + " self.output_dim = output_dim\n", + " self.features = nn.Sequential(\n", + " nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),\n", + " nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64), nn.GELU(),\n", + " nn.MaxPool2d(2),\n", + " nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),\n", + " nn.Conv2d(128, 128, 3, padding=1), nn.BatchNorm2d(128), nn.GELU(),\n", + " nn.MaxPool2d(2),\n", + " nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.GELU(),\n", + " nn.Conv2d(256, 256, 3, padding=1), 
nn.BatchNorm2d(256), nn.GELU(),\n", + " nn.MaxPool2d(2),\n", + " nn.Conv2d(256, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(),\n", + " nn.Conv2d(384, 384, 3, padding=1), nn.BatchNorm2d(384), nn.GELU(),\n", + " nn.MaxPool2d(2),\n", + " nn.AdaptiveAvgPool2d(1),\n", + " nn.Flatten(),\n", + " )\n", + " self.proj = nn.Sequential(\n", + " nn.Linear(384, output_dim),\n", + " nn.LayerNorm(output_dim),\n", + " )\n", + " def forward(self, x):\n", + " \"\"\"Returns: (B, output_dim) unnormalized features.\"\"\"\n", + " return self.proj(self.features(x))\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# INTERNAL CONSTELLATION CORE \u2014 three-domain head\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "class InternalConstellationCore(nn.Module):\n", + " \"\"\"Constellation with independent internal + external objectives.\n", + " The constellation discovers its own structure. 
The task head reads it.\n", + " Three domains:\n", + " EXTERNAL: CE + embedding NCE \u2192 task_head, patchwork, encoder\n", + " GEOMETRIC: patchwork NCE + bridge \u2192 patchwork, encoder, anchors\n", + " INTERNAL: assign + tri NCE + attract + CV + spread \u2192 anchors, encoder\n", + " Args:\n", + " num_classes: classification targets\n", + " dim: embedding dimension\n", + " n_anchors: anchors on S^(dim-1)\n", + " n_comp: patchwork compartments\n", + " d_comp: hidden dim per compartment\n", + " anchor_drop: training anchor dropout\n", + " activation: activation function name\n", + " cv_target: target CV for geometric loss\n", + " infonce_temp: embedding NCE temperature\n", + " assign_temp: assignment temperature\n", + " assign_sharpness: BCE target sharpness\n", + " \"\"\"\n", + " def __init__(\n", + " self,\n", + " num_classes=100,\n", + " dim=256,\n", + " n_anchors=128,\n", + " n_comp=8,\n", + " d_comp=64,\n", + " anchor_drop=0.15,\n", + " activation='squared_relu',\n", + " cv_target=0.22,\n", + " infonce_temp=0.07,\n", + " assign_temp=0.1,\n", + " assign_sharpness=5.0,\n", + " ):\n", + " super().__init__()\n", + " self.num_classes = num_classes\n", + " self.dim = dim\n", + " self.n_anchors = n_anchors\n", + " self.cv_target = cv_target\n", + " self.infonce_temp = infonce_temp\n", + " self.assign_temp = assign_temp\n", + " self.assign_sharpness = assign_sharpness\n", + " self.config = {k: v for k, v in locals().items()\n", + " if k != 'self' and not k.startswith('_')}\n", + " # Constellation \u2014 owns its own geometry\n", + " self.constellation = Constellation(n_anchors, dim, anchor_drop)\n", + " # Patchwork \u2014 interprets distance patterns\n", + " self.patchwork = Patchwork(n_anchors, n_comp, d_comp, activation)\n", + " pw_dim = self.patchwork.output_dim\n", + " # Bridge: patchwork predicts constellation's assignment\n", + " self.bridge = nn.Sequential(nn.Linear(pw_dim, n_anchors))\n", + " # Task head: reads assignment + patchwork + embedding\n", + " 
        # Task head consumes [soft assignment | patchwork | raw embedding].
        total_feat = n_anchors + pw_dim + dim
        self.task_head = nn.Sequential(
            nn.Linear(total_feat, pw_dim),
            make_activation(activation),
            nn.LayerNorm(pw_dim),
            nn.Dropout(0.1),
            nn.Linear(pw_dim, num_classes),
        )
        # Buffers (non-learned state, saved with the module)
        self.register_buffer('anchor_classes', torch.zeros(n_anchors, dtype=torch.long))
        self.register_buffer('class_centroids', torch.zeros(num_classes, dim))
    def _triangulate(self, emb):
        """emb -> (cos, tri, nearest, soft_assign).

        cos: cosine similarity of each embedding to every (normalized) anchor;
        tri: 1 - cos (a "distance"-like triangulation signal);
        nearest: argmax-anchor index per embedding;
        soft_assign: temperature-softened anchor assignment distribution.
        """
        anchors_n = F.normalize(self.constellation.anchors, dim=-1)
        cos = emb @ anchors_n.T
        tri = 1.0 - cos
        _, nearest = cos.max(dim=-1)
        soft_assign = F.softmax(cos / self.assign_temp, dim=-1)
        return cos, tri, nearest, soft_assign
    def forward_paired(self, emb1, emb2, mag1=None, mag2=None):
        """Paired forward for training. Returns dict with all intermediates."""
        cos1, tri1, nearest1, assign1 = self._triangulate(emb1)
        cos2, tri2, nearest2, assign2 = self._triangulate(emb2)
        # Magnitude weighting (optional per-anchor magnitudes scale the triangulation)
        tri1_w = tri1 * mag1 if mag1 is not None else tri1
        tri2_w = tri2 * mag2 if mag2 is not None else tri2
        # Patchwork
        pw1 = self.patchwork(tri1_w)
        pw2 = self.patchwork(tri2_w)
        # Bridge
        bridge1 = self.bridge(pw1)
        bridge2 = self.bridge(pw2)
        # Task head: concat of [assignment | patchwork | embedding] per view
        feat1 = torch.cat([assign1, pw1, emb1], dim=-1)
        feat2 = torch.cat([assign2, pw2, emb2], dim=-1)
        logits1 = self.task_head(feat1)
        logits2 = self.task_head(feat2)
        return {
            'embedding': emb1, 'embedding_aug': emb2,
            'mag1': mag1, 'mag2': mag2,
            'cos1': cos1, 'cos2': cos2,
            'tri1': tri1, 'tri2': tri2,
            'nearest': nearest1,
            'assign1': assign1, 'assign2': assign2,
            'patchwork1': pw1, 'patchwork1_aug': pw2,
            'bridge1': bridge1, 'bridge2': bridge2,
            'logits': logits1, 'logits_aug': logits2,
        }
    def forward(self, emb, mag=None):
        """Single view for eval.

        NOTE(review): reuses the paired path with the same input twice, so the
        second branch is redundant compute at eval time — simple but ~2x cost.
        """
        out = self.forward_paired(emb, emb, mag, mag)
        return {
            'logits': out['logits'],
            'embedding': emb,
            'magnitude': mag,
            'triangulation': out['tri1'],
            'cos_to_anchors': out['cos1'],
            'nearest': out['nearest'],
            'assignment': out['assign1'],
            'patchwork': out['patchwork1'],
        }
    def compute_loss(self, output, targets,
                     w_ce=1.0, w_nce_emb=0.5,
                     w_nce_pw=1.0, w_bridge=1.0,
                     w_assign=0.5, w_assign_nce=0.25,
                     w_nce_tri=0.5, w_attract=0.25,
                     w_cv=0.01, w_spread=0.01,
                     cv_batched=True, compute_knn=False):
        """Three-domain cooperative loss — fully batched, zero Python loops.

        The total is a weighted sum of three groups:
        external (CE + embedding InfoNCE), geometric (patchwork NCE + bridge),
        internal (assignment BCE/NCE, triangulation NCE, attraction, CV, spread).
        The individual loss helpers (ce_loss_paired, nce_loss, ...) are defined
        elsewhere in this notebook.
        Args:
            compute_knn: if False (default), skip kNN during training for speed.
                Set True during validation or every N steps.
        Returns:
            total_loss, loss_dict
        """
        ld = {}
        emb1, emb2 = output['embedding'], output['embedding_aug']
        # ── EXTERNAL (batched matmul) ──
        l_ce, acc = ce_loss_paired(output['logits'], output['logits_aug'], targets)
        ld['ce'], ld['acc'] = l_ce, acc
        l_nce_emb, nce_emb_acc = nce_loss(emb1, emb2, self.infonce_temp, normalize=False)
        ld['nce_emb'], ld['nce_emb_acc'] = l_nce_emb, nce_emb_acc
        # ── GEOMETRIC (batched matmul) ──
        l_nce_pw, nce_pw_acc = nce_loss(
            output['patchwork1'], output['patchwork1_aug'], self.assign_temp, normalize=True)
        ld['nce_pw'], ld['nce_pw_acc'] = l_nce_pw, nce_pw_acc
        l_bridge, bridge_acc = bridge_loss_paired(
            output['bridge1'], output['bridge2'],
            output['assign1'], output['assign2'])
        ld['bridge'], ld['bridge_acc'] = l_bridge, bridge_acc
        # ── INTERNAL (batched — no Python loops) ──
        l_assign, assign_ent = assign_bce_loss(output['assign1'], output['cos1'])
        ld['assign'], ld['assign_entropy'] = l_assign, assign_ent
        l_assign_nce, assign_nce_acc = assign_nce_loss(
            output['assign1'], output['assign2'], self.assign_temp)
        ld['assign_nce'], ld['assign_nce_acc'] = l_assign_nce, assign_nce_acc
        # NOTE(review): triangulation NCE temperature 0.1 is hard-coded here,
        # unlike the other temperatures which come from module config.
        l_nce_tri, nce_tri_acc = nce_loss(
            output['tri1'], output['tri2'], 0.1, normalize=True)
        ld['nce_tri'], ld['nce_tri_acc'] = l_nce_tri, nce_tri_acc
        l_attract, nearest_cos = attraction_loss(output['cos1'])
        ld['attract'], ld['nearest_cos'] = l_attract, nearest_cos
        # CV: batched Cayley-Menger, n_samples=32 for training speed
        l_cv = cv_loss(emb1, target=self.cv_target, n_samples=32, batched=cv_batched)
        ld['cv'] = l_cv
        l_spread = spread_loss(self.constellation.anchors)
        ld['spread'] = l_spread
        # ── kNN (SKIP during training — B×B matmul is expensive every batch) ──
        if compute_knn:
            ld['knn_acc'] = knn_accuracy(emb1, targets)
        # ── TOTAL ──
        loss_external = w_ce * l_ce + w_nce_emb * l_nce_emb
        loss_geometric = w_nce_pw * l_nce_pw + w_bridge * l_bridge
        loss_internal = (w_assign * l_assign + w_assign_nce * l_assign_nce
                         + w_nce_tri * l_nce_tri + w_attract * l_attract
                         + w_cv * l_cv + w_spread * l_spread)
        loss = loss_external + loss_geometric + loss_internal
        # Store tensors directly — NO .item() in hot path (causes CUDA sync stalls)
        # Use ld_to_scalars() at logging time only
        ld['loss_external'] = loss_external
        ld['loss_geometric'] = loss_geometric
        ld['loss_internal'] = loss_internal
        ld['t_ce'] = l_ce
        ld['t_nce_emb'] = l_nce_emb
        ld['t_nce_pw'] = l_nce_pw
        ld['t_bridge'] = l_bridge
        ld['t_assign'] = l_assign
        ld['t_assign_nce'] = l_assign_nce
        ld['t_nce_tri'] = l_nce_tri
        ld['t_attract'] = l_attract
        ld['total'] = loss
        return loss, ld
# ══════════════════════════════════════════════════════════════════════
# GEOLIP IMAGE ENCODER — full pipeline
# ══════════════════════════════════════════════════════════════════════
class GeoLIPImageEncoder(nn.Module):
    """Complete GeoLIP model: ConvEncoder → S^(d-1) → MagnitudeFlow → Core.
    Args:
        num_classes: classification targets
        output_dim: embedding dimension on S^(d-1)
        n_anchors: constellation anchors
        n_comp: patchwork compartments
        d_comp: per-compartment hidden dim
        anchor_drop: training anchor dropout
        activation: activation function name
        cv_target: CV loss target
        infonce_temp: embedding NCE temperature
        assign_temp: assignment temperature
        assign_sharpness: BCE sharpness
        mag_hidden: magnitude relay patchwork hidden dim
        mag_heads: unused (API compat)
        mag_layers: relay layers in MagnitudeFlow
        mag_min: minimum magnitude
        mag_max: maximum magnitude
    """
    def __init__(
        self,
        num_classes=100,
        output_dim=384,
        n_anchors=512,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        activation='squared_relu',
        cv_target=0.22,
        infonce_temp=0.07,
        assign_temp=0.1,
        assign_sharpness=5.0,
        mag_hidden=64,
        mag_heads=4,
        mag_layers=2,
        mag_min=0.1,
        mag_max=5.0,
    ):
        super().__init__()
        self.output_dim = output_dim
        # Snapshot constructor arguments for reproducibility (printed in summary()).
        self.config = {k: v for k, v in locals().items()
                       if k != 'self' and not k.startswith('_')}
        self.encoder = ConvEncoder(output_dim)
        self.mag_flow = MagnitudeFlow(
            dim=output_dim, n_anchors=n_anchors,
            hidden_dim=mag_hidden, n_heads=mag_heads, n_layers=mag_layers,
            mag_min=mag_min, mag_max=mag_max, n_comp=n_comp,
        )
        self.core = InternalConstellationCore(
            num_classes=num_classes, dim=output_dim,
            n_anchors=n_anchors, n_comp=n_comp, d_comp=d_comp,
            anchor_drop=anchor_drop, activation=activation,
            cv_target=cv_target, infonce_temp=infonce_temp,
            assign_temp=assign_temp, assign_sharpness=assign_sharpness,
        )
        self._init_encoder_weights()
    def _init_encoder_weights(self):
        # Standard init scheme: trunc-normal for Linear, Kaiming for Conv,
        # unit/zero for norm layers.
        for m in self.encoder.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None: nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.LayerNorm)):
                nn.init.ones_(m.weight); nn.init.zeros_(m.bias)
    def _encode(self, x):
        """Pixels → S^(d-1) + per-anchor magnitude."""
        feat = self.encoder(x)
        # Keep the pre-normalization norm as the raw magnitude signal.
        raw_mag = feat.norm(dim=-1, keepdim=True)
        emb = F.normalize(feat, dim=-1)
        anchors_n = F.normalize(self.core.constellation.anchors, dim=-1)
        tri = emb @ anchors_n.T
        mag, mag_comp = self.mag_flow(emb, tri, raw_mag)
        return emb, mag, mag_comp
    def forward_paired(self, v1, v2):
        """Training: two views → full pipeline."""
        emb1, mag1, mc1 = self._encode(v1)
        emb2, mag2, mc2 = self._encode(v2)
        out = self.core.forward_paired(emb1, emb2, mag1, mag2)
        out['mag_comp1'] = mc1
        out['mag_comp2'] = mc2
        return out
    def forward(self, x):
        """Eval: single view → classify."""
        emb, mag, mag_comp = self._encode(x)
        out = self.core(emb, mag)
        out['mag_comp'] = mag_comp
        return out
    def compute_loss(self, output, targets, **kwargs):
        """Delegate to core's three-domain loss."""
        return self.core.compute_loss(output, targets, **kwargs)
    def get_anchor_param_ids(self):
        """Return set of param ids that should have weight_decay=0.
        Includes constellation anchors + all relay layer anchors.
        """
        ids = set(id(p) for p in self.core.constellation.parameters())
        for relay in self.mag_flow.relays:
            ids.add(id(relay.anchors))
        return ids
    def make_optimizer(self, lr=3e-4, weight_decay=0.05):
        """Build AdamW with proper anchor exclusion from weight decay."""
        anchor_ids = self.get_anchor_param_ids()
        decay = [p for p in self.parameters() if id(p) not in anchor_ids]
        nodecay = [p for p in self.parameters() if id(p) in anchor_ids]
        return torch.optim.AdamW([
            {'params': decay, 'weight_decay': weight_decay},
            {'params': nodecay, 'weight_decay': 0.0},
        ], lr=lr)
    def summary(self):
        """Print parameter breakdown."""
        print("GeoLIPImageEncoder Summary")
        print("=" * 50)
        param_count(self.encoder, "encoder")
        param_count(self.mag_flow, "mag_flow")
        param_count(self.core.constellation, "constellation")
        param_count(self.core.patchwork, "patchwork")
        param_count(self.core.bridge, "bridge")
        param_count(self.core.task_head, "task_head")
        print("-" * 50)
        total = model_summary(self)
        print(f"\n  Config: {self.config}")
        return total

# ══════════════════════════════════════════════════════════════════════
# SPECTRAL GEOLIP ENCODER — Replaces ConvEncoder with deterministic front-end
# ══════════════════════════════════════════════════════════════════════

class SpectralGeoLIPEncoder(nn.Module):
    """Deterministic front-end → proj → S^(d-1) → MagnitudeFlow → Core.

    This replaces GeoLIPImageEncoder's ConvEncoder with any spectral/geometric
    front-end. The front-end is NOT learned (torch.no_grad).
 All learning
    happens in: projection, MagnitudeFlow relays, constellation anchors,
    patchwork compartments, bridge, and task head.

    Args:
        front_end: nn.Module — deterministic transform (Scattering2D, Gabor, etc.)
        front_end_dim: int — flattened output dimension of front_end
        num_classes: classification targets (10 for CIFAR-10)
        output_dim: embedding dimension on S^(d-1) (default 256)
        n_anchors: constellation anchors (default 128)
        n_comp: patchwork compartments (default 8)
        d_comp: per-compartment hidden dim (default 64)
        anchor_drop: training anchor dropout (default 0.15)
        activation: activation function name (default 'squared_relu')
        cv_target: CV loss target (default 0.22)
        infonce_temp: embedding NCE temperature (default 0.07)
        assign_temp: assignment temperature (default 0.1)
        assign_sharpness: BCE sharpness (default 5.0)
        mag_hidden: MagnitudeFlow relay hidden dim (default 64)
        mag_layers: relay layers in MagnitudeFlow (default 2)
        mag_min: minimum magnitude (default 0.1)
        mag_max: maximum magnitude (default 5.0)
    """
    def __init__(
        self,
        front_end,
        front_end_dim,
        num_classes=10,
        output_dim=256,
        n_anchors=128,
        n_comp=8,
        d_comp=64,
        anchor_drop=0.15,
        activation='squared_relu',
        cv_target=0.22,
        infonce_temp=0.07,
        assign_temp=0.1,
        assign_sharpness=5.0,
        mag_hidden=64,
        mag_layers=2,
        mag_min=0.1,
        mag_max=5.0,
    ):
        super().__init__()
        self.output_dim = output_dim
        self.front_end = front_end
        # Freeze front-end — all of its parameters are excluded from training.
        for p in self.front_end.parameters():
            p.requires_grad = False

        # Projection: front_end_dim → output_dim (learned)
        self.proj = nn.Sequential(
            nn.Linear(front_end_dim, output_dim),
            nn.LayerNorm(output_dim),
        )

        # MagnitudeFlow
        self.mag_flow = MagnitudeFlow(
            dim=output_dim, n_anchors=n_anchors,
            hidden_dim=mag_hidden, n_heads=4, n_layers=mag_layers,
            mag_min=mag_min, mag_max=mag_max, n_comp=n_comp,
        )

        # InternalConstellationCore — three-domain head
        self.core = InternalConstellationCore(
            num_classes=num_classes, dim=output_dim,
            n_anchors=n_anchors, n_comp=n_comp, d_comp=d_comp,
            anchor_drop=anchor_drop, activation=activation,
            cv_target=cv_target, infonce_temp=infonce_temp,
            assign_temp=assign_temp, assign_sharpness=assign_sharpness,
        )

        # Config snapshot (front_end module itself is deliberately excluded).
        self.config = {k: v for k, v in locals().items()
                       if k not in ('self', 'front_end') and not k.startswith('_')}

    def _encode(self, x):
        """Front-end → proj → S^(d-1) + per-anchor magnitude."""
        # Front-end is deterministic/frozen: no autograd graph needed here.
        with torch.no_grad():
            feat = self.front_end(x)
            if feat.dim() > 2:
                feat = feat.reshape(feat.shape[0], -1)
        # Learned projection runs with grad (outside no_grad) so the class
        # docstring's "all learning happens in projection ..." holds.
        feat = self.proj(feat)
        raw_mag = feat.norm(dim=-1, keepdim=True)
        emb = F.normalize(feat, dim=-1)
        anchors_n = F.normalize(self.core.constellation.anchors, dim=-1)
        tri = emb @ anchors_n.T
        mag, mag_comp = self.mag_flow(emb, tri, raw_mag)
        return emb, mag, mag_comp

    def forward_paired(self, v1, v2):
        """Training: two views → full pipeline with all intermediates."""
        emb1, mag1, mc1 = self._encode(v1)
        emb2, mag2, mc2 = self._encode(v2)
        out = self.core.forward_paired(emb1, emb2, mag1, mag2)
        out['mag_comp1'] = mc1
        out['mag_comp2'] = mc2
        return out

    def forward(self, x):
        """Eval: single view → classify."""
        emb, mag, mag_comp = self._encode(x)
        out = self.core(emb, mag)
        out['mag_comp'] = mag_comp
        return out

    def compute_loss(self, output, targets, **kwargs):
        """Delegate to core's three-domain loss."""
        return self.core.compute_loss(output, targets, **kwargs)

    def get_anchor_param_ids(self):
        # Param ids that must be excluded from weight decay (anchors).
        ids = set(id(p) for p in self.core.constellation.parameters())
        for relay in self.mag_flow.relays:
            ids.add(id(relay.anchors))
        return ids

    def make_optimizer(self, lr=3e-4, weight_decay=0.05):
        """AdamW with anchor exclusion from weight decay.

        Unlike GeoLIPImageEncoder.make_optimizer, this also filters on
        requires_grad so the frozen front-end never enters the optimizer.
        """
        anchor_ids = self.get_anchor_param_ids()
        decay = [p for p in self.parameters() if p.requires_grad and id(p) not in anchor_ids]
        nodecay = [p for p in self.parameters() if p.requires_grad and id(p) in anchor_ids]
        return torch.optim.AdamW([
            {'params': decay, 'weight_decay': weight_decay},
            {'params': nodecay, 'weight_decay': 0.0},
        ], lr=lr)

    def summary(self):
        # Print per-component parameter counts plus total/trainable split.
        print("SpectralGeoLIPEncoder Summary")
        print("=" * 50)
        param_count(self.proj, "projection")
        param_count(self.mag_flow, "mag_flow")
        param_count(self.core.constellation, "constellation")
        param_count(self.core.patchwork, "patchwork")
        param_count(self.core.bridge, "bridge")
        param_count(self.core.task_head, "task_head")
        print("-" * 50)
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"  Total: {total:,} ({trainable:,} trainable)")
        return total


# 
# ══════════════════════════════════════════════════════════════════════
# DATA LOADERS
# ══════════════════════════════════════════════════════════════════════

CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)
CIFAR10_CLASSES = ("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")

class PairedTransform:
    """Wrap a transform so each call yields two independently-sampled views."""
    def __init__(self, base_transform):
        self.t = base_transform
    def __call__(self, img):
        first_view = self.t(img)
        second_view = self.t(img)
        return first_view, second_view

def get_paired_loaders(batch_size=512, num_workers=4):
    """Build CIFAR-10 loaders: paired-view (augmented) train, single-view val."""
    augment = T.Compose([
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
    ])
    plain = T.Compose([T.ToTensor()])

    train_ds = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True,
        transform=PairedTransform(augment))
    val_ds = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True,
        transform=plain)

    def paired_collate(batch):
        # Each item is ((view1, view2), label); regroup into three batch tensors.
        pairs, labels = zip(*batch)
        view1 = torch.stack([p[0] for p in pairs])
        view2 = torch.stack([p[1] for p in pairs])
        return view1, view2, torch.tensor(labels)

    shared = dict(num_workers=num_workers, pin_memory=True,
                  persistent_workers=(num_workers > 0))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              drop_last=True, collate_fn=paired_collate,
                              **shared)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            **shared)
    print(f"[DATA] CIFAR-10 paired: {len(train_ds)} train, {len(val_ds)} val, bs={batch_size}")
    return train_loader, val_loader


# ══════════════════════════════════════════════════════════════════════
# TENSORBOARD HELPERS
# ══════════════════════════════════════════════════════════════════════

def make_writer(experiment_id, base_dir="runs"):
    """Create a SummaryWriter in a timestamped run directory."""
timestamp = datetime.datetime.now().strftime(\"%Y%m%d_%H%M%S\")\n", + " log_dir = os.path.join(base_dir, f\"{experiment_id}_{timestamp}\")\n", + " writer = SummaryWriter(log_dir=log_dir)\n", + " print(f\"[TB] Logging to {log_dir}\")\n", + " return writer\n", + "\n", + "def log_loss_dict(writer, ld, step, prefix=\"train\"):\n", + " \"\"\"Log every key in the loss dict to TensorBoard.\n", + " Accepts both tensors and floats \u2014 converts tensors with ld_to_scalars first.\"\"\"\n", + " scalars = ld_to_scalars(ld) if any(isinstance(v, torch.Tensor) for v in ld.values()) else ld\n", + " for k, v in scalars.items():\n", + " if isinstance(v, (int, float)):\n", + " writer.add_scalar(f\"{prefix}/{k}\", v, step)\n", + "\n", + "def log_anchor_stats(writer, constellation, step):\n", + " \"\"\"Log anchor norms, inter-anchor similarity, and utilization.\"\"\"\n", + " with torch.no_grad():\n", + " a = F.normalize(constellation.anchors, dim=-1)\n", + " sim = a @ a.T\n", + " mask = ~torch.eye(a.shape[0], dtype=torch.bool, device=a.device)\n", + " writer.add_histogram(\"anchors/inter_cos\", sim[mask].cpu(), step)\n", + " writer.add_scalar(\"anchors/mean_inter_cos\", sim[mask].mean().item(), step)\n", + " writer.add_histogram(\"anchors/norms\", constellation.anchors.norm(dim=-1).cpu(), step)\n", + "\n", + "def log_embedding_scatter(writer, embeddings, labels, step, tag=\"val_scatter\"):\n", + " \"\"\"Log PCA scatter of embeddings.\"\"\"\n", + " try:\n", + " from sklearn.decomposition import PCA\n", + " emb_np = embeddings.detach().cpu().numpy()\n", + " lab_np = labels if isinstance(labels, np.ndarray) else np.array(labels)\n", + " proj = PCA(n_components=2).fit_transform(emb_np)\n", + " fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n", + " for c in range(10):\n", + " m = lab_np == c\n", + " ax.scatter(proj[m, 0], proj[m, 1], s=3, alpha=0.5, label=CIFAR10_CLASSES[c])\n", + " ax.legend(markerscale=5, fontsize=8)\n", + " ax.set_title(f\"{tag} \u2014 Step {step}\")\n", + " 
fig.tight_layout()\n", + " writer.add_figure(f\"{tag}/pca\", fig, step)\n", + " plt.close(fig)\n", + " except ImportError:\n", + " pass\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# TRAINING HARNESS \u2014 paired views, three-domain loss, AnchorPush\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "def train_spectral_geolip(\n", + " experiment_id,\n", + " model,\n", + " train_loader,\n", + " val_loader,\n", + " device,\n", + " epochs=60,\n", + " lr=3e-4,\n", + " weight_decay=0.05,\n", + " push_every=5,\n", + " push_strategy='momentum',\n", + " log_every=1,\n", + " patience=25,\n", + " loss_weights=None,\n", + " use_amp=True,\n", + " use_compile=True,\n", + "):\n", + " \"\"\"Full training loop \u2014 sync-free hot path, AMP BF16, torch.compile.\n", + "\n", + " PERFORMANCE NOTES:\n", + " - Zero .item() calls in training loop (all loss dict values are tensors)\n", + " - AMP BF16 autocast for Blackwell/Ampere throughput\n", + " - torch.compile for kernel fusion (if available)\n", + " - ld_to_scalars() called ONLY at epoch-level logging, not per batch\n", + " - non_blocking 
transfers for overlapped H2D copies.
    """
    model = model.to(device)

    # ── AMP setup ──
    # Prefer BF16 where supported; FP16 otherwise. GradScaler is only needed
    # (and only enabled) for FP16 — BF16 has enough dynamic range without it.
    amp_dtype = torch.bfloat16 if (device.type == 'cuda' and torch.cuda.is_bf16_supported()) else torch.float16
    scaler = torch.amp.GradScaler('cuda', enabled=(use_amp and device.type == 'cuda' and amp_dtype == torch.float16))
    use_autocast = use_amp and device.type == 'cuda'
    if use_autocast:
        print(f"[PERF] AMP enabled — dtype={amp_dtype}")

    # ── torch.compile ──
    # Use mode='default' (dynamo tracing) not 'reduce-overhead' (CUDA graphs)
    # because CUDA graphs don't support dynamic ops like torch.linalg.svd/qr/eigvalsh
    if use_compile and hasattr(torch, 'compile'):
        try:
            model = torch.compile(model, mode='default')
            print("[PERF] torch.compile enabled (default — dynamo tracing)")
        except Exception as e:
            print(f"[PERF] torch.compile skipped: {e}")

    writer = make_writer(experiment_id)
    optimizer = model.make_optimizer(lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    anchor_push = AnchorPush(push_strategy, model.core.n_anchors, model.output_dim)

    # Loss weight defaults (overridable per-key via loss_weights)
    lw = dict(w_ce=1.0, w_nce_emb=0.5, w_nce_pw=1.0, w_bridge=1.0,
              w_assign=0.5, w_assign_nce=0.25, w_nce_tri=0.5, w_attract=0.25,
              w_cv=0.01, w_spread=0.01)
    if loss_weights:
        lw.update(loss_weights)

    best_val_acc = 0.0
    best_epoch = 0
    no_improve = 0
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    writer.add_text("config/params", f"Trainable: {n_params:,}")
    writer.add_text("config/experiment", experiment_id)
    writer.add_text("config/loss_weights", json.dumps(lw))
    writer.add_text("config/amp", str(amp_dtype) if use_autocast else "disabled")
    model.summary()
    print(f"\n{'='*70}")
    print(f"[EXP] {experiment_id} | {n_params:,} trainable params | {epochs} epochs")
    print(f"{'='*70}")

    emb_buffer = []
    lbl_buffer = []

    for epoch in range(1, epochs + 1):
        model.train()
        t0 = time.time()
        # Accumulate on GPU as tensors — NO .item() in loop
        epoch_loss_sum = torch.tensor(0.0, device=device)
        epoch_acc_sum = torch.tensor(0.0, device=device)
        n_batches = 0

        for v1, v2, labels in train_loader:
            v1 = v1.to(device, non_blocking=True)
            v2 = v2.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)

            # ── Forward + loss under autocast ──
            with torch.amp.autocast('cuda', dtype=amp_dtype, enabled=use_autocast):
                output = model.forward_paired(v1, v2)
                loss, ld = model.compute_loss(output, labels, compute_knn=False, **lw)

            # ── Backward (scaler handles FP16; BF16 doesn't need scaling) ──
            if amp_dtype == torch.float16 and use_autocast:
                scaler.scale(loss).backward()
                # Unscale before clipping so the clip threshold applies to
                # true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

            # Accumulate on device — tensors, no sync
            epoch_loss_sum += loss.detach()
            acc_val = ld.get('acc', None)
            if acc_val is not None:
                epoch_acc_sum += acc_val.detach() if isinstance(acc_val, torch.Tensor) else acc_val
            n_batches += 1

            # Buffer for AnchorPush (only on push epochs, to bound memory)
            if epoch % push_every == 0 or epoch == 1:
                emb_buffer.append(output['embedding'].detach())
                lbl_buffer.append(labels.detach())

        scheduler.step()
        # Single sync point per epoch — pull two scalars
        avg_loss = (epoch_loss_sum / n_batches).item()
        avg_acc = (epoch_acc_sum / n_batches).item() * 100

        # AnchorPush — use only the most recent 20 buffered batches
        if (epoch % push_every == 0 or epoch == 1) and len(emb_buffer) > 0:
            all_emb = torch.cat(emb_buffer[-20:], dim=0)
            all_lbl = torch.cat(lbl_buffer[-20:], dim=0)
            push_diag = anchor_push.push(model.core, all_emb, all_lbl)
            model.mag_flow.update_stats(push_diag, anchor_push)
            if epoch % log_every == 0:
                for k, v in push_diag.items():
                    if isinstance(v, (int, float)):
                        writer.add_scalar(f"push/{k}", v, epoch)
            emb_buffer.clear()
            lbl_buffer.clear()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        val_loss_sum = 0.0
        val_embs = []
        val_labels_list = []
        do_knn = (epoch % (log_every * 5) == 0 or epoch == epochs)

        with torch.no_grad():
            for images, labels in val_loader:
                images = images.to(device, non_blocking=True)
                labels = labels.to(device, non_blocking=True)
                with torch.amp.autocast('cuda', dtype=amp_dtype, enabled=use_autocast):
                    out = model(images)
                logits = out['logits']
                val_loss_sum += F.cross_entropy(logits, labels).item()
                pred = logits.argmax(-1)
                val_correct += (pred == labels).sum().item()
                val_total += labels.size(0)
                if do_knn:
                    val_embs.append(out['embedding'].cpu())
                    val_labels_list.extend(labels.cpu().numpy())

        val_acc = 100.0 * val_correct / val_total
        val_loss = val_loss_sum / len(val_loader)

        # kNN accuracy on validation embeddings (only at log intervals,
        # capped at 5000 points to bound the B×B similarity matmul)
        knn_acc = 0.0
        if do_knn and len(val_embs) > 0:
            all_val_emb = torch.cat(val_embs, dim=0).to(device)
            all_val_lbl = torch.tensor(val_labels_list, device=device)
            knn_acc = knn_accuracy(all_val_emb[:5000], all_val_lbl[:5000]) * 100

        # Log
        if epoch % log_every == 0:
            writer.add_scalar("train/loss", avg_loss, epoch)
            writer.add_scalar("train/acc", avg_acc, epoch)
            writer.add_scalar("val/loss", val_loss, epoch)
            writer.add_scalar("val/acc", val_acc, epoch)
            if do_knn:
                writer.add_scalar("val/knn_acc", knn_acc, epoch)
            writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)
            writer.add_scalar("train_val_gap", avg_acc - val_acc, epoch)
            # NOTE(review): `ld` is the last training batch's loss dict; it is
            # unbound if train_loader is empty.
            if ld is not None:
                log_loss_dict(writer, ld, epoch, "loss_components")
            log_anchor_stats(writer, model.core.constellation, epoch)
            for rd in model.mag_flow.get_relay_diagnostics():
                writer.add_scalar(f"relay/drift_L{rd['layer']}", rd['drift_mean'], epoch)
                writer.add_scalar(f"relay/gate_L{rd['layer']}", rd['gate_mean'], epoch)

        # Embedding visualization
        if do_knn and len(val_embs) > 0:
            log_embedding_scatter(writer, torch.cat(val_embs, 0)[:2000],
                                  val_labels_list[:2000], epoch)

        # Best tracking
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            no_improve = 0
        else:
            no_improve += 1

        # Print
        if epoch % 5 == 0 or epoch == 1 or epoch == epochs:
            elapsed = time.time() - t0
            knn_str = f" | kNN {knn_acc:5.1f}%" if do_knn else ""
            print(f"  E{epoch:3d} | Train {avg_acc:5.1f}% | Val {val_acc:5.1f}%{knn_str}"
                  f" | Loss {avg_loss:.4f} | Best {best_val_acc:.1f}% @E{best_epoch}"
                  f" | {elapsed:.1f}s")

        # Early stopping
        if no_improve >= patience:
            print(f"  [EARLY STOP] No improvement for {patience} epochs.")
break\n", + "\n", + " writer.close()\n", + " print(f\"\\n[RESULT] {experiment_id}: Best Val = {best_val_acc:.2f}% @E{best_epoch} \"\n", + " f\"| Params: {n_params:,}\")\n", + " return {\"experiment\": experiment_id, \"best_val_acc\": best_val_acc,\n", + " \"best_epoch\": best_epoch, \"params\": n_params}\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# SCOREBOARD\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "class Scoreboard:\n", + " def __init__(self):\n", + " self.results = []\n", + " def add(self, r):\n", + " self.results.append(r)\n", + " self._show()\n", + " def _show(self):\n", + " print(f\"\\n{'='*80}\")\n", + " print(f\"{'SCOREBOARD':^80}\")\n", + " print(f\"{'='*80}\")\n", + " print(f\"{'Experiment':<45} {'Val%':>7} {'Params':>10} {'Epoch':>6}\")\n", + " print(f\"{'-'*45} {'-'*7} {'-'*10} {'-'*6}\")\n", + " for r in sorted(self.results, key=lambda x: x['best_val_acc'], reverse=True):\n", + " print(f\"{r['experiment']:<45} {r['best_val_acc']:>6.2f}% {r['params']:>10,} {r['best_epoch']:>6}\")\n", + " print(f\"{'='*80}\")\n", + " def save(self, 
path=\"scoreboard.json\"):\n", + " with open(path, \"w\") as f:\n", + " json.dump(self.results, f, indent=2)\n", + "\n", + "scoreboard = Scoreboard()\n", + "\n", + "\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "# HUGGINGFACE UPLOAD \u2014 push TB runs + scoreboard after experiments\n", + "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n", + "\n", + "HF_REPO = \"AbstractPhil/geolip-hypersphere-experiments\"\n", + "HF_SUBDIR = \"spectral/notebooks\"\n", + "\n", + "def upload_runs_to_hf(notebook_name, scoreboard_path=None, runs_dir=\"runs\"):\n", + " \"\"\"Upload TensorBoard runs and scoreboard to HuggingFace repo.\n", + "\n", + " Uploads to: {HF_REPO}/spectral/notebooks/{notebook_name}/\n", + " Requires HF_TOKEN in Colab secrets or environment.\n", + " \"\"\"\n", + " try:\n", + " from huggingface_hub import HfApi, login\n", + " except ImportError:\n", + " print(\"[HF] huggingface_hub not installed \u2014 skipping upload\")\n", + " return None\n", + "\n", + " # Get token\n", + " token = os.environ.get('HF_TOKEN') or os.environ.get('HUGGINGFACE_TOKEN')\n", + " if token is None:\n", + " try:\n", + " from 
google.colab import userdata\n", + " token = userdata.get('HF_TOKEN')\n", + " except Exception:\n", + " pass\n", + " if token is None:\n", + " print(\"[HF] No HF_TOKEN found. Set it in Colab secrets or environment.\")\n", + " print(\" Colab: Settings \u2192 Secrets \u2192 Add HF_TOKEN\")\n", + " return None\n", + "\n", + " api = HfApi(token=token)\n", + " target = f\"{HF_SUBDIR}/{notebook_name}\"\n", + " print(f\"[HF] Uploading to {HF_REPO}/{target} ...\")\n", + "\n", + " # Upload runs directory (TensorBoard logs)\n", + " if os.path.exists(runs_dir) and os.listdir(runs_dir):\n", + " api.upload_folder(\n", + " folder_path=runs_dir,\n", + " repo_id=HF_REPO,\n", + " path_in_repo=f\"{target}/runs\",\n", + " token=token,\n", + " commit_message=f\"Upload TB runs: {notebook_name}\",\n", + " )\n", + " n_runs = len(os.listdir(runs_dir))\n", + " print(f\"[HF] Uploaded {n_runs} TB run(s) to {target}/runs/\")\n", + " else:\n", + " print(f\"[HF] No runs directory found at {runs_dir}\")\n", + "\n", + " # Upload scoreboard JSON if provided\n", + " if scoreboard_path and os.path.exists(scoreboard_path):\n", + " api.upload_file(\n", + " path_or_fileobj=scoreboard_path,\n", + " repo_id=HF_REPO,\n", + " path_in_repo=f\"{target}/scoreboard.json\",\n", + " token=token,\n", + " commit_message=f\"Upload scoreboard: {notebook_name}\",\n", + " )\n", + " print(f\"[HF] Uploaded scoreboard to {target}/scoreboard.json\")\n", + "\n", + " url = f\"https://huggingface.co/{HF_REPO}/tree/main/{target}\"\n", + " print(f\"[HF] Done \u2192 {url}\")\n", + " return url\n", + "\n", + "print(\"[SHARED] GeoLIP architecture loaded \u2014 full three-domain pipeline ready.\")\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Launch TensorBoard\n", + "%tensorboard --logdir runs\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.1 \u2014 
Grassmannian V-features (Baseline)\n", + "\n", + "Right singular vectors V encode subspace orientation in channel space.\n", + "This is the Grassmannian coordinate \u2014 which RGB mixing direction captures most variance.\n", + "Reproduction of NB2 experiment as baseline for this notebook." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.1 \u2014 Grassmannian V-features (Baseline)\n", + "class GrassVFrontEnd(nn.Module):\n", + " \"\"\"Baseline Grassmannian: S + log-ratios + V (right singular vectors).\"\"\"\n", + " def __init__(self, patch_size=8, k=3, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " self.k = k\n", + " self.C = 3\n", + " n_h = input_size // patch_size\n", + " self.n_patches = n_h * n_h\n", + " self.features_per_patch = k + k + k * self.C # S + logS + V = 3+3+9 = 15\n", + " self.output_dim = self.n_patches * self.features_per_patch\n", + " backend = \"Triton SVD3\" if _HAS_TRITON_SVD3 else \"torch.linalg.svd\"\n", + " print(f\"[8.1 GRASS-V] {self.n_patches} patches, dim={self.output_dim} ({backend})\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " ps = self.patch_size\n", + " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " n_p = patches.shape[2] * patches.shape[3]\n", + " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n", + " U, S, Vh = batched_svd3(X)\n", + " k = self.k\n", + " sv_ratios = torch.log(S[:, :k] / (S[:, k-1:k] + 1e-8) + 1e-8)\n", + " V_feat = Vh[:, :k, :].reshape(-1, k * C)\n", + " feats = torch.cat([S[:, :k], sv_ratios, V_feat], dim=-1)\n", + " return feats.reshape(B, -1)\n", + "\n", + "front = GrassVFrontEnd(patch_size=8, k=3).to(device)\n", + "model_8_1 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, 
d_comp=64,\n", + ").to(device)\n", + "model_8_1.summary()\n", + "\n", + "train_loader, val_loader = get_paired_loaders(batch_size=512)\n", + "result = train_spectral_geolip(\"8.1_grass_v_baseline\", model_8_1,\n", + " train_loader, val_loader, device, epochs=60)\n", + "scoreboard.add(result)\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.2 \u2014 Full SVD Signature\n", + "\n", + "Everything the SVD gives us: S, V, plus U projection norms.\n", + "U columns are orthonormal (so U\u1d40U \u2248 I is trivial), but U\u1d40X gives\n", + "how much of the original data each principal component captures \u2014 these vary per patch." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.2 \u2014 Full SVD Signature\n", + "class FullSVDFrontEnd(nn.Module):\n", + " \"\"\"Full SVD: S + log-ratios + V + U projection norms onto original data.\"\"\"\n", + " def __init__(self, patch_size=8, k=3, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " self.k = k\n", + " self.C = 3\n", + " n_h = input_size // patch_size\n", + " self.n_patches = n_h * n_h\n", + " # S(3) + logS(3) + V(9) + proj_norms(3) = 18\n", + " self.features_per_patch = k + k + k * self.C + k\n", + " self.output_dim = self.n_patches * self.features_per_patch\n", + " print(f\"[8.2 FULL-SVD] {self.n_patches} patches, dim={self.output_dim}\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " ps = self.patch_size\n", + " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " n_p = patches.shape[2] * patches.shape[3]\n", + " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n", + " U, S, Vh = batched_svd3(X)\n", + " k = self.k\n", + " sv_ratios = torch.log(S[:, :k] / (S[:, k-1:k] + 1e-8) + 1e-8)\n", + " V_feat = Vh[:, :k, 
:].reshape(-1, k * C)\n", + " # U^T X: how each principal component projects onto all channels\n", + " proj = torch.bmm(U[:, :, :k].transpose(1, 2), X) # (N, k, C)\n", + " proj_norms = proj.norm(dim=-1) # (N, k) \u2014 these are just S again for exact SVD\n", + " # But ratio of proj_norms to S reveals numerical conditioning\n", + " feats = torch.cat([S[:, :k], sv_ratios, V_feat, proj_norms], dim=-1)\n", + " return feats.reshape(B, -1)\n", + "\n", + "front = FullSVDFrontEnd(patch_size=8, k=3).to(device)\n", + "model_8_2 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, d_comp=64,\n", + ").to(device)\n", + "model_8_2.summary()\n", + "\n", + "train_loader, val_loader = get_paired_loaders(batch_size=512)\n", + "result = train_spectral_geolip(\"8.2_full_svd\", model_8_2,\n", + " train_loader, val_loader, device, epochs=60)\n", + "scoreboard.add(result)\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.3 \u2014 Per-Patch Procrustes Alignment\n", + "\n", + "For each patch \u00d7 template pair, compute 3\u00d73 color cross-covariance \u2192 SVD \u2192 Procrustes rotation.\n", + "Features: per-patch singular values (spatial alignment map) + per-template rotation statistics.\n", + "262K 3\u00d73 SVDs per batch via single Triton kernel call (~0.02ms)." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.3 \u2014 Per-Patch Procrustes\n", + "class PatchProcrustessFrontEnd(nn.Module):\n", + " \"\"\"Per-patch Procrustes: for each (patch, template_patch) pair, compute\n", + " 3\u00d73 color-space cross-covariance SVD. 
Features: per-patch singular values\n", + " (spatial alignment map) + per-template rotation statistics.\"\"\"\n", + " def __init__(self, n_templates=8, patch_size=4, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " self.n_templates = n_templates\n", + " ps2 = patch_size * patch_size\n", + " n_patches = (input_size // patch_size) ** 2\n", + " self.n_patches = n_patches\n", + " self.ps2 = ps2\n", + " # Templates: (T, n_patches, ps2, 3)\n", + " templates = torch.randn(n_templates, n_patches, ps2, 3)\n", + " templates = F.normalize(templates.reshape(n_templates, -1, 3), dim=-1)\n", + " templates = templates.reshape(n_templates, n_patches, ps2, 3)\n", + " self.register_buffer('templates', templates)\n", + " # Per-patch S map: T * n_patches * 3\n", + " # Per-template stats: T * 5 (mean_align, std_align, mean_trace, std_trace, proper_frac)\n", + " self.output_dim = n_templates * n_patches * 3 + n_templates * 5\n", + " print(f\"[8.3 PROCRUSTES] {n_templates}T \u00d7 {n_patches}P, dim={self.output_dim}\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " T, n_p, ps2 = self.n_templates, self.n_patches, self.ps2\n", + " ps = self.patch_size\n", + " # Reshape to (B, n_patches, ps2, 3)\n", + " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " patches = patches.permute(0, 2, 3, 4, 5, 1).contiguous().reshape(B, n_p, ps2, C)\n", + " patches_n = F.normalize(patches, dim=-1)\n", + " # Expand: (B*T*n_p, ps2, 3)\n", + " p_exp = patches_n.unsqueeze(1).expand(B, T, n_p, ps2, C).reshape(B*T*n_p, ps2, C)\n", + " t_exp = self.templates.unsqueeze(0).expand(B, T, n_p, ps2, C).reshape(B*T*n_p, ps2, C)\n", + " # Cross-covariance: (B*T*n_p, 3, 3)\n", + " M = torch.bmm(p_exp.transpose(1, 2), t_exp)\n", + " # SVD via Triton kernel\n", + " U, S, Vh = batched_svd3(M)\n", + " R_opt = torch.bmm(U, Vh)\n", + " # Per-patch S: spatial alignment map\n", + " 
S_map = S.reshape(B, T * n_p * 3)\n", + " # Per-template aggregates\n", + " align_q = S.reshape(B, T, n_p, 3).sum(dim=-1) # (B, T, n_p)\n", + " rot_tr = R_opt.diagonal(dim1=-2, dim2=-1).sum(-1).reshape(B, T, n_p)\n", + " mean_align = align_q.mean(dim=-1) # (B, T)\n", + " std_align = align_q.std(dim=-1)\n", + " mean_trace = rot_tr.mean(dim=-1)\n", + " std_trace = rot_tr.std(dim=-1)\n", + " proper_frac = (rot_tr > 0).float().mean(dim=-1) # fraction proper rotations\n", + " agg = torch.cat([mean_align, std_align, mean_trace, std_trace, proper_frac], dim=-1)\n", + " return torch.cat([S_map, agg], dim=-1)\n", + "\n", + "front = PatchProcrustessFrontEnd(n_templates=8, patch_size=4).to(device)\n", + "model_8_3 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, d_comp=64,\n", + ").to(device)\n", + "model_8_3.summary()\n", + "\n", + "train_loader, val_loader = get_paired_loaders(batch_size=512)\n", + "result = train_spectral_geolip(\"8.3_patch_procrustes\", model_8_3,\n", + " train_loader, val_loader, device, epochs=60)\n", + "scoreboard.add(result)\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.4 \u2014 Singular Value Spatial Map\n", + "\n", + "Just the singular values, but at high spatial resolution.\n", + "4\u00d74 patches (64 per image) give a 64-position map of spectral profiles.\n", + "Tests whether S alone carries enough info when spatial layout is preserved." 
+ ]
+ },
+ {
+ "cell_type": "code",
+ "metadata": {},
+ "source": [
+ "# @title Experiment 8.4 \u2014 Singular Value Spatial Map (Small Patches)\n",
+ "class SVSpatialFrontEnd(nn.Module):\n",
+ " \"\"\"Singular values at high spatial resolution (4\u00d74 patches = 64 patches).\"\"\"\n",
+ " def __init__(self, patch_size=4, input_size=32):\n",
+ " super().__init__()\n",
+ " self.patch_size = patch_size\n",
+ " n_h = input_size // patch_size\n",
+ " self.n_patches = n_h * n_h # 64\n",
+ " # S(3) + log_ratios(3) per patch\n",
+ " self.features_per_patch = 6\n",
+ " self.output_dim = self.n_patches * self.features_per_patch\n",
+ " print(f\"[8.4 SV-SPATIAL] {self.n_patches} patches ({patch_size}\u00d7{patch_size}), dim={self.output_dim}\")\n",
+ "\n",
+ " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n",
+ " def forward(self, x):\n",
+ " \"\"\"Map an image batch x of shape (B, C, H, W) to (B, n_patches * 6).\n",
+ "\n",
+ " Per patch the feature is [S, log(S / S_min)], where S holds the three\n",
+ " singular values of that patch's (ps*ps, 3) pixel-by-channel matrix.\n",
+ " Patch (spatial) order is preserved in the flattened output.\n",
+ " \"\"\"\n",
+ " B, C, H, W = x.shape\n",
+ " ps = self.patch_size\n",
+ " # Non-overlapping ps\u00d7ps tiles: (B, C, n_h, n_h, ps, ps)\n",
+ " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n",
+ " n_p = patches.shape[2] * patches.shape[3]\n",
+ " # (B*n_p, ps*ps, C): rows are pixels, columns are color channels\n",
+ " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n",
+ " U, S, Vh = batched_svd3(X)\n",
+ " # Log-ratio of each singular value to the smallest one, S[:, 2].\n",
+ " # NOTE(review): the third ratio is ~log(1) = 0 whenever S[:, 2] >> 1e-8,\n",
+ " # i.e. a near-constant feature \u2014 kept so features_per_patch stays 6.\n",
+ " sv_ratios = torch.log(S / (S[:, 2:3] + 1e-8) + 1e-8)\n",
+ " feats = torch.cat([S, sv_ratios], dim=-1)\n",
+ " return feats.reshape(B, -1)\n",
+ "\n",
+ "front = SVSpatialFrontEnd(patch_size=4).to(device)\n",
+ "model_8_4 = SpectralGeoLIPEncoder(\n",
+ " front_end=front, front_end_dim=front.output_dim,\n",
+ " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, d_comp=64,\n",
+ ").to(device)\n",
+ "model_8_4.summary()\n",
+ "\n",
+ "train_loader, val_loader = get_paired_loaders(batch_size=512)\n",
+ "result = train_spectral_geolip(\"8.4_sv_spatial_4x4\", model_8_4,\n",
+ " train_loader, val_loader, device, epochs=60)\n",
+ "scoreboard.add(result)\n",
+ ""
+ ],
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 8.5 \u2014 Subspace Angles Between Adjacent
Patches\n", + "\n", + "The principal angle between two subspaces = arccos(\u03c3\u2081(V\u2081\u1d40V\u2082)).\n", + "For adjacent patches, this captures edge/texture boundaries \u2014 where the\n", + "color mixing direction changes sharply, there's an edge." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.5 \u2014 Subspace Angles (Patch Boundaries)\n", + "class SubspaceAngleFrontEnd(nn.Module):\n", + " \"\"\"Subspace angles between horizontally and vertically adjacent patches.\n", + " Principal angle = arccos(\u03c3\u2081(V\u2081\u1d40 V\u2082)). Captures edges and texture boundaries.\"\"\"\n", + " def __init__(self, patch_size=8, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " self.n_h = input_size // patch_size # patches per row/col\n", + " n_p = self.n_h * self.n_h\n", + " self.n_patches = n_p\n", + " # Per patch: S(3) + V(9) = 12\n", + " # Horizontal edges: n_h * (n_h - 1) = 4*3 = 12 edges, 3 angles each = 36\n", + " # Vertical edges: same = 36\n", + " n_h_edges = self.n_h * (self.n_h - 1)\n", + " n_v_edges = (self.n_h - 1) * self.n_h\n", + " self.n_h_edges = n_h_edges\n", + " self.n_v_edges = n_v_edges\n", + " self.output_dim = n_p * 12 + (n_h_edges + n_v_edges) * 3\n", + " print(f\"[8.5 SUBSPACE-ANGLE] {n_p} patches, {n_h_edges+n_v_edges} edges, dim={self.output_dim}\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " ps = self.patch_size\n", + " n_h = self.n_h\n", + " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " n_p = patches.shape[2] * patches.shape[3]\n", + " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n", + " U, S, Vh = batched_svd3(X)\n", + " # Per-patch features\n", + " V_feat = Vh.reshape(B * n_p, 9)\n", + " patch_feats = torch.cat([S, V_feat], dim=-1).reshape(B, n_p, 12)\n", + " # Reshape Vh to grid: (B, n_h, n_h, 
3, 3)\n", + " Vh_grid = Vh.reshape(B, n_h, n_h, 3, 3)\n", + " # Horizontal subspace angles: V_i^T V_{i+1} for adjacent columns\n", + " Vh_left = Vh_grid[:, :, :-1].reshape(-1, 3, 3) # (B*n_h*(n_h-1), 3, 3)\n", + " Vh_right = Vh_grid[:, :, 1:].reshape(-1, 3, 3)\n", + " cross_h = torch.bmm(Vh_left, Vh_right.transpose(1, 2)) # (N, 3, 3)\n", + " # Singular values of V1^T V2 = cosines of principal angles\n", + " _, angles_h, _ = batched_svd3(cross_h) # (N, 3)\n", + " angles_h = torch.acos(angles_h.clamp(-1+1e-6, 1-1e-6)) # actual angles\n", + " # Vertical\n", + " Vh_top = Vh_grid[:, :-1, :].reshape(-1, 3, 3)\n", + " Vh_bot = Vh_grid[:, 1:, :].reshape(-1, 3, 3)\n", + " cross_v = torch.bmm(Vh_top, Vh_bot.transpose(1, 2))\n", + " _, angles_v, _ = batched_svd3(cross_v)\n", + " angles_v = torch.acos(angles_v.clamp(-1+1e-6, 1-1e-6))\n", + " return torch.cat([\n", + " patch_feats.reshape(B, -1),\n", + " angles_h.reshape(B, -1),\n", + " angles_v.reshape(B, -1),\n", + " ], dim=-1)\n", + "\n", + "front = SubspaceAngleFrontEnd(patch_size=8).to(device)\n", + "model_8_5 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, d_comp=64,\n", + ").to(device)\n", + "model_8_5.summary()\n", + "\n", + "train_loader, val_loader = get_paired_loaders(batch_size=512)\n", + "result = train_spectral_geolip(\"8.5_subspace_angles\", model_8_5,\n", + " train_loader, val_loader, device, epochs=60)\n", + "scoreboard.add(result)\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.6 \u2014 SVD Reconstruction Residuals\n", + "\n", + "Truncate SVD at k=1, k=2, k=3. The residual norms at each level\n", + "tell you how much information each additional principal component captures.\n", + "A patch dominated by luminance has low rank-1 residual; a chromatic patch has high." 
+ ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.6 \u2014 Reconstruction Residuals\n", + "class ReconResidualFrontEnd(nn.Module):\n", + " \"\"\"SVD reconstruction residuals at each truncation level.\n", + " ||X - U_k S_k Vh_k||_F for k=1,2,3 gives a 3-level energy pyramid.\"\"\"\n", + " def __init__(self, patch_size=8, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " n_h = input_size // patch_size\n", + " self.n_patches = n_h * n_h\n", + " # Per patch: S(3) + residual norms at k=0,1,2 (3) + energy fractions (3) = 9\n", + " self.features_per_patch = 9\n", + " self.output_dim = self.n_patches * self.features_per_patch\n", + " print(f\"[8.6 RECON-RESID] {self.n_patches} patches, dim={self.output_dim}\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " ps = self.patch_size\n", + " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " n_p = patches.shape[2] * patches.shape[3]\n", + " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n", + " U, S, Vh = batched_svd3(X)\n", + " total_energy = (S * S).sum(dim=-1, keepdim=True) + 1e-10 # ||X||_F^2\n", + " # Cumulative energy captured: sum(S[:k]^2) / sum(S^2)\n", + " S2 = S * S\n", + " cum_energy = torch.cumsum(S2, dim=-1)\n", + " energy_frac = cum_energy / total_energy # (N, 3) \u2014 fraction captured at k=1,2,3\n", + " # Residual norms: sqrt(total - cumulative)\n", + " residual = torch.sqrt((total_energy - cum_energy).clamp(min=0))\n", + " feats = torch.cat([S, residual, energy_frac], dim=-1) # (N, 9)\n", + " return feats.reshape(B, -1)\n", + "\n", + "front = ReconResidualFrontEnd(patch_size=8).to(device)\n", + "model_8_6 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, d_comp=64,\n", + ").to(device)\n", + 
"model_8_6.summary()\n", + "\n", + "train_loader, val_loader = get_paired_loaders(batch_size=512)\n", + "result = train_spectral_geolip(\"8.6_recon_residuals\", model_8_6,\n", + " train_loader, val_loader, device, epochs=60)\n", + "scoreboard.add(result)\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.7 \u2014 Grassmannian + Procrustes Joint\n", + "\n", + "Combine per-patch subspace features (from 8.1) with per-patch\n", + "Procrustes alignment features (from 8.3). Tests complementarity." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.7 \u2014 Grassmannian + Procrustes Joint\n", + "class GrassProcrustesFrontEnd(nn.Module):\n", + " \"\"\"Joint: per-patch Grassmannian SVD + per-patch Procrustes alignment.\"\"\"\n", + " def __init__(self, n_templates=4, patch_size=8, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " self.n_templates = n_templates\n", + " self.C = 3\n", + " self.k = 3\n", + " n_h = input_size // patch_size\n", + " self.n_patches = n_h * n_h\n", + " ps2 = patch_size * patch_size\n", + " self.ps2 = ps2\n", + " # Procrustes templates: (T, n_patches, ps2, 3)\n", + " templates = torch.randn(n_templates, self.n_patches, ps2, 3)\n", + " templates = F.normalize(templates.reshape(n_templates, -1, 3), dim=-1)\n", + " templates = templates.reshape(n_templates, self.n_patches, ps2, 3)\n", + " self.register_buffer('templates', templates)\n", + " # Grassmannian: n_patches * (S:3 + V:9) = n_patches * 12\n", + " grass_dim = self.n_patches * 12\n", + " # Procrustes: T * n_patches * 3 (S map) + T * 3 (mean align, mean trace, proper frac)\n", + " proc_dim = n_templates * self.n_patches * 3 + n_templates * 3\n", + " self.output_dim = grass_dim + proc_dim\n", + " print(f\"[8.7 GRASS+PROC] grass={grass_dim} + proc={proc_dim} = {self.output_dim}\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', 
cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " T, n_p, ps2 = self.n_templates, self.n_patches, self.ps2\n", + " ps = self.patch_size\n", + " # Grassmannian path (8\u00d78 patches \u2192 (B*16, 64, 3))\n", + " patches_8 = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " X = patches_8.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n", + " U_g, S_g, Vh_g = batched_svd3(X)\n", + " grass_feats = torch.cat([S_g, Vh_g.reshape(-1, 9)], dim=-1).reshape(B, -1)\n", + " # Procrustes path (same patches reshaped to pixel-color format)\n", + " patches_px = patches_8.permute(0, 2, 3, 4, 5, 1).contiguous().reshape(B, n_p, ps2, C)\n", + " patches_n = F.normalize(patches_px, dim=-1)\n", + " p_exp = patches_n.unsqueeze(1).expand(B, T, n_p, ps2, C).reshape(B*T*n_p, ps2, C)\n", + " t_exp = self.templates.unsqueeze(0).expand(B, T, n_p, ps2, C).reshape(B*T*n_p, ps2, C)\n", + " M = torch.bmm(p_exp.transpose(1, 2), t_exp)\n", + " U_p, S_p, Vh_p = batched_svd3(M)\n", + " R_opt = torch.bmm(U_p, Vh_p)\n", + " S_map = S_p.reshape(B, T * n_p * 3)\n", + " align_q = S_p.reshape(B, T, n_p, 3).sum(-1).mean(-1)\n", + " rot_tr = R_opt.diagonal(dim1=-2, dim2=-1).sum(-1).reshape(B, T, n_p).mean(-1)\n", + " proper = (R_opt.diagonal(dim1=-2, dim2=-1).sum(-1).reshape(B, T, n_p) > 0).float().mean(-1)\n", + " proc_agg = torch.cat([align_q, rot_tr, proper], dim=-1)\n", + " return torch.cat([grass_feats, S_map, proc_agg], dim=-1)\n", + "\n", + "front = GrassProcrustesFrontEnd(n_templates=4, patch_size=8).to(device)\n", + "model_8_7 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, d_comp=64,\n", + ").to(device)\n", + "model_8_7.summary()\n", + "\n", + "train_loader, val_loader = get_paired_loaders(batch_size=512)\n", + "result = train_spectral_geolip(\"8.7_grass_procrustes\", model_8_7,\n", + " train_loader, val_loader, device, epochs=60)\n", + 
"scoreboard.add(result)\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.8 \u2014 Channel Mixing V with Spatial Statistics\n", + "\n", + "V tells us which RGB direction each principal component points.\n", + "Aggregate V over patches with mean, std, and spatial gradients.\n", + "Tests whether the spatial variation of V (not just V itself) is informative." + ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.8 \u2014 V Spatial Statistics\n", + "class VSpatialStatsFrontEnd(nn.Module):\n", + " \"\"\"V (right singular vectors) with spatial statistics: mean, std, gradient.\"\"\"\n", + " def __init__(self, patch_size=8, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " n_h = input_size // patch_size\n", + " self.n_h = n_h\n", + " self.n_patches = n_h * n_h\n", + " # Per patch: S(3) + V(9) = 12 \u2192 total 192 (spatial map)\n", + " # Global stats: V mean(9) + V std(9) = 18\n", + " # Spatial gradient: mean |V_h_grad|(9) + mean |V_v_grad|(9) = 18\n", + " self.output_dim = self.n_patches * 12 + 18 + 18\n", + " print(f\"[8.8 V-SPATIAL] {self.n_patches} patches, dim={self.output_dim}\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " ps = self.patch_size\n", + " n_h = self.n_h\n", + " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " n_p = patches.shape[2] * patches.shape[3]\n", + " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n", + " U, S, Vh = batched_svd3(X)\n", + " # Per-patch features\n", + " V_flat = Vh.reshape(B * n_p, 9)\n", + " patch_feats = torch.cat([S, V_flat], dim=-1).reshape(B, n_p, 12)\n", + " # V on grid: (B, n_h, n_h, 9)\n", + " V_grid = V_flat.reshape(B, n_h, n_h, 9)\n", + " V_mean = V_grid.reshape(B, -1, 9).mean(dim=1) # (B, 9)\n", + " V_std = 
V_grid.reshape(B, -1, 9).std(dim=1) # (B, 9)\n", + " # Spatial gradients of V\n", + " V_h_grad = (V_grid[:, :, 1:, :] - V_grid[:, :, :-1, :]).abs().mean(dim=(1, 2)) # (B, 9)\n", + " V_v_grad = (V_grid[:, 1:, :, :] - V_grid[:, :-1, :, :]).abs().mean(dim=(1, 2)) # (B, 9)\n", + " return torch.cat([\n", + " patch_feats.reshape(B, -1),\n", + " V_mean, V_std, V_h_grad, V_v_grad,\n", + " ], dim=-1)\n", + "\n", + "front = VSpatialStatsFrontEnd(patch_size=8).to(device)\n", + "model_8_8 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " num_classes=10, output_dim=256, n_anchors=128, n_comp=8, d_comp=64,\n", + ").to(device)\n", + "model_8_8.summary()\n", + "\n", + "train_loader, val_loader = get_paired_loaders(batch_size=512)\n", + "result = train_spectral_geolip(\"8.8_v_spatial_stats\", model_8_8,\n", + " train_loader, val_loader, device, epochs=60)\n", + "scoreboard.add(result)\n", + "" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 8.9 \u2014 Spectral Contrast Normalization\n", + "\n", + "Normalize singular values relative to local neighborhood.\n", + "\u03c3_norm = \u03c3_i / mean(\u03c3_i over neighbors). Captures spectral contrast\n", + "\u2014 patches that are spectrally different from their neighbors stand out." 
+ ] + }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "# @title Experiment 8.9 \u2014 Spectral Contrast Normalization\n", + "class SpectralContrastFrontEnd(nn.Module):\n", + " \"\"\"SVD with spatially-normalized singular values.\n", + " \u03c3_contrast = \u03c3 / mean(\u03c3_neighbors) highlights spectral edges.\"\"\"\n", + " def __init__(self, patch_size=8, input_size=32):\n", + " super().__init__()\n", + " self.patch_size = patch_size\n", + " n_h = input_size // patch_size\n", + " self.n_h = n_h\n", + " self.n_patches = n_h * n_h\n", + " # Per patch: raw S(3) + contrast S(3) + V(9) = 15\n", + " self.features_per_patch = 15\n", + " self.output_dim = self.n_patches * self.features_per_patch\n", + " print(f\"[8.9 SPEC-CONTRAST] {self.n_patches} patches, dim={self.output_dim}\")\n", + "\n", + " @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)\n", + " def forward(self, x):\n", + " B, C, H, W = x.shape\n", + " ps = self.patch_size\n", + " n_h = self.n_h\n", + " patches = x.unfold(2, ps, ps).unfold(3, ps, ps)\n", + " n_p = patches.shape[2] * patches.shape[3]\n", + " X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)\n", + " U, S, Vh = batched_svd3(X)\n", + " V_feat = Vh.reshape(B * n_p, 9)\n", + " # S on spatial grid: (B, n_h, n_h, 3)\n", + " S_grid = S.reshape(B, n_h, n_h, 3)\n", + " # Compute local mean S via avg_pool (3\u00d73 neighborhood)\n", + " S_perm = S_grid.permute(0, 3, 1, 2) # (B, 3, n_h, n_h)\n", + " S_local_mean = F.avg_pool2d(S_perm, kernel_size=3, stride=1, padding=1) # (B, 3, n_h, n_h)\n", + " S_contrast = S_perm / (S_local_mean + 1e-8)\n", + " S_contrast = S_contrast.permute(0, 2, 3, 1).reshape(B * n_p, 3)\n", + " feats = torch.cat([S, S_contrast, V_feat], dim=-1)\n", + " return feats.reshape(B, -1)\n", + "\n", + "front = SpectralContrastFrontEnd(patch_size=8).to(device)\n", + "model_8_9 = SpectralGeoLIPEncoder(\n", + " front_end=front, front_end_dim=front.output_dim,\n", + " 
class KitchenSinkSVDFrontEnd(nn.Module):
    """All SVD features combined: S, log singular-value ratios, V, cumulative
    energy fractions, spectral contrast, V spatial statistics/gradients, and
    angles between adjacent patches' singular vectors. Maximum feature
    extraction from the 3x3 eigendecomposition."""

    def __init__(self, patch_size=8, input_size=32):
        super().__init__()
        self.patch_size = patch_size
        n_h = input_size // patch_size  # patches per image side
        self.n_h = n_h
        self.n_patches = n_h * n_h
        # Per patch: S(3) + log S-ratios(3) + V(9) + energy_frac(3) + contrast_S(3) = 21
        per_patch = 21
        self.per_patch = per_patch
        # Global: V_mean(9) + V_std(9) + V_h_grad(9) + V_v_grad(9) = 36
        # + 3 angles per horizontal/vertical patch edge
        n_h_edges = n_h * (n_h - 1)
        n_v_edges = (n_h - 1) * n_h
        self.n_h_edges = n_h_edges
        self.n_v_edges = n_v_edges
        self.output_dim = self.n_patches * per_patch + 36 + (n_h_edges + n_v_edges) * 3
        print(f"[8.10 KITCHEN-SINK] {self.n_patches} patches, dim={self.output_dim}")

    @torch.amp.custom_fwd(device_type='cuda', cast_inputs=torch.float32)
    def forward(self, x):
        """(B, C, H, W) image batch -> (B, output_dim) feature matrix."""
        B, C, H, W = x.shape
        ps = self.patch_size
        n_h = self.n_h
        patches = x.unfold(2, ps, ps).unfold(3, ps, ps)
        n_p = patches.shape[2] * patches.shape[3]
        X = patches.permute(0, 2, 3, 1, 4, 5).reshape(B * n_p, C, ps * ps).permute(0, 2, 1)
        U, S, Vh = batched_svd3(X)
        # --- per-patch features ---
        # Log ratios of each singular value to the smallest one.
        sv_ratios = torch.log(S / (S[:, 2:3] + 1e-8) + 1e-8)
        S2 = S * S
        total_e = S2.sum(dim=-1, keepdim=True) + 1e-10
        energy_frac = torch.cumsum(S2, dim=-1) / total_e  # cumulative energy captured
        V_flat = Vh.reshape(B * n_p, 9)
        # Spectral contrast vs the 3x3 patch neighborhood.
        S_grid = S.reshape(B, n_h, n_h, 3).permute(0, 3, 1, 2)
        # FIX: count_include_pad=False so border patches average only their
        # real neighbors instead of zero padding (which inflates contrast).
        S_local = F.avg_pool2d(S_grid, 3, 1, 1, count_include_pad=False)
        S_contrast = (S_grid / (S_local + 1e-8)).permute(0, 2, 3, 1).reshape(B * n_p, 3)
        patch_feats = torch.cat([S, sv_ratios, V_flat, energy_frac, S_contrast], dim=-1)
        patch_feats = patch_feats.reshape(B, n_p, self.per_patch)
        # --- V spatial statistics ---
        V_grid = V_flat.reshape(B, n_h, n_h, 9)
        V_mean = V_grid.reshape(B, -1, 9).mean(1)
        V_std = V_grid.reshape(B, -1, 9).std(1)
        V_h_grad = (V_grid[:, :, 1:] - V_grid[:, :, :-1]).abs().mean(dim=(1, 2))
        V_v_grad = (V_grid[:, 1:, :] - V_grid[:, :-1, :]).abs().mean(dim=(1, 2))
        # --- angles between adjacent patches' singular vectors ---
        # FIX: the previous code took the singular values of Vh_l @ Vh_r^T,
        # but a product of two orthogonal 3x3 matrices is itself orthogonal,
        # so those singular values are identically 1 and acos(clamp(...))
        # produced the same near-zero constant for every edge — no signal.
        # Instead use the DIAGONAL of Vh_l @ Vh_r^T: the cosines between
        # corresponding right singular vectors (abs() absorbs the sign
        # ambiguity of singular vectors). Output shape is unchanged.
        Vh_grid = Vh.reshape(B, n_h, n_h, 3, 3)

        def _pair_angles(Vh_a, Vh_b):
            # (*, 3, 3) x (*, 3, 3) -> (*, 3) angles between matching rows.
            cos = torch.bmm(Vh_a, Vh_b.transpose(1, 2)).diagonal(dim1=1, dim2=2)
            return torch.acos(cos.abs().clamp(max=1 - 1e-6))

        ah = _pair_angles(Vh_grid[:, :, :-1].reshape(-1, 3, 3),
                          Vh_grid[:, :, 1:].reshape(-1, 3, 3))
        av = _pair_angles(Vh_grid[:, :-1, :].reshape(-1, 3, 3),
                          Vh_grid[:, 1:, :].reshape(-1, 3, 3))
        return torch.cat([
            patch_feats.reshape(B, -1),
            V_mean, V_std, V_h_grad, V_v_grad,
            ah.reshape(B, -1), av.reshape(B, -1),
        ], dim=-1)
"upload_runs_to_hf(\"experiment_8_eigen_svd3_kernel\", \"scoreboard_exp8.json\")\n", + "" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file