Continual-Mega committed on
Commit
a07ade3
·
verified ·
1 Parent(s): f817c63

Upload CLIP/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. CLIP/model.py +538 -0
CLIP/model.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ CLIP Model
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+ from dataclasses import dataclass
6
+ import logging
7
+ import math
8
+ from typing import Optional, Tuple, Union
9
+ from itertools import repeat
10
+ import collections.abc
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torch import nn
15
+ from torch.utils.checkpoint import checkpoint
16
+ from .modified_resnet import ModifiedResNet
17
+ from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer
18
+ from collections import OrderedDict
19
+
20
+
21
@dataclass
class CLIPVisionCfg:
    """Configuration for the CLIP vision tower.

    ``layers`` selects the architecture: a 4-tuple of stage depths builds a
    ModifiedResNet, a single int builds a VisionTransformer with that many
    blocks (see ``_build_vision_tower``).
    """
    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    head_width: int = 64  # width per attention head; heads = width // head_width
    mlp_ratio: float = 4.0
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    ls_init_value: Optional[float] = None  # layer scale initial value
    # Fraction of patches to drop during training (0 disables). NOTE: default
    # here is 0.2, not the upstream OpenCLIP default of 0. 0.5-0.75 recommended
    # in the patch-dropout paper for optimal results.
    patch_dropout: float = 0.2
    # Whether to use dual patchnorm - would only apply the input layernorm on
    # each patch, as post-layernorm already exists in the original CLIP ViT design.
    input_patchnorm: bool = False
    # Whether to global average pool the last embedding layer, instead of using
    # the CLS token (https://arxiv.org/abs/2205.01580).
    global_average_pool: bool = False
    attentional_pool: bool = False  # whether to use attentional pooler in the last embedding layer
    n_queries: int = 256  # n_queries for attentional pooler
    attn_pooler_heads: int = 8  # n heads for attentional_pooling
    # timm settings: a valid model name overrides layers/width/patch_size.
    # Annotated Optional since None (the default) means "no timm model".
    timm_model_name: Optional[str] = None
    timm_model_pretrained: bool = False  # use (imagenet) pretrained weights for named model
    timm_pool: str = 'avg'  # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_proj: str = 'linear'  # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj_bias: bool = False  # enable bias final projection
    timm_drop: float = 0.  # head dropout
    timm_drop_path: Optional[float] = None  # backbone stochastic depth
    # NOTE: default True here (upstream default is False) — the vision tower is
    # expected to return per-patch tokens in addition to the pooled embedding.
    output_tokens: bool = True
44
+
45
+
46
@dataclass
class CLIPTextCfg:
    """Configuration for the CLIP text tower (see ``_build_text_tower``)."""
    context_length: int = 77
    vocab_size: int = 49408
    width: int = 512
    heads: int = 8
    layers: int = 12
    ls_init_value: Optional[float] = None  # layer scale initial value
    # HuggingFace text-tower settings. Annotated Optional since None (the
    # default) means "use the built-in TextTransformer, not an HF model".
    hf_model_name: Optional[str] = None
    hf_tokenizer_name: Optional[str] = None
    hf_model_pretrained: bool = True
    proj: str = 'mlp'  # projection head type for HF models ('mlp', 'linear', ...)
    pooler_type: str = 'mean_pooler'
    embed_cls: bool = False  # prepend a learnable CLS embedding to the token sequence
    pad_id: int = 0
    output_tokens: bool = False
62
+
63
+
64
def get_cast_dtype(precision: str):
    """Map a precision string to the torch dtype used for weight casting.

    Returns torch.bfloat16 for 'bf16', torch.float16 for 'fp16', and None for
    any other value (meaning: no cast).
    """
    return {
        'bf16': torch.bfloat16,
        'fp16': torch.float16,
    }.get(precision)
