Pawel Piwowarski committed on
Commit
8a6022e
·
0 Parent(s):

init commit

Browse files
Files changed (8) hide show
  1. .gitattributes +2 -0
  2. .gitignore +1 -0
  3. README.md +3 -0
  4. load_model.py +36 -0
  5. model.py +161 -0
  6. models/mixvpr.py +94 -0
  7. models/salad.py +141 -0
  8. weights/best_model_95.6.torch +3 -0
.gitattributes ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .torch filter=lfs diff=lfs merge=lfs -text
2
+ *.torch filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
load_model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from model import DINOv2FeatureExtractor
import torch


# Select GPU when available; checkpoint tensors are mapped onto this device on load.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

MODEL_CHECKPOINT_PATH = './weights/best_model_95.6.torch'


# Build the feature extractor: frozen DINOv2 ViT-B backbone
# (num_of_layers_to_unfreeze=0) with a SALAD aggregation head.
model = DINOv2FeatureExtractor(
    model_type="vit_base_patch14_reg4_dinov2.lvd142m",
    num_of_layers_to_unfreeze=0,
    desc_dim=768,
    aggregator_type="SALAD",
)
print('loading model ... ')
# NOTE(review): torch.load without weights_only=True unpickles arbitrary objects;
# for a plain state_dict checkpoint, weights_only=True is safer — confirm and enable.
model_state_dict = torch.load(MODEL_CHECKPOINT_PATH, map_location=DEVICE)
model.load_state_dict(model_state_dict)
# Move to the target device once and switch to inference mode.
# (The original script called model.to(DEVICE) a second time further down — redundant.)
model = model.to(DEVICE)
model.eval()
print('loaded ....')


# Print some info about model weights
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model total parameters: {num_params:,}")
print(f"Model trainable parameters: {num_trainable:,}")

