Add files using upload-large-folder tool
Browse files- .gitattributes +6 -0
- __pycache__/adv_clip_loss.cpython-39.pyc +0 -0
- datasets/cifar-10-batches-py/data_batch_1 +3 -0
- datasets/cifar-10-batches-py/data_batch_2 +3 -0
- datasets/cifar-10-batches-py/data_batch_3 +3 -0
- datasets/cifar-10-batches-py/data_batch_4 +3 -0
- datasets/cifar-10-batches-py/data_batch_5 +3 -0
- datasets/cifar-10-batches-py/test_batch +3 -0
- datasets/cifar-10-python.tar.gz +3 -0
- modified_clip/.ipynb_checkpoints/clip-checkpoint.py +245 -0
- modified_clip/bpe_simple_vocab_16e6.txt.gz +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
datasets/cifar-10-batches-py/data_batch_2 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
datasets/cifar-10-batches-py/data_batch_1 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
datasets/cifar-10-batches-py/data_batch_5 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
datasets/cifar-10-batches-py/data_batch_4 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
datasets/cifar-10-batches-py/data_batch_3 filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
datasets/cifar-10-batches-py/test_batch filter=lfs diff=lfs merge=lfs -text
|
__pycache__/adv_clip_loss.cpython-39.pyc
ADDED
|
Binary file (5.05 kB). View file
|
|
|
datasets/cifar-10-batches-py/data_batch_1
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:54636561a3ce25bd3e19253c6b0d8538147b0ae398331ac4a2d86c6d987368cd
|
| 3 |
+
size 31035704
|
datasets/cifar-10-batches-py/data_batch_2
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:766b2cef9fbc745cf056b3152224f7cf77163b330ea9a15f9392beb8b89bc5a8
|
| 3 |
+
size 31035320
|
datasets/cifar-10-batches-py/data_batch_3
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0f00d98ebfb30b3ec0ad19f9756dc2630b89003e10525f5e148445e82aa6a1f9
|
| 3 |
+
size 31035999
|
datasets/cifar-10-batches-py/data_batch_4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3f7bb240661948b8f4d53e36ec720d8306f5668bd0071dcb4e6c947f78e9682b
|
| 3 |
+
size 31035696
|
datasets/cifar-10-batches-py/data_batch_5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d91802434d8376bbaeeadf58a737e3a1b12ac839077e931237e0dcd43adcb154
|
| 3 |
+
size 31035623
|
datasets/cifar-10-batches-py/test_batch
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f53d8d457504f7cff4ea9e021afcf0e0ad8e24a91f3fc42091b8adef61157831
|
| 3 |
+
size 31035526
|
datasets/cifar-10-python.tar.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce
|
| 3 |
+
size 170498071
|
modified_clip/.ipynb_checkpoints/clip-checkpoint.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import hashlib
|
| 2 |
+
import os
|
| 3 |
+
import urllib
|
| 4 |
+
import warnings
|
| 5 |
+
from typing import Any, Union, List
|
| 6 |
+
from pkg_resources import packaging
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from .model import build_model
|
| 14 |
+
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
|
| 15 |
+
|
| 16 |
+
# Older torchvision builds don't expose InterpolationMode; fall back to the
# equivalent PIL constant so Resize behaves the same either way.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
# Shared BPE tokenizer instance used by tokenize() below.
_tokenizer = _Tokenizer()

# Official OpenAI CLIP checkpoint URLs. The second-to-last path segment of
# each URL is the file's expected SHA-256 digest, which _download() verifies.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _download(url: str, root: str):
|
| 44 |
+
os.makedirs(root, exist_ok=True)
|
| 45 |
+
filename = os.path.basename(url)
|
| 46 |
+
|
| 47 |
+
expected_sha256 = url.split("/")[-2]
|
| 48 |
+
download_target = os.path.join(root, filename)
|
| 49 |
+
|
| 50 |
+
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
| 51 |
+
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
| 52 |
+
|
| 53 |
+
if os.path.isfile(download_target):
|
| 54 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
| 55 |
+
return download_target
|
| 56 |
+
else:
|
| 57 |
+
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
| 58 |
+
|
| 59 |
+
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
| 60 |
+
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
| 61 |
+
while True:
|
| 62 |
+
buffer = source.read(8192)
|
| 63 |
+
if not buffer:
|
| 64 |
+
break
|
| 65 |
+
|
| 66 |
+
output.write(buffer)
|
| 67 |
+
loop.update(len(buffer))
|
| 68 |
+
|
| 69 |
+
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
| 70 |
+
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
| 71 |
+
|
| 72 |
+
return download_target
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _convert_image_to_rgb(image):
|
| 76 |
+
return image.convert("RGB")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _transform(n_px):
    """Build the standard CLIP image preprocessing pipeline for ``n_px``-sized inputs."""
    # Resize/center-crop to the model's input resolution, force RGB, then
    # normalize with the CLIP training statistics.
    steps = [
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _convert_image_to_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ]
    return Compose(steps)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def available_models() -> List[str]:
    """Return the names of the CLIP models this module knows how to load."""
    return [model_name for model_name in _MODELS]
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, prompt_len: int = 0):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    download_root: str
        path to download the model files; by default, it uses "~/.cache/clip"

    prompt_len : int
        Forwarded to `build_model`; presumably the number of learnable prompt
        tokens in this modified CLIP — confirm against modified_clip/model.py.

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    if name in _MODELS:
        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    with open(model_path, 'rb') as opened_file:
        try:
            # loading JIT archive
            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
            state_dict = None
        except RuntimeError:
            # loading saved state dict
            if jit:
                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
                jit = False
            state_dict = torch.load(opened_file, map_location="cpu")

    if not jit:
        # Rebuild a plain nn.Module from the weights (either the raw state
        # dict, or the JIT archive's own state_dict) so the model is hackable.
        model = build_model(state_dict or model.state_dict(), prompt_len).to(device)
        if str(device) == "cpu":
            # fp16 weights are intended for CUDA; promote to float32 on CPU.
            model.float()
        return model, _transform(model.visual.input_resolution)

    # --- JIT path: rewrite device (and, on CPU, dtype) constants baked into
    # the TorchScript graphs so the traced model runs on the requested device.

    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def _node_get(node: torch._C.Node, key: str):
        """Gets attributes of a node which is polymorphic over return type.

        From https://github.com/pytorch/pytorch/pull/82628
        """
        sel = node.kindOf(key)
        return getattr(node, sel)(key)

    def patch_device(module):
        # Replace every "cuda..." device constant in the module's graph(s)
        # with the constant for the requested target device.
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            # Rewrite aten::to dtype arguments of 5 (torch.half) to float32.
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if _node_get(inputs[i].node(), "value") == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]
    token_rows = [[sot_token, *_tokenizer.encode(text), eot_token] for text in texts]

    # Older torch requires long indices for index_select; newer accepts int.
    needs_long = packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0")
    result = torch.zeros(len(token_rows), context_length,
                         dtype=torch.long if needs_long else torch.int)

    for i, tokens in enumerate(token_rows):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            # Keep the first context_length-1 tokens and force a trailing EOT.
            tokens = tokens[:context_length - 1] + [eot_token]
        result[i, :len(tokens)] = torch.tensor(tokens)

    return result
|
modified_clip/bpe_simple_vocab_16e6.txt.gz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
|
| 3 |
+
size 1356917
|