Update modeling_tri_stream.py
Browse files- modeling_tri_stream.py +21 -24
modeling_tri_stream.py
CHANGED
|
@@ -182,14 +182,14 @@ def procrustes_align(source, target, whiten=False):
|
|
| 182 |
target: (N, D) β target positions (e.g. class centroids)
|
| 183 |
whiten: if True, normalize variance per dimension before alignment
|
| 184 |
"""
|
| 185 |
-
source_c = source - source.mean(0, keepdim=True)
|
| 186 |
-
target_c = target - target.mean(0, keepdim=True)
|
| 187 |
|
| 188 |
if whiten:
|
| 189 |
source_c = source_c / (source_c.std(0, keepdim=True) + 1e-8)
|
| 190 |
target_c = target_c / (target_c.std(0, keepdim=True) + 1e-8)
|
| 191 |
|
| 192 |
-
M = source_c.T @ target_c #
|
| 193 |
U, S, Vt = torch.linalg.svd(M)
|
| 194 |
|
| 195 |
# Ensure proper rotation (det = +1)
|
|
@@ -658,33 +658,30 @@ class TriStreamViT(nn.Module):
|
|
| 658 |
3. Procrustes-align matched anchors to centroids
|
| 659 |
4. Apply rotation to ALL anchors as small step
|
| 660 |
"""
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
|
|
|
|
|
|
| 664 |
|
| 665 |
-
|
| 666 |
|
| 667 |
-
|
| 668 |
-
|
| 669 |
-
|
| 670 |
-
|
| 671 |
-
matched_idx = cos.argmax(dim=1) # (C,) β one anchor per class
|
| 672 |
|
| 673 |
-
|
| 674 |
-
matched_anchors = anchors[matched_idx] # (C, D)
|
| 675 |
|
| 676 |
-
|
| 677 |
-
|
| 678 |
-
matched_anchors.float(), centroids.float(), whiten=whiten)
|
| 679 |
-
R = R.to(anchors.dtype)
|
| 680 |
|
| 681 |
-
|
| 682 |
-
|
| 683 |
-
|
| 684 |
-
|
| 685 |
-
)
|
| 686 |
|
| 687 |
-
|
| 688 |
|
| 689 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 690 |
# LOSS COMPUTATION
|
|
|
|
| 182 |
target: (N, D) β target positions (e.g. class centroids)
|
| 183 |
whiten: if True, normalize variance per dimension before alignment
|
| 184 |
"""
|
| 185 |
+
source_c = source.float() - source.float().mean(0, keepdim=True)
|
| 186 |
+
target_c = target.float() - target.float().mean(0, keepdim=True)
|
| 187 |
|
| 188 |
if whiten:
|
| 189 |
source_c = source_c / (source_c.std(0, keepdim=True) + 1e-8)
|
| 190 |
target_c = target_c / (target_c.std(0, keepdim=True) + 1e-8)
|
| 191 |
|
| 192 |
+
M = (source_c.T @ target_c).float() # SVD requires float32
|
| 193 |
U, S, Vt = torch.linalg.svd(M)
|
| 194 |
|
| 195 |
# Ensure proper rotation (det = +1)
|
|
|
|
| 658 |
3. Procrustes-align matched anchors to centroids
|
| 659 |
4. Apply rotation to ALL anchors as small step
|
| 660 |
"""
|
| 661 |
+
# Force float32 β SVD and det don't support bf16
|
| 662 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 663 |
+
centroids = simplex_buffer.class_centroids(self.num_classes)
|
| 664 |
+
if centroids is None:
|
| 665 |
+
return None
|
| 666 |
|
| 667 |
+
anchors = self.gal.gal_anchors.float() # (A, D)
|
| 668 |
|
| 669 |
+
centroid_n = F.normalize(centroids.float(), dim=-1)
|
| 670 |
+
anchor_n = F.normalize(anchors, dim=-1)
|
| 671 |
+
cos = centroid_n @ anchor_n.T
|
| 672 |
+
matched_idx = cos.argmax(dim=1)
|
|
|
|
| 673 |
|
| 674 |
+
matched_anchors = anchors[matched_idx]
|
|
|
|
| 675 |
|
| 676 |
+
R, score = procrustes_align(
|
| 677 |
+
matched_anchors, centroids.float(), whiten=whiten)
|
|
|
|
|
|
|
| 678 |
|
| 679 |
+
rotated = anchors @ R
|
| 680 |
+
new_anchors = F.normalize(
|
| 681 |
+
anchors + lr * (rotated - anchors), dim=-1)
|
| 682 |
+
self.gal.gal_anchors.copy_(new_anchors.to(self.gal.gal_anchors.dtype))
|
|
|
|
| 683 |
|
| 684 |
+
return score
|
| 685 |
|
| 686 |
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 687 |
# LOSS COMPUTATION
|