Update test_cases.py
Browse files- test_cases.py +42 -14
test_cases.py
CHANGED
|
@@ -33,6 +33,13 @@ import torch.nn.functional as F
|
|
| 33 |
from torch import Tensor
|
| 34 |
from typing import List, Optional, Tuple
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 38 |
# Base Flow
|
|
@@ -255,10 +262,18 @@ class MagnitudeFlow(BaseFlow):
|
|
| 255 |
# Project anchors to geometric space
|
| 256 |
a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
|
| 257 |
|
| 258 |
-
# Gram matrix
|
| 259 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
|
| 260 |
-
|
| 261 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 262 |
|
| 263 |
# Magnitude spectrum: how energy distributes across modes
|
| 264 |
magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] β the Ο spectrum
|
|
@@ -306,8 +321,14 @@ class OrbitalFlow(BaseFlow):
|
|
| 306 |
a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
|
| 307 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
|
| 308 |
|
| 309 |
-
#
|
| 310 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
# Ο = β|Ξ»|
|
| 313 |
omega = eigenvalues.abs().sqrt() # [B, gd]
|
|
@@ -341,29 +362,36 @@ class OrbitalFlow(BaseFlow):
|
|
| 341 |
class AlignmentFlow(BaseFlow):
|
| 342 |
"""SVD alignment flow via soft Procrustes rotation.
|
| 343 |
|
| 344 |
-
Computes
|
| 345 |
-
|
| 346 |
-
|
| 347 |
"""
|
| 348 |
def __init__(self, d_model: int, n_anchors: int):
|
| 349 |
super().__init__(d_model, n_anchors, name='alignment')
|
| 350 |
self.anchor_proj = nn.Linear(d_model, d_model)
|
| 351 |
self.query_proj = nn.Linear(d_model, d_model)
|
| 352 |
-
self.strength = nn.Parameter(torch.tensor(0.1))
|
| 353 |
|
| 354 |
def _flow(self, anchors, queries):
|
| 355 |
B, n, d = queries.shape
|
| 356 |
a_proj = self.anchor_proj(anchors) # [B, k, d]
|
| 357 |
q_proj = self.query_proj(queries) # [B, n, d]
|
| 358 |
|
| 359 |
-
#
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
|
| 362 |
# SVD β optimal rotation (Procrustes)
|
| 363 |
-
U, _, Vh = torch.linalg.svd(C)
|
| 364 |
-
|
|
|
|
|
|
|
| 365 |
|
| 366 |
-
#
|
| 367 |
q_rotated = torch.bmm(queries, R)
|
| 368 |
return queries + self.strength * (q_rotated - queries)
|
| 369 |
|
|
|
|
| 33 |
from torch import Tensor
|
| 34 |
from typing import List, Optional, Tuple
|
| 35 |
|
| 36 |
+
# Use geolip_core.linalg when available (FL eigh, Triton SVD, etc.)
|
| 37 |
+
# Falls back to torch.linalg transparently
|
| 38 |
+
try:
|
| 39 |
+
import geolip_core.linalg as LA
|
| 40 |
+
except ImportError:
|
| 41 |
+
import torch.linalg as LA
|
| 42 |
+
|
| 43 |
|
| 44 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 45 |
# Base Flow
|
|
|
|
| 262 |
# Project anchors to geometric space
|
| 263 |
a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
|
| 264 |
|
| 265 |
+
# Gram matrix
|
| 266 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
|
| 267 |
+
|
| 268 |
+
# Eigenvectors under no_grad (FL eigh has in-place deflation)
|
| 269 |
+
# Eigenvalues recomputed via Rayleigh quotient for differentiability
|
| 270 |
+
with torch.no_grad():
|
| 271 |
+
_, V = LA.eigh(G) # [B, gd, gd]
|
| 272 |
+
V = V.detach()
|
| 273 |
+
|
| 274 |
+
# Rayleigh quotient: Ξ»α΅’ = vα΅’α΅ G vα΅’ β differentiable through G
|
| 275 |
+
GV = torch.bmm(G, V) # [B, gd, gd]
|
| 276 |
+
eigenvalues = (V * GV).sum(dim=-2) # [B, gd]
|
| 277 |
|
| 278 |
# Magnitude spectrum: how energy distributes across modes
|
| 279 |
magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] β the Ο spectrum
|
|
|
|
| 321 |
a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
|
| 322 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
|
| 323 |
|
| 324 |
+
# Eigenvectors under no_grad (FL eigh has in-place deflation)
|
| 325 |
+
with torch.no_grad():
|
| 326 |
+
_, eigenvectors = LA.eigh(G) # [B, gd, gd]
|
| 327 |
+
eigenvectors = eigenvectors.detach()
|
| 328 |
+
|
| 329 |
+
# Rayleigh quotient: differentiable eigenvalues through G
|
| 330 |
+
GV = torch.bmm(G, eigenvectors) # [B, gd, gd]
|
| 331 |
+
eigenvalues = (eigenvectors * GV).sum(dim=-2) # [B, gd]
|
| 332 |
|
| 333 |
# Ο = β|Ξ»|
|
| 334 |
omega = eigenvalues.abs().sqrt() # [B, gd]
|
|
|
|
| 362 |
class AlignmentFlow(BaseFlow):
|
| 363 |
"""SVD alignment flow via soft Procrustes rotation.
|
| 364 |
|
| 365 |
+
Computes attention-weighted anchor targets per query, then finds the
|
| 366 |
+
optimal rotation aligning queries toward those targets via SVD of
|
| 367 |
+
the cross-covariance matrix.
|
| 368 |
"""
|
| 369 |
def __init__(self, d_model: int, n_anchors: int):
|
| 370 |
super().__init__(d_model, n_anchors, name='alignment')
|
| 371 |
self.anchor_proj = nn.Linear(d_model, d_model)
|
| 372 |
self.query_proj = nn.Linear(d_model, d_model)
|
| 373 |
+
self.strength = nn.Parameter(torch.tensor(0.1))
|
| 374 |
|
| 375 |
def _flow(self, anchors, queries):
|
| 376 |
B, n, d = queries.shape
|
| 377 |
a_proj = self.anchor_proj(anchors) # [B, k, d]
|
| 378 |
q_proj = self.query_proj(queries) # [B, n, d]
|
| 379 |
|
| 380 |
+
# Attention-weighted anchors β per-query targets [B, n, d]
|
| 381 |
+
sim = torch.bmm(q_proj, a_proj.transpose(-2, -1)) / math.sqrt(d)
|
| 382 |
+
weights = F.softmax(sim, dim=-1) # [B, n, k]
|
| 383 |
+
targets = torch.bmm(weights, a_proj) # [B, n, d]
|
| 384 |
+
|
| 385 |
+
# Cross-covariance: C = Q^T @ T, both [B, n, d] β C is [B, d, d]
|
| 386 |
+
C = torch.bmm(q_proj.transpose(-2, -1), targets)
|
| 387 |
|
| 388 |
# SVD β optimal rotation (Procrustes)
|
| 389 |
+
U, _, Vh = torch.linalg.svd(C) # full dΓd SVD β not through geolip.linalg.svd
|
| 390 |
+
# Note: geolip.linalg.svd is thin SVD for Mβ₯N rectangular matrices.
|
| 391 |
+
# Cross-covariance C is square [B, d, d], use torch directly.
|
| 392 |
+
R = torch.bmm(U, Vh) # [B, d, d]
|
| 393 |
|
| 394 |
+
# Soft rotation
|
| 395 |
q_rotated = torch.bmm(queries, R)
|
| 396 |
return queries + self.strength * (q_rotated - queries)
|
| 397 |
|