davjoython commited on
Commit
30a7879
·
verified ·
1 Parent(s): 8462475

Upload 68 files

Browse files
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. app.py +68 -0
  2. augmentations_clip.py +255 -0
  3. clip/.ipynb_checkpoints/clip-checkpoint.py +225 -0
  4. clip/.ipynb_checkpoints/model-checkpoint.py +432 -0
  5. clip/__init__.py +1 -0
  6. clip/__pycache__/__init__.cpython-311.pyc +0 -0
  7. clip/__pycache__/__init__.cpython-36.pyc +0 -0
  8. clip/__pycache__/__init__.cpython-38.pyc +0 -0
  9. clip/__pycache__/clip.cpython-311.pyc +0 -0
  10. clip/__pycache__/clip.cpython-36.pyc +0 -0
  11. clip/__pycache__/clip.cpython-38.pyc +0 -0
  12. clip/__pycache__/model.cpython-311.pyc +0 -0
  13. clip/__pycache__/model.cpython-36.pyc +0 -0
  14. clip/__pycache__/model.cpython-38.pyc +0 -0
  15. clip/__pycache__/simple_tokenizer.cpython-311.pyc +0 -0
  16. clip/__pycache__/simple_tokenizer.cpython-36.pyc +0 -0
  17. clip/__pycache__/simple_tokenizer.cpython-38.pyc +0 -0
  18. clip/__pycache__/utils.cpython-38.pyc +0 -0
  19. clip/bpe_simple_vocab_16e6.txt.gz +3 -0
  20. clip/clip.py +228 -0
  21. clip/model.py +432 -0
  22. clip/simple_tokenizer.py +132 -0
  23. config.py +48 -0
  24. detector.py +86 -0
  25. freia_funcs.py +473 -0
  26. loralib/__init__.py +2 -0
  27. loralib/__pycache__/__init__.cpython-38.pyc +0 -0
  28. loralib/__pycache__/layers.cpython-38.pyc +0 -0
  29. loralib/__pycache__/utils.cpython-38.pyc +0 -0
  30. loralib/easymultiheadattention.py +124 -0
  31. loralib/layers.py +598 -0
  32. loralib/utils.py +236 -0
  33. model.py +95 -0
  34. models/__init__.py +43 -0
  35. models/__pycache__/__init__.cpython-38.pyc +0 -0
  36. models/__pycache__/clip_models.cpython-38.pyc +0 -0
  37. models/__pycache__/imagenet_models.cpython-38.pyc +0 -0
  38. models/__pycache__/resnet.cpython-38.pyc +0 -0
  39. models/__pycache__/vision_transformer.cpython-38.pyc +0 -0
  40. models/__pycache__/vision_transformer_misc.cpython-38.pyc +0 -0
  41. models/__pycache__/vision_transformer_utils.cpython-38.pyc +0 -0
  42. models/clip/__init__.py +1 -0
  43. models/clip/__pycache__/__init__.cpython-310.pyc +0 -0
  44. models/clip/__pycache__/__init__.cpython-38.pyc +0 -0
  45. models/clip/__pycache__/__init__.cpython-39.pyc +0 -0
  46. models/clip/__pycache__/clip.cpython-310.pyc +0 -0
  47. models/clip/__pycache__/clip.cpython-38.pyc +0 -0
  48. models/clip/__pycache__/clip.cpython-39.pyc +0 -0
  49. models/clip/__pycache__/model.cpython-310.pyc +0 -0
  50. models/clip/__pycache__/model.cpython-38.pyc +0 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from PIL import Image
from detector import FakeImageDetector

# Decision threshold used when classifying a score as real vs. fake.
DEFAULT_THRESHOLD = 0.49999

print("正在初始化检测器,请稍候...")
try:
    detector = FakeImageDetector()
    print("检测器初始化完成,Web 服务准备就绪。")
    models_loaded = True
except Exception as e:
    print(f"模型加载失败: {e}")
    models_loaded = False
    detector = None


def predict_image(input_image_numpy, threshold=DEFAULT_THRESHOLD):
    """Run the fake-image detector on a Gradio image input.

    Parameters
    ----------
    input_image_numpy : numpy.ndarray
        Image array as produced by ``gr.Image(type="numpy")``.
    threshold : float
        Decision threshold; scores above it are rendered in red.

    Returns
    -------
    tuple
        (result text, ``gr.Label`` carrying the raw score), or an error
        message with ``None`` when the model failed to load.
    """
    if not models_loaded or detector is None:
        return "错误:模型未能成功加载,请检查后台日志。", None

    pil_image = Image.fromarray(input_image_numpy)
    result_text, score = detector.detect(pil_image, threshold)

    # NOTE(review): the colour string becomes the Label's *title*; Gradio does
    # not interpret it as a colour — confirm this is the intended presentation.
    label_color = "red" if score > threshold else "green"
    return result_text, gr.Label(value=f"{score:.10f}", label=label_color)


with gr.Blocks(title="伪造图像检测器", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 伪造图像检测器 (Fake Image Detector)
        上传一张图片,模型将判断其为 **真实的 (Real)** 还是 **AI 生成的伪造图像 (Fake)**。
        """
    )

    with gr.Row():
        with gr.Column(scale=1):
            # Input components.
            image_input = gr.Image(type="numpy", label="上传图片", height=300)
            submit_btn = gr.Button("开始检测", variant="primary")

        with gr.Column(scale=1):
            # Output components.
            result_output_text = gr.Textbox(label="检测结论", lines=2)
            # Temporary Label used to show the score with a colour hint.
            result_output_score = gr.Label(label="模型原始得分")

    # BUG FIX: the original passed the bare float 0.49999 inside `inputs`,
    # but every entry of `inputs` must be a Gradio component. Wrap the
    # constant in gr.State so it is forwarded to predict_image unchanged.
    threshold_state = gr.State(DEFAULT_THRESHOLD)
    submit_btn.click(
        fn=predict_image,
        inputs=[image_input, threshold_state],
        outputs=[result_output_text, result_output_score]
    )

if not models_loaded:
    print("\n由于模型加载失败,Gradio Web服务无法启动。")
else:
    print("正在启动 Gradio 服务...")
    demo.launch(server_name="0.0.0.0")
augmentations_clip.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # import logging
7
+
8
+ from torchvision import transforms
9
+
10
+ import torch
11
+ import cv2
12
+ from PIL import Image
13
+ import numpy as np
14
+
15
+ from my_transforms import (
16
+ GaussianBlur,
17
+ make_normalize_transform,
18
+ make_normalize_transform_clip,
19
+ )
20
+
21
def add_gaussian_noise(tensor, mean=0.0, std=0.1):
    """Return *tensor* plus element-wise Gaussian noise N(mean, std**2).

    The noise is created with ``randn_like`` so it lives on the same device
    and dtype as the input; the original hard-coded ``.cuda()``, which failed
    on CPU-only machines and for tensors on other devices.
    """
    noise = torch.randn_like(tensor) * std + mean
    return tensor + noise
24
+
25
+
26
+
27
+
28
class DataAugmentationCLIP(object):
    """Produce CLIP-normalised 224x224 center crops of an input image.

    The constructor mirrors the DINO-style augmentation API;
    ``global_crops_scale``, ``local_crops_scale``, ``global_crops_size`` and
    ``local_crops_size`` are accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        # ToTensor followed by CLIP mean/std normalisation, applied per crop.
        self.source_trans = transforms.Compose([
            transforms.ToTensor(),
            make_normalize_transform_clip(),
        ])

        # Resize the short side to 224, then take a random 224x224 crop.
        self.crop = transforms.Compose([
            transforms.Resize(224),
            transforms.RandomCrop(224),
        ])

        self.centercrop = transforms.Compose([
            transforms.CenterCrop(224),
        ])

        self.randomcrop = transforms.Compose([
            transforms.RandomCrop(224),
        ])

        self.local_crops_number = local_crops_number

    def __call__(self, image):
        """Return ``{"source": [tensor, ...], "offsets": ()}`` containing
        ``local_crops_number`` center crops of *image*."""
        output = {}
        output["source"] = []

        # BUG FIX (dead code): the original branched on whether the image was
        # smaller than 224 on either side, but both branches were identical
        # (center crop); the redundant size check is removed.
        crops_all = [
            self.centercrop(image) for _ in range(self.local_crops_number)
        ]

        for crops_image in crops_all:
            output["source"].append(self.source_trans(crops_image))

        output["offsets"] = ()

        return output
87
+
88
+
89
class DataAugmentationDINO(object):
    """Produce ImageNet-normalised 224x224 center crops of an input image.

    ``global_crops_scale``, ``local_crops_scale``, ``global_crops_size`` and
    ``local_crops_size`` are accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        # ToTensor followed by standard (non-CLIP) normalisation, per crop.
        self.source_trans = transforms.Compose([
            transforms.ToTensor(),
            make_normalize_transform(),
        ])

        # Resize the short side to 224, then take a centered 224x224 crop.
        self.crop = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
        ])

        self.centercrop = transforms.Compose([
            transforms.CenterCrop(224),
        ])

        self.local_crops_number = local_crops_number

    def __call__(self, image):
        """Return ``{"source": [tensor, ...], "offsets": ()}`` containing
        ``local_crops_number`` center crops of *image*."""
        output = {}
        output["source"] = []

        # BUG FIX (dead code): both branches of the original size check were
        # identical, so the check is dropped.
        crops_all = [
            self.centercrop(image) for _ in range(self.local_crops_number)
        ]

        for crops_image in crops_all:
            output["source"].append(self.source_trans(crops_image))

        output["offsets"] = ()

        return output
143
+
144
+
145
class DataAugmentationResNet_test(object):
    """Test-time augmentation: ImageNet-normalised 224x224 center crops.

    ``global_crops_scale``, ``local_crops_scale``, ``global_crops_size`` and
    ``local_crops_size`` are accepted for interface compatibility but unused.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        # ToTensor followed by standard (non-CLIP) normalisation, per crop.
        self.source_trans = transforms.Compose([
            transforms.ToTensor(),
            make_normalize_transform(),
        ])

        # Resize the short side to 224, then take a centered 224x224 crop.
        self.crop = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
        ])

        self.centercrop = transforms.Compose([
            transforms.CenterCrop(224),
        ])

        self.local_crops_number = local_crops_number

    def __call__(self, image):
        """Return ``{"source": [tensor, ...], "offsets": ()}`` containing
        ``local_crops_number`` center crops of *image*."""
        output = {}
        output["source"] = []

        # BUG FIX (dead code): both branches of the original size check were
        # identical, so the check is dropped.
        crops_all = [
            self.centercrop(image) for _ in range(self.local_crops_number)
        ]

        for crops_image in crops_all:
            output["source"].append(self.source_trans(crops_image))

        output["offsets"] = ()

        return output
199
+
200
+
201
+
202
class DataAugmentationCLIP_gen(object):
    """Resize-then-center-crop augmentation with CLIP normalisation.

    Unused parameters are kept so this class is interchangeable with the
    other augmentation classes in this module.
    """

    def __init__(
        self,
        global_crops_scale,
        local_crops_scale,
        local_crops_number,
        global_crops_size=224,
        local_crops_size=96,
    ):
        # ToTensor + CLIP mean/std normalisation applied to each crop.
        self.source_trans = transforms.Compose([
            transforms.ToTensor(),
            make_normalize_transform_clip(),
        ])

        # Short side resized to 224, then a centered 224x224 crop.
        self.crop = transforms.Compose([
            transforms.Resize(224),
            transforms.CenterCrop(224),
        ])

        self.centercrop = transforms.Compose([
            transforms.CenterCrop(224),
        ])

        self.local_crops_number = local_crops_number

    def __call__(self, image):
        """Return ``{"source": [tensor, ...], "offsets": ()}`` with
        ``local_crops_number`` resize-and-crop views of *image*."""
        views = [self.crop(image) for _ in range(self.local_crops_number)]
        return {
            "source": [self.source_trans(view) for view in views],
            "offsets": (),
        }
clip/.ipynb_checkpoints/clip-checkpoint.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ from tqdm import tqdm
11
+
12
+ from .model import build_model
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
+ try:
16
+ from torchvision.transforms import InterpolationMode
17
+ BICUBIC = InterpolationMode.BICUBIC
18
+ except ImportError:
19
+ BICUBIC = Image.BICUBIC
20
+
21
+
22
+ if torch.__version__.split(".") < ["1", "7", "1"]:
23
+ warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24
+
25
+
26
+ __all__ = ["available_models", "load", "tokenize"]
27
+ _tokenizer = _Tokenizer()
28
+
29
+ _MODELS = {
30
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36
+ }
37
+
38
+
39
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Download *url* into *root*, verifying the SHA256 embedded in the URL.

    Returns the local file path. An existing file that passes the checksum is
    reused; one that fails it is re-downloaded.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    # OpenAI embeds the expected SHA256 as the penultimate URL path segment.
    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if _file_sha256(download_target) == expected_sha256:
            return download_target
        else:
            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # BUG FIX: original error message read "does not not match".
    if _file_sha256(download_target) != expected_sha256:
        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

    return download_target


