Gabriele committed on
Commit
53edac2
·
1 Parent(s): 37b751d

Using safetensors for weights loading

Browse files
Files changed (2) hide show
  1. megaloc_model.py +156 -354
  2. model.safetensors +2 -2
megaloc_model.py CHANGED
@@ -18,227 +18,143 @@ import torchvision.transforms.functional as tfm
18
  from huggingface_hub import PyTorchModelHubMixin
19
 
20
 
21
- # ==============================================================================
22
- # Optimal Transport Feature Aggregation
23
- # ==============================================================================
24
- # The following implements an optimal transport-based feature aggregation module
25
- # that converts local patch features into a compact global descriptor.
26
- # ==============================================================================
27
-
28
-
29
- def sinkhorn_log_iterations(
30
- source_log_weights: torch.Tensor,
31
- target_log_weights: torch.Tensor,
32
- cost_matrix: torch.Tensor,
33
- num_iterations: int = 20,
34
- regularization: float = 1.0,
35
- ) -> torch.Tensor:
36
- """Compute optimal transport plan using Sinkhorn iterations in log space.
37
-
38
- This implements the Sinkhorn-Knopp algorithm for computing the entropy-regularized
39
- optimal transport plan between two distributions. The log-space formulation
40
- provides numerical stability.
41
-
42
  Args:
43
- source_log_weights: Log of source distribution weights [batch, m+1]
44
- target_log_weights: Log of target distribution weights [batch, n]
45
- cost_matrix: Cost/score matrix [batch, m+1, n]
46
- num_iterations: Number of Sinkhorn iterations
47
- regularization: Entropy regularization strength
48
-
49
- Returns:
50
- Log of the transport plan matrix [batch, m+1, n]
 
 
51
  """
52
- # Apply regularization scaling
53
- scaled_costs = cost_matrix / regularization
54
 
55
- # Initialize dual variables
56
- dual_source = torch.zeros_like(source_log_weights)
57
- dual_target = torch.zeros_like(target_log_weights)
58
 
59
- # Sinkhorn iterations: alternating row and column normalization
60
- for _ in range(num_iterations):
61
- # Row normalization (update source dual)
62
- dual_source = source_log_weights - torch.logsumexp(scaled_costs + dual_target.unsqueeze(1), dim=2).squeeze()
63
- # Column normalization (update target dual)
64
- dual_target = target_log_weights - torch.logsumexp(scaled_costs + dual_source.unsqueeze(2), dim=1).squeeze()
65
 
66
- # Compute final transport plan
67
- transport_plan = scaled_costs + dual_source.unsqueeze(2) + dual_target.unsqueeze(1)
68
- return transport_plan
69
 
70
 
71
- def compute_soft_assignments(
72
- affinity_scores: torch.Tensor,
73
- slack_logit: float = 1.0,
74
- num_iterations: int = 3,
75
- regularization: float = 1.0,
76
- ) -> torch.Tensor:
77
- """Compute soft cluster assignments using optimal transport with slack.
 
 
78
 
79
- Augments the affinity matrix with a slack row to handle unassigned features,
80
- then applies Sinkhorn normalization to get valid transport probabilities.
 
 
 
 
 
81
 
82
- Args:
83
- affinity_scores: Raw affinity scores [batch, num_clusters, num_patches]
84
- slack_logit: Initial logit value for the slack row
85
- num_iterations: Number of Sinkhorn iterations
86
- regularization: Entropy regularization strength
87
 
88
- Returns:
89
- Log-probabilities of assignments [batch, num_clusters+1, num_patches]
90
- """
91
- batch_size, num_clusters, num_patches = affinity_scores.size()
92
-
93
- # Augment score matrix with slack row for handling outliers/unmatched
94
- augmented_scores = torch.empty(
95
- batch_size,
96
- num_clusters + 1,
97
- num_patches,
98
- dtype=affinity_scores.dtype,
99
- device=affinity_scores.device,
100
- )
101
- augmented_scores[:, :num_clusters, :num_patches] = affinity_scores
102
- augmented_scores[:, num_clusters, :] = slack_logit
103
-
104
- # Prepare log-weights for source (clusters + slack) and target (patches)
105
- log_normalization = -torch.tensor(math.log(num_patches + num_clusters), device=affinity_scores.device)
106
-
107
- # Source weights: uniform over clusters, extra mass on slack
108
- source_log = log_normalization.expand(num_clusters + 1).contiguous()
109
- source_log = source_log.clone()
110
- source_log[-1] = source_log[-1] + math.log(num_patches - num_clusters)
111
-
112
- # Target weights: uniform over patches
113
- target_log = log_normalization.expand(num_patches).contiguous()
114
-
115
- # Expand to batch dimension
116
- source_log = source_log.expand(batch_size, -1)
117
- target_log = target_log.expand(batch_size, -1)
118
-
119
- # Solve optimal transport
120
- log_transport = sinkhorn_log_iterations(
121
- source_log,
122
- target_log,
123
- augmented_scores,
124
- num_iterations=num_iterations,
125
- regularization=regularization,
126
- )
127
-
128
- return log_transport - log_normalization
129
-
130
-
131
- class FeatureAggregationHead(nn.Module):
132
  """Optimal transport-based aggregation of local features into global descriptor.
133
 
134
- This module learns to aggregate local patch features into a compact global
135
- representation using differentiable optimal transport. It produces:
136
- 1. A global scene token from the CLS token
137
- 2. Cluster-aggregated local descriptors weighted by transport probabilities
138
-
139
- The final descriptor is the L2-normalized concatenation of both components.
140
 
141
  Args:
142
- input_channels: Number of input feature channels (from backbone)
143
- num_clusters: Number of learned cluster centers
144
- cluster_channels: Dimensionality of each cluster descriptor
145
- global_token_dim: Dimensionality of the global scene token
146
- hidden_dim: Hidden dimension for MLPs
147
- dropout_rate: Dropout probability (0 to disable)
148
  """