71
+
72
+
73
def _build_vision_tower(
        embed_dim: int,
        vision_cfg: CLIPVisionCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None
):
    """Instantiate the visual encoder described by ``vision_cfg``.

    A tuple/list in ``vision_cfg.layers`` selects a ModifiedResNet; an int
    selects a VisionTransformer. ``vision_cfg`` may also be a plain dict of
    CLIPVisionCfg fields.
    """
    if isinstance(vision_cfg, dict):
        vision_cfg = CLIPVisionCfg(**vision_cfg)

    # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both
    # faster and more memory efficient in recent PyTorch releases (>= 1.10).
    # NOTE: timm models always use native GELU regardless of quick_gelu flag.
    act_layer = QuickGELU if quick_gelu else nn.GELU

    if isinstance(vision_cfg.layers, (tuple, list)):
        # ResNet variant: final feature map is downsampled 32x, hence the
        # width * 32 factor when deriving the attention-pool head count.
        resnet_heads = vision_cfg.width * 32 // vision_cfg.head_width
        return ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=resnet_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width,
        )

    # ViT variant; keep LayerNorm math in fp32 when the tower weights are
    # cast to half precision.
    vit_heads = vision_cfg.width // vision_cfg.head_width
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
    return VisionTransformer(
        image_size=vision_cfg.image_size,
        patch_size=vision_cfg.patch_size,
        width=vision_cfg.width,
        layers=vision_cfg.layers,
        heads=vit_heads,
        mlp_ratio=vision_cfg.mlp_ratio,
        ls_init_value=vision_cfg.ls_init_value,
        patch_dropout=vision_cfg.patch_dropout,
        input_patchnorm=vision_cfg.input_patchnorm,
        global_average_pool=vision_cfg.global_average_pool,
        attentional_pool=vision_cfg.attentional_pool,
        n_queries=vision_cfg.n_queries,
        attn_pooler_heads=vision_cfg.attn_pooler_heads,
        output_tokens=vision_cfg.output_tokens,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
119
+
120
+
121
def _build_text_tower(
        embed_dim: int,
        text_cfg: CLIPTextCfg,
        quick_gelu: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
):
    """Instantiate the text encoder described by ``text_cfg``.

    ``text_cfg`` may be a CLIPTextCfg or a plain dict of its fields. Returns a
    TextTransformer projecting to ``embed_dim``.
    """
    if isinstance(text_cfg, dict):
        text_cfg = CLIPTextCfg(**text_cfg)

    # OpenAI checkpoints were trained with QuickGELU; plain nn.GELU otherwise.
    act_layer = QuickGELU if quick_gelu else nn.GELU
    # Keep LayerNorm math in fp32 when the tower weights are cast to half precision.
    norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm

    return TextTransformer(
        context_length=text_cfg.context_length,
        vocab_size=text_cfg.vocab_size,
        width=text_cfg.width,
        heads=text_cfg.heads,
        layers=text_cfg.layers,
        ls_init_value=text_cfg.ls_init_value,
        output_dim=embed_dim,
        embed_cls=text_cfg.embed_cls,
        output_tokens=text_cfg.output_tokens,
        pad_id=text_cfg.pad_id,
        act_layer=act_layer,
        norm_layer=norm_layer,
    )
148
+
149
+
150
class ResidualAttentionBlock_learnable_token(nn.Module):
    """Transformer residual block with two extra behaviours:

    1. If ``self.attn`` has been swapped for the project's custom ``Attention``
       (done externally after construction — nothing in this class does the
       swap), the block maintains *two* parallel streams: the original path
       ``x_ori`` (attention + MLP) and a surgery path ``x`` that skips the MLP.
    2. Otherwise (plain ``nn.MultiheadAttention``, e.g. text layers), the block
       can splice per-layer learnable prompt tokens into the sequence
       (MaPLe-style compound prompting) before the standard residual update.

    NOTE(review): the dict key 'learnabel_text_embedding_length' is misspelled
    but must match the producer of ``design_details`` — do not "fix" it here
    without changing the producer too.
    """
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, design_details=None,
                 text_layer=False, i = 0):
        # d_model: token embedding width; n_head: attention heads;
        # attn_mask: optional additive attention mask (e.g. causal for text);
        # design_details: dict carrying prompt-learning hyperparameters;
        # i: index of this block within the transformer stack.
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        # Standard CLIP MLP: 4x expansion with QuickGELU.
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

        self.i = i
        # Number of learnable prompt tokens spliced in after the first token.
        self.compound_prompt_nctx = design_details['learnabel_text_embedding_length']
        self.text_layer = text_layer
        # The first layer keeps its incoming prompts; deeper layers replace them.
        if i == 0:
            self.first_layer = True
        else:
            self.first_layer = False

    def attention(self, x: torch.Tensor):
        """Run self-attention on ``x``.

        Returns a [new-path, original-path] pair when ``self.attn`` is the
        custom ``Attention`` (which consumes batch-first input, hence the
        transposes), otherwise a single tensor from nn.MultiheadAttention.
        """
        # Keep the mask on the same device/dtype as the activations.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        if isinstance(self.attn, Attention):
            x = x.transpose(0, 1)
            x, x_ori = self.attn(x)
            return [x.transpose(0, 1), x_ori.transpose(0, 1)]
        else:
            return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, inputs):
        """Forward pass.

        With the custom ``Attention``: ``inputs[0]`` is either a tensor (start
        of the dual path) or a [x, x_ori] pair (continuing it); returns
        [x, x_ori]. Otherwise: ``inputs`` is [x, compound_prompts_deeper,
        counter] and the same triple shape is returned.
        """
        # dual paths for blocks deeper than "d"
        if isinstance(self.attn, Attention):
            x = inputs[0]
            if isinstance(x, list):
                # Continue both streams; attention is always computed from the
                # ORIGINAL stream's normalized activations.
                x, x_ori = x
                x_res = self.attention(self.ln_1(x_ori))
                x_res, x_ori_res = x_res
                # NOTE: in-place += mutates the tensors passed in.
                x_ori += x_ori_res
                x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                x += x_res  # skip ffn for the new path
                return [x, x_ori]
            # start of the dual path: fork a single stream into two
            else:
                x_res = self.attention(self.ln_1(x))
                if isinstance(x_res, list):
                    x_res, x_ori_res = x_res
                    x_ori = x + x_ori_res
                    x_ori = x_ori + self.mlp(self.ln_2(x_ori))
                x += x_res
                return [x, x_ori]
        # single path before "d" (plain nn.MultiheadAttention)
        else:
            x = inputs[0]
            compound_prompts_deeper = inputs[1]
            counter = inputs[2]
            if not self.first_layer:
                # First check if the ith layer needs compound prompts or not
                if not (counter > len(compound_prompts_deeper) - 1):
                    # x -> [seq_len, batch, dim]; replace the prompt tokens
                    # inserted by the previous layer with this layer's set.
                    # First remove the learnable tokens from the previous layer.
                    prefix = x[:1, :, :]
                    suffix = x[1 + self.compound_prompt_nctx:, :, :]
                    textual_context = compound_prompts_deeper[counter]
                    # NOTE(review): .half() hard-codes fp16 here — assumes the
                    # surrounding model runs in half precision; confirm.
                    textual_context = textual_context.expand(x.shape[1], -1, -1).permute(1, 0, 2).half()
                    # Splice this layer's learnable tokens between prefix and suffix.
                    x = torch.cat([prefix, textual_context, suffix], dim=0)
                    # Advance so the next layer uses the next prompt set.
                    counter += 1
            x = x + self.attention(self.ln_1(x))
            x = x + self.mlp(self.ln_2(x))
            return [x, compound_prompts_deeper, counter]
