huu-ontocord committed on
Commit
8ee292f
·
verified ·
1 Parent(s): c64d641

Delete eva_vit.py

Browse files
Files changed (1) hide show
  1. eva_vit.py +0 -493
eva_vit.py DELETED
@@ -1,493 +0,0 @@
1
- # Based on EVA, BEIT, timm and DeiT code bases
2
- # https://github.com/baaivision/EVA
3
- # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
- # https://github.com/microsoft/unilm/tree/master/beit
5
- # https://github.com/facebookresearch/deit/
6
- # https://github.com/facebookresearch/dino
7
- # --------------------------------------------------------'
8
- import math
9
- from functools import partial
10
- from torch.nn import LayerNorm
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
- import torch.utils.checkpoint as checkpoint
16
- from timm.models.layers import drop_path, to_2tuple, trunc_normal_
17
-
18
-
19
-
20
-
21
- def _cfg(url='', **kwargs):
22
- return {
23
- 'url': url,
24
- 'num_classes': 1000,
25
- 'input_size': (3, 224, 224),
26
- 'pool_size': None,
27
- 'crop_pct': .9,
28
- 'interpolation': 'bicubic',
29
- 'mean': (0.5, 0.5, 0.5),
30
- 'std': (0.5, 0.5, 0.5),
31
- **kwargs
32
- }
33
-
34
-
35
class DropPath(nn.Module):
    """Per-sample stochastic depth: randomly zeroes whole residual branches.

    Thin module wrapper around timm's functional ``drop_path``; it is a
    no-op when ``drop_prob`` is 0/None or the module is in eval mode.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # ``self.training`` gates the drop so inference is deterministic.
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return f'p={self.drop_prob}'
47
-
48
-
49
class Mlp(nn.Module):
    """Transformer feed-forward block: fc1 -> activation -> fc2 -> dropout.

    Dropout is applied only once, after fc2, matching the original BERT MLP.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # Hidden/output widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        h = self.act(self.fc1(x))
        return self.drop(self.fc2(h))
67
-
68
-
69
class Attention(nn.Module):
    """Multi-head self-attention with optional learned relative position bias.

    With ``qkv_bias`` set, q and v receive learnable biases while k stays
    bias-free (BEiT convention); the three biases are concatenated on the fly
    in ``forward``. With ``window_size`` set, a per-head bias table covering
    every relative offset inside the window — plus 3 extra slots for
    cls->token, token->cls and cls->cls — is added to the attention logits.
    """

    def __init__(self,
                 dim,
                 num_heads=8,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 window_size=None,
                 attn_head_dim=None):
        super().__init__()
        self.num_heads = num_heads
        head_dim = attn_head_dim if attn_head_dim is not None else dim // num_heads
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Single fused projection for q, k, v; bias is assembled manually.
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            # Distinct relative offsets plus 3 cls-interaction slots.
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

            # Precompute, for every pair of tokens (cls + window positions),
            # the index into the bias table.
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None):
        batch, tokens, _ = x.shape

        # Assemble the fused qkv bias on the fly; k's bias is fixed at zero.
        bias = None
        if self.q_bias is not None:
            bias = torch.cat((self.q_bias,
                              torch.zeros_like(self.v_bias, requires_grad=False),
                              self.v_bias))
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=bias)
        qkv = qkv.reshape(batch, tokens, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # torchscript cannot unpack a tensor as a tuple

        attn = (q * self.scale) @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            n = self.window_size[0] * self.window_size[1] + 1
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(n, n, -1)  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + relative_position_bias.unsqueeze(0)

        # Shared (externally supplied) relative position bias, if any.
        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = self.attn_drop(attn.softmax(dim=-1))

        out = (attn @ v).transpose(1, 2).reshape(batch, tokens, -1)
        out = self.proj(out)
        return self.proj_drop(out)
160
-
161
-
162
class Block(nn.Module):
    """Pre-norm transformer block.

    Attention and MLP residual branches, each wrapped in stochastic depth
    and optionally scaled by a learnable per-channel gamma (layer scale,
    enabled when ``init_values`` > 0).
    """

    def __init__(self,
                 dim,
                 num_heads,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 init_values=None,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 window_size=None,
                 attn_head_dim=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(dim,
                              num_heads=num_heads,
                              qkv_bias=qkv_bias,
                              qk_scale=qk_scale,
                              attn_drop=attn_drop,
                              proj_drop=drop,
                              window_size=window_size,
                              attn_head_dim=attn_head_dim)
        # Stochastic depth on both residual branches; identity when rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer,
                       drop=drop)

        # Layer scale (CaiT-style) per-channel residual weights.
        if init_values is not None and init_values > 0:
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, rel_pos_bias=None):
        if self.gamma_1 is None:
            x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
            x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
        return x