149
 
150
  def __init__(
151
  self,
152
- input_channels: int = 1536,
153
- num_clusters: int = 64,
154
- cluster_channels: int = 128,
155
- global_token_dim: int = 256,
156
- hidden_dim: int = 512,
157
- dropout_rate: float = 0.3,
158
  ) -> None:
159
  super().__init__()
160
 
161
- self.input_channels = input_channels
162
  self.num_clusters = num_clusters
163
- self.cluster_channels = cluster_channels
164
- self.global_token_dim = global_token_dim
165
- self.hidden_dim = hidden_dim
166
-
167
- # Dropout layer (or identity if disabled)
168
- regularization = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
169
-
170
- # MLP to project CLS token to global scene descriptor
171
- self.global_token_mlp = nn.Sequential(
172
- nn.Linear(self.input_channels, self.hidden_dim),
173
- nn.ReLU(),
174
- nn.Linear(self.hidden_dim, self.global_token_dim),
175
  )
176
-
177
- # Convolutional MLP to project patch features to cluster descriptors
178
- self.descriptor_projection = nn.Sequential(
179
- nn.Conv2d(self.input_channels, self.hidden_dim, 1),
180
- regularization,
181
  nn.ReLU(),
182
- nn.Conv2d(self.hidden_dim, self.cluster_channels, 1),
183
  )
184
-
185
- # Convolutional MLP to compute cluster assignment logits
186
- self.assignment_head = nn.Sequential(
187
- nn.Conv2d(self.input_channels, self.hidden_dim, 1),
188
- regularization,
189
  nn.ReLU(),
190
- nn.Conv2d(self.hidden_dim, self.num_clusters, 1),
191
  )
 
 
192
 
193
- # Learnable slack variable for optimal transport
194
- self.slack_variable = nn.Parameter(torch.tensor(1.0))
195
-
196
- def forward(self, inputs):
197
- """Aggregate local and global features into compact descriptor.
198
-
199
  Args:
200
- inputs: Tuple of (patch_features, cls_token)
201
- - patch_features: [B, C, H, W] spatial feature map
202
- - cls_token: [B, C] global CLS token
203
 
204
  Returns:
205
- Global descriptor [B, num_clusters * cluster_channels + global_token_dim]
206
  """
207
- patch_features, cls_token = inputs
208
-
209
- # Project patch features to cluster descriptors: [B, cluster_channels, H*W]
210
- local_descriptors = self.descriptor_projection(patch_features).flatten(2)
211
-
212
- # Compute assignment logits: [B, num_clusters, H*W]
213
- assignment_logits = self.assignment_head(patch_features).flatten(2)
214
-
215
- # Project CLS token to global descriptor: [B, global_token_dim]
216
- global_descriptor = self.global_token_mlp(cls_token)
217
 
218
- # Compute soft assignments via optimal transport
219
- log_assignments = compute_soft_assignments(assignment_logits, self.slack_variable, num_iterations=3)
220
- assignments = torch.exp(log_assignments)
221
 
222
- # Remove slack row (keep only cluster assignments)
223
- assignments = assignments[:, :-1, :]
 