def _file_sha256(path: str) -> str:
    """SHA256 hex digest of *path*, hashed in chunks with the handle closed.

    The original hashed via ``open(path, "rb").read()``, which both leaked
    the file handle and loaded multi-hundred-MB checkpoints fully into memory.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()
69
+
70
+
71
def _transform(n_px):
    """CLIP preprocessing pipeline: bicubic resize and center crop to
    *n_px*, conversion to RGB, then CLIP mean/std normalisation."""
    def _ensure_rgb(image):
        return image.convert("RGB")

    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _ensure_rgb,
        ToTensor(),
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])
79
+
80
+
81
def available_models() -> List[str]:
    """Return the names of the CLIP checkpoints this module can load."""
    return [*_MODELS]
84
+
85
+
86
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """
    # NOTE(review): the original name/URL resolution is disabled below and the
    # checkpoint path is hard-coded, so `name` is effectively ignored. This
    # breaks portability outside the author's environment — confirm intent.
    '''
    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
    '''
    model_path = '/model/4DaiRui/pretrained_ood/ViT-B-16.pt'

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Rebuild a pure-Python model from either the plain state dict or the
        # weights extracted from the JIT archive.
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            # fp16 weights are not usable on CPU; promote to float32.
            model.float()
        return model, _transform(model.visual.input_resolution)

    # --- JIT path: rewrite device/dtype constants baked into the graph ---
    # patch the device names
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Collect the module's graph(s); `module.graph` can raise RuntimeError
        # for modules without a graph, hence the try/except.
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                # Replace every constant that names a cuda device with the
                # constant traced above for the requested device.
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        # 5 is the TorchScript enum value for torch.float16;
                        # rewrite such casts to float32.
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
187
+
188
+
189
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot = _tokenizer.encoder["<|startoftext|>"]
    eot = _tokenizer.encoder["<|endoftext|>"]
    encoded = [[sot, *_tokenizer.encode(text), eot] for text in texts]

    result = torch.zeros(len(encoded), context_length, dtype=torch.long)
    for row, tokens in enumerate(encoded):
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {texts[row]} is too long for context length {context_length}")
            # Keep the first context_length tokens, forcing the final one
            # to be the end-of-text marker.
            tokens = tokens[:context_length]
            tokens[-1] = eot
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result
clip/.ipynb_checkpoints/model-checkpoint.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
+ class Bottleneck(nn.Module):
11
+ expansion = 4
12
+
13
+ def __init__(self, inplanes, planes, stride=1):
14
+ super().__init__()
15
+
16
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
+ self.bn1 = nn.BatchNorm2d(planes)
19
+
20
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21
+ self.bn2 = nn.BatchNorm2d(planes)
22
+
23
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24
+
25
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27
+
28
+ self.relu = nn.ReLU(inplace=True)
29
+ self.downsample = None
30
+ self.stride = stride
31
+
32
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
33
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34
+ self.downsample = nn.Sequential(OrderedDict([
35
+ ("-1", nn.AvgPool2d(stride)),
36
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37
+ ("1", nn.BatchNorm2d(planes * self.expansion))
38
+ ]))
39
+
40
+ def forward(self, x: torch.Tensor):
41
+ identity = x
42
+
43
+ out = self.relu(self.bn1(self.conv1(x)))
44
+ out = self.relu(self.bn2(self.conv2(out)))
45
+ out = self.avgpool(out)
46
+ out = self.bn3(self.conv3(out))
47
+
48
+ if self.downsample is not None:
49
+ identity = self.downsample(x)
50
+
51
+ out += identity
52
+ out = self.relu(out)
53
+ return out
54
+
55
+
56
+ class AttentionPool2d(nn.Module):
57
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58
+ super().__init__()
59
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
61
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
62
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
63
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64
+ self.num_heads = num_heads
65
+
66
+ def forward(self, x):
67
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70
+ x, _ = F.multi_head_attention_forward(
71
+ query=x, key=x, value=x,
72
+ embed_dim_to_check=x.shape[-1],
73
+ num_heads=self.num_heads,
74
+ q_proj_weight=self.q_proj.weight,
75
+ k_proj_weight=self.k_proj.weight,
76
+ v_proj_weight=self.v_proj.weight,
77
+ in_proj_weight=None,
78
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79
+ bias_k=None,
80
+ bias_v=None,
81
+ add_zero_attn=False,
82
+ dropout_p=0,
83
+ out_proj_weight=self.c_proj.weight,
84
+ out_proj_bias=self.c_proj.bias,
85
+ use_separate_proj_weight=True,
86
+ training=self.training,
87
+ need_weights=False
88
+ )
89
+
90
+ return x[0]
91
+
92
+
93
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        # First block of a stage may downsample; the remaining blocks keep
        # stride 1 and take the expanded channel count as input.
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            # Three conv-bn-relu stages followed by a 2x2 average pool.
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        # Match the input dtype to the weights (supports fp16 weights).
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
+
152
+
153
class LayerNorm(nn.LayerNorm):
    """LayerNorm that performs the normalisation in float32 and then casts
    the result back to the input dtype, so fp16 activations stay stable."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
160
+
161
+
162
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation used by CLIP: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        return torch.sigmoid(x * 1.702) * x
165
+
166
+
167
class ResidualAttentionBlock(nn.Module):
    """Pre-LN transformer block: multi-head self-attention and a 4x-expansion
    MLP (QuickGELU), each wrapped in LayerNorm and a residual connection."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive attention mask (e.g. causal mask for text).
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # Move the stored mask to the input's dtype/device on every call.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
189
+
190
+
191
class Transformer(nn.Module):
    """A stack of ``layers`` ResidualAttentionBlocks of width ``width``."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        return self.resblocks(x)
200
+
201
+
202
class VisionTransformer(nn.Module):
    """CLIP's ViT image encoder: patchify with a strided convolution, prepend
    a learned class token, run a pre-LN transformer, and project the class
    token to ``output_dim``."""

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding (kernel == stride == patch_size).
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        # One positional slot per patch plus one for the class token.
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Broadcast the class token across the batch and prepend it.
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # The class token alone is used as the image representation.
        x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x
237
+
238
+
239
class CLIP(nn.Module):
    """Contrastive Language-Image Pre-training model.

    Couples an image encoder (ModifiedResNet when `vision_layers` is a
    tuple of block counts, otherwise a VisionTransformer) with a causal
    text transformer, and learns a shared `embed_dim`-dimensional
    embedding space plus a learnable temperature (`logit_scale`).
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # ResNet backbone; attention heads use a fixed per-head dim of 64.
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            # ViT backbone.
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        # Text encoder: transformer with a causal (additive -inf) mask.
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        # Projects final text features into the shared embedding space.
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable inverse temperature, initialized to 1/0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialize embeddings and transformer weights (std scaled by width/depth)."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero-init the last BN gamma of every bottleneck so each residual
            # branch starts as identity.
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # Mirrors the visual tower's parameter dtype (fp16 after convert_weights).
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Encode a batch of preprocessed images with the visual tower."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Encode a batch of token-id sequences into text embeddings."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        """Return (logits_per_image, logits_per_text): temperature-scaled cosine similarities."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
369
+
370
+
371
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16 in place.

    Conv/linear weights and biases, MultiheadAttention projection tensors,
    and any `text_projection` / `proj` parameters are cast to half
    precision; other modules (e.g. LayerNorm) are left untouched.
    """

    def _to_half(module):
        # Purpose: per-module fp16 cast, applied recursively via model.apply.
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            attn_attrs = ["in_proj_weight", "q_proj_weight", "k_proj_weight",
                          "v_proj_weight", "in_proj_bias", "bias_k", "bias_v"]
            for attr_name in attn_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for param_name in ("text_projection", "proj"):
            param = getattr(module, param_name, None)
            if param is not None:
                param.data = param.data.half()

    model.apply(_to_half)
393
+
394
+
395
def build_model(state_dict: dict):
    """Instantiate a CLIP model whose hyperparameters are inferred from a
    pretrained `state_dict`, load the weights, and return it in eval mode.
    """
    # ViT checkpoints expose a `visual.proj` parameter; ResNet ones do not.
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # One in_proj_weight per visual transformer block.
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # Positional embedding has grid**2 + 1 rows (class token included).
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Block counts per ResNet stage, inferred from parameter key names.
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32  # ResNet downsamples by 32x overall

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64  # fixed 64-dim heads
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # Scalar bookkeeping entries are not module parameters; drop before loading.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)  # checkpoints store fp16 weights
    model.load_state_dict(state_dict)
    return model.eval()
clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip import *
clip/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (172 Bytes). View file
 
clip/__pycache__/__init__.cpython-36.pyc ADDED
Binary file (126 Bytes). View file
 
clip/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (140 Bytes). View file
 
clip/__pycache__/clip.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
clip/__pycache__/clip.cpython-36.pyc ADDED
Binary file (7.67 kB). View file
 
clip/__pycache__/clip.cpython-38.pyc ADDED
Binary file (8.03 kB). View file
 
clip/__pycache__/model.cpython-311.pyc ADDED
Binary file (31.3 kB). View file
 
clip/__pycache__/model.cpython-36.pyc ADDED
Binary file (15 kB). View file
 
clip/__pycache__/model.cpython-38.pyc ADDED
Binary file (14.9 kB). View file
 
clip/__pycache__/simple_tokenizer.cpython-311.pyc ADDED
Binary file (11 kB). View file
 
clip/__pycache__/simple_tokenizer.cpython-36.pyc ADDED
Binary file (5.76 kB). View file
 
clip/__pycache__/simple_tokenizer.cpython-38.pyc ADDED
Binary file (5.77 kB). View file
 
clip/__pycache__/utils.cpython-38.pyc ADDED
Binary file (6.82 kB). View file
 
clip/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
clip/clip.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import hashlib
2
+ import os
3
+ import urllib
4
+ import warnings
5
+ from typing import Union, List
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10
+ from tqdm import tqdm
11
+
12
+ from .model import build_model
13
+ from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14
+
15
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # Older torchvision versions: fall back to the PIL constant.
    BICUBIC = Image.BICUBIC


def _parse_version(version: str) -> tuple:
    """Numeric prefix of a version string, e.g. "1.10.0+cu113" -> (1, 10, 0)."""
    parts = []
    for piece in version.split("+")[0].split(".")[:3]:
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


# BUG FIX: the original compared version components as *strings*
# (["1", "10", "0"] < ["1", "7", "1"] is True), so e.g. PyTorch 1.10
# incorrectly triggered the warning. Compare numerically instead.
if _parse_version(torch.__version__) < (1, 7, 1):
    warnings.warn("PyTorch version 1.7.1 or higher is recommended")


__all__ = ["available_models", "load", "tokenize"]
_tokenizer = _Tokenizer()

# Pretrained checkpoints; the SHA256 of each file is embedded in its URL.
_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
39
+
40
+
41
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
    """Download `url` into `root`, verify its SHA256, and return the local path.

    The expected SHA256 digest is embedded as the second-to-last path
    component of the URL. An existing file with a matching digest is reused;
    on a mismatch the file is re-downloaded and verified again.

    Raises RuntimeError if the target exists but is not a regular file, or
    if the downloaded file fails the checksum.
    """
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2]
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    def _sha256(path: str) -> str:
        # Hash in chunks so multi-GB checkpoints are not read into memory
        # at once; the original leaked the file handle via open(...).read().
        digest = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        return digest.hexdigest()

    if os.path.isfile(download_target):
        if _sha256(download_target) == expected_sha256:
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    # Fixed message typo ("does not not match" -> "does not match").
    if _sha256(download_target) != expected_sha256:
        raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")

    return download_target
71
+
72
+
73
def _transform(n_px):
    """CLIP image preprocessing: resize + center-crop to `n_px`, force RGB,
    convert to tensor, and normalize with the CLIP dataset statistics."""
    clip_mean = (0.48145466, 0.4578275, 0.40821073)
    clip_std = (0.26862954, 0.26130258, 0.27577711)

    def _to_rgb(image):
        return image.convert("RGB")

    return Compose([
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        _to_rgb,
        ToTensor(),
        Normalize(clip_mean, clip_std),
    ])
81
+
82
+
83
def available_models() -> List[str]:
    """Return the names of the pretrained CLIP models accepted by `load`."""
    return [model_name for model_name in _MODELS]
86
+
87
+
88
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
    """Load a CLIP model

    Parameters
    ----------
    name : str
        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

    device : Union[str, torch.device]
        The device to put the loaded model

    jit : bool
        Whether to load the optimized JIT model or more hackable non-JIT model (default).

    Returns
    -------
    model : torch.nn.Module
        The CLIP model

    preprocess : Callable[[PIL.Image], torch.Tensor]
        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
    """

    if name in _MODELS:
        model_path = _download(_MODELS[name])
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
            jit = False
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        # Non-JIT path: rebuild a Python CLIP module from the state dict
        # (taken from the JIT archive when no raw state dict was loaded).
        model = build_model(state_dict or model.state_dict()).to(device)
        if str(device) == "cpu":
            model.float()  # fp16 weights are not usable on CPU
        return model, _transform(model.visual.input_resolution)

    # patch the device names: trace a tiny graph on the target device and
    # copy its device constant over any hard-coded "cuda" constants.
    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

    def patch_device(module):
        # Rewrite device constants in the module's traced graph(s) in place.
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_image)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            # Replace aten::to(..., float16) casts with float32 in place.
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:  # 5 is the fp16 dtype enum value
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_image)
        patch_float(model.encode_text)

        model.float()

    return model, _transform(model.input_resolution.item())
