AbstractPhil commited on
Commit
b6d45f7
Β·
verified Β·
1 Parent(s): a7e3c11

Update modeling_tri_stream.py

Browse files
Files changed (1) hide show
  1. modeling_tri_stream.py +21 -24
modeling_tri_stream.py CHANGED
@@ -182,14 +182,14 @@ def procrustes_align(source, target, whiten=False):
182
  target: (N, D) β€” target positions (e.g. class centroids)
183
  whiten: if True, normalize variance per dimension before alignment
184
  """
185
- source_c = source - source.mean(0, keepdim=True)
186
- target_c = target - target.mean(0, keepdim=True)
187
 
188
  if whiten:
189
  source_c = source_c / (source_c.std(0, keepdim=True) + 1e-8)
190
  target_c = target_c / (target_c.std(0, keepdim=True) + 1e-8)
191
 
192
- M = source_c.T @ target_c # (D, D)
193
  U, S, Vt = torch.linalg.svd(M)
194
 
195
  # Ensure proper rotation (det = +1)
@@ -658,33 +658,30 @@ class TriStreamViT(nn.Module):
658
  3. Procrustes-align matched anchors to centroids
659
  4. Apply rotation to ALL anchors as small step
660
  """
661
- centroids = simplex_buffer.class_centroids(self.num_classes)
662
- if centroids is None:
663
- return None
 
 
664
 
665
- anchors = self.gal.gal_anchors # (A, D)
666
 
667
- # Match: for each centroid, find nearest anchor
668
- centroid_n = F.normalize(centroids, dim=-1)
669
- anchor_n = F.normalize(anchors, dim=-1)
670
- cos = centroid_n @ anchor_n.T # (C, A)
671
- matched_idx = cos.argmax(dim=1) # (C,) β€” one anchor per class
672
 
673
- # Extract matched anchor positions
674
- matched_anchors = anchors[matched_idx] # (C, D)
675
 
676
- # Procrustes: rotate matched_anchors β†’ centroids
677
- R, score = procrustes_align(
678
- matched_anchors.float(), centroids.float(), whiten=whiten)
679
- R = R.to(anchors.dtype)
680
 
681
- # Apply as small rotation step to ALL anchors
682
- rotated = anchors @ R
683
- self.gal.gal_anchors.copy_(
684
- F.normalize(anchors + lr * (rotated - anchors), dim=-1)
685
- )
686
 
687
- return score
688
 
689
  # ──────────────────────────────────────────────────────────
690
  # LOSS COMPUTATION
 
182
  target: (N, D) β€” target positions (e.g. class centroids)
183
  whiten: if True, normalize variance per dimension before alignment
184
  """
185
+ source_c = source.float() - source.float().mean(0, keepdim=True)
186
+ target_c = target.float() - target.float().mean(0, keepdim=True)
187
 
188
  if whiten:
189
  source_c = source_c / (source_c.std(0, keepdim=True) + 1e-8)
190
  target_c = target_c / (target_c.std(0, keepdim=True) + 1e-8)
191
 
192
+ M = (source_c.T @ target_c).float() # SVD requires float32
193
  U, S, Vt = torch.linalg.svd(M)
194
 
195
  # Ensure proper rotation (det = +1)
 
658
  3. Procrustes-align matched anchors to centroids
659
  4. Apply rotation to ALL anchors as small step
660
  """
661
+ # Force float32 β€” SVD and det don't support bf16
662
+ with torch.amp.autocast("cuda", enabled=False):
663
+ centroids = simplex_buffer.class_centroids(self.num_classes)
664
+ if centroids is None:
665
+ return None
666
 
667
+ anchors = self.gal.gal_anchors.float() # (A, D)
668
 
669
+ centroid_n = F.normalize(centroids.float(), dim=-1)
670
+ anchor_n = F.normalize(anchors, dim=-1)
671
+ cos = centroid_n @ anchor_n.T
672
+ matched_idx = cos.argmax(dim=1)
 
673
 
674
+ matched_anchors = anchors[matched_idx]
 
675
 
676
+ R, score = procrustes_align(
677
+ matched_anchors, centroids.float(), whiten=whiten)
 
 
678
 
679
+ rotated = anchors @ R
680
+ new_anchors = F.normalize(
681
+ anchors + lr * (rotated - anchors), dim=-1)
682
+ self.gal.gal_anchors.copy_(new_anchors.to(self.gal.gal_anchors.dtype))
 
683
 
684
+ return score
685
 
686
  # ──────────────────────────────────────────────────────────
687
  # LOSS COMPUTATION