229
+
230
+
231
+
232
class CLIP(nn.Module):
    """CLIP model in the original OpenAI layout: the text tower is built once
    and its submodules are flattened onto this module so state_dict keys match
    OpenAI checkpoints (token_embedding, transformer, ln_final, ...)."""
    # TorchScript: declared Final so scripting treats it as a constant.
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
            design_details = None  # NOTE(review): accepted but never used here
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)

        # Flatten the text tower's pieces onto this module for OpenAI-format
        # checkpoint compatibility.
        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final
        self.text_projection = text.text_projection
        # Non-persistent: the causal mask is not saved in checkpoints.
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        # Learnable temperature, initialised to ln(1/0.07) as in the CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        # NOTE(review): hard-codes the 77-token CLIP context length rather than
        # using the configured context_length.
        mask = torch.empty(77, 77)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.transformer.grad_checkpointing = enable

    def encode_image(self, image, out_layers, normalize: bool = False):
        """Embed images; ``out_layers`` is forwarded to the (project-modified)
        visual tower — presumably selects intermediate layers to return; verify
        against the VisionTransformer implementation."""
        features = self.visual(image, out_layers)
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text, normalize: bool = False):
        """Embed tokenized text, pooling at the EOT position."""
        cast_dtype = self.transformer.get_cast_dtype()
        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        # NOTE(review): assumes the project's transformer returns a
        # (features, attention, tokens) triple — confirm against its forward().
        x, attn, tokens = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def encode_text_learn(self, prompts, tokenized_prompts, deep_compound_prompts_text = None, normalize: bool = False):
        """Embed pre-computed prompt embeddings (prompt learning path).

        ``prompts`` are already-embedded tokens; ``tokenized_prompts`` supplies
        the EOT positions for pooling. NOTE(review): the ``normalize`` flag is
        accepted but ignored — the result is always returned unnormalized.
        """
        cast_dtype = self.transformer.get_cast_dtype()

        x = prompts + self.positional_embedding.to(cast_dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        # With deep compound prompts the transformer expects the
        # [x, prompts, counter] triple consumed by
        # ResidualAttentionBlock_learnable_token.forward.
        if deep_compound_prompts_text is None:
            x = self.transformer(x)
        else:
            x = self.transformer([x, deep_compound_prompts_text, 0])
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(torch.float32)  # [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
        return x

    def forward(self, image, text):
        # NOTE(review): encode_image requires an ``out_layers`` argument but it
        # is not supplied here — as written this call raises TypeError; confirm
        # whether forward() is ever used in this project or needs updating.
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp()
            }
        return image_features, text_features, self.logit_scale.exp()