print(model.aggregator_type)
model.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import timm
5
+ import logging
6
+ from types import SimpleNamespace as Namespace
7
+
8
+ # Assuming these are in your project structure
9
+ from models.salad import SALAD
10
+ from models.mixvpr import MixVPR
11
+
12
+
13
+
14
+
15
class DINOv2FeatureExtractor(nn.Module):
    """DINOv2 (timm ViT) backbone with an optional VPR aggregation head.

    Depending on ``aggregator_type`` the forward pass returns:
      * ``"SALAD"``  — a SALAD-aggregated global descriptor,
      * ``"MixVPR"`` — a MixVPR-aggregated global descriptor,
      * anything else (default ``"No"``) — the L2-normalized pooled feature
        from the backbone head.
    """

    def __init__(
        self,
        image_size=518,  # Default input resolution for DINOv2 models
        model_type="vit_base_patch14_reg4_dinov2.lvd142m",
        num_of_layers_to_unfreeze=1,
        desc_dim=768,  # vit-base has 768-dim embeddings
        aggregator_type="No",
    ):
        """
        Args:
            image_size: Square input resolution forwarded to timm.
            model_type: timm model name for the ViT backbone.
            num_of_layers_to_unfreeze: How many final transformer blocks stay trainable.
            desc_dim: Descriptor size; overwritten below when an aggregator is chosen.
            aggregator_type: "SALAD", "MixVPR", or anything else for plain pooled features.
        """
        super().__init__()

        # Initialize backbone with registers (num_classes=0 drops the classifier head)
        self.backbone = timm.create_model(
            model_type, pretrained=True, num_classes=0, img_size=image_size
        )

        # Store configuration parameters
        self.model_type = model_type
        self.num_channels = self.backbone.embed_dim
        self.desc_dim = desc_dim
        self.image_size = image_size
        self.num_of_layers_to_unfreeze = num_of_layers_to_unfreeze
        self.aggregator_type = aggregator_type
        self.aggregator = None

        # Aggregator hyper-parameters are keyed off the ViT size in model_type.
        # NOTE(review): if model_type matches none of the vit_small/base/large
        # patterns, self.aggregator stays None while aggregator_type is "SALAD",
        # and forward() would fail when calling it — confirm intended.
        if aggregator_type == "SALAD":
            if "vit_small" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=24,
                    cluster_dim=64,
                    token_dim=512,
                    dropout=0.3,
                )
                # Output: 512 + (24 * 64) = 2,048 dims
                self.desc_dim = 512 + (24 * 64)
            elif "vit_base" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=32,
                    cluster_dim=64,
                    token_dim=1024,
                    dropout=0.3,
                )
                # Output: 1024 + (32 * 64) = 3,072 dims
                self.desc_dim = 1024 + (32 * 64)
            elif "vit_large" in model_type:
                self.aggregator = SALAD(
                    num_channels=self.num_channels,
                    num_clusters=48,
                    cluster_dim=64,
                    token_dim=1024,
                    dropout=0.3,
                )
                # Output: 1024 + (48 * 64) = 4,096 dims
                self.desc_dim = 1024 + (48 * 64)
        elif aggregator_type == "MixVPR":
            # ViT patch size is hard-coded to 14 (DINOv2 patch14 models).
            patch_dim = image_size // 14
            if "vit_small" in model_type:
                out_dim = 2048
            elif "vit_base" in model_type:
                out_dim = 3072
            elif "vit_large" in model_type:
                out_dim = 4096
            else:
                # Default or error
                out_dim = 4096

            self.aggregator = MixVPR(
                in_channels=self.num_channels,
                in_h=patch_dim,
                in_w=patch_dim,
                out_channels=out_dim,
            )
            self.desc_dim = out_dim


        # This should be called regardless of the aggregator type.
        self._freeze_parameters()

    def _freeze_parameters(self):
        """
        Freeze all parameters except the last N transformer blocks and norm layer.
        """
        # First freeze everything
        for param in self.backbone.parameters():
            param.requires_grad = False

        # Unfreeze the last N blocks
        if self.num_of_layers_to_unfreeze > 0:
            for block in self.backbone.blocks[
                -self.num_of_layers_to_unfreeze :
            ]:
                for param in block.parameters():
                    param.requires_grad = True

        # Unfreeze norm layer
        # NOTE(review): the final norm stays trainable even when
        # num_of_layers_to_unfreeze == 0 — confirm this is intended.
        for param in self.backbone.norm.parameters():
            param.requires_grad = True

        # Count trainable parameters for backbone
        def count_trainable_params(model):
            return sum(p.numel() for p in model.parameters() if p.requires_grad)

        logging.info(
            f"Number of trainable parameters backbone: {count_trainable_params(self.backbone):,}"
        )

        # Count aggregator parameters if it exists
        if self.aggregator is not None:
            aggregator_params = count_trainable_params(self.aggregator)
            logging.info(
                f"Number of trainable parameters aggregator: {aggregator_params:,}"
            )
            logging.info(
                f"Total trainable parameters: {count_trainable_params(self.backbone) + aggregator_params:,}"
            )

    def forward(self, x):
        """Extract a global descriptor from an image batch.

        Args:
            x: Image tensor of shape (B, C, H, W); H and W are assumed to be
               multiples of the 14-pixel patch size (see the // 14 reshape below).

        Returns:
            Aggregated descriptor (SALAD/MixVPR) or the L2-normalized pooled
            backbone feature.
        """
        B, _, H, W = x.shape
        x = self.backbone.forward_features(x)

        # Consistent handling for register vs. non-register models
        if self.aggregator_type in ["SALAD", "MixVPR"]:
            # DINOv2 with registers has 4 register tokens + 1 CLS token
            # Standard ViT has 1 CLS token
            start_index = 5 if "reg" in self.model_type else 1
            patch_tokens = x[:, start_index:]

            # Reshape to (B, C, H, W) for aggregators
            patch_tokens_map = patch_tokens.reshape(
                (B, H // 14, W // 14, self.num_channels)
            ).permute(0, 3, 1, 2)

            if self.aggregator_type == "SALAD":
                # SALAD consumes both the patch map and the CLS token.
                cls_token = x[:, 0]
                return self.aggregator((patch_tokens_map, cls_token))
            elif self.aggregator_type == "MixVPR":
                return self.aggregator(patch_tokens_map)

        # Default behavior: extract features from CLS pooling
        features = self.backbone.forward_head(x, pre_logits=True)

        # L2 normalization
        return F.normalize(features, p=2, dim=-1)
+
models/mixvpr.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+
5
+ import numpy as np
6
+
7
+
8
class FeatureMixerLayer(nn.Module):
    """Residual feature mixer: x + MLP(LayerNorm(x)) over the last dimension."""

    def __init__(self, in_dim, mlp_ratio=1):
        super().__init__()
        hidden = int(in_dim * mlp_ratio)
        self.mix = nn.Sequential(
            nn.LayerNorm(in_dim),
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, in_dim),
        )

        # Truncated-normal init for the linear weights, zeroed biases.
        for module in self.mix:
            if isinstance(module, nn.Linear):
                nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, x):
        """Apply the mixer MLP with a skip connection."""
        return x + self.mix(x)