190
+
191
+
192
def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
    """
    Tokenize one or more strings for consumption by a CLIP model.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize

    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    truncate: bool
        Whether to truncate the text in case its encoding is longer than the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]

    Raises
    ------
    RuntimeError
        If a text encodes to more than `context_length` tokens and `truncate` is False.
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<|startoftext|>"]
    eot_token = _tokenizer.encoder["<|endoftext|>"]

    result = torch.zeros(len(texts), context_length, dtype=torch.long)

    for row, text in enumerate(texts):
        # Wrap the BPE encoding in start/end-of-text markers.
        tokens = [sot_token] + _tokenizer.encode(text) + [eot_token]
        if len(tokens) > context_length:
            if not truncate:
                raise RuntimeError(f"Input {text} is too long for context length {context_length}")
            # Keep the first context_length tokens and force the final one
            # to be the end-of-text marker.
            tokens = tokens[:context_length]
            tokens[-1] = eot_token
        result[row, :len(tokens)] = torch.tensor(tokens)

    return result
clip/model.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ from typing import Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+
9
+
10
class Bottleneck(nn.Module):
    """CLIP-style ResNet bottleneck block.

    All convolutions have stride 1; spatial downsampling (stride > 1) is
    performed by an AvgPool2d placed after the 3x3 convolution
    (anti-aliased striding). The shortcut likewise uses avgpool followed by
    a stride-1 1x1 projection when shapes differ.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()

        out_planes = planes * Bottleneck.expansion

        # 1x1 reduce
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        # 3x3, kept at stride 1
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # spatial downsampling after conv2 when stride > 1
        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()

        # 1x1 expand
        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        if stride > 1 or inplanes != out_planes:
            # Shortcut: avgpool first, then a stride-1 projection conv + BN.
            self.downsample = nn.Sequential(OrderedDict([
                ("-1", nn.AvgPool2d(stride)),
                ("0", nn.Conv2d(inplanes, out_planes, 1, stride=1, bias=False)),
                ("1", nn.BatchNorm2d(out_planes))
            ]))
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(self.avgpool(out)))

        out += shortcut
        return self.relu(out)
54
+
55
+
56
class AttentionPool2d(nn.Module):
    """Attention-based global pooling over a 2D feature map (CLIP ResNet head).

    Flattens an NCHW feature map into a token sequence, prepends the mean
    token as a summary position, adds learned positional embeddings, and
    runs a single multi-head attention pass; the output at the mean-token
    position is returned as the pooled feature.
    """

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
        super().__init__()
        # spacial_dim**2 spatial positions plus one slot for the mean token.
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1)  # NCHW -> (HW)NC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        # One functional attention pass using separate q/k/v projection weights.
        x, _ = F.multi_head_attention_forward(
            query=x, key=x, value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False
        )

        # Keep only the output at the prepended mean-token position.
        return x[0]
91
+
92
+
93
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.input_resolution = input_resolution

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        embed_dim = width * 32  # the ResNet feature dimension
        # input_resolution // 32: the network downsamples 32x overall
        # (stem 4x, layers 2-4 another 8x).
        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) Bottleneck followed by
        `blocks - 1` stride-1 Bottlenecks; updates self._inplanes."""
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        def stem(x):
            # Three conv-bn-relu stages followed by a 2x average pool.
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        # Run in the backbone's parameter dtype (fp16 after convert_weights).
        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
151
+
152
+
153
class LayerNorm(nn.LayerNorm):
    """LayerNorm that upcasts the input to fp32 for the normalization and
    casts the result back to the original dtype (fp16-safe)."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = super().forward(x.type(torch.float32))
        return normalized.type(input_dtype)
160
+
161
+
162
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: x * sigmoid(1.702 * x)."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(x * 1.702)
        return gate * x
165
+
166
+
167
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: x + attn(ln_1(x)), then x + mlp(ln_2(x))."""

    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        # 4x expansion MLP with the QuickGELU activation.
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        # Optional additive mask (e.g. causal mask for the text encoder).
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        # NOTE: re-assigns self.attn_mask each call so the mask follows the
        # input's dtype and device.
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
189
+
190
+
191
class Transformer(nn.Module):
    """A stack of `layers` ResidualAttentionBlocks of the given `width`,
    all sharing the same optional attention mask."""

    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = [ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
        self.resblocks = nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor):
        # Apply all blocks in order; shape is preserved throughout.
        return self.resblocks(x)
200
+
201
+
202
class VisionTransformer(nn.Module):
    """ViT image encoder for CLIP.

    NOTE(review): forward() has the final `proj` matmul commented out, so
    this variant returns the `width`-dimensional post-LN class-token
    features rather than the `output_dim` CLIP embedding. `self.proj` is
    still created so pretrained state dicts load without missing keys.
    """

    def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
        super().__init__()
        self.input_resolution = input_resolution
        self.output_dim = output_dim
        # Non-overlapping patch embedding: stride == kernel == patch_size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(scale * torch.randn(width, output_dim))

    def forward(self, x: torch.Tensor):
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # Prepend the learned class token to every sequence.
        x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1)  # shape = [*, grid ** 2 + 1, width]
        x = x + self.positional_embedding.to(x.dtype)
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Post-LN features of the class token only.
        x = self.ln_post(x[:, 0, :])

        # Deliberately disabled in this copy: returns pre-projection features.
        # if self.proj is not None:
        #     x = x @ self.proj

        return x
237
+
238
+
239
class CLIP(nn.Module):
    """Contrastive Language-Image Pre-training model.

    Couples an image encoder (ModifiedResNet when `vision_layers` is a
    tuple of block counts, otherwise a VisionTransformer) with a causal
    text transformer, and learns a shared `embed_dim`-dimensional
    embedding space plus a learnable temperature (`logit_scale`).
    """

    def __init__(self,
                 embed_dim: int,
                 # vision
                 image_resolution: int,
                 vision_layers: Union[Tuple[int, int, int, int], int],
                 vision_width: int,
                 vision_patch_size: int,
                 # text
                 context_length: int,
                 vocab_size: int,
                 transformer_width: int,
                 transformer_heads: int,
                 transformer_layers: int
                 ):
        super().__init__()

        self.context_length = context_length

        if isinstance(vision_layers, (tuple, list)):
            # ResNet backbone; attention heads use a fixed per-head dim of 64.
            vision_heads = vision_width * 32 // 64
            self.visual = ModifiedResNet(
                layers=vision_layers,
                output_dim=embed_dim,
                heads=vision_heads,
                input_resolution=image_resolution,
                width=vision_width
            )
        else:
            # ViT backbone.
            vision_heads = vision_width // 64
            self.visual = VisionTransformer(
                input_resolution=image_resolution,
                patch_size=vision_patch_size,
                width=vision_width,
                layers=vision_layers,
                heads=vision_heads,
                output_dim=embed_dim
            )

        # Text encoder: transformer with a causal (additive -inf) mask.
        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask()
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        # Projects final text features into the shared embedding space.
        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
        # Learnable inverse temperature, initialized to 1/0.07.
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialize embeddings and transformer weights (std scaled by width/depth)."""
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features ** -0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            # Zero-init the last BN gamma of every bottleneck so each residual
            # branch starts as identity.
            for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
        attn_std = self.transformer.width ** -0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        # Mirrors the visual tower's parameter dtype (fp16 after convert_weights).
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        """Encode a batch of preprocessed images with the visual tower."""
        return self.visual(image.type(self.dtype))

    def encode_text(self, text):
        """Encode a batch of token-id sequences into text embeddings."""
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        """Return (logits_per_image, logits_per_text): temperature-scaled cosine similarities."""
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logit_scale * text_features @ image_features.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text
369
+
370
+
371
def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16 in place.

    Conv/linear weights and biases, MultiheadAttention projection tensors,
    and any `text_projection` / `proj` parameters are cast to half
    precision; other modules (e.g. LayerNorm) are left untouched.
    """

    def _to_half(module):
        # Purpose: per-module fp16 cast, applied recursively via model.apply.
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            attn_attrs = ["in_proj_weight", "q_proj_weight", "k_proj_weight",
                          "v_proj_weight", "in_proj_bias", "bias_k", "bias_v"]
            for attr_name in attn_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for param_name in ("text_projection", "proj"):
            param = getattr(module, param_name, None)
            if param is not None:
                param.data = param.data.half()

    model.apply(_to_half)
393
+
394
+
395
def build_model(state_dict: dict):
    """Instantiate a CLIP model whose hyperparameters are inferred from a
    pretrained `state_dict`, load the weights, and return it in eval mode.
    """
    # ViT checkpoints expose a `visual.proj` parameter; ResNet ones do not.
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        # One in_proj_weight per visual transformer block.
        vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        # Positional embedding has grid**2 + 1 rows (class token included).
        grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
        image_resolution = vision_patch_size * grid_size
    else:
        # Block counts per ResNet stage, inferred from parameter key names.
        counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
        vision_patch_size = None
        assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
        image_resolution = output_width * 32  # ResNet downsamples by 32x overall

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64  # fixed 64-dim heads
    transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))

    model = CLIP(
        embed_dim,
        image_resolution, vision_layers, vision_width, vision_patch_size,
        context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
    )

    # Scalar bookkeeping entries are not module parameters; drop before loading.
    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)  # checkpoints store fp16 weights
    model.load_state_dict(state_dict)
    return model.eval()
clip/simple_tokenizer.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gzip
2
+ import html
3
+ import os
4
+ from functools import lru_cache
5
+
6
+ import ftfy
7
+ import regex as re
8
+
9
+
10
@lru_cache()
def default_bpe():
    """Return the absolute path of the BPE vocab archive shipped next to this module."""
    here = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(here, "bpe_simple_vocab_16e6.txt.gz")
13
+
14
+
15
@lru_cache()
def bytes_to_unicode():
    """
    Returns a bijective map from every byte value (0-255) to a printable
    unicode character.

    The reversible BPE codes operate on unicode strings, so every byte must
    map to a character the BPE code will not mangle: printable latin-1 bytes
    map to themselves, and the remaining (whitespace/control) bytes are
    shifted into code points starting at U+0100. This avoids UNKs without
    needing a huge vocabulary.
    """
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    shift = 0
    for byte in range(2 ** 8):
        if byte not in printable:
            byte_values.append(byte)
            code_points.append(2 ** 8 + shift)
            shift += 1
    return {b: chr(cp) for b, cp in zip(byte_values, code_points)}
36
+
37
+
38
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings). Returns an empty set for words with fewer than two symbols
    (the original indexed word[0] unconditionally and raised IndexError on
    empty input).
    """
    return set(zip(word, word[1:]))
48
+
49
+
50
def basic_clean(text):
    """Repair mojibake via ftfy, undo (double) HTML escaping, and trim whitespace."""
    fixed = ftfy.fix_text(text)
    unescaped = html.unescape(html.unescape(fixed))
    return unescaped.strip()
54
+
55
+
56
def whitespace_clean(text):
    """Collapse every whitespace run into a single space and strip both ends."""
    return re.sub(r'\s+', ' ', text).strip()