325
+
326
+
327
class CustomTextCLIP(nn.Module):
    """CLIP variant that keeps the text encoder as a single ``self.text``
    submodule rather than flattening its pieces onto the top-level module
    (contrast with ``CLIP``, which mirrors the OpenAI checkpoint layout)."""
    # TorchScript: declared Final so scripting treats it as a constant.
    output_dict: torch.jit.Final[bool]

    def __init__(
            self,
            embed_dim: int,
            vision_cfg: CLIPVisionCfg,
            text_cfg: CLIPTextCfg,
            quick_gelu: bool = False,
            cast_dtype: Optional[torch.dtype] = None,
            output_dict: bool = False,
    ):
        super().__init__()
        self.output_dict = output_dict
        self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        # Learnable temperature, initialised to ln(1/0.07) as in the CLIP paper.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False):
        # lock image tower as per LiT - https://arxiv.org/abs/2111.07991
        self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze the text tower except for the last ``unlocked_layers`` layers."""
        self.text.lock(unlocked_layers, freeze_layer_norm)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on both towers."""
        self.visual.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)

    def encode_image(self, image, normalize: bool = False):
        """Embed images; L2-normalize along the last dim when requested."""
        embedding = self.visual(image)
        if normalize:
            embedding = F.normalize(embedding, dim=-1)
        return embedding

    def encode_text(self, text, normalize: bool = False):
        """Embed token sequences; L2-normalize along the last dim when requested."""
        embedding = self.text(text)
        if normalize:
            embedding = F.normalize(embedding, dim=-1)
        return embedding

    def forward(self, image, text):
        """Return normalized embeddings for both modalities plus the exp'd
        temperature, either as a dict (output_dict=True) or a tuple."""
        image_features = self.encode_image(image, normalize=True)
        text_features = self.encode_text(text, normalize=True)
        scale = self.logit_scale.exp()
        if self.output_dict:
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": scale
            }
        return image_features, text_features, scale
375
+
376
+
377
def convert_weights_to_lp(model: nn.Module, dtype=torch.float16):
    """Convert applicable model parameters to low-precision (bf16 or fp16)."""

    def _cast_module(module):
        # Conv / linear layers: cast weight and (optional) bias in place.
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.to(dtype)
            if module.bias is not None:
                module.bias.data = module.bias.data.to(dtype)

        # Attention modules expose several projection tensors, some of which
        # may be None depending on configuration.
        if isinstance(module, (nn.MultiheadAttention, Attention)):
            proj_attrs = [
                "in_proj_weight",
                "q_proj_weight",
                "k_proj_weight",
                "v_proj_weight",
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]
            for attr in proj_attrs:
                tensor = getattr(module, attr)
                if tensor is not None:
                    tensor.data = tensor.data.to(dtype)

        # Bare projection parameters (e.g. CLIP's text_projection / visual proj).
        for name in ("text_projection", "proj"):
            if hasattr(module, name):
                param = getattr(module, name)
                if param is not None:
                    param.data = param.data.to(dtype)

    model.apply(_cast_module)
399
+
400
+
401
# Backwards-compatible alias: older call sites used the fp16-specific name
# before the function was generalized to arbitrary low-precision dtypes.
convert_weights_to_fp16 = convert_weights_to_lp  # backwards compat
402
+
403
+
404
# used to maintain checkpoint compatibility
def convert_to_custom_text_state_dict(state_dict: dict):
    """Remap an OpenAI-layout state_dict to the CustomTextCLIP layout.

    Old-format checkpoints keep text-tower tensors at the top level; the
    custom layout nests them under a ``text.`` prefix. Detection is by the
    presence of a top-level 'text_projection' key; already-converted dicts
    are returned unchanged (same object).
    """
    if 'text_projection' not in state_dict:
        return state_dict

    # Top-level prefixes that belong to the text tower.
    text_prefixes = (
        'text_projection',
        'positional_embedding',
        'token_embedding',
        'transformer',
        'ln_final',
    )
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith(text_prefixes):
            key = 'text.' + key
        remapped[key] = value
    return remapped
