Removed iterations and optimized further.
Browse files- vit_zana_v3.py +26 -19
vit_zana_v3.py
CHANGED
|
@@ -203,6 +203,15 @@ class BaselineViT(nn.Module):
|
|
| 203 |
# Temperature for similarity-based classification
|
| 204 |
self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
# Initialize weights
|
| 207 |
self.init_weights()
|
| 208 |
|
|
@@ -220,29 +229,27 @@ class BaselineViT(nn.Module):
|
|
| 220 |
nn.init.ones_(m.weight)
|
| 221 |
nn.init.zeros_(m.bias)
|
| 222 |
|
|
|
|
| 223 |
def get_class_centroids(self) -> torch.Tensor:
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
penta.get_centroid() for penta in self.class_pentachora
|
| 227 |
-
])
|
| 228 |
-
return F.normalize(centroids, dim=-1)
|
| 229 |
-
|
| 230 |
def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
|
| 231 |
"""
|
| 232 |
-
Compute similarities between features and all class pentachora.
|
| 233 |
-
|
| 234 |
-
Args:
|
| 235 |
-
features: [batch, dim] features to compare
|
| 236 |
-
|
| 237 |
-
Returns:
|
| 238 |
-
similarities: [batch, num_classes]
|
| 239 |
"""
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 248 |
"""Extract features from images."""
|
|
|
|
| 203 |
# Temperature for similarity-based classification
|
| 204 |
self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
|
| 205 |
|
| 206 |
+
self.register_buffer(
|
| 207 |
+
'all_centroids',
|
| 208 |
+
torch.stack([penta.centroid for penta in self.class_pentachora])
|
| 209 |
+
)
|
| 210 |
+
self.register_buffer(
|
| 211 |
+
'all_centroids_norm',
|
| 212 |
+
F.normalize(self.all_centroids, dim=-1)
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
# Initialize weights
|
| 216 |
self.init_weights()
|
| 217 |
|
|
|
|
| 229 |
nn.init.ones_(m.weight)
|
| 230 |
nn.init.zeros_(m.bias)
|
| 231 |
|
| 232 |
+
# Class centroids are precomputed in __init__ (all_centroids_norm buffer), so the accessor just returns them:
|
| 233 |
def get_class_centroids(self) -> torch.Tensor:
    """Return the cached, normalized class centroids.

    Returns:
        The ``all_centroids_norm`` buffer, i.e. the stacked per-class
        centroids after ``F.normalize(..., dim=-1)``.
    """
    # NOTE(review): the buffer is built once at construction time; if the
    # pentachora can change afterwards (e.g. are trained), this value will
    # not follow them — confirm that is intended.
    return self.all_centroids_norm
|
| 235 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
    """
    Compute similarities between features and all class pentachora (vectorized).

    Args:
        features: [batch, dim] features to compare

    Returns:
        similarities: [batch, num_classes]
    """
    if self.similarity_mode == 'rose':
        # Stack all vertices into a single tensor for batch Rose scoring.
        all_vertices = torch.stack(
            [penta.vertices for penta in self.class_pentachora]
        )  # [num_classes, 5, dim]

        # Derive sizes from the tensors instead of hard-coding the feature
        # width, so the method works for any embedding dimension.
        batch_size = features.shape[0]
        feat_dim = features.shape[-1]
        num_classes = all_vertices.shape[0]

        # Pair every sample with every class: row (b * num_classes + c)
        # holds sample b's features ...
        features_flat = (
            features.unsqueeze(1)
            .expand(-1, num_classes, -1)
            .reshape(-1, feat_dim)
        )  # [batch * num_classes, dim]
        # ... and the tiled vertices put class c at that same row.
        vertices_flat = all_vertices.repeat(batch_size, 1, 1)

        # Compute Rose scores for all (sample, class) pairs in one call.
        # NOTE(review): assumes rose_score_magnitude returns one score per
        # row — confirm against PentachoronStabilizer.
        scores = PentachoronStabilizer.rose_score_magnitude(features_flat, vertices_flat)
        return scores.reshape(batch_size, -1)
    else:
        # Cosine similarity against the stacked, pre-normalized centroids.
        centroids = torch.stack(
            [penta.centroid_norm for penta in self.class_pentachora]
        )  # [num_classes, dim]
        features_norm = F.normalize(features, dim=-1)  # [batch, dim]
        return torch.matmul(features_norm, centroids.T)  # [batch, num_classes]
|
| 252 |
+
|
| 253 |
|
| 254 |
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 255 |
"""Extract features from images."""
|