60
+
61
+
62
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer used by CLIP.

    Vocabulary layout: 256 byte symbols, their 256 '</w>' end-of-word
    variants, the BPE merges, and the two special tokens — 49152 entries in
    total.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        # Skip the header line and keep exactly as many merges as the
        # 49152-entry vocabulary has room for.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        # Lower rank = earlier (more frequent) merge.
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        # Requires the third-party `regex` module for \p{L}/\p{N} classes.
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply BPE merges to one pre-tokenized `token`.

        Returns the space-joined symbol sequence; results are memoized in
        self.cache.
        """
        if token in self.cache:
            return self.cache[token]
        # Mark the last character as end-of-word.
        word = tuple(token[:-1]) + (token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            # Merge the lowest-ranked bigram present in the word first.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    # `first` does not occur at or after i — copy the rest.
                    # (Was a bare `except:`, which also swallowed
                    # KeyboardInterrupt/SystemExit; tuple.index raises
                    # ValueError on a miss.)
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Tokenize `text` into a list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map UTF-8 bytes to the reversible unicode alphabet first.
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Invert `encode`: map token ids back to (best-effort) UTF-8 text."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
config.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
'''This file configures the training procedure because handling arguments in every single function is so exhaustive for
research purposes. Don't try this code if you are a software engineer.'''

# device settings
import torch

# Select CUDA only when it is actually present. Previously
# torch.cuda.set_device(0) ran unconditionally, which crashes at import time
# on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    torch.cuda.set_device(0)

# data settings
dataset_path = "dummy_dataset"
class_name = "dummy_class"
modelname = "dummy_test"

img_size = (448, 448)  # (height, width) fed to the network
img_dims = [3] + list(img_size)

# transformation settings
transf_rotations = True
transf_brightness = 0.0
transf_contrast = 0.0
transf_saturation = 0.0
norm_mean, norm_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]  # ImageNet statistics

# network hyperparameters
n_scales = 3  # number of scales at which features are extracted, img_size is the highest - others are //2, //4,...
clamp_alpha = 3  # see paper equation 2 for explanation
n_coupling_blocks = 2
fc_internal = 4096  # number of neurons in hidden layers of s-t-networks
dropout = 0  # dropout in s-t-networks
lr_init = 2e-4
n_feat = 256 * n_scales  # do not change except you change the feature extractor

# dataloader parameters
n_transforms = 4  # number of transformations per sample in training
n_transforms_test = 64  # number of transformations per sample in testing
batch_size = 24  # actual batch size is this value multiplied by n_transforms(_test)
batch_size_test = batch_size * n_transforms // n_transforms_test

# total epochs = meta_epochs * sub_epochs
# evaluation after <sub_epochs> epochs
meta_epochs = 24
sub_epochs = 8

# output settings
verbose = True
grad_map_viz = False
hide_tqdm_bar = True
save_model = True
detector.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import clip
4
+ from PIL import Image
5
+ from torch.cuda.amp import autocast as autocast
6
+ from huggingface_hub import hf_hub_download
7
+ import spaces
8
+
9
+ from model import flow_model
10
+ from augmentations_clip import DataAugmentationCLIP as DataAugmentationCLIP_test
11
+
12
+ MODEL_REPO_ID = "davjoython/flow_fake"
13
+ FLOW_MODEL_FILENAME = "flow_fake_detector_centercrop_v4.pth"
14
+ CLIP_MODEL_FILENAME = "my_clip_ViT-L-14.pt"
15
class FakeImageDetector:
    """CLIP-visual-features + normalizing-flow fake-image detector.

    Both networks are downloaded from the Hugging Face Hub and kept on CPU;
    `detect` moves them to the current device on demand (the `spaces.GPU`
    decorator grants a short-lived GPU on HF Spaces).
    """

    def __init__(self):
        # Preferred inference device; loading itself always happens on CPU.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"检测器初始化在 CPU 上,运行时将使用 {self.device}")

        print(f"正在从 {MODEL_REPO_ID} 下载 CLIP 模型...")
        clip_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=CLIP_MODEL_FILENAME
        )
        print("CLIP 模型已下载。")
        self.clip_model, _ = clip.load(clip_model_path, device="cpu")
        self.clip_model.eval()
        print("CLIP 模型已加载到 CPU。")

        print(f"正在从 {MODEL_REPO_ID} 下载 Flow 模型...")
        flow_model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=FLOW_MODEL_FILENAME
        )
        print("Flow 模型已下载。")
        self.flow = flow_model()
        self.flow.load_state_dict(torch.load(flow_model_path, map_location="cpu"))
        self.flow = self.flow.to("cpu")
        self.flow.eval()
        print("Flow 模型已加载到 CPU。")

        print("模型加载完成。")

        # Test-time transform; only the "source" crop is consumed in detect().
        # NOTE(review): positional args presumably are crop-scale ranges and a
        # crop count — confirm against DataAugmentationCLIP's signature.
        self.transform = DataAugmentationCLIP_test(
            (0.9, 1.0), (0.05, 0.4), 1,
            global_crops_size=224, local_crops_size=96,
        )

    @spaces.GPU(duration=10)
    def detect(self, image_pil, threshold=0.5):
        """Score a PIL image and return (result_text, score).

        Higher score means more likely fake; `threshold` splits the verdict.
        Raises TypeError for non-PIL input.
        """
        if not isinstance(image_pil, Image.Image):
            raise TypeError("输入必须是 PIL Image 对象")

        img_rgb = image_pil.convert("RGB")

        # Re-check at call time: under spaces.GPU the GPU only exists while
        # this method runs, so self.device from __init__ may be stale.
        current_device = "cuda" if torch.cuda.is_available() else "cpu"

        flow_model_gpu = self.flow.to(current_device)
        clip_model_gpu = self.clip_model.to(current_device)

        transformed_img_dict = self.transform(img_rgb)
        img_tensor = transformed_img_dict["source"][0].unsqueeze(0).to(current_device)

        with torch.no_grad():
            if current_device == "cuda":
                # Mixed precision on GPU; CLIP's visual tower expects half here.
                with autocast():
                    embedding = clip_model_gpu.visual(img_tensor.half())
                    z = flow_model_gpu(embedding)
                    # sigmoid is increasing, so larger mean(z^2) => lower score.
                    score = 1 - torch.sigmoid(torch.mean(z.float()**2 / 10000, dim=1)).item()
            else:
                embedding = clip_model_gpu.visual(img_tensor)
                z = flow_model_gpu(embedding.float())
                score = 1 - torch.sigmoid(torch.mean(z.float()**2 / 10000, dim=1)).item()

        if current_device == "cuda":
            # Free the short-lived GPU allocation between Space invocations.
            torch.cuda.empty_cache()

        if score > threshold:
            result_text = f"结论: 伪造的 (Fake)\n分数: {score:.10f}"
        else:
            result_text = f"结论: 真实的 (Real)\n分数: {score:.10f}"

        return result_text, score
freia_funcs.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''This Code is based on the FrEIA Framework, source: https://github.com/VLL-HD/FrEIA
2
+ It is a assembly of the necessary modules/functions from FrEIA that are needed for our purposes.'''
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.autograd import Variable
7
+ from math import exp
8
+ import numpy as np
9
+ VERBOSE = False
10
+
11
+
12
class dummy_data:
    """Lightweight tensor stand-in that only carries a `.shape`."""

    def __init__(self, *dims):
        self.dims = dims

    @property
    def shape(self):
        # Mimic the tensor interface expected by the graph-building code.
        return self.dims
19
+
20
+ class F_fully_connected(nn.Module):
21
+ '''Fully connected tranformation, not reversible, but used below.'''
22
+
23
+ def __init__(self, size_in, size, internal_size=None, dropout=0.0):
24
+ super(F_fully_connected, self).__init__()
25
+ if not internal_size:
26
+ internal_size = 2*size
27
+
28
+ self.d1 = nn.Dropout(p=dropout)
29
+ self.d2 = nn.Dropout(p=dropout)
30
+ self.d2b = nn.Dropout(p=dropout)
31
+
32
+ self.fc1 = nn.Linear(size_in, internal_size)
33
+ self.fc2 = nn.Linear(internal_size, internal_size)
34
+ self.fc2b = nn.Linear(internal_size, internal_size)
35
+ self.fc3 = nn.Linear(internal_size, size)
36
+
37
+ self.nl1 = nn.ReLU()
38
+ self.nl2 = nn.ReLU()
39
+ self.nl2b = nn.ReLU()
40
+
41
+ self.bn = nn.BatchNorm1d(size_in)
42
+
43
+
44
+ def forward(self, x):
45
+ out = self.nl1(self.d1(self.fc1(x)))
46
+ out = self.nl2(self.d2(self.fc2(out)))
47
+ out = self.nl2b(self.d2b(self.fc2b(out)))
48
+ out = self.fc3(out)
49
+ return out
50
+
51
class permute_layer(nn.Module):
    '''Applies a random but fixed channel permutation (invertible, unit Jacobian).'''

    def __init__(self, dims_in, seed):
        super(permute_layer, self).__init__()
        self.in_channels = dims_in[0][0]

        # Draw a reproducible permutation, then re-seed numpy's global RNG so
        # the rest of the program keeps its usual randomness.
        np.random.seed(seed)
        forward_order = np.random.permutation(self.in_channels)
        np.random.seed()

        inverse_order = np.zeros_like(forward_order)
        for position, channel in enumerate(forward_order):
            inverse_order[channel] = position

        self.perm = torch.LongTensor(forward_order)
        self.perm_inv = torch.LongTensor(inverse_order)

    def forward(self, x, rev=False):
        order = self.perm_inv if rev else self.perm
        return [x[0][:, order]]

    def jacobian(self, x, rev=False):
        # TODO: use batch size, set as nn.Parameter so cuda() works
        # A permutation has |det J| = 1, hence zero log-determinant.
        return 0.

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        return input_dims
82
+
83
+
84
+
85
class glow_coupling_layer(nn.Module):
    '''GLOW-style affine coupling block with soft-clamped scales.

    The channel dimension is split into two halves; each half is scaled and
    shifted using parameters predicted from the other half, which keeps the
    transform invertible in closed form.
    '''

    def __init__(self, dims_in, F_class=F_fully_connected, F_args={},
                 clamp=5.):
        super(glow_coupling_layer, self).__init__()
        channels = dims_in[0][0]
        self.ndims = len(dims_in[0])

        # Split: first half vs. remainder (handles odd channel counts).
        self.split_len1 = channels // 2
        self.split_len2 = channels - channels // 2

        self.clamp = clamp
        self.max_s = exp(clamp)
        self.min_s = exp(-clamp)

        # Each subnetwork predicts scale AND shift for the opposite half,
        # hence the doubled output width.
        self.s1 = F_class(self.split_len1, self.split_len2*2, **F_args)
        self.s2 = F_class(self.split_len2, self.split_len1*2, **F_args)

    def e(self, s):
        # Clamped multiplicative scale.
        return torch.exp(self.log_e(s))

    def log_e(self, s):
        # Soft clamp of the log-scale into (-clamp, clamp); 0.636 ~ 2/pi so
        # atan saturates at +-clamp. See paper equation 2.
        return self.clamp * 0.636 * torch.atan(s / self.clamp)

    def forward(self, x, rev=False):
        x1, x2 = (x[0].narrow(1, 0, self.split_len1),
                  x[0].narrow(1, self.split_len1, self.split_len2))

        if not rev:
            r2 = self.s2(x2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = self.e(s2) * x1 + t2

            r1 = self.s1(y1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = self.e(s1) * x2 + t1

        else:  # names of x and y are swapped!
            r1 = self.s1(x1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = (x2 - t1) / self.e(s1)

            r2 = self.s2(y2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = (x1 - t2) / self.e(s2)
        y = torch.cat((y1, y2), 1)
        # Guard against numerical blow-up (division by tiny scales in reverse).
        y = torch.clamp(y, -1e6, 1e6)
        return [y]

    def jacobian(self, x, rev=False):
        # Recomputes the scales to accumulate the log-determinant
        # sum(log_e(s1)) + sum(log_e(s2)) per sample.
        x1, x2 = (x[0].narrow(1, 0, self.split_len1),
                  x[0].narrow(1, self.split_len1, self.split_len2))

        if not rev:
            r2 = self.s2(x2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]
            y1 = self.e(s2) * x1 + t2

            r1 = self.s1(y1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]

        else:  # names of x and y are swapped!
            r1 = self.s1(x1)
            s1, t1 = r1[:, :self.split_len2], r1[:, self.split_len2:]
            y2 = (x2 - t1) / self.e(s1)

            r2 = self.s2(y2)
            s2, t2 = r2[:, :self.split_len1], r2[:, self.split_len1:]

        jac = (torch.sum(self.log_e(s1), dim=1)
               + torch.sum(self.log_e(s2), dim=1))
        # Collapse any remaining (spatial) dimensions to one value per sample.
        for i in range(self.ndims-1):
            jac = torch.sum(jac, dim=1)

        return jac

    def output_dims(self, input_dims):
        assert len(input_dims) == 1, "Can only use 1 input"
        return input_dims
164
+
165
class Node:
    '''The Node class represents one transformation in the graph, with an
    arbitrary number of in- and outputs.'''

    def __init__(self, inputs, module_type, module_args, name=None):
        # `inputs` is a list of (parent_node, output_channel) pairs.
        self.inputs = inputs
        self.outputs = []
        self.module_type = module_type
        self.module_args = module_args

        self.input_dims, self.module = None, None
        self.computed = None       # cached forward-pass variable IDs
        self.computed_rev = None   # cached reverse-pass variable IDs
        self.id = None             # assigned later by ReversibleGraphNet

        if name:
            self.name = name
        else:
            # Fall back to a short unique name derived from the object id.
            self.name = hex(id(self))[-6:]
        # Expose self.out0 .. self.out254 as (node, channel) handles used to
        # wire nodes together when building a graph.
        for i in range(255):
            exec('self.out{0} = (self, {0})'.format(i))

    def build_modules(self, verbose=VERBOSE):
        ''' Returns a list with the dimension of each output of this node,
        recursively calling build_modules of the nodes connected to the input.
        Use this information to initialize the pytorch nn.Module of this node.
        '''

        if not self.input_dims:  # Only do it if this hasn't been computed yet
            self.input_dims = [n.build_modules(verbose=verbose)[c]
                               for n, c in self.inputs]
            try:
                self.module = self.module_type(self.input_dims,
                                               **self.module_args)
            except Exception as e:
                print('Error in node %s' % (self.name))
                raise e

            if verbose:
                print("Node %s has following input dimensions:" % (self.name))
                for d, (n, c) in zip(self.input_dims, self.inputs):
                    print("\t Output #%i of node %s:" % (c, n.name), d)
                print()

            self.output_dims = self.module.output_dims(self.input_dims)
            self.n_outputs = len(self.output_dims)

        return self.output_dims

    def run_forward(self, op_list):
        '''Determine the order of operations needed to reach this node. Calls
        run_forward of parent nodes recursively. Each operation is appended to
        the global list op_list, in the form (node ID, input variable IDs,
        output variable IDs)'''

        if not self.computed:

            # Compute all nodes which provide inputs, filter out the
            # channels you need
            self.input_vars = []
            for i, (n, c) in enumerate(self.inputs):
                self.input_vars.append(n.run_forward(op_list)[c])
                # Register yourself as an output in the input node
                n.outputs.append((self, i))

            # All outputs could now be computed
            self.computed = [(self.id, i) for i in range(self.n_outputs)]
            op_list.append((self.id, self.input_vars, self.computed))

        # Return the variables you have computed (this happens multiple times
        # without recomputing if called repeatedly)
        return self.computed

    def run_backward(self, op_list):
        '''See run_forward, this is the same, only for the reverse computation.
        Need to call run_forward first, otherwise this function will not
        work'''

        assert len(self.outputs) > 0, "Call run_forward first"
        if not self.computed_rev:

            # These are the input variables that must be computed first
            output_vars = [(self.id, i) for i in range(self.n_outputs)]

            # Recursively compute these
            for n, c in self.outputs:
                n.run_backward(op_list)

            # The variables that this node computes are the input variables
            # from the forward pass
            self.computed_rev = self.input_vars
            op_list.append((self.id, output_vars, self.computed_rev))

        return self.computed_rev
258
+
259
+
260
class InputNode(Node):
    '''Special type of node that represents the input data of the whole net (or
    the output when running in reverse)'''

    def __init__(self, *dims, name='node'):
        # Deliberately does not call Node.__init__: an input node has no
        # parent nodes and wraps no module, only a data shape.
        self.name = name
        self.data = dummy_data(*dims)
        self.outputs = []
        self.module = None
        self.computed_rev = None
        self.n_outputs = 1
        self.input_vars = []
        self.out0 = (self, 0)

    def build_modules(self, verbose=VERBOSE):
        # The only dimension information an input provides is its own shape.
        return [self.data.shape]

    def run_forward(self, op_list):
        # An input contributes exactly one variable and no operation.
        return [(self.id, 0)]
279
+
280
+
281
class OutputNode(Node):
    '''Special type of node that represents the output of the whole net (or the
    input when running in reverse)'''

    class dummy(nn.Module):
        '''No-op placeholder module so OutputNode satisfies the Node interface.'''

        def __init__(self, *args):
            super(OutputNode.dummy, self).__init__()

        def __call__(*args):
            return args

        def output_dims(*args):
            return args

    def __init__(self, inputs, name='node'):
        # Deliberately does not call Node.__init__; sets up only the fields
        # the graph machinery reads.
        self.module_type, self.module_args = self.dummy, {}
        self.output_dims = []
        self.inputs = inputs
        self.input_dims, self.module = None, None
        self.computed = None
        self.id = None
        self.name = name

        # Register this node as a consumer of each of its inputs.
        for c, inp in enumerate(self.inputs):
            inp[0].outputs.append((self, c))

    def run_backward(self, op_list):
        # In reverse mode an output behaves like an input: one variable,
        # no operation.
        return [(self.id, 0)]
309
+
310
+
311
class ReversibleGraphNet(nn.Module):
    '''This class represents the invertible net itself. It is a subclass of
    torch.nn.Module and supports the same methods. The forward method has an
    additional option 'rev', with which the net can be computed in reverse.'''

    def __init__(self, node_list, ind_in=None, ind_out=None, verbose=False):
        '''node_list should be a list of all nodes involved, and ind_in,
        ind_out are the indexes of the special nodes InputNode and OutputNode
        in this list.'''
        super(ReversibleGraphNet, self).__init__()

        # Gather lists of input and output nodes (autodetect by type when
        # the indices are not given explicitly).
        if ind_in is not None:
            if isinstance(ind_in, int):
                self.ind_in = list([ind_in])
            else:
                self.ind_in = ind_in
        else:
            self.ind_in = [i for i in range(len(node_list))
                           if isinstance(node_list[i], InputNode)]
            assert len(self.ind_in) > 0, "No input nodes specified."
        if ind_out is not None:
            if isinstance(ind_out, int):
                self.ind_out = list([ind_out])
            else:
                self.ind_out = ind_out
        else:
            self.ind_out = [i for i in range(len(node_list))
                            if isinstance(node_list[i], OutputNode)]
            assert len(self.ind_out) > 0, "No output nodes specified."

        self.return_vars = []
        self.input_vars = []

        # Assign each node a unique ID
        self.node_list = node_list
        for i, n in enumerate(node_list):
            n.id = i

        # Recursively build the nodes' nn.Modules and determine order of
        # operations
        ops = []
        for i in self.ind_out:
            node_list[i].build_modules(verbose=verbose)
            node_list[i].run_forward(ops)

        # create list of Pytorch variables that are used
        variables = set()
        for o in ops:
            variables = variables.union(set(o[1] + o[2]))
        self.variables_ind = list(variables)

        self.indexed_ops = self.ops_to_indexed(ops)

        self.module_list = nn.ModuleList([n.module for n in node_list])
        # Placeholder slots; actual tensors are filled in during forward().
        self.variable_list = [Variable(requires_grad=True) for v in variables]

        # Find out the order of operations for reverse calculations
        ops_rev = []
        for i in self.ind_in:
            node_list[i].run_backward(ops_rev)
        self.indexed_ops_rev = self.ops_to_indexed(ops_rev)

    def ops_to_indexed(self, ops):
        '''Helper function to translate the list of variables (origin ID, channel),
        to variable IDs.'''
        result = []

        for o in ops:
            try:
                vars_in = [self.variables_ind.index(v) for v in o[1]]
            except ValueError:
                # Input variables not tracked (e.g. ops of pure graph ends).
                vars_in = -1

            vars_out = [self.variables_ind.index(v) for v in o[2]]

            # Collect input/output nodes in separate lists, but don't add to
            # indexed ops
            if o[0] in self.ind_out:
                self.return_vars.append(self.variables_ind.index(o[1][0]))
                continue
            if o[0] in self.ind_in:
                self.input_vars.append(self.variables_ind.index(o[1][0]))
                continue

            result.append((o[0], vars_in, vars_out))

        # Sort input/output variables so they correspond to initial node list
        # order
        self.return_vars.sort(key=lambda i: self.variables_ind[i][0])
        self.input_vars.sort(key=lambda i: self.variables_ind[i][0])

        return result

    def forward(self, x, rev=False):
        '''Forward or backward computation of the whole net.'''
        if rev:
            # In reverse mode the roles of inputs and outputs swap.
            use_list = self.indexed_ops_rev
            input_vars, output_vars = self.return_vars, self.input_vars
        else:
            use_list = self.indexed_ops
            input_vars, output_vars = self.input_vars, self.return_vars

        if isinstance(x, (list, tuple)):
            assert len(x) == len(input_vars), (
                f"Got list of {len(x)} input tensors for "
                f"{'inverse' if rev else 'forward'} pass, but expected "
                f"{len(input_vars)}."
            )
            for i in range(len(input_vars)):
                self.variable_list[input_vars[i]] = x[i]
        else:
            assert len(input_vars) == 1, (f"Got single input tensor for "
                                          f"{'inverse' if rev else 'forward'} "
                                          f"pass, but expected list of "
                                          f"{len(input_vars)}.")
            self.variable_list[input_vars[0]] = x

        # Execute the pre-computed operation list, storing intermediate
        # results back into variable_list.
        for o in use_list:
            try:
                results = self.module_list[o[0]]([self.variable_list[i]
                                                 for i in o[1]], rev=rev)
            except TypeError:
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?")
            for i, r in zip(o[2], results):
                self.variable_list[i] = r
            # self.variable_list[o[2][0]] = self.variable_list[o[1][0]]

        out = [self.variable_list[output_vars[i]]
               for i in range(len(output_vars))]
        # Single-output graphs return the tensor directly, not a list.
        if len(out) == 1:
            return out[0]
        else:
            return out

    def jacobian(self, x=None, rev=False, run_forward=True):
        '''Compute the jacobian determinant of the whole net.'''
        jacobian = 0

        if rev:
            use_list = self.indexed_ops_rev
        else:
            use_list = self.indexed_ops

        if run_forward:
            if x is None:
                raise RuntimeError("You need to provide an input if you want "
                                   "to run a forward pass")
            self.forward(x, rev=rev)
        # jacobian_list keeps the running partial sums; only the final total
        # is returned.
        jacobian_list = list()
        for o in use_list:
            try:
                node_jac = self.module_list[o[0]].jacobian(
                    [self.variable_list[i] for i in o[1]], rev=rev
                )
                jacobian += node_jac
                jacobian_list.append(jacobian)
            except TypeError:
                raise RuntimeError("Are you sure all used Nodes are in the "
                                   "Node list?")

        return jacobian
loralib/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .layers import *
2
+ from .utils import *
loralib/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (196 Bytes). View file
 
loralib/__pycache__/layers.cpython-38.pyc ADDED
Binary file (15.5 kB). View file
 
loralib/__pycache__/utils.cpython-38.pyc ADDED
Binary file (6.07 kB). View file
 
loralib/easymultiheadattention.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ """
6
+ Source : https://github.com/KyanChen/MakeMultiHeadNaive/blob/master/main.py
7
+ """
8
+
9
class PlainMultiHeadAttention(nn.Module):
    """Drop-in replacement for nn.MultiheadAttention built from plain Linear
    layers (fused qkv projection + output projection).

    Weights are copied from an existing nn.MultiheadAttention so the module
    can be swapped in after loading a checkpoint (e.g. to expose explicit
    Linear layers for LoRA). Attention weights are never returned (always
    None) and dropout is fixed to 0.
    """

    def __init__(
            self,
            existing_mha: nn.MultiheadAttention):
        super().__init__()

        self.dropout = 0  # this module is not used to retrain the main block
        self.embed_dim = existing_mha.embed_dim
        self.kdim = existing_mha.kdim
        self.vdim = existing_mha.vdim
        self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim
        self.num_heads = existing_mha.num_heads
        self.batch_first = existing_mha.batch_first
        self.head_dim = existing_mha.head_dim
        # Fused in-projection (q, k, v stacked row-wise) and out-projection.
        self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=existing_mha.in_proj_bias is not None)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None)

        # Initialize parameters by copying from the wrapped module.
        with torch.no_grad():
            self.qkv.weight.data.copy_(existing_mha.in_proj_weight.data)
            if self.qkv.bias is not None:
                self.qkv.bias.data.copy_(existing_mha.in_proj_bias.data)
            self.proj.weight.data.copy_(existing_mha.out_proj.weight.data)
            if self.proj.bias is not None:
                self.proj.bias.data.copy_(existing_mha.out_proj.bias.data)

        self.scaled_dot_product_attention = F.scaled_dot_product_attention

    def forward(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            average_attn_weights=True,
            is_causal=False):
        # Mirrors nn.MultiheadAttention.forward; returns (output, None).
        # NOTE(review): uses the private torch helpers F._canonical_mask and
        # F._none_or_dtype, which may change between torch versions.
        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")
        is_batched = query.dim() == 3
        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        # Internally work in (seq, batch, embed) layout; the identity checks
        # avoid transposing the same tensor more than once.
        if self.batch_first and is_batched:
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape

        # Single fused projection of the query, then split into q, k, v.
        E = query.size(-1)
        qkv = self.qkv(query)
        qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=q.dtype,
            check_other=False,
        )

        if attn_mask is not None:
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        if attn_mask is not None:
            # Broadcast a single shared mask; otherwise reshape per (batch, head).
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)

        dropout_p = self.dropout if self.training else 0.

        # Reshape to (batch, heads, seq, head_dim) as expected by SDPA.
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
        k = k.view(bsz, self.num_heads, src_len, self.head_dim)
        v = v.view(bsz, self.num_heads, src_len, self.head_dim)

        attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        attn_output = self.proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        # Restore the caller's layout; attention weights are not computed.
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), None
        return attn_output, None
loralib/layers.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------------
2
+ # This code is reconstructed based on loralib (https://github.com/microsoft/LoRA) by Baijiong Lin.
3
+ # ------------------------------------------------------------------------------------------
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import math
9
+ from typing import Optional, List
10
+
11
def set_param(curr_mod, name, param=None, mode='update'):
    r"""Get or replace a (possibly nested) attribute of a module.

    Refer to https://github.com/Baijiong-Lin/MOML/blob/main/MTL/utils.py

    Args:
        curr_mod: module whose attribute tree is navigated.
        name: dotted path such as ``'encoder.0.weight'``.
        param: replacement value (only used when ``mode == 'update'``).
        mode: ``'update'`` replaces the attribute, ``'get'`` returns it.

    Returns:
        The attribute for ``mode='get'`` (or ``None`` if absent); ``None``
        for ``mode='update'``.
    """
    if '.' in name:
        # Recurse into the first path component; the original code shadowed
        # the ``name`` parameter with the loop variable here.
        head, _, rest = name.partition('.')
        for child_name, child in curr_mod.named_children():
            if child_name == head:
                return set_param(child, rest, param, mode=mode)
        # Path component not found: mirror the original silent no-op.
        return None
    if mode == 'update':
        # delattr first so nn.Module re-registers the new value cleanly
        # (a plain tensor cannot overwrite a registered Parameter via setattr).
        delattr(curr_mod, name)
        setattr(curr_mod, name, param)
    elif mode == 'get':
        if hasattr(curr_mod, name):
            return getattr(curr_mod, name)
    return None
28
+
29
class LoRALayer():
    r"""Mixin adding LoRA (low-rank adaptation) state to an ``nn`` layer.

    Subclasses fill ``self.params_with_lora`` with a mapping
    ``{'param_name': 'lora_name'}`` and then call
    :meth:`register_lora_param` and :meth:`init_lora_param`.

    The original implementation accessed attributes via
    ``eval(f'self.{name}')``; this version uses a getattr-based helper
    (:meth:`_get_attr`) with identical behavior for attribute paths and
    no arbitrary-code-evaluation hazard.
    """

    def __init__(
        self,
        r: int,
        lora_alpha: int,
        fan_in_fan_out: bool = False,
        dropout_rate: float = 0,
    ):
        self.r = r
        self.lora_alpha = lora_alpha
        self.dropout_rate = dropout_rate
        if self.r > 0:
            # rsLoRA-style scaling: alpha / sqrt(r) instead of alpha / r.
            self.scaling = self.lora_alpha / math.sqrt(self.r)
        # Mark the weight as unmerged.
        self.merged = False
        # True if the wrapped layer stores weight like (fan_in, fan_out).
        self.fan_in_fan_out = fan_in_fan_out
        # Params that require LoRA: {'param_name': 'lora_name'}.
        self.params_with_lora = {}

    def _get_attr(self, path: str):
        """Follow a dotted attribute path (replacement for the old eval)."""
        obj = self
        for part in path.split('.'):
            obj = getattr(obj, part)
        return obj

    def register_lora_param(self):
        r"""Register the rank-r LoRA factors A and B and freeze the base weight."""
        for param_name, lora_name in self.params_with_lora.items():
            w = self._get_attr(param_name)
            assert len(w.size()) == 2
            self.register_parameter(
                f'{lora_name}_lora_A',
                nn.Parameter(w.new_zeros((self.r, w.size()[1])))
            )
            self.register_parameter(
                f'{lora_name}_lora_B',
                nn.Parameter(w.new_zeros((w.size()[0], self.r)))
            )
            # Only A and B train; the pre-trained weight stays frozen.
            w.requires_grad = False

    def init_lora_param(self):
        """A ~ Kaiming-uniform (as nn.Linear), B = 0, so the initial delta is 0."""
        for param_name, lora_name in self.params_with_lora.items():
            if hasattr(self, f'{lora_name}_lora_A'):
                nn.init.kaiming_uniform_(self._get_attr(f'{lora_name}_lora_A'), a=math.sqrt(5))
                nn.init.zeros_(self._get_attr(f'{lora_name}_lora_B'))

    def transpose(self, w: torch.Tensor):
        return w.transpose(0, 1) if self.fan_in_fan_out else w

    def merge_BA(self, param_name: str):
        """Return B @ A reshaped (and possibly transposed) to the weight's shape."""
        lora_name = self.params_with_lora[param_name]
        delta = self._get_attr(f'{lora_name}_lora_B') @ self._get_attr(f'{lora_name}_lora_A')
        return self.transpose(delta.view(self._get_attr(param_name).shape))

    def merge_lora_param(self):
        r"""p_new = p + scaling * B @ A, differentiable w.r.t. A and B only."""
        for param_name, lora_name in self.params_with_lora.items():
            p = set_param(self, param_name, mode='get')
            # detach() is very important here: gradients must not flow into
            # the frozen base weight, only into the LoRA factors.
            p_new = p.detach() + self.merge_BA(param_name) * self.scaling
            set_param(self, param_name, param=p_new, mode='update')

    def add_lora_data(self):
        r"""In-place weight += scaling * B @ A (NOT differentiable)."""
        for param_name, lora_name in self.params_with_lora.items():
            self._get_attr(param_name).data += self.merge_BA(param_name) * self.scaling

    def sub_lora_data(self):
        r"""In-place weight -= scaling * B @ A (NOT differentiable)."""
        for param_name, lora_name in self.params_with_lora.items():
            self._get_attr(param_name).data -= self.merge_BA(param_name) * self.scaling

    def lora_train(self, mode: bool = True):
        """Keep weights unmerged while training, merged for inference."""
        if mode:
            if self.merged and self.r > 0:
                # Make sure that the weights are not merged.
                self.sub_lora_data()
                self.merged = False
        else:
            if not self.merged and self.r > 0:
                # Merge the weights and mark it.
                self.add_lora_data()
                self.merged = True
