Continual-Mega committed on
Commit
1b8e995
·
verified ·
1 Parent(s): a07ade3

Upload CLIP/modified_resnet.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. CLIP/modified_resnet.py +218 -0
CLIP/modified_resnet.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
def freeze_batch_norm_2d(module, module_match=None, name=''):
    """
    Convert all `BatchNorm2d` and `SyncBatchNorm` layers of the provided module into `FrozenBatchNorm2d`.

    If `module` is itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into
    `FrozenBatchNorm2d` and returned. Otherwise, the module is walked recursively and matching
    submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict or None): Full module names to freeze (all modules when None or empty).
            Defaults to None instead of a mutable ``{}`` so the default object cannot be shared
            and mutated across calls.
        name (str): Full dotted module name (prefix) used while recursing.

    Returns:
        torch.nn.Module: The resulting module — a new ``FrozenBatchNorm2d`` when a conversion
        happened, otherwise the (possibly in-place modified) original module.

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762

    NOTE(review): ``FrozenBatchNorm2d`` is neither defined nor imported in this file; it must be
    supplied by the surrounding package (presumably ``torchvision.ops.misc.FrozenBatchNorm2d``) —
    confirm before calling this on a module that actually contains batch-norm layers.
    """
    res = module
    # An empty/None match dict means "freeze everything".
    is_match = True
    if module_match:
        is_match = name in module_match
    if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            # Detached clones so the frozen layer shares no autograd history with the original.
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for child_name, child in module.named_children():
            full_child_name = '.'.join([name, child_name]) if name else child_name
            new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
            # Only re-register when a conversion actually produced a new object.
            if new_child is not child:
                res.add_module(child_name, new_child)
    return res
45
+
46
+
47
class Bottleneck(nn.Module):
    """CLIP-style ResNet bottleneck block.

    Every convolution runs with stride 1; spatial downsampling (when
    ``stride > 1``) is performed by an ``AvgPool2d`` placed after the second
    conv (anti-aliased downsampling). The shortcut path likewise prepends an
    avgpool before its 1x1 projection.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        # 1x1 reduction
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.act1 = nn.ReLU(inplace=True)

        # 3x3 spatial conv
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.act2 = nn.ReLU(inplace=True)

        # anti-aliased downsample; identity when no striding is requested
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        # 1x1 expansion
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.act3 = nn.ReLU(inplace=True)

        self.downsample = None
        self.stride = stride

        needs_projection = stride > 1 or inplanes != planes * Bottleneck.expansion
        if needs_projection:
            # shortcut = avgpool first, then a stride-1 1x1 projection + BN
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(planes * self.expansion))
            ]))

    def forward(self, x: torch.Tensor):
        # shortcut branch (projected when shapes would otherwise mismatch)
        shortcut = x if self.downsample is None else self.downsample(x)

        # main branch: reduce -> spatial -> (pool) -> expand
        y = self.act1(self.bn1(self.conv1(x)))
        y = self.act2(self.bn2(self.conv2(y)))
        y = self.avgpool(y)
        y = self.bn3(self.conv3(y))

        y += shortcut
        return self.act3(y)
93
+
94
+
95
class AttentionPool2d(nn.Module):
    """2-D attention pooling head (final pooling of the CLIP ResNet).

    Flattens an NCHW feature map into a token sequence, prepends a
    mean-pooled token, adds a learned positional embedding, and applies one
    round of multi-head attention with separate q/k/v projections. The
    attended first (mean) token is returned, i.e. one pooled vector per image.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # one position per spatial location, plus one for the prepended mean token
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        # NCHW -> (HW)NC: one token per spatial position
        tokens = x.reshape(batch, channels, -1).permute(2, 0, 1)
        # prepend the global-mean token -> (HW+1)NC
        tokens = torch.cat([tokens.mean(dim=0, keepdim=True), tokens], dim=0)
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)
        attended, _ = F.multi_head_attention_forward(
            query=tokens, key=tokens, value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0.,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # keep only the attended mean token
        return attended[0]
130
+
131
+
132
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # 3-conv stem: the first conv halves the resolution, the trailing
        # avgpool halves it again (total stem downsampling: 4x).
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.act1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.act2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.act3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual stages; _inplanes is mutated by _make_layer during construction
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        # final feature dim: width * 8 (stage 4) * 4 (Bottleneck expansion)
        embed_dim = width * 32
        self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage; only the first block may downsample."""
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(Bottleneck(self._inplanes, planes) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def init_parameters(self):
        if self.attnpool is not None:
            # scaled-normal init for all four attention-pool projections
            std = self.attnpool.c_proj.in_features ** -0.5
            for proj in (self.attnpool.q_proj, self.attnpool.k_proj,
                         self.attnpool.v_proj, self.attnpool.c_proj):
                nn.init.normal_(proj.weight, std=std)

        # zero-init the last BN gamma of every bottleneck so each residual
        # block starts out close to identity
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for param_name, param in stage.named_parameters():
                if param_name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze all parameters (and optionally BN running stats)."""
        assert unlocked_groups == 0, 'partial locking not currently supported for this model'
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # FIXME support for non-transformer
        pass

    def stem(self, x):
        """Run the 3-conv stem followed by the stem avgpool."""
        for conv, bn, act in ((self.conv1, self.bn1, self.act1),
                              (self.conv2, self.bn2, self.act2),
                              (self.conv3, self.bn3, self.act3)):
            x = act(bn(conv(x)))
        return self.avgpool(x)

    def forward(self, x):
        x = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)