Continual-Mega committed on
Commit
f817c63
·
verified ·
1 Parent(s): 370c0d0

Upload CLIP/clip.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. CLIP/clip.py +240 -0
CLIP/clip.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import json
import logging
import math
import os
import pathlib
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import torch

from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict, resize_pos_embed, get_cast_dtype
from .openai import load_openai_model
12
+
13
+
14
+ _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
15
+ _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
16
+ _MODEL_CKPT_PATHS = {'ViT-L-14-336': Path(__file__).parent / "ckpt/ViT-L-14-336px.pt"}
17
+
18
+
19
+ def _natural_key(string_):
20
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
21
+
22
+
23
def _rescan_model_configs():
    """Scan _MODEL_CONFIG_PATHS for JSON model configs and (re)populate _MODEL_CONFIGS.

    Only files whose top level contains 'embed_dim', 'vision_cfg' and 'text_cfg'
    are registered; the registry is kept in natural-sort order by model name.
    """
    global _MODEL_CONFIGS

    extensions = ('.json',)
    candidates = []
    for path in _MODEL_CONFIG_PATHS:
        if path.is_file() and path.suffix in extensions:
            candidates.append(path)
        elif path.is_dir():
            for ext in extensions:
                candidates.extend(path.glob(f'*{ext}'))

    required_keys = ('embed_dim', 'vision_cfg', 'text_cfg')
    for candidate in candidates:
        with open(candidate, 'r') as handle:
            cfg = json.load(handle)
        # Register only configs that carry the full CLIP architecture description.
        if all(key in cfg for key in required_keys):
            _MODEL_CONFIGS[candidate.stem] = cfg

    # Rebuild the registry so iteration order follows natural sort of the names.
    _MODEL_CONFIGS = dict(sorted(_MODEL_CONFIGS.items(), key=lambda item: _natural_key(item[0])))


_rescan_model_configs()  # initial populate of model config registry
45
+
46
+
47
def list_models():
    """Enumerate available model architectures based on registered config files."""
    return [name for name in _MODEL_CONFIGS]
50
+
51
+
52
+
53
def get_model_config(model_name):
    """Return a deep copy of the registered config for *model_name*, or None if unknown."""
    if model_name not in _MODEL_CONFIGS:
        return None
    # Deep-copy so callers can freely mutate (e.g. override image_size) without
    # corrupting the shared registry.
    return deepcopy(_MODEL_CONFIGS[model_name])
60
+
61
+
62
def load_state_dict(checkpoint_path: str, map_location='cpu'):
    """Load a model state dict from *checkpoint_path*.

    Unwraps a ``{'state_dict': ...}`` checkpoint container when present and strips
    a leading ``module.`` (DataParallel) prefix from parameter names.

    Args:
        checkpoint_path: path to a torch checkpoint file.
        map_location: forwarded to ``torch.load`` (default: 'cpu').

    Returns:
        A (possibly empty) mapping of parameter name -> tensor.
    """
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    # Fix: the original peeked at the first key unconditionally and raised
    # StopIteration on an empty state dict; guard before peeking.
    if state_dict and next(iter(state_dict)).startswith('module'):
        state_dict = {k[7:]: v for k, v in state_dict.items()}
    return state_dict
71
+
72
+
73
+
74
def load_checkpoint(model, checkpoint_path, strict=True):
    """Load weights from *checkpoint_path* into *model* and return the incompatible keys.

    Old-format checkpoints carry a top-level 'positional_embedding'; when the target
    model does not, the state dict is converted to the custom-text layout first.
    """
    sd = load_state_dict(checkpoint_path)
    if 'positional_embedding' in sd and not hasattr(model, 'positional_embedding'):
        sd = convert_to_custom_text_state_dict(sd)
    # Adapt the positional embedding to the model's configured image size.
    resize_pos_embed(sd, model)
    return model.load_state_dict(sd, strict=strict)
82
+
83
+
84
def _build_openai_clip(model_name, model_cfg, precision, device, jit, output_dict, resize_pos):
    """Build a CLIP model from the bundled OpenAI checkpoint for *model_name*.

    Loads the reference OpenAI model to obtain its state dict, instantiates a fresh
    CLIP from *model_cfg*, and transfers the weights. When *resize_pos* is True the
    positional embedding is resized (via resize_pos_embed) to the configured image
    size before loading. Extracted to remove the near-identical duplicated blocks
    in the original resize / no-resize paths.
    """
    cast_dtype = get_cast_dtype(precision)
    model_pre = load_openai_model(
        name=_MODEL_CKPT_PATHS[model_name],
        precision=precision,
        device=device,
        jit=jit,
    )
    state_dict = model_pre.state_dict()

    # to always output dict even if it is clip
    if output_dict and hasattr(model_pre, "output_dict"):
        model_pre.output_dict = True

    model = CLIP(**model_cfg, cast_dtype=cast_dtype)
    # ResNet visual towers have no grid_size attribute; derive it from the attention
    # pool's positional embedding (one extra token ahead of the spatial grid).
    if not hasattr(model.visual, 'grid_size'):
        # Fix: the original called np.sqrt, but numpy is never imported in this module
        # (NameError on this path); math.isqrt gives the same floor square root.
        model.visual.grid_size = math.isqrt(model.visual.attnpool.positional_embedding.shape[0] - 1)
    if resize_pos:
        resize_pos_embed(state_dict, model)
    model.load_state_dict(state_dict, strict=True)
    model.to(device=device)
    if precision in ("fp16", "bf16"):
        convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

    # OpenAI CLIP normalization constants.
    model.visual.image_mean = (0.48145466, 0.4578275, 0.40821073)
    model.visual.image_std = (0.26862954, 0.26130258, 0.27577711)

    # to always output dict even if it is clip
    if output_dict and hasattr(model, "output_dict"):
        model.output_dict = True

    if jit:
        model = torch.jit.script(model)
    return model


