AbstractPhil commited on
Commit
eaeeecc
Β·
verified Β·
1 Parent(s): 744d24d

Update test_cases.py

Browse files
Files changed (1) hide show
  1. test_cases.py +42 -14
test_cases.py CHANGED
@@ -33,6 +33,13 @@ import torch.nn.functional as F
33
  from torch import Tensor
34
  from typing import List, Optional, Tuple
35
 
 
 
 
 
 
 
 
36
 
37
  # ═══════════════════════════════════════════════════════════════════
38
  # Base Flow
@@ -255,10 +262,18 @@ class MagnitudeFlow(BaseFlow):
255
  # Project anchors to geometric space
256
  a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
257
 
258
- # Gram matrix eigenvalues β†’ spectral profile
259
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
260
- # Use torch.linalg.eigh for now; swap to FL eigh in geolip.linalg
261
- eigenvalues, _ = torch.linalg.eigh(G) # [B, geom_dim]
 
 
 
 
 
 
 
 
262
 
263
  # Magnitude spectrum: how energy distributes across modes
264
  magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] β€” the Ο‰ spectrum
@@ -306,8 +321,14 @@ class OrbitalFlow(BaseFlow):
306
  a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
307
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
308
 
309
- # Eigendecomposition β€” the Ο‰ spectrum
310
- eigenvalues, eigenvectors = torch.linalg.eigh(G) # [B, gd], [B, gd, gd]
 
 
 
 
 
 
311
 
312
  # Ο‰ = √|Ξ»|
313
  omega = eigenvalues.abs().sqrt() # [B, gd]
@@ -341,29 +362,36 @@ class OrbitalFlow(BaseFlow):
341
  class AlignmentFlow(BaseFlow):
342
  """SVD alignment flow via soft Procrustes rotation.
343
 
344
- Computes the optimal rotation aligning queries toward the anchor
345
- geometry using SVD of the cross-covariance matrix. The rotation
346
- is applied as a soft geometric bias.
347
  """
348
  def __init__(self, d_model: int, n_anchors: int):
349
  super().__init__(d_model, n_anchors, name='alignment')
350
  self.anchor_proj = nn.Linear(d_model, d_model)
351
  self.query_proj = nn.Linear(d_model, d_model)
352
- self.strength = nn.Parameter(torch.tensor(0.1)) # learnable blend
353
 
354
  def _flow(self, anchors, queries):
355
  B, n, d = queries.shape
356
  a_proj = self.anchor_proj(anchors) # [B, k, d]
357
  q_proj = self.query_proj(queries) # [B, n, d]
358
 
359
- # Cross-covariance: C = Q^T A
360
- C = torch.bmm(q_proj.transpose(-2, -1), a_proj) # [B, d, d]
 
 
 
 
 
361
 
362
  # SVD β†’ optimal rotation (Procrustes)
363
- U, _, Vh = torch.linalg.svd(C)
364
- R = torch.bmm(U, Vh) # [B, d, d] rotation matrix
 
 
365
 
366
- # Apply soft rotation
367
  q_rotated = torch.bmm(queries, R)
368
  return queries + self.strength * (q_rotated - queries)
369
 
 
33
  from torch import Tensor
34
  from typing import List, Optional, Tuple
35
 
36
+ # Use geolip_core.linalg when available (FL eigh, Triton SVD, etc.)
37
+ # Falls back to torch.linalg transparently
38
+ try:
39
+ import geolip_core.linalg as LA
40
+ except ImportError:
41
+ import torch.linalg as LA
42
+
43
 
44
  # ═══════════════════════════════════════════════════════════════════
45
  # Base Flow
 
262
  # Project anchors to geometric space
263
  a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
264
 
265
+ # Gram matrix
266
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
267
+
268
+ # Eigenvectors under no_grad (FL eigh has in-place deflation)
269
+ # Eigenvalues recomputed via Rayleigh quotient for differentiability
270
+ with torch.no_grad():
271
+ _, V = LA.eigh(G) # [B, gd, gd]
272
+ V = V.detach()
273
+
274
+ # Rayleigh quotient: Ξ»α΅’ = vα΅’α΅€ G vα΅’ β€” differentiable through G
275
+ GV = torch.bmm(G, V) # [B, gd, gd]
276
+ eigenvalues = (V * GV).sum(dim=-2) # [B, gd]
277
 
278
  # Magnitude spectrum: how energy distributes across modes
279
  magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] β€” the Ο‰ spectrum
 
321
  a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
322
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
323
 
324
+ # Eigenvectors under no_grad (FL eigh has in-place deflation)
325
+ with torch.no_grad():
326
+ _, eigenvectors = LA.eigh(G) # [B, gd, gd]
327
+ eigenvectors = eigenvectors.detach()
328
+
329
+ # Rayleigh quotient: differentiable eigenvalues through G
330
+ GV = torch.bmm(G, eigenvectors) # [B, gd, gd]
331
+ eigenvalues = (eigenvectors * GV).sum(dim=-2) # [B, gd]
332
 
333
  # Ο‰ = √|Ξ»|
334
  omega = eigenvalues.abs().sqrt() # [B, gd]
 
362
  class AlignmentFlow(BaseFlow):
363
  """SVD alignment flow via soft Procrustes rotation.
364
 
365
+ Computes attention-weighted anchor targets per query, then finds the
366
+ optimal rotation aligning queries toward those targets via SVD of
367
+ the cross-covariance matrix.
368
  """
369
  def __init__(self, d_model: int, n_anchors: int):
370
  super().__init__(d_model, n_anchors, name='alignment')
371
  self.anchor_proj = nn.Linear(d_model, d_model)
372
  self.query_proj = nn.Linear(d_model, d_model)
373
+ self.strength = nn.Parameter(torch.tensor(0.1))
374
 
375
  def _flow(self, anchors, queries):
376
  B, n, d = queries.shape
377
  a_proj = self.anchor_proj(anchors) # [B, k, d]
378
  q_proj = self.query_proj(queries) # [B, n, d]
379
 
380
+ # Attention-weighted anchors β†’ per-query targets [B, n, d]
381
+ sim = torch.bmm(q_proj, a_proj.transpose(-2, -1)) / math.sqrt(d)
382
+ weights = F.softmax(sim, dim=-1) # [B, n, k]
383
+ targets = torch.bmm(weights, a_proj) # [B, n, d]
384
+
385
+ # Cross-covariance: C = Q^T @ T, both [B, n, d] β†’ C is [B, d, d]
386
+ C = torch.bmm(q_proj.transpose(-2, -1), targets)
387
 
388
  # SVD β†’ optimal rotation (Procrustes)
389
+ U, _, Vh = torch.linalg.svd(C) # full dΓ—d SVD β€” not through geolip.linalg.svd
390
+ # Note: geolip.linalg.svd is thin SVD for Mβ‰₯N rectangular matrices.
391
+ # Cross-covariance C is square [B, d, d], use torch directly.
392
+ R = torch.bmm(U, Vh) # [B, d, d]
393
 
394
+ # Soft rotation
395
  q_rotated = torch.bmm(queries, R)
396
  return queries + self.strength * (q_rotated - queries)
397