AbstractPhil commited on
Commit
013d216
·
verified ·
1 Parent(s): ed7f648

Update test_cases.py

Browse files
Files changed (1) hide show
  1. test_cases.py +25 -37
test_cases.py CHANGED
@@ -265,15 +265,8 @@ class MagnitudeFlow(BaseFlow):
265
  # Gram matrix
266
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
267
 
268
- # Eigenvectors under no_grad (FL eigh has in-place deflation)
269
- # Eigenvalues recomputed via Rayleigh quotient for differentiability
270
- with torch.no_grad():
271
- _, V = LA.eigh(G) # [B, gd, gd]
272
- V = V.detach()
273
-
274
- # Rayleigh quotient: λᵢ = vᵢᵀ G vᵢ — differentiable through G
275
- GV = torch.bmm(G, V) # [B, gd, gd]
276
- eigenvalues = (V * GV).sum(dim=-2) # [B, gd]
277
 
278
  # Magnitude spectrum: how energy distributes across modes
279
  magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] — the ω spectrum
@@ -321,14 +314,8 @@ class OrbitalFlow(BaseFlow):
321
  a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
322
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
323
 
324
- # Eigenvectors under no_grad (FL eigh has in-place deflation)
325
- with torch.no_grad():
326
- _, eigenvectors = LA.eigh(G) # [B, gd, gd]
327
- eigenvectors = eigenvectors.detach()
328
-
329
- # Rayleigh quotient: differentiable eigenvalues through G
330
- GV = torch.bmm(G, eigenvectors) # [B, gd, gd]
331
- eigenvalues = (eigenvectors * GV).sum(dim=-2) # [B, gd]
332
 
333
  # ω = √|λ|
334
  omega = eigenvalues.abs().sqrt() # [B, gd]
@@ -360,40 +347,41 @@ class OrbitalFlow(BaseFlow):
360
  # ═══════════════════════════════════════════════════════════════════
361
 
362
  class AlignmentFlow(BaseFlow):
363
- """SVD alignment flow via soft Procrustes rotation.
364
 
365
- Computes attention-weighted anchor targets per query, then finds the
366
- optimal rotation aligning queries toward those targets via SVD of
367
- the cross-covariance matrix.
368
  """
369
  def __init__(self, d_model: int, n_anchors: int):
370
  super().__init__(d_model, n_anchors, name='alignment')
371
- self.anchor_proj = nn.Linear(d_model, d_model)
372
- self.query_proj = nn.Linear(d_model, d_model)
 
 
373
  self.strength = nn.Parameter(torch.tensor(0.1))
374
 
375
  def _flow(self, anchors, queries):
376
  B, n, d = queries.shape
377
- a_proj = self.anchor_proj(anchors) # [B, k, d]
378
- q_proj = self.query_proj(queries) # [B, n, d]
 
379
 
380
- # Attention-weighted anchors → per-query targets [B, n, d]
381
- sim = torch.bmm(q_proj, a_proj.transpose(-2, -1)) / math.sqrt(d)
382
  weights = F.softmax(sim, dim=-1) # [B, n, k]
383
- targets = torch.bmm(weights, a_proj) # [B, n, d]
384
 
385
- # Cross-covariance: C = Q^T @ T, both [B, n, d] → C is [B, d, d]
386
  C = torch.bmm(q_proj.transpose(-2, -1), targets)
387
 
388
- # SVD → optimal rotation (Procrustes)
389
- U, _, Vh = torch.linalg.svd(C) # full d×d SVD — not through geolip.linalg.svd
390
- # Note: geolip.linalg.svd is thin SVD for M≥N rectangular matrices.
391
- # Cross-covariance C is square [B, d, d], use torch directly.
392
- R = torch.bmm(U, Vh) # [B, d, d]
393
 
394
- # Soft rotation
395
- q_rotated = torch.bmm(queries, R)
396
- return queries + self.strength * (q_rotated - queries)
 
397
 
398
 
399
  # ═════════════════════════════════════════��═════════════════════════
 
265
  # Gram matrix
266
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, geom_dim, geom_dim]
267
 
268
+ # Eigendecomposition differentiable through torch.linalg.eigh
269
+ eigenvalues, _ = LA.eigh(G, method='torch') # [B, geom_dim]
 
 
 
 
 
 
 
270
 
271
  # Magnitude spectrum: how energy distributes across modes
272
  magnitudes = eigenvalues.abs().sqrt() # [B, geom_dim] — the ω spectrum
 
314
  a_geom = self.anchor_proj(anchors) # [B, k, geom_dim]
315
  G = torch.bmm(a_geom.transpose(-2, -1), a_geom) # [B, gd, gd]
316
 
317
+ # Eigendecomposition the ω spectrum (differentiable via torch.linalg.eigh)
318
+ eigenvalues, eigenvectors = LA.eigh(G, method='torch') # [B, gd], [B, gd, gd]
 
 
 
 
 
 
319
 
320
  # ω = √|λ|
321
  omega = eigenvalues.abs().sqrt() # [B, gd]
 
347
  # ═══════════════════════════════════════════════════════════════════
348
 
349
  class AlignmentFlow(BaseFlow):
350
+ """SVD alignment flow via soft Procrustes rotation in projected space.
351
 
352
+ Projects to geom_dim, computes optimal rotation via SVD of the
353
+ cross-covariance in the small space, applies rotation, projects back.
 
354
  """
355
  def __init__(self, d_model: int, n_anchors: int):
356
  super().__init__(d_model, n_anchors, name='alignment')
357
+ self.geom_dim = min(n_anchors, 12) # FL eigh sweet spot
358
+ self.anchor_proj = nn.Linear(d_model, self.geom_dim)
359
+ self.query_proj = nn.Linear(d_model, self.geom_dim)
360
+ self.geom_to_query = nn.Linear(self.geom_dim, d_model)
361
  self.strength = nn.Parameter(torch.tensor(0.1))
362
 
363
  def _flow(self, anchors, queries):
364
  B, n, d = queries.shape
365
+ # Project to small geometric space
366
+ a_proj = self.anchor_proj(anchors) # [B, k, geom_dim]
367
+ q_proj = self.query_proj(queries) # [B, n, geom_dim]
368
 
369
+ # Attention-weighted anchors → per-query targets [B, n, geom_dim]
370
+ sim = torch.bmm(q_proj, a_proj.transpose(-2, -1)) / math.sqrt(self.geom_dim)
371
  weights = F.softmax(sim, dim=-1) # [B, n, k]
372
+ targets = torch.bmm(weights, a_proj) # [B, n, geom_dim]
373
 
374
+ # Cross-covariance in small space: [B, geom_dim, geom_dim]
375
  C = torch.bmm(q_proj.transpose(-2, -1), targets)
376
 
377
+ # SVD → optimal rotation via gram_eigh (differentiable, no in-place ops)
378
+ U, _, Vh = LA.svd(C, method='gram_eigh')
379
+ R = torch.bmm(U, Vh) # [B, geom_dim, geom_dim]
 
 
380
 
381
+ # Rotate queries in small space, project back to d_model
382
+ q_rotated = torch.bmm(q_proj, R) # [B, n, geom_dim]
383
+ delta = self.geom_to_query(q_rotated - q_proj) # [B, n, d]
384
+ return queries + self.strength * delta
385
 
386
 
387
  # ═════════════════════════════════════════��═════════════════════════