111
+
112
+
113
class Embedding(nn.Embedding, LoRALayer):
    """nn.Embedding with optional rank-r LoRA factors on its weight."""

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        nn.Embedding.reset_parameters(self)
        self.init_lora_param()

    def init_lora_param(self):
        # Embedding convention is inverted w.r.t. Linear: A starts at zero
        # and B is Gaussian, so the initial delta is still zero.
        if hasattr(self, 'w_lora_A'):
            nn.init.zeros_(self.w_lora_A)
            nn.init.normal_(self.w_lora_B)

    def train(self, mode: bool = True):
        nn.Embedding.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if self.r <= 0 or self.merged:
            return nn.Embedding.forward(self, x, **kwargs)
        # Temporarily merge so autograd reaches A/B, then restore the weight.
        self.merge_lora_param()
        out = nn.Embedding.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
151
+
152
class LinearLoRA(nn.Linear, LoRALayer):
    # LoRA implemented in a Linear layer.
    #
    # Wraps an existing nn.Linear: copies its weights, then (optionally)
    # attaches rank-r LoRA factors. When ``dropout_rate > 0`` a dropout is
    # applied to the input of the *LoRA branch only* (not the base layer).
    def __init__(
        self,
        existing_linear: nn.Linear,
        r: int = 0,
        lora_alpha: int = 1,
        fan_in_fan_out: bool = False,
        dropout_rate = 0.,
        **kwargs
    ):
        # Recreate the base Linear with the same shape, then copy weights.
        super().__init__(
            in_features=existing_linear.in_features,
            out_features=existing_linear.out_features)

        self.load_state_dict(existing_linear.state_dict())
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, fan_in_fan_out=fan_in_fan_out)

        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            self.register_lora_param()
        self.init_lora_param()
        self.weight.data = self.transpose(self.weight.data)
        if dropout_rate > 0:
            self.dropout = nn.Dropout(dropout_rate)
        else:
            # None marks "no dropout": forward() uses the legacy merge path.
            self.dropout = None

    def train(self, mode: bool = True):
        super().train(mode)
        # Unmerge for training, merge for eval (see LoRALayer.lora_train).
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):

        if self.dropout is None: # do as before
            if self.r > 0 and not self.merged:
                # Merge (differentiably), run, then restore the base weight.
                self.merge_lora_param()
                result = nn.Linear.forward(self, x, **kwargs)
                self.sub_lora_data()
                return result
            else:
                return nn.Linear.forward(self, x, **kwargs)

        # Dropout path: base output is computed on the *undropped* input.
        # Compute the original linear transformation
        original_output = nn.Linear.forward(self, x)

        # ``self.dropout.p > 0`` is always true here given the constructor,
        # so dropout effectively applies whenever self.training is set.
        if self.training and self.dropout.p > 0:
            x = self.dropout(x)

        if self.r > 0 and not self.merged:
            # Explicit low-rank branch: x @ (B A)^T * scaling.
            lora_adjustment = torch.matmul(x,self.merge_BA('weight').transpose(0, 1)) * self.scaling
            result = original_output + lora_adjustment
        else:
            result = original_output
        return result
