Vidit2003 committed on
Commit
4842e7d
·
verified ·
1 Parent(s): 791a96b

Update vision_transformer.py

Browse files
Files changed (1) hide show
  1. vision_transformer.py +164 -120
vision_transformer.py CHANGED
@@ -1,43 +1,41 @@
1
- # Copyright (c) Facebook, Inc. and its affiliates.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """
15
- Mostly copy-paste from timm library.
16
- https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17
- """
18
  import math
19
  from functools import partial
20
-
21
  import torch
22
  import torch.nn as nn
23
-
24
- from utils import trunc_normal_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
 
27
  def drop_path(x, drop_prob: float = 0., training: bool = False):
28
  if drop_prob == 0. or not training:
29
  return x
30
  keep_prob = 1 - drop_prob
31
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
32
  random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
33
- random_tensor.floor_() # binarize
34
  output = x.div(keep_prob) * random_tensor
35
  return output
36
 
37
 
38
  class DropPath(nn.Module):
39
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
40
- """
41
  def __init__(self, drop_prob=None):
42
  super(DropPath, self).__init__()
43
  self.drop_prob = drop_prob
@@ -71,7 +69,6 @@ class Attention(nn.Module):
71
  self.num_heads = num_heads
72
  head_dim = dim // num_heads
73
  self.scale = qk_scale or head_dim ** -0.5
74
-
75
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
76
  self.attn_drop = nn.Dropout(attn_drop)
77
  self.proj = nn.Linear(dim, dim)
@@ -81,11 +78,9 @@ class Attention(nn.Module):
81
  B, N, C = x.shape
82
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83
  q, k, v = qkv[0], qkv[1], qkv[2]
84
-
85
  attn = (q @ k.transpose(-2, -1)) * self.scale
86
  attn = attn.softmax(dim=-1)
87
  attn = self.attn_drop(attn)
88
-
89
  x = (attn @ v).transpose(1, 2).reshape(B, N, C)
90
  x = self.proj(x)
91
  x = self.proj_drop(x)
@@ -114,15 +109,13 @@ class Block(nn.Module):
114
 
115
 
116
  class PatchEmbed(nn.Module):
117
- """ Image to Patch Embedding
118
- """
119
  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
120
  super().__init__()
