gberton committed on
Commit
9b84e11
·
verified ·
1 Parent(s): 2065c6c

Upload megaloc_model.py

Browse files
Files changed (1) hide show
  1. megaloc_model.py +255 -0
megaloc_model.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Code for the MegaLoc model.
2
+ Much of the code in this file is from SALAD https://github.com/serizba/salad
3
+ """
4
+
5
+ import math
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ import torchvision.transforms as tfm
11
+
12
+
13
class MegaLocModel(nn.Module):
    """End-to-end MegaLoc descriptor model.

    Pipeline: DINOv2 backbone -> SALAD-style Aggregator -> L2 normalization,
    producing one global descriptor of size ``feat_dim`` per image.
    """

    def __init__(
        self,
        feat_dim=8448,
        num_clusters=64,
        cluster_dim=256,
        token_dim=256,
        mlp_dim=512,
    ):
        super().__init__()
        self.backbone = DINOv2()
        # Raw SALAD output size before the final linear projection.
        self.salad_out_dim = num_clusters * cluster_dim + token_dim
        agg_config = {
            "num_channels": self.backbone.num_channels,
            "num_clusters": num_clusters,
            "cluster_dim": cluster_dim,
            "token_dim": token_dim,
            "mlp_dim": mlp_dim,
        }
        self.aggregator = Aggregator(
            feat_dim=feat_dim,
            agg_config=agg_config,
            salad_out_dim=self.salad_out_dim,
        )
        self.feat_dim = feat_dim
        self.l2norm = L2Norm()

    def forward(self, images):
        """Return L2-normalized global descriptors [B, feat_dim] for a batch [B, 3, H, W]."""
        _, _, height, width = images.shape
        # DINOv2 works on 14x14 patches, so snap H and W to the nearest
        # multiple of 14 before feeding the backbone.
        if height % 14 or width % 14:
            height = round(height / 14) * 14
            width = round(width / 14) * 14
            images = tfm.functional.resize(images, [height, width], antialias=True)
        descriptors = self.aggregator(self.backbone(images))
        return self.l2norm(descriptors)
50
+
51
+
52
class L2Norm(nn.Module):
    """Thin module wrapper around ``F.normalize`` (p=2) along a fixed dimension."""

    def __init__(self, dim=1):
        super().__init__()
        # Dimension along which vectors are scaled to unit Euclidean norm.
        self.dim = dim

    def forward(self, x):
        """Return ``x`` rescaled so slices along ``self.dim`` have unit L2 norm."""
        normalized = F.normalize(x, p=2.0, dim=self.dim)
        return normalized
59
+
60
+
61
class Aggregator(nn.Module):
    """SALAD aggregation followed by a linear projection to the final size.

    Args:
        feat_dim: dimensionality of the final output descriptor.
        agg_config: keyword arguments forwarded verbatim to ``SALAD``.
        salad_out_dim: output size of the SALAD stage (input of the projection).
    """

    def __init__(self, feat_dim, agg_config, salad_out_dim):
        super().__init__()
        self.agg = SALAD(**agg_config)
        self.linear = nn.Linear(salad_out_dim, feat_dim)

    def forward(self, x):
        """Aggregate backbone outputs, then project them to ``feat_dim``."""
        pooled = self.agg(x)
        return self.linear(pooled)
70
+
71
+
72
class DINOv2(nn.Module):
    """DINOv2 ViT-B/14 backbone wrapper.

    Only the last ``num_trainable_blocks`` transformer blocks receive
    gradients; all earlier blocks run under ``torch.no_grad``.

    NOTE(review): with ``num_trainable_blocks=0`` the slice ``blocks[:-0]``
    is empty and ``blocks[-0:]`` is *all* blocks, i.e. nothing would be
    frozen — confirm callers never pass 0.
    """

    def __init__(self, num_trainable_blocks=4, norm_layer=True, return_token=True):
        super().__init__()
        # Downloads/loads the pretrained checkpoint via torch.hub at
        # construction time (network access on first use).
        self.model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
        self.num_channels = 768  # ViT-B/14 embedding dimension
        self.num_trainable_blocks = num_trainable_blocks
        self.norm_layer = norm_layer  # apply the model's final LayerNorm
        self.return_token = return_token  # also return the CLS token

    def forward(self, x):
        """
        The forward method for the DINOv2 class

        Parameters:
            x (torch.Tensor): The input tensor [B, 3, H, W]. H and W should be divisible by 14.

        Returns:
            f (torch.Tensor): The feature map [B, C, H // 14, W // 14].
            t (torch.Tensor): The token [B, C]. This is only returned if return_token is True.
        """

        B, C, H, W = x.shape

        # Patchify and prepend the CLS token (DINOv2 helper).
        x = self.model.prepare_tokens_with_masks(x)

        # First blocks are frozen
        with torch.no_grad():
            for blk in self.model.blocks[: -self.num_trainable_blocks]:
                x = blk(x)
        # Cut the autograd graph so no gradients flow into the frozen blocks.
        x = x.detach()

        # Last blocks are trained
        for blk in self.model.blocks[-self.num_trainable_blocks :]:
            x = blk(x)

        if self.norm_layer:
            x = self.model.norm(x)

        t = x[:, 0]   # CLS token
        f = x[:, 1:]  # patch tokens

        # Reshape to (B, C, H, W)
        f = f.reshape((B, H // 14, W // 14, self.num_channels)).permute(0, 3, 1, 2)

        if self.return_token:
            return f, t
        return f
119
+
120
+
121
# Code adapted from OpenGlue, MIT license
# https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/optimal_transport.py
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
    r"""Sinkhorn matrix scaling algorithm for the Differentiable Optimal Transport problem.

    Alternately rescales rows and columns of the cost matrix in log space so
    the resulting transport plan approximately matches the given marginals.

    Args:
        log_a : torch.Tensor
            Log source weights, shape [B, m].
        log_b : torch.Tensor
            Log target weights, shape [B, n].
        M : torch.Tensor
            Metric cost matrix, shape [B, m, n].
        num_iters : int, default=20
            The number of Sinkhorn iterations.
        reg : float, default=1.0
            Regularization value.

    Returns:
        torch.Tensor: log transport plan, shape [B, m, n].
    """
    M = M / reg  # entropic regularization

    u, v = torch.zeros_like(log_a), torch.zeros_like(log_b)

    for _ in range(num_iters):
        # logsumexp over dim=2 / dim=1 already yields [B, m] / [B, n].
        # (The previous code called bare .squeeze() here, which removed *all*
        # size-1 dims: a no-op for typical shapes, but it mis-broadcast for
        # n == 1 with B > 1, so it has been removed.)
        u = log_a - torch.logsumexp(M + v.unsqueeze(1), dim=2)
        v = log_b - torch.logsumexp(M + u.unsqueeze(2), dim=1)

    return M + u.unsqueeze(2) + v.unsqueeze(1)
147
+
148
+
149
# Code adapted from OpenGlue, MIT license
# https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/superglue.py
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
    """Run Sinkhorn on a score matrix augmented with a constant dustbin row.

    NOTE(review): ``math.log(n - m)`` below requires n > m (more columns than
    rows in ``S``) — confirm this invariant against callers.
    """
    batch_size, m, n = S.size()

    # Append one extra "dustbin" row holding a constant score.
    S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    S_aug[:, :m, :n] = S
    S_aug[:, m, :] = dustbin_score

    # Uniform log-weights over the m + n slots; the dustbin row then absorbs
    # the n - m units of leftover column mass.
    norm = -torch.tensor(math.log(n + m), device=S.device)
    log_a = norm.repeat(m + 1)
    log_a[-1] += math.log(n - m)
    log_b = norm.repeat(n)
    log_a = log_a.expand(batch_size, -1)
    log_b = log_b.expand(batch_size, -1)

    log_P = log_otp_solver(log_a, log_b, S_aug, num_iters=num_iters, reg=reg)
    # Undo the global normalization so rows/columns sum to their raw counts.
    return log_P - norm
166
+
167
+
168
class SALAD(nn.Module):
    """
    This class represents the Sinkhorn Algorithm for Locally Aggregated Descriptors (SALAD) model.

    Attributes:
        num_channels (int): The number of channels of the inputs (d).
        num_clusters (int): The number of clusters in the model (m).
        cluster_dim (int): The number of channels of the clusters (l).
        token_dim (int): The dimension of the global scene token (g).
        dropout (float): The dropout rate.
    """

    def __init__(
        self,
        num_channels=1536,
        num_clusters=64,
        cluster_dim=128,
        token_dim=256,
        mlp_dim=512,
        dropout=0.3,
    ) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.num_clusters = num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim
        self.mlp_dim = mlp_dim

        # One shared dropout module (identity when the rate is <= 0), reused
        # in both convolutional MLPs below.
        if dropout > 0:
            dropout = nn.Dropout(dropout)
        else:
            dropout = nn.Identity()

        # MLP for global scene token g
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, self.mlp_dim), nn.ReLU(), nn.Linear(self.mlp_dim, self.token_dim)
        )
        # MLP for local features f_i
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            dropout,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.cluster_dim, 1),
        )
        # MLP for score matrix S
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, self.mlp_dim, 1),
            dropout,
            nn.ReLU(),
            nn.Conv2d(self.mlp_dim, self.num_clusters, 1),
        )
        # Dustbin parameter z
        self.dust_bin = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        """
        x (tuple): A tuple containing two elements, f and t.
            (torch.Tensor): The feature tensors (t_i) [B, C, H // 14, W // 14].
            (torch.Tensor): The token tensor (t_{n+1}) [B, C].

        Returns:
            f (torch.Tensor): The global descriptor [B, m*l + g]
        """
        x, t = x  # Extract features and token

        f = self.cluster_features(x).flatten(2)  # [B, l, hw]
        p = self.score(x).flatten(2)  # [B, m, hw]
        t = self.token_features(t)  # [B, g]

        # Sinkhorn algorithm: soft-assign each local feature to the clusters
        # (plus a dustbin row for unmatched mass).
        p = get_matching_probs(p, self.dust_bin, 3)
        p = torch.exp(p)
        # Drop the dustbin row: its mass is simply discarded (the remaining
        # rows are NOT re-normalized).
        p = p[:, :-1, :]  # [B, m, hw]

        # Assignment-weighted aggregation via broadcasting. This replaces the
        # earlier unsqueeze+repeat formulation, which materialized two
        # [B, l, m, hw] tensors; the broadcast product is numerically
        # identical but avoids those copies.
        # f: [B, l, 1, hw] * p: [B, 1, m, hw] -> sum over hw -> [B, l, m]
        agg = (f.unsqueeze(2) * p.unsqueeze(1)).sum(dim=-1)

        f = torch.cat(
            [
                nn.functional.normalize(t, p=2, dim=-1),
                nn.functional.normalize(agg, p=2, dim=1).flatten(1),
            ],
            dim=-1,
        )

        return nn.functional.normalize(f, p=2, dim=-1)