Continual-Mega committed on
Commit
537f272
·
verified ·
1 Parent(s): 1b8e995

Upload CLIP/openai.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. CLIP/openai.py +136 -0
CLIP/openai.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenAI pretrained model functions
2
+
3
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
+ """
5
+
6
+ import os
7
+ import warnings
8
+ from typing import List, Optional, Union
9
+
10
+ import torch
11
+
12
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13
+
14
+ __all__ = ["list_openai_models", "load_openai_model"]
15
+
16
+
def load_openai_model(
        name: str,
        precision: Optional[str] = None,
        device: Optional[Union[str, torch.device]] = None,
        jit: bool = True,
        cache_dir: Optional[str] = None,
):
    """Load an OpenAI CLIP model from a local checkpoint file.

    Parameters
    ----------
    name : str
        Path to a model checkpoint: either a TorchScript (JIT) archive or a
        file containing a plain state dict.
    precision : Optional[str]
        Model precision; if None, defaults to 'fp32' when device == 'cpu',
        otherwise 'fp16'.
    device : Optional[Union[str, torch.device]]
        The device to put the loaded model on; defaults to CUDA when available.
    jit : bool
        Whether to keep the optimized JIT model (default) or rebuild the more
        hackable non-JIT model from its state dict.
    cache_dir : Optional[str]
        Accepted for interface compatibility; currently unused — this build
        only supports local checkpoint files (no download path).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model.

    Raises
    ------
    RuntimeError
        If `name` is not an existing checkpoint file.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if precision is None:
        precision = 'fp32' if device == 'cpu' else 'fp16'

    if os.path.isfile(name):
        model_path = name
    else:
        # Only local files are supported here; the upstream message dangled
        # ("; available models") because the download/listing path was removed.
        raise RuntimeError(f"Model {name} not found; expected a path to a local checkpoint file")

    try:
        # First try loading as a TorchScript (JIT) archive.
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # Not a JIT archive — fall back to loading a saved state dict.
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Build a non-jit model from the OpenAI jitted model state dict.
        cast_dtype = get_cast_dtype(precision)
        try:
            model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
        except KeyError:
            # Some training checkpoints nest weights under 'state_dict' with a
            # 'module.' prefix (k[7:] strips it); retry with that layout.
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)

        # Model from OpenAI state dict is in manually cast fp16 mode; it must
        # be converted for AMP/fp32/bf16 use.
        model = model.to(device)
        if precision.startswith('amp') or precision == 'fp32':
            model.float()
        elif precision == 'bf16':
            convert_weights_to_lp(model, dtype=torch.bfloat16)

        return model

    # --- JIT path: rewrite device constants baked into the traced graph. ---
    # Trace a trivial op on the target device to obtain a prim::Constant node
    # carrying the desired Device attribute, then copy it over any "cuda"
    # constants found in the model's graphs.
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            # Accessing .graph can raise on some scripted modules; skip them.
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # Patch dtype constants to float32 (typically needed for CPU inference).
    if precision == 'fp32':
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # 5 is the TorchScript ScalarType enum value for half-precision
                        # targets being rewritten — NOTE(review): matches upstream; confirm
                        # against torch's ScalarType table if bumping torch versions.
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)
        model.float()

    # Ensure image_size attr is available at a consistent location for both
    # jit and non-jit models.
    model.visual.image_size = model.input_resolution.item()
    return model