209
+
210
class Conv1d(nn.Conv1d, LoRALayer):
    """nn.Conv1d with optional LoRA factors on its (flattened) kernel."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # A: (r*k, in*k), B: (out//groups*k, r*k); B @ A has exactly the
            # number of elements of the conv kernel and is reshaped by merge_BA.
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            # Freeze the pre-trained kernel; only A/B train.
            self.weight.requires_grad = False
        nn.Conv1d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        nn.Conv1d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if self.r <= 0 or self.merged:
            return nn.Conv1d.forward(self, x, **kwargs)
        # Merge (differentiably), run, then restore the base kernel.
        self.merge_lora_param()
        out = nn.Conv1d.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
252
+
253
class Conv2d(nn.Conv2d, LoRALayer):
    """nn.Conv2d with optional LoRA factors on its (flattened) kernel."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # A: (r*k, in*k), B: (out//groups*k, r*k); B @ A has exactly the
            # number of elements of the conv kernel and is reshaped by merge_BA.
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            # Freeze the pre-trained kernel; only A/B train.
            self.weight.requires_grad = False
        nn.Conv2d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        nn.Conv2d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if self.r <= 0 or self.merged:
            return nn.Conv2d.forward(self, x, **kwargs)
        # Merge (differentiably), run, then restore the base kernel.
        self.merge_lora_param()
        out = nn.Conv2d.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