224
 
225
- # Aggregate local descriptors weighted by assignments
226
- # assignments: [B, num_clusters, num_patches]
227
- # local_descriptors: [B, cluster_channels, num_patches]
228
- # We want: [B, cluster_channels, num_clusters]
229
- assignments = assignments.unsqueeze(1).repeat(1, self.cluster_channels, 1, 1)
230
- local_descriptors = local_descriptors.unsqueeze(2).repeat(1, 1, self.num_clusters, 1)
231
 
232
- # Weighted sum over patches for each cluster
233
- aggregated_clusters = (local_descriptors * assignments).sum(dim=-1)
234
-
235
- # Normalize and concatenate
236
- normalized_global = F.normalize(global_descriptor, p=2, dim=-1)
237
- normalized_local = F.normalize(aggregated_clusters, p=2, dim=1).flatten(1)
238
-
239
- combined = torch.cat([normalized_global, normalized_local], dim=-1)
240
 
241
- return F.normalize(combined, p=2, dim=-1)
242
 
243
 
244
  # ==============================================================================
@@ -249,13 +165,7 @@ class FeatureAggregationHead(nn.Module):
249
  class PatchEmbedding(nn.Module):
250
  """Convert image patches to embeddings using a convolutional layer."""
251
 
252
- def __init__(
253
- self,
254
- image_size: int = 518,
255
- patch_size: int = 14,
256
- in_channels: int = 3,
257
- embed_dim: int = 768,
258
- ):
259
  super().__init__()
260
  self.image_size = image_size
261
  self.patch_size = patch_size
@@ -263,11 +173,8 @@ class PatchEmbedding(nn.Module):
263
  self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
264
 
265
  def forward(self, x: torch.Tensor) -> torch.Tensor:
266
- # x: [B, C, H, W] -> [B, embed_dim, H/patch_size, W/patch_size]
267
  x = self.proj(x)
268
- # Flatten spatial dimensions: [B, embed_dim, num_patches]
269
  x = x.flatten(2)
270
- # Transpose to [B, num_patches, embed_dim]
271
  x = x.transpose(1, 2)
272
  return x
273
 
@@ -287,12 +194,7 @@ class MultiHeadAttention(nn.Module):
287
  """Multi-head self-attention module."""
288
 
289
  def __init__(
290
- self,
291
- dim: int,
292
- num_heads: int = 12,
293
- qkv_bias: bool = True,
294
- attn_drop: float = 0.0,
295
- proj_drop: float = 0.0,
296
  ):
297
  super().__init__()
298
  self.num_heads = num_heads
@@ -307,17 +209,14 @@ class MultiHeadAttention(nn.Module):
307
  def forward(self, x: torch.Tensor) -> torch.Tensor:
308
  B, N, C = x.shape
309
 
310
- # Compute Q, K, V
311
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
312
- qkv = qkv.permute(2, 0, 3, 1, 4) # [3, B, num_heads, N, head_dim]
313
  q, k, v = qkv[0], qkv[1], qkv[2]
314
 
315
- # Scaled dot-product attention
316
  attn = (q @ k.transpose(-2, -1)) * self.scale
317
  attn = attn.softmax(dim=-1)
318
  attn = self.attn_drop(attn)
319
 
320
- # Apply attention to values
321
  x = (attn @ v).transpose(1, 2).reshape(B, N, C)
322
  x = self.proj(x)
323
  x = self.proj_drop(x)
@@ -328,13 +227,7 @@ class MultiHeadAttention(nn.Module):
328
  class MLP(nn.Module):
329
  """MLP module with GELU activation."""
330
 
331
- def __init__(
332
- self,
333
- in_features: int,
334
- hidden_features: int = None,
335
- out_features: int = None,
336
- drop: float = 0.0,
337
- ):
338
  super().__init__()
339
  out_features = out_features or in_features
340
  hidden_features = hidden_features or in_features
@@ -368,21 +261,11 @@ class TransformerBlock(nn.Module):
368
  ):
369
  super().__init__()
370
  self.norm1 = nn.LayerNorm(dim, eps=1e-6)
371
- self.attn = MultiHeadAttention(
372
- dim,
373
- num_heads=num_heads,
374
- qkv_bias=qkv_bias,
375
- attn_drop=attn_drop,
376
- proj_drop=drop,
377
- )
378
  self.ls1 = LayerScale(dim, init_value=init_values)
379
 
380
  self.norm2 = nn.LayerNorm(dim, eps=1e-6)