26
+
27
+
28
class MixVPR(nn.Module):
    """Aggregate a (B, C, H, W) feature map into a compact global descriptor.

    The map is flattened to (B, C, H*W), passed through ``mix_depth`` stacked
    FeatureMixerLayer blocks over the spatial axis, then projected channel-wise
    (C -> out_channels) and row-wise (H*W -> out_rows). The flattened result is
    L2-normalized, giving a descriptor of size out_channels * out_rows.
    """

    def __init__(self,
                 in_channels=1024,
                 in_h=20,
                 in_w=20,
                 out_channels=512,
                 mix_depth=1,
                 mlp_ratio=1,
                 out_rows=4,
                 ) -> None:
        super().__init__()

        # Input feature-map geometry.
        self.in_h = in_h
        self.in_w = in_w
        self.in_channels = in_channels

        # Output projection sizes.
        self.out_channels = out_channels  # channel-wise projection dimension
        self.out_rows = out_rows          # row-wise projection dimension

        self.mix_depth = mix_depth        # number of stacked FeatureMixers
        self.mlp_ratio = mlp_ratio        # hidden ratio inside each mixer MLP

        num_tokens = in_h * in_w
        mixers = [FeatureMixerLayer(in_dim=num_tokens, mlp_ratio=mlp_ratio)
                  for _ in range(self.mix_depth)]
        self.mix = nn.Sequential(*mixers)
        self.channel_proj = nn.Linear(in_channels, out_channels)
        self.row_proj = nn.Linear(num_tokens, out_rows)

    def forward(self, x):
        """Map (B, C, H, W) features to an L2-normalized (B, out_channels*out_rows) vector."""
        tokens = self.mix(x.flatten(2))                          # (B, C, H*W)
        projected = self.channel_proj(tokens.permute(0, 2, 1))   # (B, H*W, out_channels)
        rows = self.row_proj(projected.permute(0, 2, 1))         # (B, out_channels, out_rows)
        return F.normalize(rows.flatten(1), p=2, dim=1)
67
+
68
+
69
+ # -------------------------------------------------------------------------------
70
+
71
def print_nb_params(m):
    """Print the number of trainable parameters of module *m*, in millions."""
    trainable = [p for p in m.parameters() if p.requires_grad]
    total = sum(np.prod(p.size()) for p in trainable)
    print(f'Trainable parameters: {total/1e6:.3}M')
75
+
76
+
77
def main():
    """Smoke test: run MixVPR on a random feature map and report its size."""
    features = torch.randn(1, 1024, 20, 20)
    aggregator = MixVPR(
        in_channels=1024,
        in_h=20,
        in_w=20,
        out_channels=1024,
        mix_depth=4,
        mlp_ratio=1,
        out_rows=4)

    print_nb_params(aggregator)
    descriptor = aggregator(features)
    print(descriptor.shape)


if __name__ == '__main__':
    main()
models/salad.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ # Code adapted from OpenGlue, MIT license
6
+ # https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/optimal_transport.py
7
def log_otp_solver(log_a, log_b, M, num_iters: int = 20, reg: float = 1.0) -> torch.Tensor:
    r"""Sinkhorn matrix scaling in log space for entropic optimal transport.

    Code adapted from OpenGlue, MIT license:
    https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/optimal_transport.py

    Args:
        log_a: Log source marginals, shape (B, m).
        log_b: Log target marginals, shape (B, n).
        M: Metric cost/score matrix, shape (B, m, n).
        num_iters: Number of Sinkhorn iterations (default 20).
        reg: Entropic regularization strength.

    Returns:
        Log of the (approximate) optimal transport plan, shape (B, m, n).
    """
    cost = M / reg  # apply entropic regularization

    u = torch.zeros_like(log_a)
    v = torch.zeros_like(log_b)

    # Alternate row/column scaling updates, all in log space for stability.
    for _ in range(num_iters):
        u = log_a - torch.logsumexp(cost + v.unsqueeze(1), dim=2).squeeze()
        v = log_b - torch.logsumexp(cost + u.unsqueeze(2), dim=1).squeeze()

    return cost + u.unsqueeze(2) + v.unsqueeze(1)
