iclr2025-anonymous committed
Commit 6a06696 · verified · 1 Parent(s): 546bb84

Upload 2 files

Files changed (2)
  1. model.py +108 -0
  2. vits.py +177 -0
model.py ADDED
@@ -0,0 +1,108 @@
+ import json
+ import torch
+ from torch import nn, Tensor
+ from loguru import logger
+ from pathlib import Path
+
+ from torchvision.transforms import ToTensor
+ from torchvision.transforms.v2 import CenterCrop, Compose, Normalize
+
+
+ import vits
+
+ def _clean_moco_state_dict(state_dict: dict[str, Tensor], linear_keyword: str) -> dict[str, Tensor]:
+     """
+     Filters and renames keys from a MoCo state_dict.
+
+     It selects keys from the 'base_encoder', removes the given linear layer keyword,
+     and strips the 'module.base_encoder.' prefix.
+     """
+     for key in list(state_dict.keys()):
+         # Check if the key belongs to the base encoder's backbone
+         if key.startswith('module.base_encoder') and not key.startswith(f'module.base_encoder.{linear_keyword}'):
+             # Create a new key by stripping the prefix
+             new_key = key[len("module.base_encoder."):]
+             state_dict[new_key] = state_dict[key]
+
+         # Delete the old key (either renamed or unused)
+         del state_dict[key]
+
+     return state_dict
+
+ def load_moco_encoder(
+     model: nn.Module,
+     weight_path: Path,
+     linear_keyword: str,
+ ) -> nn.Module:
+     """
+     Loads pre-trained MoCo weights into a given model instance (ResNet, ViT, etc.).
+
+     This function handles loading the checkpoint, cleaning the state dictionary keys,
+     and loading the weights into the model's backbone. It finishes by replacing
+     the model's linear head with an Identity layer to turn it into a feature extractor.
+
+     Args:
+         model: An instantiated PyTorch model (e.g., from timm or a custom module).
+         weight_path: Path to the .pth or .pt MoCo checkpoint file.
+         linear_keyword: The name of the final linear layer to exclude (e.g., 'fc' or 'head').
+
+     Returns:
+         The same model, with pre-trained backbone weights and the head replaced
+         by nn.Identity(), ready for feature extraction.
+     """
+     assert weight_path.exists(), f"Checkpoint not found at '{weight_path}'"
+     logger.info(f"=> Loading MoCo checkpoint from '{weight_path}'")
+
+     # Use weights_only=True for added security if the checkpoint doesn't contain pickled code
+     checkpoint = torch.load(weight_path, map_location="cpu", weights_only=True)
+
+     # Extract the state dictionary containing the model weights
+     state_dict = checkpoint["state_dict"]
+
+     # Clean the state_dict to match the model's architecture
+     cleaned_state_dict = _clean_moco_state_dict(state_dict, linear_keyword)
+
+     # Load the cleaned weights into the model
+     msg = model.load_state_dict(cleaned_state_dict, strict=False)
+     logger.info(msg)
+     logger.info("=> Successfully loaded pre-trained model backbone.")
+
+     # Replace the model's head to turn it into a feature extractor
+     if hasattr(model, linear_keyword):
+         setattr(model, linear_keyword, nn.Identity())
+         logger.info(f"=> Model's '{linear_keyword}' layer replaced with nn.Identity for feature extraction.")
+
+     return model
+
+ def get_vit_feature_extractor(weight_path: Path, model_name: str = "vits8", img_size: int = 40) -> nn.Module:
+     """Creates a ViT feature extractor using the unified loader."""
+     # 1. Create the model architecture shell
+     vit_model = vits.__dict__[model_name](img_size=img_size, num_classes=0)
+
+     # 2. Use the unified function to load weights and prepare for feature extraction
+     feature_extractor = load_moco_encoder(
+         model=vit_model,
+         weight_path=weight_path,
+         linear_keyword='head'
+     )
+     return feature_extractor
+
+
+ def prepare_transform(
+     stats_path: Path,
+     size: int = 40,
+ ) -> Compose:
+     # Get normalisation stats
+     with open(stats_path, "r") as f:
+         norm_dict = json.load(f)
+     mean = norm_dict["mean"]
+     std = norm_dict["std"]
+
+     # Prepare transform
+     list_transform = [
+         ToTensor(),
+         Normalize(mean=mean, std=std),
+         CenterCrop(size=size),
+     ]
+     transform = Compose(list_transform)
+     return transform
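
Taken together, model.py gives a small end-to-end inference path: build the ViT shell, load the MoCo backbone, and preprocess inputs with the saved normalisation stats. A minimal usage sketch follows; the checkpoint, stats file, and image paths are placeholders, and the stats JSON is assumed to hold "mean" and "std" lists as read by prepare_transform:

from pathlib import Path

import torch
from PIL import Image

from model import get_vit_feature_extractor, prepare_transform

# Placeholder paths; substitute a real MoCo checkpoint and stats file.
extractor = get_vit_feature_extractor(weight_path=Path("moco_checkpoint.pth"))
transform = prepare_transform(stats_path=Path("norm_stats.json"), size=40)

image = Image.open("example.png").convert("RGB")  # any image >= 40x40 pixels
batch = transform(image).unsqueeze(0)             # shape (1, 3, 40, 40)

extractor.eval()
with torch.no_grad():
    features = extractor(batch)  # (1, 384) for the default vits8 backbone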
vits.py ADDED
@@ -0,0 +1,177 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ import torch
+ import torch.nn as nn
+ from functools import partial, reduce
+ from operator import mul
+
+ from timm.layers import to_2tuple
+ from timm.models.vision_transformer import VisionTransformer, _cfg
+ from timm.layers import PatchEmbed
+
+ __all__ = [
+     "vits4",
+     "vits8",
+     "vitb4",
+     "vitb8",
+ ]
+
+
+ class VisionTransformerMoCo(VisionTransformer):
+     def __init__(self, stop_grad_conv1=False, **kwargs):
+         super().__init__(**kwargs)
+         # Use fixed 2D sin-cos position embedding
+         self.build_2d_sincos_position_embedding()
+
+         # weight initialization
+         for name, m in self.named_modules():
+             if isinstance(m, nn.Linear):
+                 if "qkv" in name:
+                     # treat the weights of Q, K, V separately
+                     val = math.sqrt(
+                         6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1])
+                     )
+                     nn.init.uniform_(m.weight, -val, val)
+                 else:
+                     nn.init.xavier_uniform_(m.weight)
+                 nn.init.zeros_(m.bias)
+         nn.init.normal_(self.cls_token, std=1e-6)
+
+         if isinstance(self.patch_embed, PatchEmbed):
+             # xavier_uniform initialization
+             val = math.sqrt(
+                 6.0
+                 / float(
+                     3 * reduce(mul, self.patch_embed.patch_size, 1) + self.embed_dim
+                 )
+             )
+             nn.init.uniform_(self.patch_embed.proj.weight, -val, val)
+             nn.init.zeros_(self.patch_embed.proj.bias)
+
+         if stop_grad_conv1:
+             self.patch_embed.proj.weight.requires_grad = False
+             self.patch_embed.proj.bias.requires_grad = False
+
+     def build_2d_sincos_position_embedding(self, temperature=10000.0):
+         h, w = self.patch_embed.grid_size
+         grid_w = torch.arange(w, dtype=torch.float32)
+         grid_h = torch.arange(h, dtype=torch.float32)
+         grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="ij")
+         assert (
+             self.embed_dim % 4 == 0
+         ), "Embed dimension must be divisible by 4 for 2D sin-cos position embedding"
+         pos_dim = self.embed_dim // 4
+         omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+         omega = 1.0 / (temperature**omega)
+         out_w = torch.einsum("m,d->md", [grid_w.flatten(), omega])
+         out_h = torch.einsum("m,d->md", [grid_h.flatten(), omega])
+         pos_emb = torch.cat(
+             [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)],
+             dim=1,
+         )[None, :, :]
+
+         pe_token = torch.zeros([1, self.num_prefix_tokens, self.embed_dim], dtype=torch.float32)
+         self.pos_embed = nn.Parameter(torch.cat([pe_token, pos_emb], dim=1))
+         self.pos_embed.requires_grad = False
+
+
+ class ConvStem(nn.Module):
+     """
+     ConvStem, from Early Convolutions Help Transformers See Better, Xiao et al. https://arxiv.org/abs/2106.14881
+     """
+
+     def __init__(
+         self,
+         img_size=224,
+         patch_size=16,
+         in_chans=3,
+         embed_dim=768,
+         norm_layer=None,
+         flatten=True,
+     ):
+         super().__init__()
+
+         assert patch_size == 16, "ConvStem only supports patch size of 16"
+         assert embed_dim % 8 == 0, "Embed dimension must be divisible by 8 for ConvStem"
+
+         img_size = to_2tuple(img_size)
+         patch_size = to_2tuple(patch_size)
+         self.img_size = img_size
+         self.patch_size = patch_size
+         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+         self.num_patches = self.grid_size[0] * self.grid_size[1]
+         self.flatten = flatten
+
+         # build stem, similar to the design in https://arxiv.org/abs/2106.14881
+         stem = []
+         input_dim, output_dim = 3, embed_dim // 8
+         for _ in range(4):
+             stem.append(
+                 nn.Conv2d(
+                     input_dim,
+                     output_dim,
+                     kernel_size=3,
+                     stride=2,
+                     padding=1,
+                     bias=False,
+                 )
+             )
+             stem.append(nn.BatchNorm2d(output_dim))
+             stem.append(nn.ReLU(inplace=True))
+             input_dim = output_dim
+             output_dim *= 2
+         stem.append(nn.Conv2d(input_dim, embed_dim, kernel_size=1))
+         self.proj = nn.Sequential(*stem)
+
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         assert (
+             H == self.img_size[0] and W == self.img_size[1]
+         ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
+         x = self.proj(x)
+         if self.flatten:
+             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
+         x = self.norm(x)
+         return x
+
+
+ def vits(patch_size: int, **kwargs):
+     model = VisionTransformerMoCo(
+         patch_size=patch_size,
+         embed_dim=384,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs,
+     )
+     model.default_cfg = _cfg()
+     return model
+
+ vits4 = partial(vits, patch_size=4)
+ vits8 = partial(vits, patch_size=8)
+
+ def vitb(patch_size: int, **kwargs):
+     model = VisionTransformerMoCo(
+         patch_size=patch_size,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         qkv_bias=True,
+         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         **kwargs,
+     )
+     model.default_cfg = _cfg()
+     return model
+
+ vitb4 = partial(vitb, patch_size=4)
+ vitb8 = partial(vitb, patch_size=8)
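
As a quick sanity check on the factories above, the sketch below instantiates vits8 at the 40-pixel input size used in model.py. The expected shapes are assumptions derived from embed_dim=384 and patch_size=8 (a 5x5 grid plus one class token), not outputs recorded in this commit:

import torch
import vits

model = vits.vits8(img_size=40, num_classes=0)

# 40 // 8 = 5, so 25 patch tokens plus the class token.
print(model.pos_embed.shape)          # expected: torch.Size([1, 26, 384])
print(model.pos_embed.requires_grad)  # False: the sin-cos embedding stays frozen

x = torch.randn(2, 3, 40, 40)
with torch.no_grad():
    out = model(x)
print(out.shape)                      # expected: torch.Size([2, 384])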