381
- self.mlp = MLP(
382
- in_features=dim,
383
- hidden_features=int(dim * mlp_ratio),
384
- drop=drop,
385
- )
386
  self.ls2 = LayerScale(dim, init_value=init_values)
387
 
388
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -391,12 +274,10 @@ class TransformerBlock(nn.Module):
391
  return x
392
 
393
 
394
- class VisionTransformerBackbone(nn.Module):
395
  """DINOv2 Vision Transformer backbone for feature extraction.
396
 
397
  This implements a ViT-B/14 architecture compatible with DINOv2 weights.
398
- The positional encoding interpolation matches the Facebook implementation
399
- for exact output compatibility.
400
  """
401
 
402
  def __init__(
@@ -413,54 +294,34 @@ class VisionTransformerBackbone(nn.Module):
413
  super().__init__()
414
  self.patch_size = patch_size
415
  self.embed_dim = embed_dim
416
- self.num_channels = embed_dim # For compatibility
417
 
418
- # Patch embedding
419
  self.patch_embed = PatchEmbedding(
420
- image_size=image_size,
421
- patch_size=patch_size,
422
- in_channels=in_channels,
423
- embed_dim=embed_dim,
424
  )
425
 
426
- # Positional encoding interpolation parameters (matching Facebook's DINO)
427
  self.interpolate_offset = 0.1
428
  self.interpolate_antialias = False
429
 
430
- # Class token
431
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
432
-
433
- # Positional embedding (for 37x37 = 1369 patches + 1 CLS token = 1370)
434
  num_patches = (image_size // patch_size) ** 2
435
  self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
436
 
437
- # Transformer blocks
438
  self.blocks = nn.ModuleList(
439
  [
440
- TransformerBlock(
441
- dim=embed_dim,
442
- num_heads=num_heads,
443
- mlp_ratio=mlp_ratio,
444
- qkv_bias=qkv_bias,
445
- )
446
  for _ in range(depth)
447
  ]
448
  )
449
 
450
- # Final layer norm
451
  self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
452
 
453
  def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
454
- """Interpolate positional encoding for different input sizes.
455
-
456
- This matches the Facebook DINOv2 implementation exactly, including
457
- the interpolation offset kludge for backward compatibility.
458
- """
459
  previous_dtype = x.dtype
460
- npatch = x.shape[1] - 1 # Exclude CLS token
461
- N = self.pos_embed.shape[1] - 1 # Number of patches in pos_embed
462
 
463
- # If input matches training resolution, return as-is
464
  if npatch == N and w == h:
465
  return self.pos_embed
466
 
@@ -471,10 +332,8 @@ class VisionTransformerBackbone(nn.Module):
471
  dim = x.shape[-1]
472
  w0 = w // self.patch_size
473
  h0 = h // self.patch_size
474
- M = int(math.sqrt(N)) # Original number of patches per dimension
475
 
476
- # Use scale_factor with offset for backward compatibility
477
- # This is the "kludge" from Facebook's DINO implementation
478
  sx = float(w0 + self.interpolate_offset) / M
479
  sy = float(h0 + self.interpolate_offset) / M
480
 
@@ -490,22 +349,6 @@ class VisionTransformerBackbone(nn.Module):
490
 
491
  return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
492
 
493
- def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor:
494
- """Prepare input tokens with positional encoding."""
495
- B, C, W, H = x.shape
496
-
497
- # Patch embedding
498
- x = self.patch_embed(x)
499
-
500
- # Add CLS token
501
- cls_tokens = self.cls_token.expand(B, -1, -1)
502
- x = torch.cat((cls_tokens, x), dim=1)
503
-
504
- # Add positional encoding
505
- x = x + self.interpolate_pos_encoding(x, W, H)
506
-
507
- return x
508
-
509
  def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
510
  """Extract features from images.
511
 
@@ -513,80 +356,52 @@ class VisionTransformerBackbone(nn.Module):
513
  images: Input images [B, 3, H, W] where H, W are multiples of 14
514
 
515
  Returns:
516
- Tuple of:
517
- - patch_features: [B, 768, H//14, W//14] spatial feature map
518
- - cls_token: [B, 768] global CLS token
519
  """
520
- batch_size, _, height, width = images.shape
521
 
522
- # Prepare tokens with positional encoding
523
- x = self.prepare_tokens(images)
 
 
524
 
525
- # Apply transformer blocks
526
  for block in self.blocks:
527
  x = block(x)
528
 
529
- # Apply final layer norm
530
  x = self.norm(x)
531
 
532
- # Extract CLS token and patch tokens
533
  cls_token = x[:, 0]
534
  patch_tokens = x[:, 1:]
535
-
536
- # Reshape patch tokens to spatial format
537
- h_patches = height // self.patch_size
538
- w_patches = width // self.patch_size
539
- patch_features = patch_tokens.reshape(batch_size, h_patches, w_patches, self.embed_dim).permute(0, 3, 1, 2)
540
 
541
  return patch_features, cls_token
542
 
543
 
544
  # ==============================================================================
545
- # Feature Dimension Reduction
546
  # ==============================================================================
547
 
548
 
549
- class DescriptorAggregator(nn.Module):
550
- """Wrapper combining feature aggregation with linear projection.
551
-
552
- Applies the optimal transport aggregation followed by a linear layer
553
- to reduce dimensionality to the desired output size.
554
-
555
- Args:
556
- output_dim: Final descriptor dimensionality
557
- aggregator_config: Configuration for FeatureAggregationHead
558
- aggregator_output_dim: Output dimension of the aggregation head
559
- """
560
-
561
- def __init__(self, output_dim: int, aggregator_config: dict, aggregator_output_dim: int):
562
  super().__init__()
563
- self.aggregation = FeatureAggregationHead(**aggregator_config)
564
- self.projection = nn.Linear(aggregator_output_dim, output_dim)
565
 
566
  def forward(self, x):
567
- aggregated = self.aggregation(x)
568
- return self.projection(aggregated)
569
-
570
-
571
- # ==============================================================================
572
- # L2 Normalization Layer
573
- # ==============================================================================
574
 
575
 
576
- class L2Normalize(nn.Module):
577
- """L2 normalization layer."""
578
-
579
- def __init__(self, dim: int = -1):
580
  super().__init__()
581
- self.dim = dim
582
-
583
- def forward(self, x: torch.Tensor) -> torch.Tensor:
584
- return F.normalize(x, p=2, dim=self.dim)
585
-
586
 
587
- # ==============================================================================
588
- # Main Model
589
- # ==============================================================================
590
 
591
 
592
  class MegaLoc(nn.Module, PyTorchModelHubMixin):
@@ -604,10 +419,9 @@ class MegaLoc(nn.Module, PyTorchModelHubMixin):
604
  mlp_dim: Hidden dimension for MLPs (default: 512)
605
 
606
  Example:
607
- >>> model = MegaLoc.from_pretrained("gberton/MegaLoc")
608
  >>> model.eval()
609
- >>> image = torch.randn(1, 3, 322, 322) # Will auto-resize to 322x322
610
- >>> descriptor = model(image) # [1, 8448]
611
  """
612
 
613
  def __init__(
@@ -620,25 +434,21 @@ class MegaLoc(nn.Module, PyTorchModelHubMixin):
620
  ):
621
  super().__init__()
622
 
623
- self.backbone = VisionTransformerBackbone()
624
-
625
- # Aggregator output: num_clusters * cluster_dim + token_dim
626
- self.aggregator_output_dim = num_clusters * cluster_dim + token_dim
627
-
628
- self.aggregator = DescriptorAggregator(
629
- output_dim=feat_dim,
630
- aggregator_config={
631
- "input_channels": self.backbone.num_channels,
632
  "num_clusters": num_clusters,
633
- "cluster_channels": cluster_dim,
634
- "global_token_dim": token_dim,
635
- "hidden_dim": mlp_dim,
636
  },
637
- aggregator_output_dim=self.aggregator_output_dim,
638
  )
639
-
640
  self.feat_dim = feat_dim
641
- self.normalize = L2Normalize()
642
 
643
  def forward(self, images: torch.Tensor) -> torch.Tensor:
644
  """Extract global descriptor from images.
@@ -649,19 +459,11 @@ class MegaLoc(nn.Module, PyTorchModelHubMixin):
649
  Returns:
650
  L2-normalized descriptors [B, feat_dim]
651
  """
652
- batch_size, channels, height, width = images.shape
653
-
654
- # Ensure dimensions are multiples of 14 (ViT patch size)
655
- if height % 14 != 0 or width % 14 != 0:
656
- height = round(height / 14) * 14
657
- width = round(width / 14) * 14
658
- images = tfm.resize(images, [height, width], antialias=True)
659
-
660
- # Extract backbone features
661
- features = self.backbone(images)
662
-
663
- # Aggregate into global descriptor
664
- descriptor = self.aggregator(features)
665
-
666
- # Final L2 normalization
667
- return self.normalize(descriptor)
 
18
  from huggingface_hub import PyTorchModelHubMixin
19
 
20
 
21
# Code adapted from OpenGlue, MIT license
# https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/optimal_transport.py
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
    r"""Sinkhorn matrix scaling algorithm for the differentiable optimal transport problem.

    Solves the entropy-regularized optimal transport problem in log space
    (numerically stable) and returns the log of the transport plan.

    Args:
        log_a : torch.Tensor
            Source log-weights [batch, m]
        log_b : torch.Tensor
            Target log-weights [batch, n]
        M : torch.Tensor
            metric cost matrix [batch, m, n]
        num_iters : int, default=20
            The number of iterations.
        reg : float, default=1.0
            regularization value

    Returns:
        Log transport plan [batch, m, n].
    """
    M = M / reg  # regularization

    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    # Alternating row/column normalization in log space.
    # NOTE: torch.logsumexp already drops the reduced dim, so no squeeze() is
    # needed here. The previous `.squeeze()` also stripped the batch dim when
    # batch_size == 1, which broke the broadcasting below for n != m.
    for _ in range(num_iters):
        u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1)

    return M + u.unsqueeze(2) + v.unsqueeze(1)
 
 
47
 
48
 
49
# Code adapted from OpenGlue, MIT license
# https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/superglue.py
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
    """Run Sinkhorn normalization on a dustbin-augmented score matrix.

    Args:
        S: Raw score matrix [batch, m, n].
        dustbin_score: Logit written into the extra dustbin row (scalar or tensor).
        num_iters: Number of Sinkhorn iterations.
        reg: Entropy regularization strength.

    Returns:
        Log assignment probabilities [batch, m + 1, n].
    """
    batch_size, m, n = S.size()

    # Append a dustbin row so mass not assigned to any of the m rows has
    # somewhere to go.
    augmented = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    augmented[:, :m, :n] = S
    augmented[:, m, :] = dustbin_score

    # Uniform log-weights over n + m mass units; the dustbin row absorbs the
    # surplus mass so both marginals total the same.
    # NOTE(review): assumes n > m (log(n - m) below fails otherwise).
    norm = -torch.tensor(math.log(n + m), device=S.device)
    log_a = norm.expand(m + 1).contiguous()
    log_a[-1] = log_a[-1] + math.log(n - m)
    log_b = norm.expand(n).contiguous()

    log_P = log_otp_solver(
        log_a.expand(batch_size, -1),
        log_b.expand(batch_size, -1),
        augmented,
        num_iters=num_iters,
        reg=reg,
    )
    return log_P - norm
66
 
 
 
 
 
 
67
 
68
class FeatureAggregator(nn.Module):
    """Optimal transport-based aggregation of local features into global descriptor.

    Aggregates local patch features into a compact global representation using
    differentiable optimal transport, concatenated with a projected CLS token.

    Args:
        num_channels: Number of input feature channels (from backbone)
        num_clusters: Number of cluster centers
        cluster_dim: Dimensionality of cluster descriptors
        token_dim: Dimensionality of global scene token
        mlp_dim: Hidden dimension for MLPs
        dropout: Dropout probability (0 to disable)
    """

    def __init__(
        self,
        num_channels=1536,
        num_clusters=64,
        cluster_dim=128,
        token_dim=256,
        mlp_dim=512,
        dropout=0.3,
    ) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim
        self.mlp_dim = mlp_dim

        # Single dropout module (identity when disabled) shared by both
        # convolutional branches below.
        dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # MLP for global scene token
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, self.mlp_dim), nn.ReLU(), nn.Linear(self.mlp_dim, self.token_dim)
        )
        # MLP for local features
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            dropout,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.cluster_dim, 1),
        )
        # MLP for score matrix
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            dropout,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.num_clusters, 1),
        )
        # Dustbin parameter
        self.dust_bin = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """Aggregate backbone outputs into a single global descriptor.

        Args:
            x: Tuple of (features, token)
                features: [B, C, H, W] spatial feature map
                token: [B, C] global CLS token

        Returns:
            Global descriptor [B, num_clusters * cluster_dim + token_dim]
        """
        feature_map, cls_token = x

        local_desc = self.cluster_features(feature_map).flatten(2)
        scores = self.score(feature_map).flatten(2)
        global_token = self.token_features(cls_token)

        # Soft cluster assignments via optimal transport; drop the dustbin row.
        assign = torch.exp(get_matching_probs(scores, self.dust_bin, 3))
        assign = assign[:, :-1, :]

        # Broadcast both tensors to [B, cluster_dim, num_clusters, num_patches]
        # so the weighted sum over patches yields one descriptor per cluster.
        assign = assign.unsqueeze(1).repeat(1, self.cluster_dim, 1, 1)
        local_desc = local_desc.unsqueeze(2).repeat(1, 1, self.num_clusters, 1)
        clusters = (local_desc * assign).sum(dim=-1)

        combined = torch.cat(
            [
                F.normalize(global_token, p=2, dim=-1),
                F.normalize(clusters, p=2, dim=1).flatten(1),
            ],
            dim=-1,
        )
        return F.normalize(combined, p=2, dim=-1)
158
 
159
 
160
  # ==============================================================================
 
165
  class PatchEmbedding(nn.Module):
166
  """Convert image patches to embeddings using a convolutional layer."""
167
 
168
+ def __init__(self, image_size: int = 518, patch_size: int = 14, in_channels: int = 3, embed_dim: int = 768):
 
 
 
 
 
 
169
  super().__init__()
170
  self.image_size = image_size
171
  self.patch_size = patch_size
 
173
  self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
174
 
175
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
176
  x = self.proj(x)
 
177
  x = x.flatten(2)
 
178
  x = x.transpose(1, 2)
179
  return x
180
 
 
194
  """Multi-head self-attention module."""
195
 
196
  def __init__(
197
+ self, dim: int, num_heads: int = 12, qkv_bias: bool = True, attn_drop: float = 0.0, proj_drop: float = 0.0
 
 
 
 
 
198
  ):
199
  super().__init__()
200
  self.num_heads = num_heads
 
209
  def forward(self, x: torch.Tensor) -> torch.Tensor:
210
  B, N, C = x.shape
211
 
 
212
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
213
+ qkv = qkv.permute(2, 0, 3, 1, 4)
214
  q, k, v = qkv[0], qkv[1], qkv[2]
215
 
 
216
  attn = (q @ k.transpose(-2, -1)) * self.scale
217
  attn = attn.softmax(dim=-1)
218
  attn = self.attn_drop(attn)
219
 
 
220
  x = (attn @ v).transpose(1, 2).reshape(B, N, C)
221
  x = self.proj(x)
222
  x = self.proj_drop(x)
 
227
  class MLP(nn.Module):
228
  """MLP module with GELU activation."""
229
 
230
+ def __init__(self, in_features: int, hidden_features: int = None, out_features: int = None, drop: float = 0.0):
 
 
 
 
 
 
231
  super().__init__()
232
  out_features = out_features or in_features
233
  hidden_features = hidden_features or in_features
 
261
  ):
262
  super().__init__()
263
  self.norm1 = nn.LayerNorm(dim, eps=1e-6)
264
+ self.attn = MultiHeadAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
 
 
 
 
 
 
265
  self.ls1 = LayerScale(dim, init_value=init_values)
266
 
267
  self.norm2 = nn.LayerNorm(dim, eps=1e-6)
268
+ self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)
 
 
 
 
269
  self.ls2 = LayerScale(dim, init_value=init_values)
270
 
271
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
274
  return x
275
 
276
 
277
+ class DINOv2(nn.Module):
278
  """DINOv2 Vision Transformer backbone for feature extraction.
279
 
280
  This implements a ViT-B/14 architecture compatible with DINOv2 weights.
 
 
281
  """
282
 
283
  def __init__(
 
294
  super().__init__()
295
  self.patch_size = patch_size
296
  self.embed_dim = embed_dim
297
+ self.num_channels = embed_dim
298
 
 
299
  self.patch_embed = PatchEmbedding(
300
+ image_size=image_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim
 
 
 
301
  )
302
 
 
303
  self.interpolate_offset = 0.1
304
  self.interpolate_antialias = False
305
 
 
306
  self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
 
 
307
  num_patches = (image_size // patch_size) ** 2
308
  self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
309
 
 
310
  self.blocks = nn.ModuleList(
311
  [
312
+ TransformerBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias)
 
 
 
 
 
313
  for _ in range(depth)
314
  ]
315
  )
316
 
 
317
  self.norm = nn.LayerNorm(embed_dim, eps=1e-6)
318
 
319
  def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
320
+ """Interpolate positional encoding for different input sizes."""
 
 
 
 
321
  previous_dtype = x.dtype
322
+ npatch = x.shape[1] - 1
323
+ N = self.pos_embed.shape[1] - 1
324
 
 
325
  if npatch == N and w == h:
326
  return self.pos_embed
327
 
 
332
  dim = x.shape[-1]
333
  w0 = w // self.patch_size
334
  h0 = h // self.patch_size
335
+ M = int(math.sqrt(N))
336
 
 
 
337
  sx = float(w0 + self.interpolate_offset) / M
338
  sy = float(h0 + self.interpolate_offset) / M
339
 
 
349
 
350
  return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract features from images.

        Args:
            images: Input images [B, 3, H, W] where H, W are multiples of 14

        Returns:
            Tuple of (patch_features [B, 768, H//14, W//14], cls_token [B, 768])
        """
        B, _, H, W = images.shape

        # Tokenize patches, prepend the CLS token, then add positional
        # encodings (interpolated to the current input size).
        x = self.patch_embed(images)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, W, H)

        # Transformer encoder stack.
        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        # Split the CLS token from the patch tokens and restore the patch
        # tokens to a [B, embed_dim, H', W'] spatial layout.
        cls_token = x[:, 0]
        patch_tokens = x[:, 1:]
        patch_features = patch_tokens.reshape(B, H // self.patch_size, W // self.patch_size, self.embed_dim).permute(
            0, 3, 1, 2
        )

        return patch_features, cls_token
380
 
381
 
382
  # ==============================================================================
383
+ # Main Model
384
  # ==============================================================================
385
 
386
 
387
class L2Norm(nn.Module):
    """L2-normalize inputs along a fixed dimension."""

    def __init__(self, dim=1):
        super().__init__()
        # Dimension along which to normalize.
        self.dim = dim

    def forward(self, x):
        """Return x scaled to unit L2 norm along ``self.dim``."""
        normalized = F.normalize(x, p=2.0, dim=self.dim)
        return normalized
394
 
395
 
396
class Aggregator(nn.Module):
    """Feature aggregation head followed by a linear projection.

    Args:
        feat_dim: Final descriptor dimensionality.
        agg_config: Keyword arguments for FeatureAggregator.
        salad_out_dim: Output dimensionality of the aggregation head.
    """

    def __init__(self, feat_dim, agg_config, salad_out_dim):
        super().__init__()
        self.agg = FeatureAggregator(**agg_config)
        self.linear = nn.Linear(salad_out_dim, feat_dim)

    def forward(self, x):
        """Aggregate backbone features, then project to the target dimension."""
        aggregated = self.agg(x)
        return self.linear(aggregated)
 
406
 
407
  class MegaLoc(nn.Module, PyTorchModelHubMixin):
 
419
  mlp_dim: Hidden dimension for MLPs (default: 512)
420
 
421
  Example:
422
+ >>> model = torch.hub.load("gmberton/MegaLoc", "get_trained_model")
423
  >>> model.eval()
424
+ >>> descriptor = model(image) # [B, 8448]
 
425
  """
426
 
427
  def __init__(
 
434
  ):
435
  super().__init__()
436
 
437
+ self.backbone = DINOv2()
438
+ self.salad_out_dim = num_clusters * cluster_dim + token_dim
439
+ self.aggregator = Aggregator(
440
+ feat_dim=feat_dim,
441
+ agg_config={
442
+ "num_channels": self.backbone.num_channels,
 
 
 
443
  "num_clusters": num_clusters,
444
+ "cluster_dim": cluster_dim,
445
+ "token_dim": token_dim,
446
+ "mlp_dim": mlp_dim,
447
  },
448
+ salad_out_dim=self.salad_out_dim,
449
  )
 
450
  self.feat_dim = feat_dim
451
+ self.l2norm = L2Norm()
452
 
453
    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Extract global descriptor from images.

        Args:
            images: Input images [B, 3, H, W]; resized in-place below when H
                or W is not a multiple of 14.

        Returns:
            L2-normalized descriptors [B, feat_dim]
        """
        b, c, h, w = images.shape
        # Snap spatial dims to the nearest multiple of 14 so the ViT patch
        # embedding divides the image evenly.
        if h % 14 != 0 or w % 14 != 0:
            h = round(h / 14) * 14
            w = round(w / 14) * 14
            images = tfm.resize(images, [h, w], antialias=True)
        # Backbone features -> aggregation head -> final L2 normalization.
        features = self.aggregator(self.backbone(images))
        features = self.l2norm(features)
        return features
 
 
 
 
 
 
 
 
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:9d8716ac9959a86e00494f605a4be46aebed15694ab4ad77c27b91ada9ab51e4
3
- size 914577620
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4f9f2bcb60018f91eb6a8e061ed054fd55654e10c2569cf13841ea986ffb4f8
3
+ size 914577436