295
+
296
class Conv3d(nn.Conv3d, LoRALayer):
    """nn.Conv3d with optional LoRA factors on its (flattened) kernel."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        r: int = 0,
        lora_alpha: int = 1,
        **kwargs
    ):
        nn.Conv3d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert type(kernel_size) is int
        self.params_with_lora = {'weight': 'w'}
        if r > 0:
            # A: (r*k, in*k), B: (out//groups*k, r*k); B @ A has exactly the
            # number of elements of the conv kernel and is reshaped by merge_BA.
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * kernel_size, in_channels * kernel_size))
            )
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_channels // self.groups * kernel_size, r * kernel_size))
            )
            # Freeze the pre-trained kernel; only A/B train.
            self.weight.requires_grad = False
        nn.Conv3d.reset_parameters(self)
        self.init_lora_param()

    def train(self, mode: bool = True):
        nn.Conv3d.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):
        if self.r <= 0 or self.merged:
            return nn.Conv3d.forward(self, x, **kwargs)
        # Merge (differentiably), run, then restore the base kernel.
        self.merge_lora_param()
        out = nn.Conv3d.forward(self, x, **kwargs)
        self.sub_lora_data()
        return out
338
+
339
+
340
class PlainMultiheadAttentionLoRA(nn.Module):
    # Drop-in replacement for nn.MultiheadAttention with *separate*
    # q/k/v/out nn.Linear projections, so LoRA adapters (LinearLoRA) can
    # wrap any subset of them. Weights are copied from ``existing_mha``.
    #
    # NOTE(review): ``enable_lora`` is a mutable default list — callers
    # should not mutate it.
    def __init__(
        self,
        existing_mha: nn.MultiheadAttention,
        enable_lora: list = ['q', 'k', 'v', 'o'],
        r: int = 0,
        lora_alpha: int = 1,
        dropout_rate:float = 0.,
        **kwargs
    ):
        super().__init__()

        self.dropout = 0 # this module is not used to retrain the main block
        self.embed_dim = existing_mha.embed_dim
        self.kdim = existing_mha.kdim
        self.vdim = existing_mha.vdim
        self._qkv_same_embed_dim = existing_mha._qkv_same_embed_dim
        self.num_heads = existing_mha.num_heads
        self.batch_first = existing_mha.batch_first
        self.head_dim = existing_mha.head_dim
        # The fused in_proj of nn.MultiheadAttention is split into three
        # independent projections so LoRA can target them individually.
        #self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=existing_mha.in_proj_bias is not None)
        self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None)
        self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None)
        self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.in_proj_bias is not None)
        self.proj = nn.Linear(self.embed_dim, self.embed_dim, bias=existing_mha.out_proj.bias is not None)

        # Initialize parameters
        with torch.no_grad():

            # Extract the existing weights and biases.
            # in_proj_weight is stacked as [q; k; v] along dim 0.
            existing_weight = existing_mha.in_proj_weight.data
            existing_bias = existing_mha.in_proj_bias.data if existing_mha.in_proj_bias is not None else None

            # Initialize q_proj
            self.q_proj.weight.data.copy_(existing_weight[:self.embed_dim, :])
            if existing_bias is not None:
                self.q_proj.bias.data.copy_(existing_bias[:self.embed_dim])

            # Initialize k_proj
            self.k_proj.weight.data.copy_(existing_weight[self.embed_dim:2*self.embed_dim, :])
            if existing_bias is not None:
                self.k_proj.bias.data.copy_(existing_bias[self.embed_dim:2*self.embed_dim])

            # Initialize v_proj
            self.v_proj.weight.data.copy_(existing_weight[2*self.embed_dim:, :])
            if existing_bias is not None:
                self.v_proj.bias.data.copy_(existing_bias[2*self.embed_dim:])

            # Initialize proj
            self.proj.weight.data.copy_(existing_mha.out_proj.weight.data)
            if self.proj.bias is not None:
                self.proj.bias.data.copy_(existing_mha.out_proj.bias.data)

        self.scaled_dot_product_attention = F.scaled_dot_product_attention

        # Called on a plain nn.Module: only sets bookkeeping attributes
        # (r, scaling, merged, ...); no LoRA parameters are registered here.
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, dropout_rate=dropout_rate)

        # Init qkv as a new lora linear layer
        for item in enable_lora:
            if item == 'q':
                self.q_proj = LinearLoRA(self.q_proj,
                                         r=r,
                                         lora_alpha=lora_alpha,
                                         fan_in_fan_out=False,
                                         dropout_rate = dropout_rate)
            elif item == 'k':
                self.k_proj = LinearLoRA(self.k_proj,
                                         r=r,
                                         lora_alpha=lora_alpha,
                                         fan_in_fan_out=False,
                                         dropout_rate = dropout_rate)
            elif item == 'v':
                self.v_proj = LinearLoRA(self.v_proj,
                                         r=r,
                                         lora_alpha=lora_alpha,
                                         fan_in_fan_out=False,
                                         dropout_rate = dropout_rate)
            elif item == 'o':
                self.proj = LinearLoRA(self.proj,
                                       r=r,
                                       lora_alpha=lora_alpha,
                                       fan_in_fan_out=False,
                                       dropout_rate = dropout_rate)

    def forward_module(
            self,
            query,
            key,
            value,
            key_padding_mask=None,
            need_weights=True,
            attn_mask=None,
            average_attn_weights=True,
            is_causal=False):
        # Re-implementation of the nn.MultiheadAttention forward using the
        # split projections and F.scaled_dot_product_attention.
        # Always returns (attn_output, None): attention weights are never
        # computed, regardless of ``need_weights``/``average_attn_weights``.

        if attn_mask is not None and is_causal:
            raise AssertionError("Only allow causal mask or attn_mask")
        is_batched = query.dim() == 3
        # NOTE(review): key_padding_mask is canonicalized here but never
        # merged into attn_mask below, so it is effectively ignored —
        # confirm callers never pass one.
        key_padding_mask = F._canonical_mask(
            mask=key_padding_mask,
            mask_name="key_padding_mask",
            other_type=F._none_or_dtype(attn_mask),
            other_name="attn_mask",
            target_type=query.dtype
        )

        # Internally the math is done sequence-first; convert if needed.
        if self.batch_first and is_batched:
            if key is value:
                if query is key:
                    query = key = value = query.transpose(1, 0)
                else:
                    query, key = [x.transpose(1, 0) for x in (query, key)]
                    value = key
            else:
                query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        tgt_len, bsz, embed_dim = query.shape
        src_len, _, _ = key.shape
        """
        E = query.size(-1)
        qkv = self.qkv(query)
        qkv = qkv.unflatten(-1, (3, E)).unsqueeze(0).transpose(0, -2).squeeze(-2).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]
        """

        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        attn_mask = F._canonical_mask(
            mask=attn_mask,
            mask_name="attn_mask",
            other_type=F._none_or_dtype(key_padding_mask),
            other_name="key_padding_mask",
            target_type=q.dtype,
            check_other=False,
        )

        if attn_mask is not None:
            # ensure attn_mask's dim is 3
            if attn_mask.dim() == 2:
                correct_2d_size = (tgt_len, src_len)
                if attn_mask.shape != correct_2d_size:
                    raise RuntimeError(
                        f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.")
                attn_mask = attn_mask.unsqueeze(0)
            elif attn_mask.dim() == 3:
                correct_3d_size = (bsz * self.num_heads, tgt_len, src_len)
                if attn_mask.shape != correct_3d_size:
                    raise RuntimeError(
                        f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.")
            else:
                raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")

        if attn_mask is not None:
            # Reshape to the 4D (batch, heads, tgt, src) layout expected by
            # scaled_dot_product_attention (broadcast case kept as size-1 dim).
            if attn_mask.size(0) == 1 and attn_mask.dim() == 3:
                attn_mask = attn_mask.unsqueeze(0)
            else:
                attn_mask = attn_mask.view(bsz, self.num_heads, -1, src_len)

        # self.dropout is hard-coded to 0 in __init__, so this is always 0.
        dropout_p = self.dropout if self.training else 0.

        # (seq, batch, embed) -> (batch, heads, seq, head_dim)
        q = q.view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        k = k.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        v = v.view(src_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
        src_len = k.size(1)
        q = q.view(bsz, self.num_heads, tgt_len, self.head_dim)
        k = k.view(bsz, self.num_heads, src_len, self.head_dim)
        v = v.view(bsz, self.num_heads, src_len, self.head_dim)

        attn_output = self.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
        # Back to (tgt*batch, embed) for the output projection.
        attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
        attn_output = self.proj(attn_output)
        attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
        if self.batch_first and is_batched:
            return attn_output.transpose(1, 0), None
        return attn_output, None

    def train(self, mode: bool = True):
        # LoRA merge/unmerge is handled by the LinearLoRA children's own
        # train(); nothing extra needed here.
        super().train(mode)
        #self.lora_train(mode)

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                **kwargs):
        return self.forward_module(query, key, value, **kwargs)
531
+
532
+
533
+
534
class MergedLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer.
    #
    # The output dimension is split into len(enable_lora) equal groups
    # (e.g. the fused qkv projection of an attention block), and LoRA is
    # applied only to the groups flagged True in ``enable_lora``.
    #
    # NOTE(review): ``enable_lora`` is a mutable default list — callers
    # should not mutate it.
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        enable_lora: List[bool] = [False],
        fan_in_fan_out: bool = False,
        **kwargs
    ):
        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha)

        assert out_features % len(enable_lora) == 0, \
            'The length of enable_lora must divide out_features'
        self.enable_lora = enable_lora
        # Actual trainable parameters
        self.params_with_lora = {'weight': 'w'}
        if r > 0 and any(enable_lora):
            # A stacks sum(enable_lora) rank-r factors; B is laid out for a
            # grouped 1D convolution (one group per enabled output slice).
            self.w_lora_A = nn.Parameter(
                self.weight.new_zeros((r * sum(enable_lora), in_features)))
            self.w_lora_B = nn.Parameter(
                self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))
            ) # weights for Conv1D with groups=sum(enable_lora)
            # Freezing the pre-trained weight matrix
            self.weight.requires_grad = False
            # Compute the indices: boolean mask over the output dim marking
            # which rows receive a LoRA delta (True per enabled group).
            self.lora_ind = self.weight.new_zeros(
                (out_features, ), dtype=torch.bool
            ).view(len(enable_lora), -1)
            self.lora_ind[enable_lora, :] = True
            self.lora_ind = self.lora_ind.view(-1)
        nn.Linear.reset_parameters(self)
        self.init_lora_param()
        self.weight.data = self.transpose(self.weight.data)

    def zero_pad(self, x):
        # Scatter the compact delta rows back to their positions in the
        # full (out_features, ...) weight; disabled groups stay zero.
        result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))
        result[self.lora_ind] = x
        return result

    def merge_BA(self, param_name: str):
        # Grouped conv1d computes the per-group B @ A products in one call;
        # overrides LoRALayer.merge_BA because of the group structure.
        lora_name = self.params_with_lora[param_name]
        delta_w = F.conv1d(
            eval(f'self.{lora_name}_lora_A').unsqueeze(0),
            eval(f'self.{lora_name}_lora_B').unsqueeze(-1),
            groups=sum(self.enable_lora)
        ).squeeze(0)
        return self.transpose(self.zero_pad(delta_w))

    def train(self, mode: bool = True):
        nn.Linear.train(self, mode)
        self.lora_train(mode)

    def forward(self, x: torch.Tensor, **kwargs):

        if self.r > 0 and not self.merged:
            # Merge (differentiably), run, then restore the base weight.
            self.merge_lora_param()
            result = nn.Linear.forward(self, x, **kwargs)
            self.sub_lora_data()
            return result
        else:
            return nn.Linear.forward(self, x, **kwargs)
loralib/utils.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from typing import Dict
7
+
8
+ from .layers import LoRALayer, PlainMultiheadAttentionLoRA
9
+
10
+ INDEX_POSITIONS_TEXT = {
11
+ 'top1': [11],
12
+ 'top2': [10, 11],
13
+ 'top3': [9, 10, 11],
14
+ 'bottom': [0, 1, 2, 3],
15
+ 'mid': [4, 5, 6, 7],
16
+ 'up': [8, 9, 10, 11],
17
+ 'half-up': [6, 7, 8, 9, 10, 11],
18
+ 'half-bottom': [0, 1, 2, 3, 4, 5],
19
+ 'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]}
20
+
21
+
22
+ INDEX_POSITIONS_VISION = {
23
+ 'ViT-B/16': {
24
+ 'top': [11],
25
+ 'top3': [9, 10, 11],
26
+ 'bottom': [0, 1, 2, 3],
27
+ 'mid': [4, 5, 6, 7],
28
+ 'up': [8, 9, 10, 11],
29
+ 'half-up': [6, 7, 8, 9, 10, 11],
30
+ 'half-bottom': [0, 1, 2, 3, 4, 5],
31
+ 'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]},
32
+ 'ViT-B/32': {
33
+ 'bottom': [0, 1, 2, 3],
34
+ 'mid': [4, 5, 6, 7],
35
+ 'up': [8, 9, 10, 11],
36
+ 'half-up': [6, 7, 8, 9, 10, 11],
37
+ 'half-bottom': [0, 1, 2, 3, 4, 5],
38
+ 'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]},
39
+
40
+ 'ViT-L/14': {
41
+ 'half-up': [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
42
+ 'half-bottom': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
43
+ 'all': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]}
44
+ }
45
+
46
+
47
def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
    """Freeze every non-LoRA parameter of ``model``.

    bias='none'      : biases stay frozen.
    bias='all'       : every bias is made trainable.
    bias='lora_only' : only biases owned by LoRALayer modules are trainable.
    """
    for param_name, param in model.named_parameters():
        if 'lora_' not in param_name:
            param.requires_grad = False
    if bias == 'none':
        return
    if bias == 'all':
        for param_name, param in model.named_parameters():
            if 'bias' in param_name:
                param.requires_grad = True
    elif bias == 'lora_only':
        for module in model.modules():
            if isinstance(module, LoRALayer) and getattr(module, 'bias', None) is not None:
                module.bias.requires_grad = True
    else:
        raise NotImplementedError
65
+
66
+
67
def lora_state_dict(model: nn.Module, bias: str = 'none') -> Dict[str, torch.Tensor]:
    """Return only the LoRA-related entries of ``model.state_dict()``.

    ``bias`` selects which bias tensors accompany the LoRA factors:
    'none', 'all', or 'lora_only' (a bias whose key shares the prefix of a
    LoRA parameter key).
    """
    full = model.state_dict()
    if bias == 'none':
        return {k: v for k, v in full.items() if 'lora_' in k}
    if bias == 'all':
        return {k: v for k, v in full.items() if 'lora_' in k or 'bias' in k}
    if bias == 'lora_only':
        selected = {}
        for k, v in full.items():
            if 'lora_' not in k:
                continue
            selected[k] = v
            bias_key = k.split('lora_')[0] + 'bias'
            if bias_key in full:
                selected[bias_key] = full[bias_key]
        return selected
    raise NotImplementedError
84
+
85
+
86
def get_lora_parameters(model, bias='none'):
    """Collect the trainable LoRA parameters of ``model``.

    Args:
        model: module to scan.
        bias: 'none' (LoRA factors only), 'all' (LoRA factors plus every
            bias), or 'lora_only' (LoRA factors plus biases sharing a
            prefix with a LoRA parameter name).

    Returns:
        list of nn.Parameter. Each parameter appears at most once — the
        original implementation appended a shared bias once per matching
        LoRA factor (e.g. for both ``w_lora_A`` and ``w_lora_B``), which
        breaks optimizers that reject duplicate parameters. It also
        rebuilt ``dict(model.named_parameters())`` inside the loop (O(n^2));
        the dict is now built once.
    """
    if bias not in ('none', 'all', 'lora_only'):
        raise NotImplementedError
    named = dict(model.named_parameters())  # hoisted: built once, not per param
    params = []
    seen = set()

    def _add(p):
        # Deduplicate by identity so a shared bias is returned only once.
        if id(p) not in seen:
            seen.add(id(p))
            params.append(p)

    for name, param in named.items():
        if bias == 'none':
            if 'lora_' in name:
                _add(param)
        elif bias == 'all':
            if 'lora_' in name or 'bias' in name:
                _add(param)
        else:  # 'lora_only'
            if 'lora_' in name:
                _add(param)
                bias_name = name.split('lora_')[0] + 'bias'
                if bias_name in named:
                    _add(named[bias_name])
    return params
105
+
106
+
107
def apply_lora(args, clip_model):
    """Swap the nn.MultiheadAttention blocks of a CLIP model for LoRA ones.

    Args:
        args: namespace with ``encoder`` ('text'/'vision'/'both'),
            ``position``, ``backbone``, ``params``, ``r``, ``alpha`` and
            ``dropout_rate``.
        clip_model: CLIP model whose transformer blocks are patched in place.

    Returns:
        list of the PlainMultiheadAttentionLoRA modules that were inserted.

    The original duplicated the replacement loop for the text and vision
    branches; the shared logic now lives in ``_replace_attention_blocks``.
    """
    list_lora_layers = []
    if args.encoder in ('text', 'both'):
        indices = INDEX_POSITIONS_TEXT[args.position]
        _replace_attention_blocks(args, clip_model.transformer, indices, list_lora_layers)
    if args.encoder in ('vision', 'both'):
        indices = INDEX_POSITIONS_VISION[args.backbone][args.position]
        _replace_attention_blocks(args, clip_model.visual.transformer, indices, list_lora_layers)
    return list_lora_layers


def _replace_attention_blocks(args, encoder, indices, list_lora_layers):
    """Replace the MultiheadAttention children of ``encoder.resblocks`` at ``indices``."""
    for i, block in enumerate(encoder.resblocks):
        print(f"Residual Attention Block {i}: {block}")
        if i in indices:
            for name, submodule in block.named_children():
                if isinstance(submodule, nn.MultiheadAttention):
                    new_multi_head_lora = PlainMultiheadAttentionLoRA(
                        submodule, enable_lora=args.params, r=args.r,
                        lora_alpha=args.alpha, dropout_rate=args.dropout_rate)
                    setattr(block, name, new_multi_head_lora)
                    list_lora_layers.append(new_multi_head_lora)
135
+
136
+
137
def save_lora(args, list_lora_layers):
    """Serialize the LoRA A/B factors of every layer plus run metadata.

    The checkpoint goes to
    ``{save_path}/{backbone}/{dataset}/{shots}shots/seed{seed}/{filename}.pt``.

    The original repeated the same extraction block four times (q/k/v/o);
    a flag -> attribute mapping loop produces identical keys and ordering.
    """
    # Maps single-letter flags in args.params to projection attributes.
    proj_by_flag = {'q': 'q_proj', 'k': 'k_proj', 'v': 'v_proj', 'o': 'proj'}

    weights = {}
    for i, layer in enumerate(list_lora_layers):
        layer_weights = {}
        for flag, attr in proj_by_flag.items():
            if flag in args.params:
                proj = getattr(layer, attr)
                layer_weights[attr] = {
                    'w_lora_A': proj.w_lora_A.data,
                    'w_lora_B': proj.w_lora_B.data,
                }
        weights[f'layer_{i}'] = layer_weights

    # Stored so load_lora can verify the run configuration matches.
    metadata = {
        'r': args.r,
        'alpha': args.alpha,
        'encoder': args.encoder,
        'params': args.params,
        'position': args.position
    }

    save_data = {
        'weights': weights,
        'metadata': metadata
    }

    # to manage names like ViT-B/16
    backbone = args.backbone.replace('/', '').replace('-', '').lower()
    save_dir = f'{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}'
    os.makedirs(save_dir, exist_ok=True)

    save_path = f'{save_dir}/{args.filename}.pt'
    torch.save(save_data, save_path)
    print(f'LoRA weights saved to {save_path}')
185
+
186
+
187
def load_lora(args, list_lora_layers):
    """Load LoRA factors saved by :func:`save_lora` into ``list_lora_layers``.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        ValueError: if any checkpoint metadata field (r, alpha, encoder,
            params, position) does not match the current ``args``.

    The original repeated five near-identical metadata checks and four
    copy blocks; both are now table-driven with identical messages/keys.
    """
    # to manage names like ViT-B/16
    backbone = args.backbone.replace('/', '').replace('-', '').lower()
    load_path = f'{args.save_path}/{backbone}/{args.dataset}/{args.shots}shots/seed{args.seed}/{args.filename}.pt'

    if not os.path.exists(load_path):
        raise FileNotFoundError(f'File {load_path} does not exist.')

    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    loaded_data = torch.load(load_path)

    metadata = loaded_data['metadata']
    expected = {
        'r': (args.r, 'r mismatch'),
        'alpha': (args.alpha, 'alpha mismatch'),
        'encoder': (args.encoder, 'Encoder mismatch'),
        'params': (args.params, 'Params mismatch'),
        'position': (args.position, 'Position mismatch'),
    }
    for key, (want, label) in expected.items():
        if metadata[key] != want:
            raise ValueError(
                f"{label}: expected {want}, found {metadata[key]}")

    # Maps single-letter flags in args.params to projection attributes.
    proj_by_flag = {'q': 'q_proj', 'k': 'k_proj', 'v': 'v_proj', 'o': 'proj'}
    weights = loaded_data['weights']
    for i, layer in enumerate(list_lora_layers):
        layer_weights = weights[f'layer_{i}']
        for flag, attr in proj_by_flag.items():
            if flag in args.params and attr in layer_weights:
                proj = getattr(layer, attr)
                proj.w_lora_A.data.copy_(layer_weights[attr]['w_lora_A'])
                proj.w_lora_B.data.copy_(layer_weights[attr]['w_lora_B'])

    print(f'LoRA weights loaded from {load_path}')
model.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from torchvision.models import alexnet
7
+
8
+ import config as c
9
+ from freia_funcs import permute_layer, glow_coupling_layer, F_fully_connected, ReversibleGraphNet, OutputNode, \
10
+ InputNode, Node
11
+
12
+ WEIGHT_DIR = './weights'
13
+ MODEL_DIR = './models'
14
+
15
+
16
def nf_head(input_dim=c.n_feat):
    """Build the normalizing-flow head: alternating channel permutations and
    clamped GLOW coupling blocks over a flat feature vector of ``input_dim``.

    NOTE(review): the default ``input_dim`` is read from the config module
    at import time; later changes to ``c.n_feat`` do not affect it.
    """
    nodes = list()
    nodes.append(InputNode(input_dim, name='input'))
    for k in range(c.n_coupling_blocks):
        # Fixed (seeded) permutation so every run shuffles channels the same
        # way, followed by a GLOW coupling layer with an FC subnet.
        nodes.append(Node([nodes[-1].out0], permute_layer, {'seed': k}, name=F'permute_{k}'))
        nodes.append(Node([nodes[-1].out0], glow_coupling_layer,
                          {'clamp': c.clamp_alpha, 'F_class': F_fully_connected,
                           'F_args': {'internal_size': c.fc_internal, 'dropout': c.dropout}},
                          name=F'fc_{k}'))
    nodes.append(OutputNode([nodes[-1].out0], name='output'))
    coder = ReversibleGraphNet(nodes)
    return coder
28
+
29
+
30
class flow_model(nn.Module):
    """Thin wrapper: a normalizing-flow head over 1024-d feature vectors."""

    def __init__(self):
        super(flow_model, self).__init__()
        self.nf = nf_head(input_dim=1024)

    def forward(self, x):
        return self.nf(x)
39
+
40
class flow_model_multi_fc(nn.Module):
    """Projects 1024-d features down to 256-d through two FC layers (with a
    LeakyReLU in between) before feeding them to the normalizing-flow head."""

    def __init__(self):
        super(flow_model_multi_fc, self).__init__()
        self.fc1 = torch.nn.Linear(1024, 512)
        self.relu = torch.nn.LeakyReLU(0.2)
        self.fc2 = torch.nn.Linear(512, 256)
        self.nf = nf_head(input_dim=256)

    def forward(self, x):
        projected = self.fc2(self.relu(self.fc1(x)))
        return self.nf(projected)
53
+
54
+
55
class DifferNet(nn.Module):
    """Multi-scale AlexNet features followed by a normalizing-flow head.

    NOTE(review): appears to follow the DifferNet architecture (multi-scale
    conv features -> flow) — confirm against the original paper/config.
    """
    def __init__(self):
        super(DifferNet, self).__init__()
        # Pretrained AlexNet serves as the feature extractor; only its conv
        # trunk (``.features``) is used in forward().
        self.feature_extractor = alexnet(pretrained=True)
        self.nf = nf_head()

    def forward(self, x):
        y_cat = list()

        # Global-average-pooled conv features at c.n_scales resolutions
        # (full size, then halved per scale) are concatenated channel-wise.
        for s in range(c.n_scales):
            x_scaled = F.interpolate(x, size=c.img_size[0] // (2 ** s)) if s > 0 else x
            feat_s = self.feature_extractor.features(x_scaled)
            y_cat.append(torch.mean(feat_s, dim=(2, 3)))

        y = torch.cat(y_cat, dim=1)
        z = self.nf(y)
        return z
72
+
73
+
74
def save_model(model, filename):
    """Pickle the entire model object to ``MODEL_DIR/filename``."""
    # exist_ok=True replaces the original exists()+makedirs pair, which
    # raced when two processes created the directory concurrently.
    os.makedirs(MODEL_DIR, exist_ok=True)
    torch.save(model, os.path.join(MODEL_DIR, filename))


def load_model(filename):
    """Load a model object pickled by :func:`save_model`.

    NOTE(review): torch.load unpickles arbitrary code — only load files
    from trusted sources.
    """
    path = os.path.join(MODEL_DIR, filename)
    model = torch.load(path)
    return model


def save_weights(model, filename):
    """Save only the model's state_dict to ``WEIGHT_DIR/filename``."""
    os.makedirs(WEIGHT_DIR, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(WEIGHT_DIR, filename))


