AbstractPhil commited on
Commit
1d40566
·
verified ·
1 Parent(s): 04248a4

Removed iterations and optimized further.

Browse files
Files changed (1) hide show
  1. vit_zana_v3.py +26 -19
vit_zana_v3.py CHANGED
@@ -203,6 +203,15 @@ class BaselineViT(nn.Module):
203
  # Temperature for similarity-based classification
204
  self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
205
 
 
 
 
 
 
 
 
 
 
206
  # Initialize weights
207
  self.init_weights()
208
 
@@ -220,29 +229,27 @@ class BaselineViT(nn.Module):
220
  nn.init.ones_(m.weight)
221
  nn.init.zeros_(m.bias)
222
 
 
223
  def get_class_centroids(self) -> torch.Tensor:
224
- """Get all class centroids for similarity computation."""
225
- centroids = torch.stack([
226
- penta.get_centroid() for penta in self.class_pentachora
227
- ])
228
- return F.normalize(centroids, dim=-1)
229
-
230
  def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
231
  """
232
- Compute similarities between features and all class pentachora.
233
-
234
- Args:
235
- features: [batch, dim] features to compare
236
-
237
- Returns:
238
- similarities: [batch, num_classes]
239
  """
240
- similarities = []
241
- for penta in self.class_pentachora:
242
- sim = penta.compute_similarity(features, mode=self.similarity_mode)
243
- similarities.append(sim)
244
-
245
- return torch.stack(similarities, dim=-1)
 
 
 
 
 
 
 
246
 
247
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
248
  """Extract features from images."""
 
203
  # Temperature for similarity-based classification
204
  self.temperature = nn.Parameter(torch.ones(1) * np.log(1/0.07))
205
 
206
+ self.register_buffer(
207
+ 'all_centroids',
208
+ torch.stack([penta.centroid for penta in self.class_pentachora])
209
+ )
210
+ self.register_buffer(
211
+ 'all_centroids_norm',
212
+ F.normalize(self.all_centroids, dim=-1)
213
+ )
214
+
215
  # Initialize weights
216
  self.init_weights()
217
 
 
229
  nn.init.ones_(m.weight)
230
  nn.init.zeros_(m.bias)
231
 
232
+ # Then get_class_centroids becomes:
233
  def get_class_centroids(self) -> torch.Tensor:
234
+ return self.all_centroids_norm
235
+
 
 
 
 
236
  def compute_pentachora_similarities(self, features: torch.Tensor) -> torch.Tensor:
237
  """
238
+ Compute similarities between features and all class pentachora (vectorized).
 
 
 
 
 
 
239
  """
240
+ if self.similarity_mode == 'rose':
241
+ # Stack all vertices into single tensor for batch Rose scoring
242
+ all_vertices = torch.stack([penta.vertices for penta in self.class_pentachora]) # [100, 5, 128]
243
+ # Expand features for batch computation
244
+ features_exp = features.unsqueeze(1).expand(-1, self.num_classes, -1) # [B, 100, 128]
245
+ # Compute Rose scores in parallel
246
+ return PentachoronStabilizer.rose_score_magnitude(features_exp.reshape(-1, 128), all_vertices.repeat(features.shape[0], 1, 1)).reshape(features.shape[0], -1)
247
+ else:
248
+ # Stack all centroids
249
+ centroids = torch.stack([penta.centroid_norm for penta in self.class_pentachora]) # [100, 128]
250
+ features_norm = F.normalize(features, dim=-1) # [B, 128]
251
+ return torch.matmul(features_norm, centroids.T) # [B, 100]
252
+
253
 
254
  def forward_features(self, x: torch.Tensor) -> torch.Tensor:
255
  """Extract features from images."""