121
  num_patches = (img_size // patch_size) * (img_size // patch_size)
122
  self.img_size = img_size
123
  self.patch_size = patch_size
124
  self.num_patches = num_patches
125
-
126
  self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
127
 
128
  def forward(self, x):
@@ -131,37 +124,109 @@ class PatchEmbed(nn.Module):
131
  return x
132
 
133
 
134
- class VisionTransformer(nn.Module):
135
- """ Vision Transformer """
136
- def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
137
- num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
138
- drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
139
- super().__init__()
140
- self.num_features = self.embed_dim = embed_dim
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
- self.patch_embed = PatchEmbed(
143
- img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
144
- num_patches = self.patch_embed.num_patches
145
 
146
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
147
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
148
- self.pos_drop = nn.Dropout(p=drop_rate)
149
 
150
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  self.blocks = nn.ModuleList([
152
  Block(
153
- dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
154
- drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
155
- for i in range(depth)])
156
- self.norm = norm_layer(embed_dim)
157
-
 
 
 
 
 
 
 
 
 
158
  # Classifier head
159
- self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
160
-
 
161
  trunc_normal_(self.pos_embed, std=.02)
162
  trunc_normal_(self.cls_token, std=.02)
163
  self.apply(self._init_weights)
164
-
165
  def _init_weights(self, m):
166
  if isinstance(m, nn.Linear):
167
  trunc_normal_(m.weight, std=.02)
@@ -170,7 +235,7 @@ class VisionTransformer(nn.Module):
170
  elif isinstance(m, nn.LayerNorm):
171
  nn.init.constant_(m.bias, 0)
172
  nn.init.constant_(m.weight, 1.0)
173
-
174
  def interpolate_pos_encoding(self, x, w, h):
175
  npatch = x.shape[1] - 1
176
  N = self.pos_embed.shape[1] - 1
@@ -181,8 +246,6 @@ class VisionTransformer(nn.Module):
181
  dim = x.shape[-1]
182
  w0 = w // self.patch_embed.patch_size
183
  h0 = h // self.patch_embed.patch_size
184
- # we add a small number to avoid floating point error in the interpolation
185
- # see discussion at https://github.com/facebookresearch/dino/issues/8
186
  w0, h0 = w0 + 0.1, h0 + 0.1
187
  patch_pos_embed = nn.functional.interpolate(
188
  patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
@@ -192,39 +255,67 @@ class VisionTransformer(nn.Module):
192
  assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
193
  patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
194
  return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
195
-
196
  def prepare_tokens(self, x):
197
  B, nc, w, h = x.shape
198
- x = self.patch_embed(x) # patch linear embedding
199
-
200
- # add the [CLS] token to the embed patch tokens
201
  cls_tokens = self.cls_token.expand(B, -1, -1)
202
  x = torch.cat((cls_tokens, x), dim=1)
203
-
204
- # add positional encoding to each token
205
  x = x + self.interpolate_pos_encoding(x, w, h)
206
-
207
  return self.pos_drop(x)
208
-
209
- def forward(self, x):
210
- x = self.prepare_tokens(x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
  for blk in self.blocks:
212
  x = blk(x)
 
213
  x = self.norm(x)
214
- return x[:, 0]
215
-
 
 
 
 
 
 
 
 
 
 
 
216
  def get_last_selfattention(self, x):
 
217
  x = self.prepare_tokens(x)
218
  for i, blk in enumerate(self.blocks):
219
  if i < len(self.blocks) - 1:
220
  x = blk(x)
221
  else:
222
- # return attention of the last block
223
  return blk(x, return_attention=True)
224
-
225
  def get_intermediate_layers(self, x, n=1):
 
226
  x = self.prepare_tokens(x)
227
- # we return the output tokens from the `n` last blocks
228
  output = []
229
  for i, blk in enumerate(self.blocks):
230
  x = blk(x)
@@ -233,53 +324,6 @@ class VisionTransformer(nn.Module):
233
  return output
234
 
235
 
236
-
237
def vit_small(patch_size=16, **kwargs):
    """Build a ViT-Small backbone (384-dim embeddings, 12 blocks, 6 heads)."""
    return VisionTransformer(
        patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
242
-
243
-
244
- def vit_base(patch_size=16, **kwargs):
245
- model = VisionTransformer(
246
- patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
247
- qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
248
- return model
249
-
250
-
251
class DINOHead(nn.Module):
    """DINO projection head.

    A small MLP projects backbone features down to a bottleneck, the result
    is L2-normalized, and a weight-normalized linear layer (whose scale can
    be frozen at 1) maps it to the output prototype dimension.
    """

    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            # Degenerate case: a single linear projection.
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            modules = [nn.Linear(in_dim, hidden_dim)]
            if use_bn:
                modules.append(nn.BatchNorm1d(hidden_dim))
            modules.append(nn.GELU())
            # Hidden layers between the input and bottleneck projections.
            for _ in range(nlayers - 2):
                modules.append(nn.Linear(hidden_dim, hidden_dim))
                if use_bn:
                    modules.append(nn.BatchNorm1d(hidden_dim))
                modules.append(nn.GELU())
            modules.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*modules)
        self.apply(self._init_weights)
        # Weight-normalized output layer; fixing weight_g keeps its norm at 1.
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        # Truncated-normal init for linear weights, zero biases.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        projected = self.mlp(x)
        # L2-normalize before the prototype layer, as in the DINO paper.
        projected = nn.functional.normalize(projected, dim=-1, p=2)
        return self.last_layer(projected)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import math
2
  from functools import partial
 
3
  import torch
4
  import torch.nn as nn
5
+ from transformers import PretrainedConfig, PreTrainedModel
6
+ from transformers.modeling_outputs import BaseModelOutput
7
+ from typing import Optional, Tuple, Union
8
+
9
+
10
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill *tensor* in place from a truncated normal distribution.

    Values are drawn from N(mean, std) restricted to [a, b] using the
    inverse-CDF method (same approach as the timm library): sample
    uniformly in the CDF range, then map back through erfinv.

    Returns the same tensor object, modified in place.
    """
    def _cdf(v):
        # CDF of the standard normal distribution.
        return (1. + math.erf(v / math.sqrt(2.))) / 2.

    with torch.no_grad():
        lower = _cdf((a - mean) / std)
        upper = _cdf((b - mean) / std)
        # Uniform sample over the truncated CDF range, rescaled to [-1, 1]
        # so erfinv_ maps it back to the (standard) normal domain.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        tensor.erfinv_()
        # Shift/scale to the requested mean and std, then clamp to [a, b]
        # to guard against floating-point excursions at the boundaries.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor
24
 
25
 
26
def drop_path(x, drop_prob: float = 0., training: bool = False):
    """Stochastic depth: randomly zero out whole samples of a residual branch.

    At inference time (``training=False``) or with ``drop_prob == 0`` the
    input is returned unchanged. Otherwise each sample along dim 0 survives
    with probability ``1 - drop_prob``, and survivors are rescaled by
    ``1 / keep_prob`` so the expected activation is preserved.
    """
    if not training or drop_prob == 0.:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample; singleton trailing dims broadcast over
    # whatever rank x has (works beyond 2D ConvNet tensors).
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    mask.floor_()  # binarize: 1 with probability keep_prob, else 0
    return x.div(keep_prob) * mask
35
 
36
 
37
  class DropPath(nn.Module):
38
+ """Drop paths (Stochastic Depth) per sample"""
 
39
  def __init__(self, drop_prob=None):
40
  super(DropPath, self).__init__()
41
  self.drop_prob = drop_prob
 
69
  self.num_heads = num_heads
70
  head_dim = dim // num_heads
71
  self.scale = qk_scale or head_dim ** -0.5
 
72
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
73
  self.attn_drop = nn.Dropout(attn_drop)
74
  self.proj = nn.Linear(dim, dim)
 
78
  B, N, C = x.shape
79
  qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
80
  q, k, v = qkv[0], qkv[1], qkv[2]
 
81
  attn = (q @ k.transpose(-2, -1)) * self.scale
82
  attn = attn.softmax(dim=-1)
83
  attn = self.attn_drop(attn)
 
84
  x = (attn @ v).transpose(1, 2).reshape(B, N, C)
85
  x = self.proj(x)
86
  x = self.proj_drop(x)
 
109
 
110
 
111
  class PatchEmbed(nn.Module):
112
+ """ Image to Patch Embedding """
 
113
  def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
114
  super().__init__()
115
  num_patches = (img_size // patch_size) * (img_size // patch_size)
116
  self.img_size = img_size
117
  self.patch_size = patch_size
118
  self.num_patches = num_patches
 
119
  self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
120
 
121
  def forward(self, x):
 
124
  return x
125
 
126
 
127
+ # ============================================================================
128
+ # HUGGING FACE CONFIGURATION CLASS (REQUIRED)
129
+ # ============================================================================
130
+
131
class VisionTransformerConfig(PretrainedConfig):
    """Hugging Face configuration for the DINO-style Vision Transformer.

    Holds every architectural hyper-parameter needed to rebuild
    ``VisionTransformer`` through ``AutoModel.from_pretrained``.
    """

    model_type = "vit"

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=0,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Patch-embedding geometry.
        self.img_size = img_size
        self.patch_size = patch_size
        self.in_chans = in_chans
        # Transformer capacity.
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        # Attention options (qk_scale=None -> default head_dim ** -0.5).
        self.qkv_bias = qkv_bias
        self.qk_scale = qk_scale
        # Regularisation rates.
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate
        # Classifier head size (0 -> identity head, the DINO default).
        self.num_classes = num_classes
167
 
 
 
 
168
 
169
+ # ============================================================================
170
+ # HUGGING FACE COMPATIBLE WRAPPER (REQUIRED)
171
+ # ============================================================================
172
 
173
class VisionTransformer(PreTrainedModel):
    """DINO-style Vision Transformer wrapped for Hugging Face.

    Wraps the original VisionTransformer so it loads through
    ``AutoModel.from_pretrained``; the forward pass takes ``pixel_values``
    and returns a ``BaseModelOutput``.
    """

    config_class = VisionTransformerConfig
    base_model_prefix = "vit"
    main_input_name = "pixel_values"

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        embed_dim = config.embed_dim
        self.num_features = self.embed_dim = embed_dim

        # Convolutional patchifier: (B, C, H, W) -> (B, num_patches, embed_dim).
        self.patch_embed = PatchEmbed(
            img_size=config.img_size,
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=embed_dim,
        )

        # Learnable [CLS] token plus positional embeddings (patches + CLS).
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )
        self.pos_drop = nn.Dropout(p=config.drop_rate)

        # Stochastic-depth rates grow linearly with block index.
        drop_path_rates = [
            r.item() for r in torch.linspace(0, config.drop_path_rate, config.depth)
        ]
        self.blocks = nn.ModuleList(
            Block(
                dim=embed_dim,
                num_heads=config.num_heads,
                mlp_ratio=config.mlp_ratio,
                qkv_bias=config.qkv_bias,
                qk_scale=config.qk_scale,
                drop=config.drop_rate,
                attn_drop=config.attn_drop_rate,
                drop_path=drop_path_rates[i],
                norm_layer=nn.LayerNorm,
            )
            for i in range(config.depth)
        )
        self.norm = nn.LayerNorm(embed_dim)

        # Classifier head (identity when num_classes == 0, the DINO default).
        self.head = (
            nn.Linear(embed_dim, config.num_classes)
            if config.num_classes > 0
            else nn.Identity()
        )

        # Initialize embeddings, then recursively init submodule weights.
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
229
+
230
  def _init_weights(self, m):
231
  if isinstance(m, nn.Linear):
232
  trunc_normal_(m.weight, std=.02)
 
235
  elif isinstance(m, nn.LayerNorm):
236
  nn.init.constant_(m.bias, 0)
237
  nn.init.constant_(m.weight, 1.0)
238
+
239
  def interpolate_pos_encoding(self, x, w, h):
240
  npatch = x.shape[1] - 1
241
  N = self.pos_embed.shape[1] - 1
 
246
  dim = x.shape[-1]
247
  w0 = w // self.patch_embed.patch_size
248
  h0 = h // self.patch_embed.patch_size
 
 
249
  w0, h0 = w0 + 0.1, h0 + 0.1
250
  patch_pos_embed = nn.functional.interpolate(
251
  patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
 
255
  assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
256
  patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
257
  return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
258
+
259
  def prepare_tokens(self, x):
260
  B, nc, w, h = x.shape
261
+ x = self.patch_embed(x)
 
 
262
  cls_tokens = self.cls_token.expand(B, -1, -1)
263
  x = torch.cat((cls_tokens, x), dim=1)
 
 
264
  x = x + self.interpolate_pos_encoding(x, w, h)
 
265
  return self.pos_drop(x)
266
+
267
+ def forward(
268
+ self,
269
+ pixel_values: Optional[torch.FloatTensor] = None,
270
+ output_attentions: Optional[bool] = None,
271
+ output_hidden_states: Optional[bool] = None,
272
+ return_dict: Optional[bool] = None,
273
+ ) -> Union[Tuple, BaseModelOutput]:
274
+ """
275
+ Forward pass compatible with Hugging Face
276
+
277
+ Args:
278
+ pixel_values: Input images (batch_size, channels, height, width)
279
+ output_attentions: Whether to return attention weights
280
+ output_hidden_states: Whether to return all hidden states
281
+ return_dict: Whether to return BaseModelOutput
282
+
283
+ Returns:
284
+ BaseModelOutput or tuple
285
+ """
286
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
287
+
288
+ x = self.prepare_tokens(pixel_values)
289
+
290
  for blk in self.blocks:
291
  x = blk(x)
292
+
293
  x = self.norm(x)
294
+
295
+ # Return CLS token output
296
+ pooled_output = x[:, 0]
297
+
298
+ if not return_dict:
299
+ return (x, pooled_output)
300
+
301
+ return BaseModelOutput(
302
+ last_hidden_state=x,
303
+ hidden_states=None,
304
+ attentions=None,
305
+ )
306
+
307
  def get_last_selfattention(self, x):
308
+ """Get attention from last block"""
309
  x = self.prepare_tokens(x)
310
  for i, blk in enumerate(self.blocks):
311
  if i < len(self.blocks) - 1:
312
  x = blk(x)
313
  else:
 
314
  return blk(x, return_attention=True)
315
+
316
  def get_intermediate_layers(self, x, n=1):
317
+ """Get outputs from last n blocks"""
318
  x = self.prepare_tokens(x)
 
319
  output = []
320
  for i, blk in enumerate(self.blocks):
321
  x = blk(x)
 
324
  return output
325
 
326
 
327
# Register for auto classes, so that AutoConfig / AutoModel.from_pretrained
# can resolve this custom architecture from the Hub (requires
# trust_remote_code=True on the caller's side).
VisionTransformerConfig.register_for_auto_class()
VisionTransformer.register_for_auto_class("AutoModel")