def load_weights(model, filename):
    """Load a state_dict saved by :func:`save_weights` into ``model`` in place.

    Returns the same ``model`` for convenience.
    """
    path = os.path.join(WEIGHT_DIR, filename)
    model.load_state_dict(torch.load(path))
    return model
models/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .clip_models import CLIPModel
2
+ from .imagenet_models import ImagenetModel
3
+
4
+
5
# All recognized model identifiers, 'Backend:arch'. Order is preserved
# from the original hand-written list (Imagenet entries first, then CLIP).
VALID_NAMES = (
    ['Imagenet:' + arch for arch in (
        'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152',
        'vgg11', 'vgg19',
        'swin-b', 'swin-s', 'swin-t',
        'vit_b_16', 'vit_b_32', 'vit_l_16', 'vit_l_32',
    )]
    + ['CLIP:' + arch for arch in (
        'RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64',
        'ViT-B/32', 'ViT-B/16', 'ViT-L/14', 'ViT-L/14@336px',
    )]
)
31
+
32
+
33
+
34
+
35
+
36
def get_model(name):
    """Build a model from a 'Backend:arch' identifier.

    Args:
        name: one of VALID_NAMES, e.g. 'Imagenet:resnet50' or 'CLIP:ViT-L/14'.

    Returns:
        An ImagenetModel or CLIPModel wrapping the requested architecture.

    Raises:
        ValueError: if `name` is not in VALID_NAMES.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently let unknown names through the prefix checks.
    if name not in VALID_NAMES:
        raise ValueError(f"unknown model name {name!r}; expected one of VALID_NAMES")
    if name.startswith("Imagenet:"):
        return ImagenetModel(name[len("Imagenet:"):])
    # VALID_NAMES guarantees the only remaining prefix is "CLIP:".
    return CLIPModel(name[len("CLIP:"):])
models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.02 kB). View file
 
models/__pycache__/clip_models.cpython-38.pyc ADDED
Binary file (1.01 kB). View file
 
models/__pycache__/imagenet_models.cpython-38.pyc ADDED
Binary file (1.33 kB). View file
 
models/__pycache__/resnet.cpython-38.pyc ADDED
Binary file (9.76 kB). View file
 
models/__pycache__/vision_transformer.cpython-38.pyc ADDED
Binary file (12.2 kB). View file
 
models/__pycache__/vision_transformer_misc.cpython-38.pyc ADDED
Binary file (6.53 kB). View file
 
models/__pycache__/vision_transformer_utils.cpython-38.pyc ADDED
Binary file (18.1 kB). View file
 
models/clip/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .clip import *
models/clip/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (176 Bytes). View file
 
models/clip/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (193 Bytes). View file
 
models/clip/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (193 Bytes). View file
 
models/clip/__pycache__/clip.cpython-310.pyc ADDED
Binary file (8.82 kB). View file
 
models/clip/__pycache__/clip.cpython-38.pyc ADDED
Binary file (8.72 kB). View file
 
models/clip/__pycache__/clip.cpython-39.pyc ADDED
Binary file (8.8 kB). View file
 
models/clip/__pycache__/model.cpython-310.pyc ADDED
Binary file (15.4 kB). View file
 
models/clip/__pycache__/model.cpython-38.pyc ADDED
Binary file (15.2 kB). View file