31
+
32
+ # Code adapted from OpenGlue, MIT license
33
+ # https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/superglue.py
34
def get_matching_probs(S, dustbin_score=1.0, num_iters=3, reg=1.0):
    """Run Sinkhorn on the score matrix S augmented with a dustbin row.

    Code adapted from OpenGlue, MIT license:
    https://github.com/ucuapps/OpenGlue/blob/main/models/superglue/superglue.py

    Args:
        S: Score matrix, shape (B, m, n). Assumes n > m so the dustbin row
           can absorb the n - m surplus mass (log(n - m) below).
        dustbin_score: Constant score assigned to the appended dustbin row.
        num_iters: Sinkhorn iterations forwarded to the solver.
        reg: Entropic regularization forwarded to the solver.

    Returns:
        Log transport plan of shape (B, m + 1, n), shifted by the log
        normalization constant.
    """
    batch_size, m, n = S.size()

    # Augment the score matrix with one dustbin row of constant score.
    S_aug = torch.empty(batch_size, m + 1, n, dtype=S.dtype, device=S.device)
    S_aug[:, :m, :n] = S
    S_aug[:, m, :] = dustbin_score

    # Uniform log marginals; the dustbin entry carries the extra n - m mass.
    norm = -torch.tensor(math.log(n + m), device=S.device)
    log_a = norm.expand(m + 1).contiguous()
    log_b = norm.expand(n).contiguous()
    log_a[-1] = log_a[-1] + math.log(n - m)
    log_a = log_a.expand(batch_size, -1)
    log_b = log_b.expand(batch_size, -1)

    log_P = log_otp_solver(
        log_a,
        log_b,
        S_aug,
        num_iters=num_iters,
        reg=reg
    )
    return log_P - norm
55
+
56
+
57
class SALAD(nn.Module):
    """
    Sinkhorn Algorithm for Locally Aggregated Descriptors (SALAD).

    Softly assigns local patch features to learned clusters via optimal
    transport (Sinkhorn with a dustbin), aggregates them per cluster, and
    concatenates a projected global scene token.

    Attributes:
        num_channels (int): The number of channels of the inputs (d).
        num_clusters (int): The number of clusters in the model (m).
        cluster_dim (int): The number of channels of the clusters (l).
        token_dim (int): The dimension of the global scene token (g).
        dropout (float): The dropout rate.
    """
    def __init__(self,
            num_channels=1536,
            num_clusters=64,
            cluster_dim=128,
            token_dim=256,
            dropout=0.3,
        ) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.num_clusters= num_clusters
        self.cluster_dim = cluster_dim
        self.token_dim = token_dim

        # A single dropout module (or identity) is shared by the two conv
        # MLPs below.
        if dropout > 0:
            dropout = nn.Dropout(dropout)
        else:
            dropout = nn.Identity()

        # MLP for global scene token g
        self.token_features = nn.Sequential(
            nn.Linear(self.num_channels, 512),
            nn.ReLU(),
            nn.Linear(512, self.token_dim)
        )
        # MLP for local features f_i (1x1 convs act per-patch)
        self.cluster_features = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            dropout,
            nn.ReLU(),
            nn.Conv2d(512, self.cluster_dim, 1)
        )
        # MLP for score matrix S (patch-to-cluster affinities)
        self.score = nn.Sequential(
            nn.Conv2d(self.num_channels, 512, 1),
            dropout,
            nn.ReLU(),
            nn.Conv2d(512, self.num_clusters, 1),
        )
        # Dustbin parameter z (learnable scalar score for unassigned patches)
        self.dust_bin = nn.Parameter(torch.tensor(1.))


    def forward(self, x):
        """
        x (tuple): A tuple containing two elements, f and t.
            (torch.Tensor): The feature tensors (t_i) [B, C, H // 14, W // 14].
            (torch.Tensor): The token tensor (t_{n+1}) [B, C].

        Returns:
            f (torch.Tensor): The global descriptor [B, m*l + g]
        """
        x, t = x  # Extract features and token

        # f: (B, cluster_dim, HW); p: (B, num_clusters, HW); t: (B, token_dim)
        f = self.cluster_features(x).flatten(2)
        p = self.score(x).flatten(2)
        t = self.token_features(t)

        # Sinkhorn algorithm: soft patch-to-cluster assignment with dustbin
        p = get_matching_probs(p, self.dust_bin, 3)
        p = torch.exp(p)
        # discard the dustbin
        p = p[:, :-1, :]


        # Broadcast assignments and features to (B, cluster_dim, num_clusters, HW)
        p = p.unsqueeze(1).repeat(1, self.cluster_dim, 1, 1)
        f = f.unsqueeze(2).repeat(1, 1, self.num_clusters, 1)

        # Concatenate normalized token with per-cluster-normalized aggregates
        f = torch.cat([
            nn.functional.normalize(t, p=2, dim=-1),
            nn.functional.normalize((f * p).sum(dim=-1), p=2, dim=1).flatten(1)
        ], dim=-1)

        # Final L2 normalization of the full descriptor
        return nn.functional.normalize(f, p=2, dim=-1)
weights/best_model_95.6.torch ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6cea6330719dee2b63e70438a2addc7f85242737a5079d6b88af10f7794669b
3
+ size 353426618