207
-
208
-
209
class PatchEmbed(nn.Module):
    """Image-to-patch embedding.

    Splits an image into non-overlapping ``patch_size`` patches and embeds
    each with a strided convolution, returning (B, num_patches, embed_dim).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = self.patch_shape[0] * self.patch_shape[1]

        # Patchify + embed in one strided conv.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x, **kwargs):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        # (B, C, H, W) -> (B, embed_dim, H', W') -> (B, H'*W', embed_dim)
        return self.proj(x).flatten(2).transpose(1, 2)
231
-
232
-
233
class RelativePositionBias(nn.Module):
    """Shared relative position bias (BEiT-style), one table for all blocks.

    Learns one bias per head for every relative offset inside the window,
    with 3 extra slots for cls->token, token->cls and cls->cls pairs.
    ``forward`` returns the expanded bias of shape (nH, N, N) where
    N = Wh*Ww + 1.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        # Distinct relative offsets plus 3 cls-interaction slots.
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # Map every (token, token) pair to its slot in the table.
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        n = self.window_size[0] * self.window_size[1] + 1
        bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(n, n, -1)  # Wh*Ww,Wh*Ww,nH
        return bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
269
-
270
-
271
- class VisionTransformerEvaClip(nn.Module):
272
- """ Vision Transformer with support for patch or hybrid CNN input stage
273
- """
274
- def __init__(self,
275
- img_size=224,
276
- patch_size=16,
277
- in_chans=3,
278
- num_classes=1000,
279
- embed_dim=768,
280
- depth=12,
281
- num_heads=12,
282
- mlp_ratio=4.,
283
- qkv_bias=False,
284
- qk_scale=None,
285
- drop_rate=0.,
286
- attn_drop_rate=0.,
287
- drop_path_rate=0.,
288
- norm_layer=nn.LayerNorm,
289
- init_values=None,
290
- use_abs_pos_emb=True,
291
- use_rel_pos_bias=False,
292
- use_shared_rel_pos_bias=False,
293
- use_mean_pooling=True,
294
- init_scale=0.001,
295
- use_checkpoint=False):
296
- super().__init__()
297
- self.image_size = img_size
298
- self.num_classes = num_classes
299
- self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
300
-
301
- self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
302
- num_patches = self.patch_embed.num_patches
303
-
304
- self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
305
- if use_abs_pos_emb:
306
- self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
307
- else:
308
- self.pos_embed = None
309
- self.pos_drop = nn.Dropout(p=drop_rate)
310
-
311
- if use_shared_rel_pos_bias:
312
- self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
313
- else:
314
- self.rel_pos_bias = None
315
- self.use_checkpoint = use_checkpoint
316
-
317
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
318
- self.use_rel_pos_bias = use_rel_pos_bias
319
- self.blocks = nn.ModuleList([
320
- Block(dim=embed_dim,
321
- num_heads=num_heads,
322
- mlp_ratio=mlp_ratio,
323
- qkv_bias=qkv_bias,
324
- qk_scale=qk_scale,
325
- drop=drop_rate,
326
- attn_drop=attn_drop_rate,
327
- drop_path=dpr[i],
328
- norm_layer=norm_layer,
329
- init_values=init_values,
330
- window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None) for i in range(depth)
331
- ])
332
- # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
333
- # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
334
- # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
335
-
336
- if self.pos_embed is not None:
337
- trunc_normal_(self.pos_embed, std=.02)
338
- trunc_normal_(self.cls_token, std=.02)
339
- # trunc_normal_(self.mask_token, std=.02)
340
- # if isinstance(self.head, nn.Linear):
341
- # trunc_normal_(self.head.weight, std=.02)
342
- self.apply(self._init_weights)
343
- self.fix_init_weight()
344
- self.ln_vision = LayerNorm(self.num_features)
345
-
346
- def fix_init_weight(self):
347
- def rescale(param, layer_id):
348
- param.div_(math.sqrt(2.0 * layer_id))
349
-
350
- for layer_id, layer in enumerate(self.blocks):
351
- rescale(layer.attn.proj.weight.data, layer_id + 1)
352
- rescale(layer.mlp.fc2.weight.data, layer_id + 1)
353
-
354
- def _init_weights(self, m):
355
- if isinstance(m, nn.Linear):
356
- trunc_normal_(m.weight, std=.02)
357
- if isinstance(m, nn.Linear) and m.bias is not None:
358
- nn.init.constant_(m.bias, 0)
359
- elif isinstance(m, nn.LayerNorm):
360
- nn.init.constant_(m.bias, 0)
361
- nn.init.constant_(m.weight, 1.0)
362
-
363
- _initialize_weights = _init_weights
364
-
365
- def get_classifier(self):
366
- return self.head
367
-
368
- def reset_classifier(self, num_classes, global_pool=''):
369
- self.num_classes = num_classes
370
- self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
371
-
372
- def forward_features(self, x):
373
- x = self.patch_embed(x)
374
- batch_size, seq_len, _ = x.size()
375
-
376
- cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
377
- x = torch.cat((cls_tokens, x), dim=1)
378
- if self.pos_embed is not None:
379
- x = x + self.pos_embed
380
- x = self.pos_drop(x)
381
-
382
- rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
383
- for blk in self.blocks:
384
- if self.use_checkpoint:
385
- x = checkpoint.checkpoint(blk, x, rel_pos_bias)
386
- else:
387
- x = blk(x, rel_pos_bias)
388
- return x
389
-
390
- def forward(self, x):
391
- x = self.forward_features(x)
392
- # x = self.head(x)
393
- return x
394
-
395
- def get_intermediate_layers(self, x):
396
- x = self.patch_embed(x)
397
- batch_size, seq_len, _ = x.size()
398
-
399
- cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
400
- x = torch.cat((cls_tokens, x), dim=1)
401
- if self.pos_embed is not None:
402
- x = x + self.pos_embed
403
- x = self.pos_drop(x)
404
-
405
- features = []
406
- rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
407
- for blk in self.blocks:
408
- x = blk(x, rel_pos_bias)
409
- features.append(x)
410
-
411
- return features
412
-
413
- def get_num_layer(self, var_name=""):
414
- if var_name in ("cls_token", "mask_token", "pos_embed"):
415
- return 0
416
- elif var_name.startswith("patch_embed"):
417
- return 0
418
- elif var_name.startswith("rel_pos_bias"):
419
- return len(self.blocks) - 1
420
- elif var_name.startswith("blocks"):
421
- layer_id = int(var_name.split('.')[1])
422
- return layer_id + 1
423
- else:
424
- return len(self.blocks)
425
-
426
-
427
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's absolute position embedding to ``model``'s grid.

    Mutates ``checkpoint_model['pos_embed']`` in place when the checkpoint's
    (square) patch grid differs from the model's, using bicubic
    interpolation. Extra tokens (cls/dist) are passed through unchanged.
    No-op when the key is absent or the sizes already match.
    """
    if 'pos_embed' not in checkpoint_model:
        return
    pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
    embedding_size = pos_embed_checkpoint.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra_tokens = model.pos_embed.shape[-2] - num_patches
    # Both grids are assumed square: side length = sqrt(patch-token count).
    orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
    new_size = int(num_patches ** 0.5)
    if orig_size == new_size:
        return
    print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
    extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
    # Reshape the patch tokens to a 2-D grid, interpolate, flatten back.
    pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
    pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
    pos_tokens = torch.nn.functional.interpolate(pos_tokens,
                                                 size=(new_size, new_size),
                                                 mode='bicubic',
                                                 align_corners=False)
    pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model['pos_embed'] = torch.cat((extra_tokens, pos_tokens), dim=1)
451
-
452
-
453
def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16"""
    def _to_half(module):
        # Only conv and linear layers are halved; norm layers stay fp32.
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

    model.apply(_to_half)
462
-
463
-
464
def create_eva_vit_g(img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16", cache_dir="./",):
    """Build the EVA ViT-g/14 vision tower and load its pretrained weights.

    Loads ``cache_dir``/eva_vit_g.pth, resizes the checkpoint's position
    embedding if ``img_size`` differs, loads non-strictly (prints the
    incompatible keys), and converts conv/linear weights to fp16 when
    ``precision == "fp16"``.
    """
    model = VisionTransformerEvaClip(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408 // 88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    # NOTE(review): torch.load deserializes arbitrary pickles — only load
    # checkpoints from a trusted cache_dir.
    state_dict = torch.load(cache_dir + "/eva_vit_g.pth", map_location="cpu")
    interpolate_pos_embed(model, state_dict)

    incompatible_keys = model.load_state_dict(state_dict, strict=False)
    print(incompatible_keys)

    if precision == "fp16":
        convert_weights_to_fp16(model)
    return model
489
-
490
-
491
- if __name__ == "__main__":
492
- model = create_eva_vit_g()
493
- print (model.num_features)