Update test_cases.py
Browse files- test_cases.py +25 -37
test_cases.py
CHANGED
|
@@ -265,15 +265,8 @@ class MagnitudeFlow(BaseFlow):
|
|
| 265 |
# Gram matrix
|
| 266 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
|
| 267 |
|
| 268 |
-
#
|
| 269 |
-
|
| 270 |
-
with torch.no_grad():
|
| 271 |
-
_, V = LA.eigh(G) # [B, gd, gd]
|
| 272 |
-
V = V.detach()
|
| 273 |
-
|
| 274 |
-
# Rayleigh quotient: λᵢ = vᵢᵀ G vᵢ — differentiable through G
|
| 275 |
-
GV = torch.bmm(G, V) # [B, gd, gd]
|
| 276 |
-
eigenvalues = (V * GV).sum(dim=-2) # [B, gd]
|
| 277 |
|
| 278 |
# Magnitude spectrum: how energy distributes across modes
|
| 279 |
magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] — the ω spectrum
|
|
@@ -321,14 +314,8 @@ class OrbitalFlow(BaseFlow):
|
|
| 321 |
a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
|
| 322 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
|
| 323 |
|
| 324 |
-
#
|
| 325 |
-
|
| 326 |
-
_, eigenvectors = LA.eigh(G) # [B, gd, gd]
|
| 327 |
-
eigenvectors = eigenvectors.detach()
|
| 328 |
-
|
| 329 |
-
# Rayleigh quotient: differentiable eigenvalues through G
|
| 330 |
-
GV = torch.bmm(G, eigenvectors) # [B, gd, gd]
|
| 331 |
-
eigenvalues = (eigenvectors * GV).sum(dim=-2) # [B, gd]
|
| 332 |
|
| 333 |
# ω = √|λ|
|
| 334 |
omega = eigenvalues.abs().sqrt() # [B, gd]
|
|
@@ -360,40 +347,41 @@ class OrbitalFlow(BaseFlow):
|
|
| 360 |
# ═══════════════════════════════════════════════════════════════════
|
| 361 |
|
| 362 |
class AlignmentFlow(BaseFlow):
|
| 363 |
-
"""SVD alignment flow via soft Procrustes rotation.
|
| 364 |
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
the cross-covariance matrix.
|
| 368 |
"""
|
| 369 |
def __init__(self, d_model: int, n_anchors: int):
|
| 370 |
super().__init__(d_model, n_anchors, name='alignment')
|
| 371 |
-
self.
|
| 372 |
-
self.
|
|
|
|
|
|
|
| 373 |
self.strength = nn.Parameter(torch.tensor(0.1))
|
| 374 |
|
| 375 |
def _flow(self, anchors, queries):
|
| 376 |
B, n, d = queries.shape
|
| 377 |
-
|
| 378 |
-
|
|
|
|
| 379 |
|
| 380 |
-
# Attention-weighted anchors → per-query targets [B, n,
|
| 381 |
-
sim = torch.bmm(q_proj, a_proj.transpose(-2, -1)) / math.sqrt(
|
| 382 |
weights = F.softmax(sim, dim=-1) # [B, n, k]
|
| 383 |
-
targets = torch.bmm(weights, a_proj) # [B, n,
|
| 384 |
|
| 385 |
-
# Cross-covariance
|
| 386 |
C = torch.bmm(q_proj.transpose(-2, -1), targets)
|
| 387 |
|
| 388 |
-
# SVD → optimal rotation (
|
| 389 |
-
U, _, Vh =
|
| 390 |
-
|
| 391 |
-
# Cross-covariance C is square [B, d, d], use torch directly.
|
| 392 |
-
R = torch.bmm(U, Vh) # [B, d, d]
|
| 393 |
|
| 394 |
-
#
|
| 395 |
-
q_rotated = torch.bmm(
|
| 396 |
-
|
|
|
|
| 397 |
|
| 398 |
|
| 399 |
# ═══════════════════════════════════════════════════════════════════
|
|
|
|
| 265 |
# Gram matrix
|
| 266 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
|
| 267 |
|
| 268 |
+
# Eigendecomposition — differentiable through torch.linalg.eigh
|
| 269 |
+
eigenvalues, _ = LA.eigh(G, method='torch') # [B, geom_dim]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
# Magnitude spectrum: how energy distributes across modes
|
| 272 |
magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] — the ω spectrum
|
|
|
|
| 314 |
a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
|
| 315 |
G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
|
| 316 |
|
| 317 |
+
# Eigendecomposition — the ω spectrum (differentiable via torch.linalg.eigh)
|
| 318 |
+
eigenvalues, eigenvectors = LA.eigh(G, method='torch') # [B, gd], [B, gd, gd]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 319 |
|
| 320 |
# ω = √|λ|
|
| 321 |
omega = eigenvalues.abs().sqrt() # [B, gd]
|
|
|
|
| 347 |
# ═══════════════════════════════════════════════════════════════════
|
| 348 |
|
| 349 |
class AlignmentFlow(BaseFlow):
|
| 350 |
+
"""SVD alignment flow via soft Procrustes rotation in projected space.
|
| 351 |
|
| 352 |
+
Projects to geom_dim, computes optimal rotation via SVD of the
|
| 353 |
+
cross-covariance in the small space, applies rotation, projects back.
|
|
|
|
| 354 |
"""
|
| 355 |
def __init__(self, d_model: int, n_anchors: int):
|
| 356 |
super().__init__(d_model, n_anchors, name='alignment')
|
| 357 |
+
self.geom_dim = min(n_anchors, 12) # FL eigh sweet spot
|
| 358 |
+
self.anchor_proj = nn.Linear(d_model, self.geom_dim)
|
| 359 |
+
self.query_proj = nn.Linear(d_model, self.geom_dim)
|
| 360 |
+
self.geom_to_query = nn.Linear(self.geom_dim, d_model)
|
| 361 |
self.strength = nn.Parameter(torch.tensor(0.1))
|
| 362 |
|
| 363 |
def _flow(self, anchors, queries):
|
| 364 |
B, n, d = queries.shape
|
| 365 |
+
# Project to small geometric space
|
| 366 |
+
a_proj = self.anchor_proj(anchors) # [B, k, geom_dim]
|
| 367 |
+
q_proj = self.query_proj(queries) # [B, n, geom_dim]
|
| 368 |
|
| 369 |
+
# Attention-weighted anchors → per-query targets [B, n, geom_dim]
|
| 370 |
+
sim = torch.bmm(q_proj, a_proj.transpose(-2, -1)) / math.sqrt(self.geom_dim)
|
| 371 |
weights = F.softmax(sim, dim=-1) # [B, n, k]
|
| 372 |
+
targets = torch.bmm(weights, a_proj) # [B, n, geom_dim]
|
| 373 |
|
| 374 |
+
# Cross-covariance in small space: [B, geom_dim, geom_dim]
|
| 375 |
C = torch.bmm(q_proj.transpose(-2, -1), targets)
|
| 376 |
|
| 377 |
+
# SVD → optimal rotation via gram_eigh (differentiable, no in-place ops)
|
| 378 |
+
U, _, Vh = LA.svd(C, method='gram_eigh')
|
| 379 |
+
R = torch.bmm(U, Vh) # [B, geom_dim, geom_dim]
|
|
|
|
|
|
|
| 380 |
|
| 381 |
+
# Rotate queries in small space, project back to d_model
|
| 382 |
+
q_rotated = torch.bmm(q_proj, R) # [B, n, geom_dim]
|
| 383 |
+
delta = self.geom_to_query(q_rotated - q_proj) # [B, n, d]
|
| 384 |
+
return queries + self.strength * delta
|
| 385 |
|
| 386 |
|
| 387 |
# ═══════════════════════════════════════════════════════════════════
|