Gabriele commited on
Commit
37b751d
·
1 Parent(s): c6d8ab9

Using safetensors for weights loading

Browse files
Files changed (4) hide show
  1. README.md +41 -6
  2. config.json +4 -4
  3. megaloc_model.py +603 -192
  4. model.safetensors +2 -2
README.md CHANGED
@@ -1,17 +1,52 @@
1
  ---
 
 
 
2
  tags:
3
- - model_hub_mixin
 
4
  - pytorch_model_hub_mixin
5
  - arxiv:2502.17237
6
- license: mit
7
  ---
8
 
9
  # MegaLoc
10
- MegaLoc is an image retrieval model for any localization task, which achieves SOTA on most VPR datasets, including indoor and outdoor ones.
11
- You can find details in our paper [MegaLoc: One Retrieval to Place Them All](https://arxiv.org/abs/2502.17237)
12
 
13
- ### Qualitative examples
14
- Here are some examples of top-1 retrieved images from the SF-XL test set, which has 2.8M images as database.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
  ![teaser](https://github.com/user-attachments/assets/a90b8d4c-ab53-4151-aacc-93493d583713)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ pipeline_tag: image-feature-extraction
3
+ library_name: pytorch
4
+ license: mit
5
  tags:
6
+ - visual-place-recognition
7
+ - image-retrieval
8
  - pytorch_model_hub_mixin
9
  - arxiv:2502.17237
 
10
  ---
11
 
12
  # MegaLoc
 
 
13
 
14
+ MegaLoc is an image retrieval model for visual place recognition (VPR) that achieves state-of-the-art on most VPR datasets, including indoor and outdoor environments.
15
+
16
+ **Paper:** [MegaLoc: One Retrieval to Place Them All](https://arxiv.org/abs/2502.17237) (CVPR 2025 Workshop)
17
+
18
+ **GitHub:** [gmberton/MegaLoc](https://github.com/gmberton/MegaLoc)
19
+
20
+ ## Usage
21
+
22
+ ```python
23
+ import torch
24
+ model = torch.hub.load("gmberton/MegaLoc", "get_trained_model")
25
+ model.eval()
26
+
27
+ # Extract descriptor from an image
28
+ image = torch.randn(1, 3, 322, 322) # [B, 3, H, W] - any size works
29
+ with torch.no_grad():
30
+ descriptor = model(image) # [B, 8448] L2-normalized descriptor
31
+ ```
32
+
33
+ For benchmarking on VPR datasets, see [VPR-methods-evaluation](https://github.com/gmberton/VPR-methods-evaluation).
34
+
35
+ ## Qualitative Examples
36
+
37
+ Top-1 retrieved images from the SF-XL test set (2.8M database images):
38
 
39
  ![teaser](https://github.com/user-attachments/assets/a90b8d4c-ab53-4151-aacc-93493d583713)
40
 
41
+ ## Citation
42
+
43
+ ```bibtex
44
+ @InProceedings{Berton_2025_CVPR,
45
+ author = {Berton, Gabriele and Masone, Carlo},
46
+ title = {MegaLoc: One Retrieval to Place Them All},
47
+ booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
48
+ month = {June},
49
+ year = {2025},
50
+ pages = {2861-2867}
51
+ }
52
+ ```
config.json CHANGED
@@ -1,7 +1,7 @@
1
  {
2
- "cluster_dim": 256,
3
  "feat_dim": 8448,
4
- "mlp_dim": 512,
5
  "num_clusters": 64,
6
- "token_dim": 256
7
- }
 
 
 
1
  {
 
2
  "feat_dim": 8448,
 
3
  "num_clusters": 64,
4
+ "cluster_dim": 256,
5
+ "token_dim": 256,
6
+ "mlp_dim": 512
7
+ }
megaloc_model.py CHANGED
@@ -1,256 +1,667 @@
1
- """Code for the MegaLoc model.
2
- Much of the code in this file is from SALAD https://github.com/serizba/salad
 
 
 
 
 
 
3
  """
4
 
5
  import math
 
6
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
- import torchvision.transforms as tfm
11
  from huggingface_hub import PyTorchModelHubMixin
12
 
13
 
14
- class MegaLocModel(nn.Module, PyTorchModelHubMixin):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def __init__(
16
  self,
17
- feat_dim=8448,
18
- num_clusters=64,
19
- cluster_dim=256,
20
- token_dim=256,
21
- mlp_dim=512,
22
- ):
 
23
  super().__init__()
24
- self.backbone = DINOv2()
25
- self.salad_out_dim = num_clusters * cluster_dim + token_dim
26
- self.aggregator = Aggregator(
27
- feat_dim=feat_dim,
28
- agg_config={
29
- "num_channels": self.backbone.num_channels,
30
- "num_clusters": num_clusters,
31
- "cluster_dim": cluster_dim,
32
- "token_dim": token_dim,
33
- "mlp_dim": mlp_dim,
34
- },
35
- salad_out_dim=self.salad_out_dim,
 
 
 
36
  )
37
- self.feat_dim = feat_dim
38
- self.l2norm = L2Norm()
39
-
40
- def forward(self, images):
41
- b, c, h, w = images.shape
42
- if h % 14 != 0 or w % 14 != 0:
43
- # DINO needs height and width as multiple of 14, therefore resize them
44
- # to the nearest multiple of 14
45
- h = round(h / 14) * 14
46
- w = round(w / 14) * 14
47
- images = tfm.functional.resize(images, [h, w], antialias=True)
48
- features = self.aggregator(self.backbone(images))
49
- features = self.l2norm(features)
50
- return features
51
-
52
-
53
- class L2Norm(nn.Module):
54
- def __init__(self, dim=1):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  super().__init__()
56
- self.dim = dim
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
- def forward(self, x):
59
- return F.normalize(x, p=2.0, dim=self.dim)
60
 
 
 
61
 
62
- class Aggregator(nn.Module):
63
- def __init__(self, feat_dim, agg_config, salad_out_dim):
64
  super().__init__()
65
- self.agg = SALAD(**agg_config)
66
- self.linear = nn.Linear(salad_out_dim, feat_dim)
67
 
68
- def forward(self, x):
69
- x = self.agg(x)
70
- return self.linear(x)
71
 
72
 
73
- class DINOv2(nn.Module):
74
- def __init__(self, num_trainable_blocks=4, norm_layer=True, return_token=True):
 
 
 
 
 
 
 
 
 
75
  super().__init__()
76
- self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
77
- self.num_channels = 768
78
- self.num_trainable_blocks = num_trainable_blocks
79
- self.norm_layer = norm_layer
80
- self.return_token = return_token
81
 
82
- def forward(self, x):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  """
84
- The forward method for the DINOv2 class
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
- Parameters:
87
- x (torch.Tensor): The input tensor [B, 3, H, W]. H and W should be divisible by 14.
 
 
 
 
 
 
 
 
88
 
89
  Returns:
90
- f (torch.Tensor): The feature map [B, C, H // 14, W // 14].
91
- t (torch.Tensor): The token [B, C]. This is only returned if return_token is True.
 
92
  """
 
 
 
 
 
 
 
 
93
 
94
- B, C, H, W = x.shape
 
95
 
96
- x = self.model.prepare_tokens_with_masks(x)
 
 
97
 
98
- # First blocks are frozen
99
- with torch.no_grad():
100
- for blk in self.model.blocks[: -self.num_trainable_blocks]:
101
- x = blk(x)
102
- x = x.detach()
103
 
104
- # Last blocks are trained
105
- for blk in self.model.blocks[-self.num_trainable_blocks :]:
106
- x = blk(x)
107
 
108
- if self.norm_layer:
109
- x = self.model.norm(x)
110
 
111
- t = x[:, 0]
112
- f = x[:, 1:]
 
113
 
114
- # Reshape to (B, C, H, W)
115
- f = f.reshape((B, H // 14, W // 14, self.num_channels)).permute(0, 3, 1, 2)
116
 
117
- if self.return_token:
118
- return f, t
119
- return f
120
 
 
 
121
 
122
- # Code adapted from OpenGlue, MIT license
123
- # https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/optimal_transport.py
124
- def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
125
- r"""Sinkhorn matrix scaling algorithm for Differentiable Optimal Transport problem.
126
- This function solves the optimization problem and returns the OT matrix for the given parameters.
127
  Args:
128
- log_a : torch.Tensor
129
- Source weights
130
- log_b : torch.Tensor
131
- Target weights
132
- M : torch.Tensor
133
- metric cost matrix
134
- num_iters : int, default=100
135
- The number of iterations.
136
- reg : float, default=1.0
137
- regularization value
138
  """
139
- M = M / reg # regularization
140
 
141
- u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)
 
 
 
142
 
143
- for _ in range(num_iters):
144
- u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2).squeeze()
145
- v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1).squeeze()
146
 
147
- return M + u.unsqueeze(2) + v.unsqueeze(1)
148
 
 
 
 
149
 
150
- # Code adapted from OpenGlue, MIT license
151
- # https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/superglue.py
152
- def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
153
- """sinkhorn"""
154
- batch_size, m, n = S.size()
155
- # augment scores matrix
156
- S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
157
- S_aug[:, :m, :n] = S
158
- S_aug[:, m, :] = dustbin_score
159
 
160
- # prepare normalized source and target log-weights
161
- norm = -torch.tensor(math.log(n + m), device=S.device)
162
- log_a, log_b = norm.expand(m + 1).contiguous(), norm.expand(n).contiguous()
163
- log_a[-1] = log_a[-1] + math.log(n - m)
164
- log_a, log_b = log_a.expand(batch_size, -1), log_b.expand(batch_size, -1)
165
- log_P = log_otp_solver(log_a, log_b, S_aug, num_iters=num_iters, reg=reg)
166
- return log_P - norm
167
 
 
 
 
168
 
169
- class SALAD(nn.Module):
170
- """
171
- This class represents the Sinkhorn Algorithm for Locally Aggregated Descriptors (SALAD) model.
172
-
173
- Attributes:
174
- num_channels (int): The number of channels of the inputs (d).
175
- num_clusters (int): The number of clusters in the model (m).
176
- cluster_dim (int): The number of channels of the clusters (l).
177
- token_dim (int): The dimension of the global scene token (g).
178
- dropout (float): The dropout rate.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  """
180
 
181
  def __init__(
182
  self,
183
- num_channels=1536,
184
- num_clusters=64,
185
- cluster_dim=128,
186
- token_dim=256,
187
- mlp_dim=512,
188
- dropout=0.3,
189
- ) -> None:
190
  super().__init__()
191
 
192
- self.num_channels = num_channels
193
- self.num_clusters = num_clusters
194
- self.cluster_dim = cluster_dim
195
- self.token_dim = token_dim
196
- self.mlp_dim = mlp_dim
197
-
198
- if dropout > 0:
199
- dropout = nn.Dropout(dropout)
200
- else:
201
- dropout = nn.Identity()
202
-
203
- # MLP for global scene token g
204
- self.token_features = nn.Sequential(
205
- nn.Linear(self.num_channels, self.mlp_dim), nn.ReLU(), nn.Linear(self.mlp_dim, self.token_dim)
206
- )
207
- # MLP for local features f_i
208
- self.cluster_features = nn.Sequential(
209
- nn.Conv2d(self.num_channels, self.mlp_dim, 1),
210
- dropout,
211
- nn.ReLU(),
212
- nn.Conv2d(self.mlp_dim, self.cluster_dim, 1),
213
- )
214
- # MLP for score matrix S
215
- self.score = nn.Sequential(
216
- nn.Conv2d(self.num_channels, self.mlp_dim, 1),
217
- dropout,
218
- nn.ReLU(),
219
- nn.Conv2d(self.mlp_dim, self.num_clusters, 1),
220
  )
221
- # Dustbin parameter z
222
- self.dust_bin = nn.Parameter(torch.tensor(1.0))
223
 
224
- def forward(self, x):
225
- """
226
- x (tuple): A tuple containing two elements, f and t.
227
- (torch.Tensor): The feature tensors (t_i) [B, C, H // 14, W // 14].
228
- (torch.Tensor): The token tensor (t_{n+1}) [B, C].
 
 
 
229
 
230
  Returns:
231
- f (torch.Tensor): The global descriptor [B, m*l + g]
232
  """
233
- x, t = x # Extract features and token
234
-
235
- f = self.cluster_features(x).flatten(2)
236
- p = self.score(x).flatten(2)
237
- t = self.token_features(t)
238
 
239
- # Sinkhorn algorithm
240
- p = get_matching_probs(p, self.dust_bin, 3)
241
- p = torch.exp(p)
242
- # Normalize to maintain mass
243
- p = p[:, :-1, :]
244
 
245
- p = p.unsqueeze(1).repeat(1, self.cluster_dim, 1, 1)
246
- f = f.unsqueeze(2).repeat(1, 1, self.num_clusters, 1)
247
 
248
- f = torch.cat(
249
- [
250
- nn.functional.normalize(t, p=2, dim=-1),
251
- nn.functional.normalize((f * p).sum(dim=-1), p=2, dim=1).flatten(1),
252
- ],
253
- dim=-1,
254
- )
255
 
256
- return nn.functional.normalize(f, p=2, dim=-1)
 
 
1
+ """MegaLoc: One Retrieval to Place Them All
2
+
3
+ This module implements the MegaLoc model for visual place recognition.
4
+ The model combines a Vision Transformer backbone with an optimal transport-based
5
+ feature aggregation module.
6
+
7
+ Paper: https://arxiv.org/abs/2502.17237
8
+ License: MIT
9
  """
10
 
11
  import math
12
+ from typing import Tuple
13
 
14
  import torch
15
  import torch.nn as nn
16
  import torch.nn.functional as F
17
+ import torchvision.transforms.functional as tfm
18
  from huggingface_hub import PyTorchModelHubMixin
19
 
20
 
21
+ # ==============================================================================
22
+ # Optimal Transport Feature Aggregation
23
+ # ==============================================================================
24
+ # The following implements an optimal transport-based feature aggregation module
25
+ # that converts local patch features into a compact global descriptor.
26
+ # ==============================================================================
27
+
28
+
29
def sinkhorn_log_iterations(
    source_log_weights: torch.Tensor,
    target_log_weights: torch.Tensor,
    cost_matrix: torch.Tensor,
    num_iterations: int = 20,
    regularization: float = 1.0,
) -> torch.Tensor:
    """Compute an optimal transport plan using Sinkhorn iterations in log space.

    Implements the Sinkhorn-Knopp algorithm for the entropy-regularized
    optimal transport problem. Working in log space keeps the alternating
    row/column normalizations numerically stable.

    Args:
        source_log_weights: Log of source distribution weights [batch, m+1].
        target_log_weights: Log of target distribution weights [batch, n].
        cost_matrix: Cost/score matrix [batch, m+1, n].
        num_iterations: Number of Sinkhorn iterations.
        regularization: Entropy regularization strength (divides the costs).

    Returns:
        Log of the transport plan matrix [batch, m+1, n].
    """
    # Entropy regularization: larger values flatten the resulting plan.
    scaled_costs = cost_matrix / regularization

    # Dual potentials, initialized to zero.
    dual_source = torch.zeros_like(source_log_weights)
    dual_target = torch.zeros_like(target_log_weights)

    # Alternate row and column normalization in log space.
    # Note: logsumexp(dim=...) already removes the reduced dimension, so no
    # squeeze() is needed here (a bare squeeze() would wrongly drop the batch
    # dimension when batch_size == 1).
    for _ in range(num_iterations):
        dual_source = source_log_weights - torch.logsumexp(scaled_costs + dual_target.unsqueeze(1), dim=2)
        dual_target = target_log_weights - torch.logsumexp(scaled_costs + dual_source.unsqueeze(2), dim=1)

    # Final log transport plan combines costs with both dual potentials.
    return scaled_costs + dual_source.unsqueeze(2) + dual_target.unsqueeze(1)
69
+
70
+
71
def compute_soft_assignments(
    affinity_scores: torch.Tensor,
    slack_logit: float = 1.0,
    num_iterations: int = 3,
    regularization: float = 1.0,
) -> torch.Tensor:
    """Compute soft cluster assignments using optimal transport with slack.

    Augments the affinity matrix with a slack ("dustbin") row so that patches
    not matching any cluster can be absorbed, then runs Sinkhorn normalization
    to obtain valid transport probabilities.

    Args:
        affinity_scores: Raw affinity scores [batch, num_clusters, num_patches].
            Requires num_patches > num_clusters (the slack row receives the
            excess mass log(num_patches - num_clusters)).
        slack_logit: Logit value filling the slack row (may be a learnable scalar).
        num_iterations: Number of Sinkhorn iterations.
        regularization: Entropy regularization strength.

    Returns:
        Log-probabilities of assignments [batch, num_clusters+1, num_patches].
    """
    batch_size, num_clusters, num_patches = affinity_scores.size()

    # Augment the score matrix with one extra row for outliers/unmatched patches.
    augmented_scores = torch.empty(
        batch_size,
        num_clusters + 1,
        num_patches,
        dtype=affinity_scores.dtype,
        device=affinity_scores.device,
    )
    augmented_scores[:, :num_clusters, :] = affinity_scores
    augmented_scores[:, num_clusters, :] = slack_logit

    # Shared log-normalization constant: -log(n + m).
    log_normalization = -torch.tensor(math.log(num_patches + num_clusters), device=affinity_scores.device)

    # Source weights: uniform over clusters, with the surplus patch mass on the
    # slack row. clone() gives a writable contiguous copy of the expanded view.
    source_log = log_normalization.expand(num_clusters + 1).clone()
    source_log[-1] = source_log[-1] + math.log(num_patches - num_clusters)

    # Target weights: uniform over patches.
    target_log = log_normalization.expand(num_patches).contiguous()

    # Broadcast to the batch dimension.
    source_log = source_log.expand(batch_size, -1)
    target_log = target_log.expand(batch_size, -1)

    # Solve the entropy-regularized optimal transport problem.
    log_transport = sinkhorn_log_iterations(
        source_log,
        target_log,
        augmented_scores,
        num_iterations=num_iterations,
        regularization=regularization,
    )

    # Re-scale so each patch column sums to 1 in probability space.
    return log_transport - log_normalization
129
+
130
+
131
class FeatureAggregationHead(nn.Module):
    """Optimal transport-based aggregation of local features into a global descriptor.

    Produces two components that are concatenated and L2-normalized:
      1. A global scene token projected from the backbone CLS token.
      2. Cluster-aggregated local descriptors, weighted by transport probabilities
         computed with Sinkhorn optimal transport (SALAD-style aggregation).

    Args:
        input_channels: Number of input feature channels (from backbone).
        num_clusters: Number of learned cluster centers.
        cluster_channels: Dimensionality of each cluster descriptor.
        global_token_dim: Dimensionality of the global scene token.
        hidden_dim: Hidden dimension for the MLPs.
        dropout_rate: Dropout probability (<= 0 disables dropout).
    """

    def __init__(
        self,
        input_channels: int = 1536,
        num_clusters: int = 64,
        cluster_channels: int = 128,
        global_token_dim: int = 256,
        hidden_dim: int = 512,
        dropout_rate: float = 0.3,
    ) -> None:
        super().__init__()

        self.input_channels = input_channels
        self.num_clusters = num_clusters
        self.cluster_channels = cluster_channels
        self.global_token_dim = global_token_dim
        self.hidden_dim = hidden_dim

        # NOTE: this single Dropout/Identity module instance is shared by the
        # two Sequentials below; PyTorch handles shared submodules correctly.
        regularization = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()

        # MLP projecting the CLS token to the global scene descriptor.
        self.global_token_mlp = nn.Sequential(
            nn.Linear(self.input_channels, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.global_token_dim),
        )

        # 1x1-conv MLP projecting patch features to per-patch cluster descriptors.
        self.descriptor_projection = nn.Sequential(
            nn.Conv2d(self.input_channels, self.hidden_dim, 1),
            regularization,
            nn.ReLU(),
            nn.Conv2d(self.hidden_dim, self.cluster_channels, 1),
        )

        # 1x1-conv MLP producing cluster assignment logits per patch.
        self.assignment_head = nn.Sequential(
            nn.Conv2d(self.input_channels, self.hidden_dim, 1),
            regularization,
            nn.ReLU(),
            nn.Conv2d(self.hidden_dim, self.num_clusters, 1),
        )

        # Learnable slack ("dustbin") logit for the optimal transport.
        self.slack_variable = nn.Parameter(torch.tensor(1.0))

    def forward(self, inputs):
        """Aggregate local and global features into a compact descriptor.

        Args:
            inputs: Tuple (patch_features, cls_token):
                - patch_features: [B, C, H, W] spatial feature map.
                - cls_token: [B, C] global CLS token.

        Returns:
            Global descriptor [B, num_clusters * cluster_channels + global_token_dim].
        """
        patch_features, cls_token = inputs

        # Per-patch cluster descriptors: [B, cluster_channels, H*W].
        local_descriptors = self.descriptor_projection(patch_features).flatten(2)
        # Assignment logits: [B, num_clusters, H*W].
        assignment_logits = self.assignment_head(patch_features).flatten(2)
        # Global descriptor from CLS token: [B, global_token_dim].
        global_descriptor = self.global_token_mlp(cls_token)

        # Soft assignments via Sinkhorn optimal transport; drop the slack row.
        log_assignments = compute_soft_assignments(assignment_logits, self.slack_variable, num_iterations=3)
        assignments = torch.exp(log_assignments)[:, :-1, :]

        # Broadcast to [B, cluster_channels, num_clusters, num_patches] and
        # compute the assignment-weighted sum over patches per cluster.
        assignments = assignments.unsqueeze(1).repeat(1, self.cluster_channels, 1, 1)
        local_descriptors = local_descriptors.unsqueeze(2).repeat(1, 1, self.num_clusters, 1)
        aggregated_clusters = (local_descriptors * assignments).sum(dim=-1)

        # Intra-normalize both parts, concatenate, then normalize the whole.
        normalized_global = F.normalize(global_descriptor, p=2, dim=-1)
        normalized_local = F.normalize(aggregated_clusters, p=2, dim=1).flatten(1)
        combined = torch.cat([normalized_global, normalized_local], dim=-1)

        return F.normalize(combined, p=2, dim=-1)
242
+
243
+
244
+ # ==============================================================================
245
+ # Vision Transformer Components
246
+ # ==============================================================================
247
+
248
+
249
class PatchEmbedding(nn.Module):
    """Convert an image into a sequence of patch embeddings via a strided conv.

    Args:
        image_size: Nominal square input size used to compute ``num_patches``.
        patch_size: Side length of each square patch (conv kernel and stride).
        in_channels: Number of input image channels.
        embed_dim: Output embedding dimension per patch.
    """

    def __init__(
        self,
        image_size: int = 518,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 768,
    ):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        # Number of patches at the nominal image size (actual inputs may differ).
        self.num_patches = (image_size // patch_size) ** 2
        # Non-overlapping patches: kernel_size == stride == patch_size.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed patches.

        Args:
            x: Images [B, C, H, W] with H, W divisible by ``patch_size``.

        Returns:
            Patch embeddings [B, (H//patch_size)*(W//patch_size), embed_dim].
        """
        x = self.proj(x)          # [B, embed_dim, H/p, W/p]
        x = x.flatten(2)          # [B, embed_dim, num_patches]
        return x.transpose(1, 2)  # [B, num_patches, embed_dim]
273
 
 
 
274
 
275
class LayerScale(nn.Module):
    """Learnable per-channel scaling, as used in CaiT and DINOv2.

    Multiplies the last dimension of the input by a learnable vector ``gamma``
    initialized to a small constant, stabilizing residual branches.

    Args:
        dim: Number of channels (size of the last input dimension).
        init_value: Initial value for every element of ``gamma``.
    """

    def __init__(self, dim: int, init_value: float = 1e-5):
        super().__init__()
        self.gamma = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Broadcast elementwise scale over all leading dimensions.
        return x * self.gamma
 
284
 
285
 
286
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    Args:
        dim: Embedding dimension (must be divisible by ``num_heads``).
        num_heads: Number of attention heads.
        qkv_bias: Whether the QKV projection uses a bias term.
        attn_drop: Dropout probability on attention weights.
        proj_drop: Dropout probability on the output projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 12,
        qkv_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Standard 1/sqrt(d_head) attention scaling.
        self.scale = self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention over a token sequence [B, N, C] -> [B, N, C]."""
        B, N, C = x.shape

        # Fused QKV, then split into per-head tensors [B, num_heads, N, head_dim].
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, B, num_heads, N, head_dim]
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Weighted sum of values, heads merged back into the channel dim.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        return self.proj_drop(x)
326
+
327
+
328
class MLP(nn.Module):
    """Two-layer MLP with GELU activation and optional dropout.

    Args:
        in_features: Input dimension.
        hidden_features: Hidden dimension (defaults to ``in_features``).
        out_features: Output dimension (defaults to ``in_features``).
        drop: Dropout probability applied after each linear layer.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        out_features: int = None,
        drop: float = 0.0,
    ):
        super().__init__()
        # Fall back to in_features when hidden/out dims are not given.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> GELU -> dropout -> fc2 -> dropout."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return self.drop(x)
354
+
355
+
356
class TransformerBlock(nn.Module):
    """Pre-norm Vision Transformer block with LayerScale on both residuals.

    Structure: x + ls1(attn(norm1(x))), then x + ls2(mlp(norm2(x))).

    Args:
        dim: Embedding dimension.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden dim as a multiple of ``dim``.
        qkv_bias: Whether the QKV projection uses bias.
        drop: Dropout probability for attention projection and MLP.
        attn_drop: Dropout probability on attention weights.
        init_values: Initial LayerScale value.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: float = 1e-5,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.attn = MultiHeadAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_value=init_values)

        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=drop,
        )
        self.ls2 = LayerScale(dim, init_value=init_values)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and MLP sub-blocks with scaled residuals."""
        x = x + self.ls1(self.attn(self.norm1(x)))
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x
392
+
393
+
394
class VisionTransformerBackbone(nn.Module):
    """ViT-B/14 backbone intended to be weight-compatible with DINOv2.

    The positional encoding interpolation follows the Facebook DINOv2
    implementation (including its ``interpolate_offset`` kludge) so that
    outputs match the reference model when loaded with DINOv2 weights.

    Args:
        image_size: Nominal training image size (defines ``pos_embed`` length).
        patch_size: ViT patch size.
        in_channels: Number of input image channels.
        embed_dim: Token embedding dimension.
        depth: Number of transformer blocks.
        num_heads: Attention heads per block.
        mlp_ratio: MLP hidden dim as a multiple of ``embed_dim``.
        qkv_bias: Whether QKV projections use bias.
    """

    def __init__(
        self,
        image_size: int = 518,
        patch_size: int = 14,
        in_channels: int = 3,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.num_channels = embed_dim  # Alias kept for aggregator compatibility.

        self.patch_embed = PatchEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )

        # Interpolation parameters matching Facebook's DINOv2 implementation.
        self.interpolate_offset = 0.1
        self.interpolate_antialias = False

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Positional embedding for num_patches patch tokens + 1 CLS token.
        num_patches = (image_size // patch_size) ** 2
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim, eps=1e-6)

    def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
        """Bicubically interpolate the positional encoding for a new input size.

        Mirrors the DINOv2 reference, including the offset "kludge" in the
        scale factors for backward compatibility with trained weights.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1  # Patch tokens in the input (CLS excluded).
        N = self.pos_embed.shape[1] - 1  # Patch positions in pos_embed.

        # Fast path: input matches the training grid and is square.
        if npatch == N and w == h:
            return self.pos_embed

        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]

        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Original patches per side (assumed square grid).

        # Offset kludge from Facebook's DINO: nudges scale so floor() lands
        # exactly on (w0, h0).
        sx = float(w0 + self.interpolate_offset) / M
        sy = float(h0 + self.interpolate_offset) / M

        patch_pos_embed = F.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            mode="bicubic",
            antialias=self.interpolate_antialias,
        )

        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)

        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens(self, x: torch.Tensor) -> torch.Tensor:
        """Patch-embed the image, prepend CLS, and add positional encoding."""
        # NOTE(review): dims 2 and 3 are unpacked as (W, H), matching the
        # DINOv2 reference's ordering quirk — keep as-is for compatibility.
        B, C, W, H = x.shape

        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        x = x + self.interpolate_pos_encoding(x, W, H)
        return x

    def forward(self, images: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract spatial features and the CLS token.

        Args:
            images: Input images [B, 3, H, W] with H, W multiples of ``patch_size``.

        Returns:
            Tuple of:
              - patch_features: [B, embed_dim, H//patch_size, W//patch_size].
              - cls_token: [B, embed_dim].
        """
        batch_size, _, height, width = images.shape

        x = self.prepare_tokens(images)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)

        cls_token = x[:, 0]
        patch_tokens = x[:, 1:]

        # Restore the 2D spatial layout of the patch tokens.
        h_patches = height // self.patch_size
        w_patches = width // self.patch_size
        patch_features = patch_tokens.reshape(batch_size, h_patches, w_patches, self.embed_dim).permute(0, 3, 1, 2)

        return patch_features, cls_token
 
 
542
 
 
 
543
 
544
+ # ==============================================================================
545
+ # Feature Dimension Reduction
546
+ # ==============================================================================
547
 
 
 
548
 
549
class DescriptorAggregator(nn.Module):
    """Feature aggregation followed by a linear projection to the output size.

    Applies the optimal transport aggregation head, then a linear layer that
    maps the aggregated descriptor to the desired final dimensionality.

    Args:
        output_dim: Final descriptor dimensionality.
        aggregator_config: Keyword arguments for ``FeatureAggregationHead``.
        aggregator_output_dim: Output dimension of the aggregation head
            (num_clusters * cluster_channels + global_token_dim).
    """

    def __init__(self, output_dim: int, aggregator_config: dict, aggregator_output_dim: int):
        super().__init__()
        self.aggregation = FeatureAggregationHead(**aggregator_config)
        self.projection = nn.Linear(aggregator_output_dim, output_dim)

    def forward(self, x):
        """Aggregate backbone features and project to ``output_dim``."""
        aggregated = self.aggregation(x)
        return self.projection(aggregated)
569
 
 
570
 
571
+ # ==============================================================================
572
+ # L2 Normalization Layer
573
+ # ==============================================================================
574
 
 
 
 
 
 
 
 
 
 
575
 
576
class L2Normalize(nn.Module):
    """L2 normalization along a configurable dimension.

    Args:
        dim: Dimension along which to normalize (default: last).
    """

    def __init__(self, dim: int = -1):
        super().__init__()
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unit-norm the input along self.dim.
        return F.normalize(x, p=2, dim=self.dim)
585
+
586
+
587
+ # ==============================================================================
588
+ # Main Model
589
+ # ==============================================================================
590
+
591
+
592
class MegaLoc(nn.Module, PyTorchModelHubMixin):
    """MegaLoc: unified visual place recognition model.

    Combines a DINOv2-style Vision Transformer backbone with optimal
    transport-based feature aggregation to produce compact, L2-normalized
    image descriptors for place recognition and image retrieval.

    Args:
        feat_dim: Output descriptor dimensionality.
        num_clusters: Number of cluster centers for aggregation.
        cluster_dim: Dimensionality of cluster descriptors.
        token_dim: Dimensionality of the global scene token.
        mlp_dim: Hidden dimension for the aggregation MLPs.

    Example:
        >>> model = MegaLoc.from_pretrained("gberton/MegaLoc")
        >>> model.eval()
        >>> image = torch.randn(1, 3, 322, 322)
        >>> descriptor = model(image)  # [1, 8448]
    """

    def __init__(
        self,
        feat_dim: int = 8448,
        num_clusters: int = 64,
        cluster_dim: int = 256,
        token_dim: int = 256,
        mlp_dim: int = 512,
    ):
        super().__init__()

        self.backbone = VisionTransformerBackbone()

        # Aggregation head output size before the final linear projection.
        self.aggregator_output_dim = num_clusters * cluster_dim + token_dim

        self.aggregator = DescriptorAggregator(
            output_dim=feat_dim,
            aggregator_config={
                "input_channels": self.backbone.num_channels,
                "num_clusters": num_clusters,
                "cluster_channels": cluster_dim,
                "global_token_dim": token_dim,
                "hidden_dim": mlp_dim,
            },
            aggregator_output_dim=self.aggregator_output_dim,
        )

        self.feat_dim = feat_dim
        self.normalize = L2Normalize()

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """Extract a global descriptor from a batch of images.

        Args:
            images: Input images [B, 3, H, W]; any size is accepted and is
                resized to the nearest multiple of 14 when needed.

        Returns:
            L2-normalized descriptors [B, feat_dim].
        """
        batch_size, channels, height, width = images.shape

        # The ViT requires H and W to be multiples of the patch size (14);
        # resize to the nearest valid size otherwise.
        if height % 14 != 0 or width % 14 != 0:
            height = round(height / 14) * 14
            width = round(width / 14) * 14
            images = tfm.resize(images, [height, width], antialias=True)

        # Backbone -> (patch_features, cls_token) -> aggregated descriptor.
        features = self.backbone(images)
        descriptor = self.aggregator(features)

        return self.normalize(descriptor)
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7f84357772ac8c92eedb86267afe013fde9ab68bb9dbe478866d08fe04c38684
3
- size 914581652
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d8716ac9959a86e00494f605a4be46aebed15694ab4ad77c27b91ada9ab51e4
3
+ size 914577620