421
+
422
+
423
def build_model_from_openai_state_dict(
        state_dict: dict,
        quick_gelu=True,
        cast_dtype=torch.float16,
):
    """Reconstruct a CLIP model from an OpenAI-format state_dict.

    All architecture hyperparameters are inferred from tensor shapes and key
    names in the state_dict. Weights are cast to fp16 (matching how OpenAI
    shipped them) and the model is returned in eval mode.
    """
    # ViT checkpoints have a 'visual.proj' tensor; ResNet checkpoints do not.
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # One transformer block per attention in_proj_weight under 'visual.'.
        vision_layers = len(
            [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # Positional embedding has grid*grid patch entries + 1 class token.
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_size = vision_patch_size * grid_size
    else:
        # ResNet: count distinct block indices in each of the 4 stages.
        counts: list = [
            len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        # Attention-pool positional embedding: grid*grid entries + 1.
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        # ResNet downsamples 32x, so input resolution is 32 * final grid width.
        image_size = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    # CLIP uses 64-dim attention heads throughout.
    transformer_heads = transformer_width // 64
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    vision_cfg = CLIPVisionCfg(
        layers=vision_layers,
        width=vision_width,
        patch_size=vision_patch_size,
        image_size=image_size,
    )
    text_cfg = CLIPTextCfg(
        context_length=context_length,
        vocab_size=vocab_size,
        width=transformer_width,
        heads=transformer_heads,
        layers=transformer_layers,
    )
    model = CLIP(
        embed_dim,
        vision_cfg=vision_cfg,
        text_cfg=text_cfg,
        quick_gelu=quick_gelu,  # OpenAI models were trained with QuickGELU
        cast_dtype=cast_dtype,
    )

    # Metadata keys present in OpenAI checkpoints but not model parameters.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        state_dict.pop(key, None)

    convert_weights_to_fp16(model)  # OpenAI state dicts are partially converted to float16
    model.load_state_dict(state_dict)
    return model.eval()
481
+
482
+
483
def trace_model(model, batch_size=256, device=torch.device('cpu')):
    """TorchScript-trace a CLIP model's forward/encode_text/encode_image.

    Tracing requires concrete example inputs, so dummy image and text batches
    are synthesized from the model's own image size and context length. The
    traced module gets ``visual.image_size`` re-attached since tracing drops
    plain attributes.
    """
    model.eval()
    image_size = model.visual.image_size
    dummy_images = torch.ones((batch_size, 3, image_size, image_size), device=device)
    dummy_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device)
    traced = torch.jit.trace_module(
        model,
        inputs=dict(
            forward=(dummy_images, dummy_text),
            encode_text=(dummy_text,),
            encode_image=(dummy_images,),
        ))
    traced.visual.image_size = image_size
    return traced
497
+
498
+ # From PyTorch internals
499
+ def _ntuple(n):
500
+ def parse(x):
501
+ if isinstance(x, collections.abc.Iterable):
502
+ return x
503
+ return tuple(repeat(x, n))
504
+ return parse
505
+ to_2tuple = _ntuple(2)
506
+
507
def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True):
    """Rescale the visual positional-embedding grid in ``state_dict`` (in
    place) to match ``model.visual.grid_size`` when loading a checkpoint
    trained at a different resolution. No-op if the sizes already match or the
    model/state_dict lacks the relevant pieces."""
    # Rescale the grid of position embeddings when loading from state_dict
    old_pos_embed = state_dict.get('visual.positional_embedding', None)
    if old_pos_embed is None or not hasattr(model.visual, 'grid_size'):
        return
    grid_size = to_2tuple(model.visual.grid_size)
    extra_tokens = 1  # FIXME detect different token configs (ie no class token, or more)
    new_seq_len = grid_size[0] * grid_size[1] + extra_tokens
    if new_seq_len == old_pos_embed.shape[0]:
        # already the right size; nothing to do
        return

    # Split off the class-token embedding(s): only the patch-grid part is
    # spatially interpolated.
    if extra_tokens:
        pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:]
    else:
        pos_emb_tok, pos_emb_img = None, old_pos_embed
    # Assumes the old patch grid was square.
    old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img))))

    logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size)
    # (seq, dim) -> (1, dim, H, W) so F.interpolate can resample spatially.
    pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2)
    pos_emb_img = F.interpolate(
        pos_emb_img,
        size=grid_size,
        mode=interpolation,
        antialias=antialias,
        align_corners=False,
    )
    # Back to (new_seq, dim).
    pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0]
    if pos_emb_tok is not None:
        new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0)
    else:
        new_pos_embed = pos_emb_img
    # Write the resized embedding back so load_state_dict succeeds.
    state_dict['visual.positional_embedding'] = new_pos_embed