def create_model(
        model_name: str,
        img_size: int,
        pretrained: Optional[str] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        adapter=False,  # NOTE(review): unused in this function — kept for interface compatibility
):
    """Create a CLIP model, optionally loading pretrained weights.

    Args:
        model_name: architecture name; '/' is normalized to '-'.
        img_size: requested input image size (OpenAI path only; triggers pos-embed resize).
        pretrained: 'openai' loads the bundled OpenAI checkpoint; any other truthy
            value loads the checkpoint registered in _MODEL_CKPT_PATHS.
        precision: 'fp32', 'fp16' or 'bf16'.
        device: target device (string or torch.device).
        jit: script the final model with torch.jit.
        force_quick_gelu / force_custom_text / force_patch_dropout / force_image_size:
            config overrides for the non-OpenAI path.
        output_dict: if truthy, set model.output_dict when the model supports it.
        require_pretrained: raise if no pretrained weights end up loaded.
        adapter: currently unused.

    Returns:
        The constructed (and possibly jit-scripted) model.

    Raises:
        RuntimeError: unknown model config, or pretrained weights required/missing.
    """
    model_name = model_name.replace('/', '-')  # for callers using old naming with / in ViT names
    checkpoint_path = None
    model_cfg = None

    if isinstance(device, str):
        device = torch.device(device)

    if pretrained and pretrained.lower() == 'openai':
        logging.info(f'Loading pretrained {model_name} from OpenAI.')
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is None:
            # Fix: previously an unsubscriptable-None TypeError; raise the same error
            # the non-OpenAI branch uses.
            raise RuntimeError(f'Model config for {model_name} not found.')
        # NOTE(review): branch nesting reconstructed from an indentation-less diff view —
        # a differing img_size resizes the positional embedding, otherwise weights load
        # as-is. Confirm against the original file.
        if model_cfg['vision_cfg']['image_size'] != img_size:
            model_cfg['vision_cfg']['image_size'] = img_size
            model = _build_openai_clip(model_name, model_cfg, precision, device, jit, output_dict, resize_pos=True)
        else:
            model = _build_openai_clip(model_name, model_cfg, precision, device, jit, output_dict, resize_pos=False)
    else:
        model_cfg = model_cfg or get_model_config(model_name)
        if model_cfg is not None:
            print(f'Loaded {model_name} model config.')
        else:
            raise RuntimeError(f'Model config for {model_name} not found.')

        if force_quick_gelu:
            # override for use of QuickGELU on non-OpenAI transformer models
            model_cfg["quick_gelu"] = True

        if force_patch_dropout is not None:
            # override the default patch dropout value
            model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout

        if force_image_size is not None:
            # override model config's image size
            model_cfg["vision_cfg"]["image_size"] = force_image_size

        cast_dtype = get_cast_dtype(precision)
        custom_text = model_cfg.pop('custom_text', False) or force_custom_text

        if custom_text:
            model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype)
        else:
            model = CLIP(**model_cfg, cast_dtype=cast_dtype)

        pretrained_loaded = False
        if pretrained:
            checkpoint_path = _MODEL_CKPT_PATHS[model_name]
            if checkpoint_path:
                print(f'Loading pretrained {model_name} weights ({pretrained}).')
                load_checkpoint(model, checkpoint_path)
            else:
                raise RuntimeError(f'Pretrained weights ({pretrained}) not found for model {model_name}.')
            pretrained_loaded = True

        if require_pretrained and not pretrained_loaded:
            # callers of create_model_from_pretrained always expect pretrained weights
            raise RuntimeError(
                f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.')

        model.to(device=device)
        if precision in ("fp16", "bf16"):
            convert_weights_to_lp(model, dtype=torch.bfloat16 if precision == 'bf16' else torch.float16)

        # OpenAI CLIP normalization constants.
        model.visual.image_mean = (0.48145466, 0.4578275, 0.40821073)
        model.visual.image_std = (0.26862954, 0.26130258, 0.27577711)

        # to always output dict even if it is clip
        if output_dict and hasattr(model, "output_dict"):
            model.output_dict = True

        if jit:
            model = torch.jit.script(model)

    return model