tonyassi committed
Commit 17cdd51 · 1 Parent(s): 9e9ee28

Upload 84 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +3 -0
  2. LICENSE_Lavis.md +14 -0
  3. cheetah/.DS_Store +0 -0
  4. cheetah/__init__.py +22 -0
  5. cheetah/__pycache__/__init__.cpython-310.pyc +0 -0
  6. cheetah/common/__init__.py +0 -0
  7. cheetah/common/__pycache__/__init__.cpython-310.pyc +0 -0
  8. cheetah/common/__pycache__/config.cpython-310.pyc +0 -0
  9. cheetah/common/__pycache__/dist_utils.cpython-310.pyc +0 -0
  10. cheetah/common/__pycache__/logger.cpython-310.pyc +0 -0
  11. cheetah/common/__pycache__/registry.cpython-310.pyc +0 -0
  12. cheetah/common/__pycache__/utils.cpython-310.pyc +0 -0
  13. cheetah/common/config.py +468 -0
  14. cheetah/common/dist_utils.py +137 -0
  15. cheetah/common/gradcam.py +24 -0
  16. cheetah/common/logger.py +195 -0
  17. cheetah/common/optims.py +119 -0
  18. cheetah/common/registry.py +216 -0
  19. cheetah/common/utils.py +424 -0
  20. cheetah/configs/default.yaml +5 -0
  21. cheetah/configs/models/cheetah_llama2.yaml +33 -0
  22. cheetah/configs/models/cheetah_vicuna.yaml +33 -0
  23. cheetah/conversation/__init__.py +0 -0
  24. cheetah/conversation/__pycache__/__init__.cpython-310.pyc +0 -0
  25. cheetah/conversation/__pycache__/conversation.cpython-310.pyc +0 -0
  26. cheetah/conversation/__pycache__/conversation_llama2.cpython-310.pyc +0 -0
  27. cheetah/conversation/conversation.py +364 -0
  28. cheetah/conversation/conversation_llama2.py +387 -0
  29. cheetah/models/Qformer.py +1216 -0
  30. cheetah/models/__init__.py +202 -0
  31. cheetah/models/__pycache__/Qformer.cpython-310.pyc +0 -0
  32. cheetah/models/__pycache__/__init__.cpython-310.pyc +0 -0
  33. cheetah/models/__pycache__/base_model.cpython-310.pyc +0 -0
  34. cheetah/models/__pycache__/blip2.cpython-310.pyc +0 -0
  35. cheetah/models/__pycache__/cheetah_llama2.cpython-310.pyc +0 -0
  36. cheetah/models/__pycache__/cheetah_vicuna.cpython-310.pyc +0 -0
  37. cheetah/models/__pycache__/eva_vit.cpython-310.pyc +0 -0
  38. cheetah/models/__pycache__/modeling_llama.cpython-310.pyc +0 -0
  39. cheetah/models/__pycache__/modeling_llama2.cpython-310.pyc +0 -0
  40. cheetah/models/base_model.py +247 -0
  41. cheetah/models/blip2.py +221 -0
  42. cheetah/models/blip2_outputs.py +110 -0
  43. cheetah/models/cheetah_llama2.py +388 -0
  44. cheetah/models/cheetah_vicuna.py +387 -0
  45. cheetah/models/eva_vit.py +442 -0
  46. cheetah/models/modeling_llama.py +803 -0
  47. cheetah/models/modeling_llama2.py +1070 -0
  48. cheetah/processors/__init__.py +33 -0
  49. cheetah/processors/__pycache__/__init__.cpython-310.pyc +0 -0
  50. cheetah/processors/__pycache__/base_processor.cpython-310.pyc +0 -0
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ examples/18.png filter=lfs diff=lfs merge=lfs -text
+ examples/19.png filter=lfs diff=lfs merge=lfs -text
+ examples/24.png filter=lfs diff=lfs merge=lfs -text
LICENSE_Lavis.md ADDED
@@ -0,0 +1,14 @@
BSD 3-Clause License

Copyright (c) 2022 Salesforce, Inc.
All rights reserved.

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.

3. Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cheetah/.DS_Store ADDED
Binary file (6.15 kB).
cheetah/__init__.py ADDED
@@ -0,0 +1,22 @@
import os
import sys

from omegaconf import OmegaConf

from cheetah.common.registry import registry
from cheetah.models import *
from cheetah.processors import *



root_dir = os.path.dirname(os.path.abspath(__file__))
default_cfg = OmegaConf.load(os.path.join(root_dir, "configs/default.yaml"))

registry.register_path("library_root", root_dir)
repo_root = os.path.join(root_dir, "..")
registry.register_path("repo_root", repo_root)
cache_root = os.path.join(repo_root, default_cfg.env.cache_root)
registry.register_path("cache_root", cache_root)

registry.register("MAX_INT", sys.maxsize)
registry.register("SPLIT_NAMES", ["train", "val", "test"])
cheetah/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (730 Bytes).
cheetah/common/__init__.py ADDED
File without changes
cheetah/common/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (169 Bytes).
cheetah/common/__pycache__/config.cpython-310.pyc ADDED
Binary file (12.1 kB).
cheetah/common/__pycache__/dist_utils.cpython-310.pyc ADDED
Binary file (3.78 kB).
cheetah/common/__pycache__/logger.cpython-310.pyc ADDED
Binary file (6.43 kB).
cheetah/common/__pycache__/registry.cpython-310.pyc ADDED
Binary file (6.31 kB).
cheetah/common/__pycache__/utils.cpython-310.pyc ADDED
Binary file (12.8 kB).
cheetah/common/config.py ADDED
@@ -0,0 +1,468 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import json
10
+ from typing import Dict
11
+
12
+ from omegaconf import OmegaConf
13
+ from cheetah.common.registry import registry
14
+
15
+
16
+ class Config:
17
+ def __init__(self, args):
18
+ self.config = {}
19
+
20
+ self.args = args
21
+
22
+ # Register the config and configuration for setup
23
+ registry.register("configuration", self)
24
+
25
+ user_config = self._build_opt_list(self.args.options)
26
+
27
+ config = OmegaConf.load(self.args.cfg_path)
28
+
29
+ runner_config = self.build_runner_config(config)
30
+ model_config = self.build_model_config(config, **user_config)
31
+ dataset_config = self.build_dataset_config(config)
32
+
33
+ # Validate the user-provided runner configuration
34
+ # model and dataset configuration are supposed to be validated by the respective classes
35
+ # [TODO] validate the model/dataset configuration
36
+ # self._validate_runner_config(runner_config)
37
+
38
+ # Override the default configuration with user options.
39
+ self.config = OmegaConf.merge(
40
+ runner_config, model_config, dataset_config, user_config
41
+ )
42
+
43
+ def _validate_runner_config(self, runner_config):
44
+ """
45
+ This method validates the configuration, such that
46
+ 1) all the user specified options are valid;
47
+ 2) no type mismatches between the user specified options and the config.
48
+ """
49
+ runner_config_validator = create_runner_config_validator()
50
+ runner_config_validator.validate(runner_config)
51
+
52
+ def _build_opt_list(self, opts):
53
+ opts_dot_list = self._convert_to_dot_list(opts)
54
+ return OmegaConf.from_dotlist(opts_dot_list)
55
+
56
+ @staticmethod
57
+ def build_model_config(config, **kwargs):
58
+ model = config.get("model", None)
59
+ assert model is not None, "Missing model configuration file."
60
+
61
+ model_cls = registry.get_model_class(model.arch)
62
+ assert model_cls is not None, f"Model '{model.arch}' has not been registered."
63
+
64
+ model_type = kwargs.get("model.model_type", None)
65
+ if not model_type:
66
+ model_type = model.get("model_type", None)
67
+ # else use the model type selected by user.
68
+
69
+ assert model_type is not None, "Missing model_type."
70
+
71
+ model_config_path = model_cls.default_config_path(model_type=model_type)
72
+
73
+ model_config = OmegaConf.create()
74
+ # hierarchy override, customized config > default config
75
+ model_config = OmegaConf.merge(
76
+ model_config,
77
+ OmegaConf.load(model_config_path),
78
+ {"model": config["model"]},
79
+ )
80
+
81
+ return model_config
82
+
83
+ @staticmethod
84
+ def build_runner_config(config):
85
+ return {"run": config.run}
86
+
87
+ @staticmethod
88
+ def build_dataset_config(config):
89
+ datasets = config.get("datasets", None)
90
+ if datasets is None:
91
+ raise KeyError(
92
+ "Expecting 'datasets' as the root key for dataset configuration."
93
+ )
94
+
95
+ dataset_config = OmegaConf.create()
96
+
97
+ for dataset_name in datasets:
98
+ builder_cls = registry.get_builder_class(dataset_name)
99
+
100
+ dataset_config_type = datasets[dataset_name].get("type", "default")
101
+ dataset_config_path = builder_cls.default_config_path(
102
+ type=dataset_config_type
103
+ )
104
+
105
+ # hierarchy override, customized config > default config
106
+ dataset_config = OmegaConf.merge(
107
+ dataset_config,
108
+ OmegaConf.load(dataset_config_path),
109
+ {"datasets": {dataset_name: config["datasets"][dataset_name]}},
110
+ )
111
+
112
+ return dataset_config
113
+
114
+ def _convert_to_dot_list(self, opts):
115
+ if opts is None:
116
+ opts = []
117
+
118
+ if len(opts) == 0:
119
+ return opts
120
+
121
+ has_equal = opts[0].find("=") != -1
122
+
123
+ if has_equal:
124
+ return opts
125
+
126
+ return [(opt + "=" + value) for opt, value in zip(opts[0::2], opts[1::2])]
127
+
128
+ def get_config(self):
129
+ return self.config
130
+
131
+ @property
132
+ def run_cfg(self):
133
+ return self.config.run
134
+
135
+ @property
136
+ def datasets_cfg(self):
137
+ return self.config.datasets
138
+
139
+ @property
140
+ def model_cfg(self):
141
+ return self.config.model
142
+
143
+ def pretty_print(self):
144
+ logging.info("\n===== Running Parameters =====")
145
+ logging.info(self._convert_node_to_json(self.config.run))
146
+
147
+ logging.info("\n====== Dataset Attributes ======")
148
+ datasets = self.config.datasets
149
+
150
+ for dataset in datasets:
151
+ if dataset in self.config.datasets:
152
+ logging.info(f"\n======== {dataset} =======")
153
+ dataset_config = self.config.datasets[dataset]
154
+ logging.info(self._convert_node_to_json(dataset_config))
155
+ else:
156
+ logging.warning(f"No dataset named '{dataset}' in config. Skipping")
157
+
158
+ logging.info(f"\n====== Model Attributes ======")
159
+ logging.info(self._convert_node_to_json(self.config.model))
160
+
161
+ def _convert_node_to_json(self, node):
162
+ container = OmegaConf.to_container(node, resolve=True)
163
+ return json.dumps(container, indent=4, sort_keys=True)
164
+
165
+ def to_dict(self):
166
+ return OmegaConf.to_container(self.config)
167
+
168
+
169
+ def node_to_dict(node):
170
+ return OmegaConf.to_container(node)
171
+
172
+
173
+ class ConfigValidator:
174
+ """
175
+ This is a preliminary implementation to centralize and validate the configuration.
176
+ May be altered in the future.
177
+
178
+ A helper class to validate configurations from yaml file.
179
+
180
+ This serves the following purposes:
181
+ 1. Ensure all the options in the yaml are defined, raise error if not.
182
+ 2. when type mismatches are found, the validator will raise an error.
183
+ 3. a central place to store and display helpful messages for supported configurations.
184
+
185
+ """
186
+
187
+ class _Argument:
188
+ def __init__(self, name, choices=None, type=None, help=None):
189
+ self.name = name
190
+ self.val = None
191
+ self.choices = choices
192
+ self.type = type
193
+ self.help = help
194
+
195
+ def __str__(self):
196
+ s = f"{self.name}={self.val}"
197
+ if self.type is not None:
198
+ s += f", ({self.type})"
199
+ if self.choices is not None:
200
+ s += f", choices: {self.choices}"
201
+ if self.help is not None:
202
+ s += f", ({self.help})"
203
+ return s
204
+
205
+ def __init__(self, description):
206
+ self.description = description
207
+
208
+ self.arguments = dict()
209
+
210
+ self.parsed_args = None
211
+
212
+ def __getitem__(self, key):
213
+ assert self.parsed_args is not None, "No arguments parsed yet."
214
+
215
+ return self.parsed_args[key]
216
+
217
+ def __str__(self) -> str:
218
+ return self.format_help()
219
+
220
+ def add_argument(self, *args, **kwargs):
221
+ """
222
+ Assume the first argument is the name of the argument.
223
+ """
224
+ self.arguments[args[0]] = self._Argument(*args, **kwargs)
225
+
226
+ def validate(self, config=None):
227
+ """
228
+ Convert yaml config (dict-like) to list, required by argparse.
229
+ """
230
+ for k, v in config.items():
231
+ assert (
232
+ k in self.arguments
233
+ ), f"""{k} is not a valid argument. Support arguments are {self.format_arguments()}."""
234
+
235
+ if self.arguments[k].type is not None:
236
+ try:
237
+ self.arguments[k].val = self.arguments[k].type(v)
238
+ except ValueError:
239
+ raise ValueError(f"{k} is not a valid {self.arguments[k].type}.")
240
+
241
+ if self.arguments[k].choices is not None:
242
+ assert (
243
+ v in self.arguments[k].choices
244
+ ), f"""{k} must be one of {self.arguments[k].choices}."""
245
+
246
+ return config
247
+
248
+ def format_arguments(self):
249
+ return str([f"{k}" for k in sorted(self.arguments.keys())])
250
+
251
+ def format_help(self):
252
+ # description + key-value pair string for each argument
253
+ help_msg = str(self.description)
254
+ return help_msg + ", available arguments: " + self.format_arguments()
255
+
256
+ def print_help(self):
257
+ # display help message
258
+ print(self.format_help())
259
+
260
+
261
+ def create_runner_config_validator():
262
+ validator = ConfigValidator(description="Runner configurations")
263
+
264
+ validator.add_argument(
265
+ "runner",
266
+ type=str,
267
+ choices=["runner_base", "runner_iter"],
268
+ help="""Runner to use. The "runner_base" uses epoch-based training while iter-based
269
+ runner runs based on iters. Default: runner_base""",
270
+ )
271
+ # add argumetns for training dataset ratios
272
+ validator.add_argument(
273
+ "train_dataset_ratios",
274
+ type=Dict[str, float],
275
+ help="""Ratios of training dataset. This is used in iteration-based runner.
276
+ Do not support for epoch-based runner because how to define an epoch becomes tricky.
277
+ Default: None""",
278
+ )
279
+ validator.add_argument(
280
+ "max_iters",
281
+ type=float,
282
+ help="Maximum number of iterations to run.",
283
+ )
284
+ validator.add_argument(
285
+ "max_epoch",
286
+ type=int,
287
+ help="Maximum number of epochs to run.",
288
+ )
289
+ # add arguments for iters_per_inner_epoch
290
+ validator.add_argument(
291
+ "iters_per_inner_epoch",
292
+ type=float,
293
+ help="Number of iterations per inner epoch. This is required when runner is runner_iter.",
294
+ )
295
+ lr_scheds_choices = registry.list_lr_schedulers()
296
+ validator.add_argument(
297
+ "lr_sched",
298
+ type=str,
299
+ choices=lr_scheds_choices,
300
+ help="Learning rate scheduler to use, from {}".format(lr_scheds_choices),
301
+ )
302
+ task_choices = registry.list_tasks()
303
+ validator.add_argument(
304
+ "task",
305
+ type=str,
306
+ choices=task_choices,
307
+ help="Task to use, from {}".format(task_choices),
308
+ )
309
+ # add arguments for init_lr
310
+ validator.add_argument(
311
+ "init_lr",
312
+ type=float,
313
+ help="Initial learning rate. This will be the learning rate after warmup and before decay.",
314
+ )
315
+ # add arguments for min_lr
316
+ validator.add_argument(
317
+ "min_lr",
318
+ type=float,
319
+ help="Minimum learning rate (after decay).",
320
+ )
321
+ # add arguments for warmup_lr
322
+ validator.add_argument(
323
+ "warmup_lr",
324
+ type=float,
325
+ help="Starting learning rate for warmup.",
326
+ )
327
+ # add arguments for learning rate decay rate
328
+ validator.add_argument(
329
+ "lr_decay_rate",
330
+ type=float,
331
+ help="Learning rate decay rate. Required if using a decaying learning rate scheduler.",
332
+ )
333
+ # add arguments for weight decay
334
+ validator.add_argument(
335
+ "weight_decay",
336
+ type=float,
337
+ help="Weight decay rate.",
338
+ )
339
+ # add arguments for training batch size
340
+ validator.add_argument(
341
+ "batch_size_train",
342
+ type=int,
343
+ help="Training batch size.",
344
+ )
345
+ # add arguments for evaluation batch size
346
+ validator.add_argument(
347
+ "batch_size_eval",
348
+ type=int,
349
+ help="Evaluation batch size, including validation and testing.",
350
+ )
351
+ # add arguments for number of workers for data loading
352
+ validator.add_argument(
353
+ "num_workers",
354
+ help="Number of workers for data loading.",
355
+ )
356
+ # add arguments for warm up steps
357
+ validator.add_argument(
358
+ "warmup_steps",
359
+ type=int,
360
+ help="Number of warmup steps. Required if a warmup schedule is used.",
361
+ )
362
+ # add arguments for random seed
363
+ validator.add_argument(
364
+ "seed",
365
+ type=int,
366
+ help="Random seed.",
367
+ )
368
+ # add arguments for output directory
369
+ validator.add_argument(
370
+ "output_dir",
371
+ type=str,
372
+ help="Output directory to save checkpoints and logs.",
373
+ )
374
+ # add arguments for whether only use evaluation
375
+ validator.add_argument(
376
+ "evaluate",
377
+ help="Whether to only evaluate the model. If true, training will not be performed.",
378
+ )
379
+ # add arguments for splits used for training, e.g. ["train", "val"]
380
+ validator.add_argument(
381
+ "train_splits",
382
+ type=list,
383
+ help="Splits to use for training.",
384
+ )
385
+ # add arguments for splits used for validation, e.g. ["val"]
386
+ validator.add_argument(
387
+ "valid_splits",
388
+ type=list,
389
+ help="Splits to use for validation. If not provided, will skip the validation.",
390
+ )
391
+ # add arguments for splits used for testing, e.g. ["test"]
392
+ validator.add_argument(
393
+ "test_splits",
394
+ type=list,
395
+ help="Splits to use for testing. If not provided, will skip the testing.",
396
+ )
397
+ # add arguments for accumulating gradient for iterations
398
+ validator.add_argument(
399
+ "accum_grad_iters",
400
+ type=int,
401
+ help="Number of iterations to accumulate gradient for.",
402
+ )
403
+
404
+ # ====== distributed training ======
405
+ validator.add_argument(
406
+ "device",
407
+ type=str,
408
+ choices=["cpu", "cuda"],
409
+ help="Device to use. Support 'cuda' or 'cpu' as for now.",
410
+ )
411
+ validator.add_argument(
412
+ "world_size",
413
+ type=int,
414
+ help="Number of processes participating in the job.",
415
+ )
416
+ validator.add_argument("dist_url", type=str)
417
+ validator.add_argument("distributed", type=bool)
418
+ # add arguments to opt using distributed sampler during evaluation or not
419
+ validator.add_argument(
420
+ "use_dist_eval_sampler",
421
+ type=bool,
422
+ help="Whether to use distributed sampler during evaluation or not.",
423
+ )
424
+
425
+ # ====== task specific ======
426
+ # generation task specific arguments
427
+ # add arguments for maximal length of text output
428
+ validator.add_argument(
429
+ "max_len",
430
+ type=int,
431
+ help="Maximal length of text output.",
432
+ )
433
+ # add arguments for minimal length of text output
434
+ validator.add_argument(
435
+ "min_len",
436
+ type=int,
437
+ help="Minimal length of text output.",
438
+ )
439
+ # add arguments number of beams
440
+ validator.add_argument(
441
+ "num_beams",
442
+ type=int,
443
+ help="Number of beams used for beam search.",
444
+ )
445
+
446
+ # vqa task specific arguments
447
+ # add arguments for number of answer candidates
448
+ validator.add_argument(
449
+ "num_ans_candidates",
450
+ type=int,
451
+ help="""For ALBEF and BLIP, these models first rank answers according to likelihood to select answer candidates.""",
452
+ )
453
+ # add arguments for inference method
454
+ validator.add_argument(
455
+ "inference_method",
456
+ type=str,
457
+ choices=["genearte", "rank"],
458
+ help="""Inference method to use for question answering. If rank, requires a answer list.""",
459
+ )
460
+
461
+ # ====== model specific ======
462
+ validator.add_argument(
463
+ "k_test",
464
+ type=int,
465
+ help="Number of top k most similar samples from ITC/VTC selection to be tested.",
466
+ )
467
+
468
+ return validator
cheetah/common/dist_utils.py ADDED
@@ -0,0 +1,137 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import functools
import os

import torch
import torch.distributed as dist
import timm.models.hub as timm_hub


def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop("force", False)
        if is_master or force:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def init_distributed_mode(args):
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ["WORLD_SIZE"])
        args.gpu = int(os.environ["LOCAL_RANK"])
    elif "SLURM_PROCID" in os.environ:
        args.rank = int(os.environ["SLURM_PROCID"])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print("Not using distributed mode")
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = "nccl"
    print(
        "| distributed init (rank {}, world {}): {}".format(
            args.rank, args.world_size, args.dist_url
        ),
        flush=True,
    )
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
        timeout=datetime.timedelta(
            days=365
        ),  # allow auto-downloading and de-compressing
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)


def get_dist_info():
    if torch.__version__ < "1.0":
        initialized = dist._initialized
    else:
        initialized = dist.is_initialized()
    if initialized:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:  # non-distributed training
        rank = 0
        world_size = 1
    return rank, world_size


def main_process(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        rank, _ = get_dist_info()
        if rank == 0:
            return func(*args, **kwargs)

    return wrapper


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # a hack to sync the file path across processes
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()
cheetah/common/gradcam.py ADDED
@@ -0,0 +1,24 @@
import numpy as np
from matplotlib import pyplot as plt
from scipy.ndimage import filters
from skimage import transform as skimage_transform


def getAttMap(img, attMap, blur=True, overlap=True):
    attMap -= attMap.min()
    if attMap.max() > 0:
        attMap /= attMap.max()
    attMap = skimage_transform.resize(attMap, (img.shape[:2]), order=3, mode="constant")
    if blur:
        attMap = filters.gaussian_filter(attMap, 0.02 * max(img.shape[:2]))
        attMap -= attMap.min()
        attMap /= attMap.max()
    cmap = plt.get_cmap("jet")
    attMapV = cmap(attMap)
    attMapV = np.delete(attMapV, 3, 2)
    if overlap:
        attMap = (
            1 * (1 - attMap**0.7).reshape(attMap.shape + (1,)) * img
            + (attMap**0.7).reshape(attMap.shape + (1,)) * attMapV
        )
    return attMap
cheetah/common/logger.py ADDED
@@ -0,0 +1,195 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import datetime
import logging
import time
from collections import defaultdict, deque

import torch
import torch.distributed as dist

from cheetah.common import dist_utils


class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt

    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        if not dist_utils.is_dist_avail_and_initialized():
            return
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        d = torch.tensor(list(self.deque))
        return d.median().item()

    @property
    def avg(self):
        d = torch.tensor(list(self.deque), dtype=torch.float32)
        return d.mean().item()

    @property
    def global_avg(self):
        return self.total / self.count

    @property
    def max(self):
        return max(self.deque)

    @property
    def value(self):
        return self.deque[-1]

    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )


class MetricLogger(object):
    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(
            "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
        )

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {}".format(name, str(meter)))
        return self.delimiter.join(loss_str)

    def global_avg(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append("{}: {:.4f}".format(name, meter.global_avg))
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        i = 0
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")
        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
        log_msg = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if torch.cuda.is_available():
            log_msg.append("max mem: {memory:.0f}")
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == len(iterable) - 1:
                eta_seconds = iter_time.global_avg * (len(iterable) - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                            memory=torch.cuda.max_memory_allocated() / MB,
                        )
                    )
                else:
                    print(
                        log_msg.format(
                            i,
                            len(iterable),
                            eta=eta_string,
                            meters=str(self),
                            time=str(iter_time),
                            data=str(data_time),
                        )
                    )
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print(
            "{} Total time: {} ({:.4f} s / it)".format(
                header, total_time_str, total_time / len(iterable)
            )
        )


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def setup_logger():
    logging.basicConfig(
        level=logging.INFO if dist_utils.is_main_process() else logging.WARN,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[logging.StreamHandler()],
    )
cheetah/common/optims.py ADDED
@@ -0,0 +1,119 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import math

from cheetah.common.registry import registry


@registry.register_lr_scheduler("linear_warmup_step_lr")
class LinearWarmupStepLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        min_lr,
        init_lr,
        decay_rate=1,
        warmup_start_lr=-1,
        warmup_steps=0,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.min_lr = min_lr

        self.decay_rate = decay_rate

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        if cur_epoch == 0:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            step_lr_schedule(
                epoch=cur_epoch,
                optimizer=self.optimizer,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
                decay_rate=self.decay_rate,
            )


@registry.register_lr_scheduler("linear_warmup_cosine_lr")
class LinearWarmupCosineLRScheduler:
    def __init__(
        self,
        optimizer,
        max_epoch,
        iters_per_epoch,
        min_lr,
        init_lr,
        warmup_steps=0,
        warmup_start_lr=-1,
        **kwargs
    ):
        self.optimizer = optimizer

        self.max_epoch = max_epoch
        self.iters_per_epoch = iters_per_epoch
        self.min_lr = min_lr

        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.warmup_start_lr = warmup_start_lr if warmup_start_lr >= 0 else init_lr

    def step(self, cur_epoch, cur_step):
        total_cur_step = cur_epoch * self.iters_per_epoch + cur_step
        if total_cur_step < self.warmup_steps:
            warmup_lr_schedule(
                step=cur_step,
                optimizer=self.optimizer,
                max_step=self.warmup_steps,
                init_lr=self.warmup_start_lr,
                max_lr=self.init_lr,
            )
        else:
            cosine_lr_schedule(
                epoch=total_cur_step,
                optimizer=self.optimizer,
                max_epoch=self.max_epoch * self.iters_per_epoch,
                init_lr=self.init_lr,
                min_lr=self.min_lr,
            )


def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Decay the learning rate"""
    lr = (init_lr - min_lr) * 0.5 * (
        1.0 + math.cos(math.pi * epoch / max_epoch)
    ) + min_lr
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Warmup the learning rate"""
    lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max(max_step, 1))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Decay the learning rate"""
    lr = max(min_lr, init_lr * (decay_rate**epoch))
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
cheetah/common/registry.py ADDED
@@ -0,0 +1,216 @@
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""


class Registry:
    mapping = {
        "builder_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
        "runner_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_model(cls, name):
        r"""Register a task to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from cheetah.common.registry import registry
        """

        def wrap(model_cls):
            from cheetah.models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Register a processor to registry with key 'name'

        Args:
            name: Key with which the task will be registered.

        Usage:

            from cheetah.common.registry import registry
        """

        def wrap(processor_cls):
            from cheetah.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path to registry with key 'name'

        Args:
            name: Key with which the path will be registered.

        Usage:

            from cheetah.common.registry import registry
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Register an item to registry with key 'name'

        Args:
            name: Key with which the item will be registered.

        Usage::

            from cheetah.common.registry import registry

            registry.register("config", {})
        """
        path = name.split(".")
        current = cls.mapping["state"]

        for part in path[:-1]:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[path[-1]] = obj

    @classmethod
    def get_builder_class(cls, name):
        return cls.mapping["builder_name_mapping"].get(name, None)

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def get_processor_class(cls, name):
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_lr_scheduler_class(cls, name):
        return cls.mapping["lr_scheduler_name_mapping"].get(name, None)

    @classmethod
    def get_runner_class(cls, name):
        return cls.mapping["runner_name_mapping"].get(name, None)

    @classmethod
    def list_runners(cls):
        return sorted(cls.mapping["runner_name_mapping"].keys())

    @classmethod
    def list_models(cls):
        return sorted(cls.mapping["model_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def list_processors(cls):
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_lr_schedulers(cls):
        return sorted(cls.mapping["lr_scheduler_name_mapping"].keys())

    @classmethod
    def list_datasets(cls):
        return sorted(cls.mapping["builder_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from registry with key 'name'

        Args:
            name (string): Key whose value needs to be retrieved.
            default: If passed and key is not in registry, default value will
                     be returned with a warning. Default: None
            no_warning (bool): If passed as True, warning when key doesn't exist
                               will not be generated. Useful for MMF's
                               internal operations. Default: False
        """
        original_name = name
        name = name.split(".")
        value = cls.mapping["state"]
        for subname in name:
            value = value.get(subname, default)
            if value is default:
                break

        if (
            "writer" in cls.mapping["state"]
            and value == default
            and no_warning is False
        ):
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(original_name, default)
            )
        return value

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from registry with key 'name'

        Args:
            name: Key which needs to be removed.
        Usage::

            from mmf.common.registry import registry

            config = registry.unregister("config")
        """
        return cls.mapping["state"].pop(name, None)


registry = Registry()
cheetah/common/utils.py ADDED
@@ -0,0 +1,424 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import io
9
+ import json
10
+ import logging
11
+ import os
12
+ import pickle
13
+ import re
14
+ import shutil
15
+ import urllib
16
+ import urllib.error
17
+ import urllib.request
18
+ from typing import Optional
19
+ from urllib.parse import urlparse
20
+
21
+ import numpy as np
22
+ import pandas as pd
23
+ import yaml
24
+ from iopath.common.download import download
25
+ from iopath.common.file_io import file_lock, g_pathmgr
26
+ from cheetah.common.registry import registry
27
+ from torch.utils.model_zoo import tqdm
28
+ from torchvision.datasets.utils import (
29
+ check_integrity,
30
+ download_file_from_google_drive,
31
+ extract_archive,
32
+ )
33
+
34
+
35
+ def now():
36
+ from datetime import datetime
37
+
38
+ return datetime.now().strftime("%Y%m%d%H%M")[:-1]
39
+
40
+
41
+ def is_url(url_or_filename):
42
+ parsed = urlparse(url_or_filename)
43
+ return parsed.scheme in ("http", "https")
44
+
45
+
46
+ def get_cache_path(rel_path):
47
+ return os.path.expanduser(os.path.join(registry.get_path("cache_root"), rel_path))
48
+
49
+
50
+ def get_abs_path(rel_path):
51
+ return os.path.join(registry.get_path("library_root"), rel_path)
52
+
53
+
54
+ def load_json(filename):
55
+ with open(filename, "r") as f:
56
+ return json.load(f)
57
+
58
+
59
+ # The following are adapted from torchvision and vissl
60
+ # torchvision: https://github.com/pytorch/vision
61
+ # vissl: https://github.com/facebookresearch/vissl/blob/main/vissl/utils/download.py
62
+
63
+
64
+ def makedir(dir_path):
65
+ """
66
+ Create the directory if it does not exist.
67
+ """
68
+ is_success = False
69
+ try:
70
+ if not g_pathmgr.exists(dir_path):
71
+ g_pathmgr.mkdirs(dir_path)
72
+ is_success = True
73
+ except BaseException:
74
+ print(f"Error creating directory: {dir_path}")
75
+ return is_success
76
+
77
+
78
+ def get_redirected_url(url: str):
79
+ """
80
+ Given a URL, returns the URL it redirects to or the
81
+ original URL in case of no indirection
82
+ """
83
+ import requests
84
+
85
+ with requests.Session() as session:
86
+ with session.get(url, stream=True, allow_redirects=True) as response:
87
+ if response.history:
88
+ return response.url
89
+ else:
90
+ return url
91
+
92
+
93
+ def to_google_drive_download_url(view_url: str) -> str:
94
+ """
95
+ Utility function to transform a view URL of google drive
96
+ to a download URL for google drive
97
+ Example input:
98
+ https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp/view
99
+ Example output:
100
+ https://drive.google.com/uc?export=download&id=137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp
101
+ """
102
+ splits = view_url.split("/")
103
+ assert splits[-1] == "view"
104
+ file_id = splits[-2]
105
+ return f"https://drive.google.com/uc?export=download&id={file_id}"
106
+
107
+
108
+ def download_google_drive_url(url: str, output_path: str, output_file_name: str):
109
+ """
110
+ Download a file from google drive
111
+ Downloading an URL from google drive requires confirmation when
112
+ the file of the size is too big (google drive notifies that
113
+ anti-viral checks cannot be performed on such files)
114
+ """
115
+ import requests
116
+
117
+ with requests.Session() as session:
118
+
119
+ # First get the confirmation token and append it to the URL
120
+ with session.get(url, stream=True, allow_redirects=True) as response:
121
+ for k, v in response.cookies.items():
122
+ if k.startswith("download_warning"):
123
+ url = url + "&confirm=" + v
124
+
125
+ # Then download the content of the file
126
+ with session.get(url, stream=True, verify=True) as response:
127
+ makedir(output_path)
128
+ path = os.path.join(output_path, output_file_name)
129
+ total_size = int(response.headers.get("Content-length", 0))
130
+ with open(path, "wb") as file:
131
+ from tqdm import tqdm
132
+
133
+ with tqdm(total=total_size) as progress_bar:
134
+ for block in response.iter_content(
135
+ chunk_size=io.DEFAULT_BUFFER_SIZE
136
+ ):
137
+ file.write(block)
138
+ progress_bar.update(len(block))
139
+
140
+
141
+ def _get_google_drive_file_id(url: str) -> Optional[str]:
142
+ parts = urlparse(url)
143
+
144
+ if re.match(r"(drive|docs)[.]google[.]com", parts.netloc) is None:
145
+ return None
146
+
147
+ match = re.match(r"/file/d/(?P<id>[^/]*)", parts.path)
148
+ if match is None:
149
+ return None
150
+
151
+ return match.group("id")
152
+
153
+
154
+ def _urlretrieve(url: str, filename: str, chunk_size: int = 1024) -> None:
155
+ with open(filename, "wb") as fh:
156
+ with urllib.request.urlopen(
157
+ urllib.request.Request(url, headers={"User-Agent": "vissl"})
158
+ ) as response:
159
+ with tqdm(total=response.length) as pbar:
160
+ for chunk in iter(lambda: response.read(chunk_size), ""):
161
+ if not chunk:
162
+ break
163
+ pbar.update(chunk_size)
164
+ fh.write(chunk)
165
+
166
+
167
+ def download_url(
168
+ url: str,
169
+ root: str,
170
+ filename: Optional[str] = None,
171
+ md5: Optional[str] = None,
172
+ ) -> None:
173
+ """Download a file from a url and place it in root.
174
+ Args:
175
+ url (str): URL to download file from
176
+ root (str): Directory to place downloaded file in
177
+ filename (str, optional): Name to save the file under.
178
+ If None, use the basename of the URL.
179
+ md5 (str, optional): MD5 checksum of the download. If None, do not check
180
+ """
181
+ root = os.path.expanduser(root)
182
+ if not filename:
183
+ filename = os.path.basename(url)
184
+ fpath = os.path.join(root, filename)
185
+
186
+ makedir(root)
187
+
188
+ # check if file is already present locally
189
+ if check_integrity(fpath, md5):
190
+ print("Using downloaded and verified file: " + fpath)
191
+ return
192
+
193
+ # expand redirect chain if needed
194
+ url = get_redirected_url(url)
195
+
196
+ # check if file is located on Google Drive
197
+ file_id = _get_google_drive_file_id(url)
198
+ if file_id is not None:
199
+ return download_file_from_google_drive(file_id, root, filename, md5)
200
+
201
+ # download the file
202
+ try:
203
+ print("Downloading " + url + " to " + fpath)
204
+ _urlretrieve(url, fpath)
205
+ except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined]
206
+ if url[:5] == "https":
207
+ url = url.replace("https:", "http:")
208
+ print(
209
+ "Failed download. Trying https -> http instead."
210
+ " Downloading " + url + " to " + fpath
211
+ )
212
+ _urlretrieve(url, fpath)
213
+ else:
214
+ raise e
215
+
216
+ # check integrity of downloaded file
217
+ if not check_integrity(fpath, md5):
218
+ raise RuntimeError("File not found or corrupted.")
219
+
220
+
221
+ def download_and_extract_archive(
222
+ url: str,
223
+ download_root: str,
224
+ extract_root: Optional[str] = None,
225
+ filename: Optional[str] = None,
226
+ md5: Optional[str] = None,
227
+ remove_finished: bool = False,
228
+ ) -> None:
229
+ download_root = os.path.expanduser(download_root)
230
+ if extract_root is None:
231
+ extract_root = download_root
232
+ if not filename:
233
+ filename = os.path.basename(url)
234
+
235
+ download_url(url, download_root, filename, md5)
236
+
237
+ archive = os.path.join(download_root, filename)
238
+ print("Extracting {} to {}".format(archive, extract_root))
239
+ extract_archive(archive, extract_root, remove_finished)
240
+
241
+
242
+ def cache_url(url: str, cache_dir: str) -> str:
243
+ """
244
+ This implementation downloads the remote resource and caches it locally.
245
+ The resource will only be downloaded if not previously requested.
246
+ """
247
+ parsed_url = urlparse(url)
248
+ dirname = os.path.join(cache_dir, os.path.dirname(parsed_url.path.lstrip("/")))
249
+ makedir(dirname)
250
+ filename = url.split("/")[-1]
251
+ cached = os.path.join(dirname, filename)
252
+ with file_lock(cached):
253
+ if not os.path.isfile(cached):
254
+ logging.info(f"Downloading {url} to {cached} ...")
255
+ cached = download(url, dirname, filename=filename)
256
+ logging.info(f"URL {url} cached in {cached}")
257
+ return cached
258
+
259
+
260
+ # TODO (prigoyal): convert this into RAII-style API
261
+ def create_file_symlink(file1, file2):
262
+ """
263
+ Simply create the symlinks for a given file1 to file2.
264
+ Useful during model checkpointing to symlinks to the
265
+ latest successful checkpoint.
266
+ """
267
+ try:
268
+ if g_pathmgr.exists(file2):
269
+ g_pathmgr.rm(file2)
270
+ g_pathmgr.symlink(file1, file2)
271
+ except Exception as e:
272
+ logging.info(f"Could NOT create symlink. Error: {e}")
273
+
274
+
275
+ def save_file(data, filename, append_to_json=True, verbose=True):
276
+ """
277
+ Common i/o utility to handle saving data to various file formats.
278
+ Supported:
279
+ .pkl, .pickle, .npy, .json
280
+ Specifically for .json, users have the option to either append (default)
281
+ or rewrite by passing in Boolean value to append_to_json.
282
+ """
283
+ if verbose:
284
+ logging.info(f"Saving data to file: {filename}")
285
+ file_ext = os.path.splitext(filename)[1]
286
+ if file_ext in [".pkl", ".pickle"]:
287
+ with g_pathmgr.open(filename, "wb") as fopen:
288
+ pickle.dump(data, fopen, pickle.HIGHEST_PROTOCOL)
289
+ elif file_ext == ".npy":
290
+ with g_pathmgr.open(filename, "wb") as fopen:
291
+ np.save(fopen, data)
292
+ elif file_ext == ".json":
293
+ if append_to_json:
294
+ with g_pathmgr.open(filename, "a") as fopen:
295
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
296
+ fopen.flush()
297
+ else:
298
+ with g_pathmgr.open(filename, "w") as fopen:
299
+ fopen.write(json.dumps(data, sort_keys=True) + "\n")
300
+ fopen.flush()
301
+ elif file_ext == ".yaml":
302
+ with g_pathmgr.open(filename, "w") as fopen:
303
+ dump = yaml.dump(data)
304
+ fopen.write(dump)
305
+ fopen.flush()
306
+ else:
307
+ raise Exception(f"Saving {file_ext} is not supported yet")
308
+
309
+ if verbose:
310
+ logging.info(f"Saved data to file: {filename}")
311
+
312
+
313
+ def load_file(filename, mmap_mode=None, verbose=True, allow_pickle=False):
314
+ """
315
+ Common i/o utility to handle loading data from various file formats.
316
+ Supported:
317
+ .pkl, .pickle, .npy, .json
318
+ For the npy files, we support reading the files in mmap_mode.
319
+ If the mmap_mode of reading is not successful, we load data without the
320
+ mmap_mode.
321
+ """
322
+ if verbose:
323
+ logging.info(f"Loading data from file: {filename}")
324
+
325
+ file_ext = os.path.splitext(filename)[1]
326
+ if file_ext == ".txt":
327
+ with g_pathmgr.open(filename, "r") as fopen:
328
+ data = fopen.readlines()
329
+ elif file_ext in [".pkl", ".pickle"]:
330
+ with g_pathmgr.open(filename, "rb") as fopen:
331
+ data = pickle.load(fopen, encoding="latin1")
332
+ elif file_ext == ".npy":
333
+ if mmap_mode:
334
+ try:
335
+ with g_pathmgr.open(filename, "rb") as fopen:
336
+ data = np.load(
337
+ fopen,
338
+ allow_pickle=allow_pickle,
339
+ encoding="latin1",
340
+ mmap_mode=mmap_mode,
341
+ )
342
+ except ValueError as e:
343
+ logging.info(
344
+ f"Could not mmap {filename}: {e}. Trying without g_pathmgr"
345
+ )
346
+ data = np.load(
347
+ filename,
348
+ allow_pickle=allow_pickle,
349
+ encoding="latin1",
350
+ mmap_mode=mmap_mode,
351
+ )
352
+ logging.info("Successfully loaded without g_pathmgr")
353
+ except Exception:
354
+ logging.info("Could not mmap without g_pathmgr. Trying without mmap")
355
+ with g_pathmgr.open(filename, "rb") as fopen:
356
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
357
+ else:
358
+ with g_pathmgr.open(filename, "rb") as fopen:
359
+ data = np.load(fopen, allow_pickle=allow_pickle, encoding="latin1")
360
+ elif file_ext == ".json":
361
+ with g_pathmgr.open(filename, "r") as fopen:
362
+ data = json.load(fopen)
363
+ elif file_ext == ".yaml":
364
+ with g_pathmgr.open(filename, "r") as fopen:
365
+ data = yaml.load(fopen, Loader=yaml.FullLoader)
366
+ elif file_ext == ".csv":
367
+ with g_pathmgr.open(filename, "r") as fopen:
368
+ data = pd.read_csv(fopen)
369
+ else:
370
+ raise Exception(f"Reading from {file_ext} is not supported yet")
371
+ return data
372
+
373
+
374
+ def abspath(resource_path: str):
375
+ """
376
+ Make a path absolute, but take into account prefixes like
377
+ "http://" or "manifold://"
378
+ """
379
+ regex = re.compile(r"^\w+://")
380
+ if regex.match(resource_path) is None:
381
+ return os.path.abspath(resource_path)
382
+ else:
383
+ return resource_path
384
+
385
+
386
+ def makedir(dir_path):
387
+ """
388
+ Create the directory if it does not exist.
389
+ """
390
+ is_success = False
391
+ try:
392
+ if not g_pathmgr.exists(dir_path):
393
+ g_pathmgr.mkdirs(dir_path)
394
+ is_success = True
395
+ except BaseException:
396
+ logging.info(f"Error creating directory: {dir_path}")
397
+ return is_success
398
+
399
+
400
+ def is_url(input_url):
401
+ """
402
+ Check if an input string is a URL, looking for http(s):// and ignoring case.
403
+ """
404
+ is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None
405
+ return is_url
406
+
407
+
408
+ def cleanup_dir(dir):
409
+ """
410
+ Utility for deleting a directory. Useful for cleaning the storage space
411
+ that contains various training artifacts like checkpoints, data etc.
412
+ """
413
+ if os.path.exists(dir):
414
+ logging.info(f"Deleting directory: {dir}")
415
+ shutil.rmtree(dir)
416
+ logging.info(f"Deleted contents of directory: {dir}")
417
+
418
+
419
+ def get_file_size(filename):
420
+ """
421
+ Given a file, get the size of file in MB
422
+ """
423
+ size_in_mb = os.path.getsize(filename) / float(1024**2)
424
+ return size_in_mb
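The helpers above (makedir, save_file, load_file) are the generic I/O layer used across the repo. A minimal usage sketch, not part of this commit, assuming the package is importable as `cheetah` and writing to a hypothetical output/ directory:

    from cheetah.common.utils import save_file, load_file, makedir

    makedir("output")  # no-op if the directory already exists
    save_file({"epoch": 1, "acc": 0.87}, "output/metrics.json", append_to_json=False)
    print(load_file("output/metrics.json"))  # -> {'acc': 0.87, 'epoch': 1}
    # With append_to_json=True, each call appends one JSON line instead, which suits
    # incremental logging but can no longer be read back with a single json.load.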
cheetah/configs/default.yaml ADDED
@@ -0,0 +1,5 @@
1
+ env:
2
+ # For default users
3
+ # cache_root: "cache"
4
+ # For internal use with persistent storage
5
+ cache_root: "/home/user/.cache/cheetah"
cheetah/configs/models/cheetah_llama2.yaml ADDED
@@ -0,0 +1,33 @@
1
+ model:
2
+ arch: cheetah_llama2
3
+ # vit encoder
4
+ image_size: 224
5
+ drop_path_rate: 0
6
+ use_grad_checkpoint: False
7
+ vit_precision: "bf16"
8
+ freeze_vit: True
9
+ freeze_qformer: True
10
+ freeze_llama_proj: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # llama2
16
+ llama_model: "/content/drive/MyDrive/HuggingFace/llama/weights"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "blip2_image_train"
25
+ image_size: 224
26
+ eval:
27
+ name: "blip2_image_eval"
28
+ image_size: 224
29
+ text_processor:
30
+ train:
31
+ name: "blip_caption"
32
+ eval:
33
+ name: "blip_caption"
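These model configs are normally resolved through cheetah/common/config.py, but a quick way to inspect or override one, for example to point llama_model at a local checkpoint instead of the Drive path above, is to load it with OmegaConf directly. A sketch only; the local path below is hypothetical:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("cheetah/configs/models/cheetah_llama2.yaml")
    cfg.model.llama_model = "/path/to/llama-2-7b-chat-hf"  # hypothetical local weights
    print(cfg.model.arch, cfg.model.num_query_token)       # cheetah_llama2 32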
cheetah/configs/models/cheetah_vicuna.yaml ADDED
@@ -0,0 +1,33 @@
1
+ model:
2
+ arch: cheetah_vicuna
3
+ # vit encoder
4
+ image_size: 224
5
+ drop_path_rate: 0
6
+ use_grad_checkpoint: False
7
+ vit_precision: "fp16"
8
+ freeze_vit: True
9
+ freeze_qformer: True
10
+ freeze_llama_proj: True
11
+
12
+ # Q-Former
13
+ num_query_token: 32
14
+
15
+ # Vicuna
16
+ llama_model: "/content/drive/MyDrive/HuggingFace/cheetah/weights"
17
+
18
+ # generation configs
19
+ prompt: ""
20
+
21
+ preprocess:
22
+ vis_processor:
23
+ train:
24
+ name: "blip2_image_train"
25
+ image_size: 224
26
+ eval:
27
+ name: "blip2_image_eval"
28
+ image_size: 224
29
+ text_processor:
30
+ train:
31
+ name: "blip_caption"
32
+ eval:
33
+ name: "blip_caption"
cheetah/conversation/__init__.py ADDED
File without changes
cheetah/conversation/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes). View file
 
cheetah/conversation/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
cheetah/conversation/__pycache__/conversation_llama2.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
cheetah/conversation/conversation.py ADDED
@@ -0,0 +1,364 @@
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+ from cheetah.common.registry import registry
14
+
15
+
16
+ class SeparatorStyle(Enum):
17
+ """Different separator style."""
18
+ SINGLE = auto()
19
+ TWO = auto()
20
+
21
+
22
+ @dataclasses.dataclass
23
+ class Conversation:
24
+ """A class that keeps all conversation history."""
25
+ system: str
26
+ roles: List[str]
27
+ messages: List[List[str]]
28
+ offset: int
29
+ # system_img: List[Image.Image] = []
30
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
31
+ sep: str = "###"
32
+ sep2: str = None
33
+
34
+ skip_next: bool = False
35
+ conv_id: Any = None
36
+
37
+ def get_prompt(self):
38
+ if self.sep_style == SeparatorStyle.SINGLE:
39
+ ret = self.system + self.sep
40
+ for role, message in self.messages:
41
+ if message:
42
+ ret += role + ": " + message + self.sep
43
+ else:
44
+ ret += role + ":"
45
+ return ret
46
+ elif self.sep_style == SeparatorStyle.TWO:
47
+ seps = [self.sep, self.sep2]
48
+ ret = self.system + seps[0]
49
+ for i, (role, message) in enumerate(self.messages):
50
+ if message:
51
+ ret += role + ": " + message + seps[i % 2]
52
+ else:
53
+ ret += role + ":"
54
+ return ret
55
+ else:
56
+ raise ValueError(f"Invalid style: {self.sep_style}")
57
+
58
+ def append_message(self, role, message):
59
+ self.messages.append([role, message])
60
+
61
+ def to_gradio_chatbot(self):
62
+ ret = []
63
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
64
+ if i % 2 == 0:
65
+ ret.append([msg, None])
66
+ else:
67
+ ret[-1][-1] = msg
68
+ return ret
69
+
70
+ def copy(self):
71
+ return Conversation(
72
+ system=self.system,
73
+ # system_img=self.system_img,
74
+ roles=self.roles,
75
+ messages=[[x, y] for x, y in self.messages],
76
+ offset=self.offset,
77
+ sep_style=self.sep_style,
78
+ sep=self.sep,
79
+ sep2=self.sep2,
80
+ conv_id=self.conv_id)
81
+
82
+ def dict(self):
83
+ return {
84
+ "system": self.system,
85
+ # "system_img": self.system_img,
86
+ "roles": self.roles,
87
+ "messages": self.messages,
88
+ "offset": self.offset,
89
+ "sep": self.sep,
90
+ "sep2": self.sep2,
91
+ "conv_id": self.conv_id,
92
+ }
93
+
94
+
95
+ class StoppingCriteriaSub(StoppingCriteria):
96
+
97
+ def __init__(self, stops=[], encounters=1):
98
+ super().__init__()
99
+ self.stops = stops
100
+
101
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
102
+ for stop in self.stops:
103
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
104
+ return True
105
+
106
+ return False
107
+
108
+
109
+ CONV_VISION = Conversation(
110
+ system="",
111
+ roles=("Human", "Assistant"),
112
+ messages=[],
113
+ offset=2,
114
+ sep_style=SeparatorStyle.SINGLE,
115
+ sep="###",
116
+ )
117
+
118
+
119
+
120
+ class Chat:
121
+ def __init__(self, model, vis_processor, device='cuda:0'):
122
+ self.device = device
123
+ self.model = model
124
+ self.vis_processor = vis_processor
125
+ stop_words_ids = [torch.tensor([835]).to(self.device),
126
+ torch.tensor([2277, 29937]).to(self.device)] # '###' can be encoded in two different ways.
127
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
128
+
129
+ def ask(self, text, conv):
130
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
131
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
132
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
133
+ else:
134
+ conv.append_message(conv.roles[0], text)
135
+
136
+ def batch_answer(self, batch_raw_img_list, batch_context, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
137
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000,
138
+ update_layer=16):
139
+ all_imgs = []
140
+ for raw_img_list in batch_raw_img_list:
141
+ images = []
142
+ for raw_image in raw_img_list:
143
+ img = self.vis_processor(raw_image).unsqueeze(1).to(self.device)
144
+ images.append(img)
145
+ images = torch.cat(images, 1)
146
+ all_imgs.append(images)
147
+ all_imgs = torch.stack(all_imgs, 0)
148
+
149
+ img_list, vit_list, att_list = [], [], []
150
+ for j in range(all_imgs.size(2)):
151
+ image = all_imgs[:,:,j,:,:]
152
+ image_emb, image_att, vit_emb = self.model.encode_img(image)
153
+ img_list.append(image_emb)
154
+ vit_list.append(vit_emb)
155
+ att_list.append(image_att)
156
+
157
+ conv_list = []
158
+ for context in batch_context:
159
+ chat_state = CONV_VISION.copy()
160
+ img_embd_list = []
161
+ for i, text in enumerate(context.split("<ImageHere>")):
162
+ if text != '' and text.strip()!='':
163
+ self.ask(text, chat_state)
164
+ if i < len(raw_img_list):
165
+ if len(chat_state.messages)>0:
166
+ chat_state.messages[-1][1] = ' '.join([chat_state.messages[-1][1], "<Img><HereForImage></Img>"])
167
+ else:
168
+ chat_state.append_message(chat_state.roles[0], "<Img><HereForImage></Img>")
169
+ chat_state.append_message(chat_state.roles[1], None)
170
+ conv_list.append(chat_state)
171
+
172
+ split_prompt = []
173
+ for conv in conv_list:
174
+ prompt = conv.get_prompt()
175
+ cur_split_prompt = prompt.split('<HereForImage>')
176
+ assert len(cur_split_prompt) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
177
+ split_prompt.append(cur_split_prompt)
178
+ embs, attention_mask, img_position_list, input_part_targets_len = self.batch_get_context_emb(split_prompt, img_list, att_list)
179
+
180
+ assert embs.shape[1] + max_new_tokens < max_length
181
+ with self.model.maybe_autocast():
182
+ outputs = self.model.llama_model.generate(
183
+ inputs_embeds=embs,
184
+ attention_mask=attention_mask,
185
+ max_new_tokens=max_new_tokens,
186
+ stopping_criteria=self.stopping_criteria,
187
+ num_beams=num_beams,
188
+ do_sample=True,
189
+ min_length=min_length,
190
+ top_p=top_p,
191
+ repetition_penalty=repetition_penalty,
192
+ length_penalty=length_penalty,
193
+ temperature=temperature,
194
+ # new add
195
+ update_layer = update_layer,
196
+ image_position_list = img_position_list,
197
+ input_part_targets_len = input_part_targets_len,
198
+ all_image_embeds = torch.stack(vit_list,dim=1)
199
+ )
200
+
201
+ batch_outputs = []
202
+ for output_token in outputs:
203
+ if output_token[0] == 0: # the model might output an unknown token <unk> at the beginning; remove it
204
+ output_token = output_token[1:]
205
+ if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
206
+ output_token = output_token[1:]
207
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
208
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
209
+ output_text = output_text.split('Assistant:')[-1].strip()
210
+ batch_outputs.append(output_text)
211
+ return batch_outputs
212
+
213
+
214
+ def answer(self, raw_img_list, context, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
215
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000,
216
+ update_layer=16):
217
+ chat_state = CONV_VISION.copy()
218
+ img_embd_list = []
219
+ vit_list = []
220
+ for i, text in enumerate(context.split("<Img><HereForImage></Img>")):
221
+ if text != '' and text.strip()!='':
222
+ self.ask(text, chat_state)
223
+ if i < len(raw_img_list):
224
+ self.upload_img(raw_img_list[i], chat_state, img_embd_list, vit_list)
225
+ chat_state.append_message(chat_state.roles[1], None)
226
+ embs, img_position_list, input_part_targets_len = self.get_context_emb(chat_state, img_embd_list)
227
+
228
+ assert embs.shape[1] + max_new_tokens < max_length
229
+
230
+ with self.model.maybe_autocast():
231
+ outputs = self.model.llama_model.generate(
232
+ inputs_embeds=embs,
233
+ max_new_tokens=max_new_tokens,
234
+ stopping_criteria=self.stopping_criteria,
235
+ num_beams=num_beams,
236
+ do_sample=True,
237
+ min_length=min_length,
238
+ top_p=top_p,
239
+ repetition_penalty=repetition_penalty,
240
+ length_penalty=length_penalty,
241
+ temperature=temperature,
242
+ # new add
243
+ update_layer = update_layer,
244
+ image_position_list = img_position_list,
245
+ input_part_targets_len = input_part_targets_len,
246
+ all_image_embeds = torch.stack(vit_list,dim=1)
247
+ )
248
+ output_token = outputs[0]
249
+ if output_token[0] == 0: # the model might output an unknown token <unk> at the beginning; remove it
250
+ output_token = output_token[1:]
251
+ if output_token[0] == 1: # some users find that there is a start token <s> at the beginning. remove it
252
+ output_token = output_token[1:]
253
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
254
+ output_text = output_text.split('###')[0] # remove the stop sign '###'
255
+ output_text = output_text.split('Assistant:')[-1].strip()
256
+ return output_text
257
+
258
+ def upload_img(self, image, conv, img_list, vit_list, att_list=None):
259
+ if isinstance(image, str): # is an image path
260
+ raw_image = Image.open(image).convert('RGB')
261
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
262
+ elif isinstance(image, Image.Image):
263
+ raw_image = image
264
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
265
+ elif isinstance(image, torch.Tensor):
266
+ if len(image.shape) == 3:
267
+ image = image.unsqueeze(0)
268
+ image = image.to(self.device)
269
+
270
+ image_emb, image_att, vit_emb = self.model.encode_img(image)
271
+ img_list.append(image_emb)
272
+ vit_list.append(vit_emb)
273
+ if att_list is not None:
274
+ att_list.append(image_att)
275
+ if isinstance(conv, list):
276
+ for c in conv:
277
+ if len(c.messages)>0:
278
+ c.messages[-1][1] = ' '.join([c.messages[-1][1], "<Img><HereForImage></Img>"])
279
+ else:
280
+ c.append_message(c.roles[0], "<Img><HereForImage></Img>")
281
+ else:
282
+ if len(conv.messages)>0:
283
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], "<Img><HereForImage></Img>"])
284
+ else:
285
+ conv.append_message(conv.roles[0], "<Img><HereForImage></Img>")
286
+ msg = "Received."
287
+ return msg
288
+
289
+ def get_context_emb(self, conv, img_list):
290
+ prompt = conv.get_prompt()
291
+ # print(prompt)
292
+ prompt_segs = prompt.split('<HereForImage>')
293
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
294
+ seg_tokens = [
295
+ self.model.llama_tokenizer(
296
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
297
+ # only add bos to the first seg
298
+ for i, seg in enumerate(prompt_segs)
299
+ ]
300
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
301
+
302
+ mixed_embs = []
303
+ img_position_list = []
304
+ img_start = 0
305
+ for i in range(len(prompt_segs)):
306
+ mixed_embs.append(seg_embs[i])
307
+ if i != len(img_list):
308
+ mixed_embs.append(img_list[i])
309
+ img_start += seg_embs[i].size(1)
310
+ img_end = img_start + img_list[i].size(1)
311
+ img_position_list.append((img_start, img_end))
312
+ img_start = img_end
313
+
314
+ mixed_embs = torch.cat(mixed_embs, dim=1)
315
+
316
+ input_part_targets_len = []
317
+ for i in range(mixed_embs.size(0)):
318
+ input_part_targets_len.append(mixed_embs.size(1)-1)
319
+ input_part_targets_len = torch.tensor(input_part_targets_len)
320
+ return mixed_embs, img_position_list, input_part_targets_len
321
+
322
+ def batch_get_context_emb(self, split_prompt, img_list, img_attns):
323
+ prompt_segs = []
324
+ for i in range(len(img_list) + 1):
325
+ prompt_segs.append([p[i] for p in split_prompt])
326
+
327
+ self.model.llama_tokenizer.padding_side = "left"
328
+
329
+ seg_tokens = [
330
+ self.model.llama_tokenizer(
331
+ seg, return_tensors="pt", padding=True, add_special_tokens=i == 0).to(self.device)
332
+ # only add bos to the first seg
333
+ for i, seg in enumerate(prompt_segs)
334
+ ]
335
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t.input_ids) for seg_t in seg_tokens]
336
+ seg_attns = [seg_t.attention_mask for seg_t in seg_tokens]
337
+
338
+ mixed_embs = []
339
+ mixed_attns = []
340
+ img_position_list = []
341
+ img_start = 0
342
+ for i in range(len(prompt_segs)):
343
+ mixed_embs.append(seg_embs[i])
344
+ mixed_attns.append(seg_attns[i])
345
+ if i != len(img_list):
346
+ mixed_embs.append(img_list[i])
347
+ mixed_attns.append(img_attns[i])
348
+ img_start += seg_embs[i].size(1)
349
+ img_end = img_start + img_list[i].size(1)
350
+ img_position_list.append((img_start, img_end))
351
+ img_start = img_end
352
+
353
+ mixed_embs = torch.cat(mixed_embs, dim=1)
354
+ mixed_attns = torch.cat(mixed_attns, dim=1)
355
+
356
+ input_part_targets_len = []
357
+ for i in range(mixed_embs.size(0)):
358
+ input_part_targets_len.append(mixed_embs.size(1)-1)
359
+ input_part_targets_len = torch.tensor(input_part_targets_len)
360
+
361
+ return mixed_embs, mixed_attns, img_position_list, input_part_targets_len
362
+
363
+
364
+
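A sketch of how the Chat class above is driven for a single interleaved image-text query. Building `model` and `vis_processor` happens elsewhere in the repo (registry plus a model config), so they are assumed to exist here; note that answer() expects one "<Img><HereForImage></Img>" placeholder per image in the context string, while batch_answer() splits on "<ImageHere>":

    from PIL import Image
    from cheetah.conversation.conversation import Chat

    # `model` and `vis_processor` are assumed to have been built already,
    # e.g. from cheetah.models / cheetah.processors with a cheetah_vicuna config.
    chat = Chat(model, vis_processor, device="cuda:0")
    imgs = [Image.open("examples/18.png").convert("RGB")]
    context = "<Img><HereForImage></Img> Describe what is happening in this image."
    print(chat.answer(imgs, context, max_new_tokens=128))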
cheetah/conversation/conversation_llama2.py ADDED
@@ -0,0 +1,387 @@
1
+ import argparse
2
+ import time
3
+ from PIL import Image
4
+
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer
7
+ from transformers import StoppingCriteria, StoppingCriteriaList
8
+
9
+ import dataclasses
10
+ from enum import auto, Enum
11
+ from typing import List, Tuple, Any
12
+
13
+ from cheetah.common.registry import registry
14
+
15
+
16
+ class SeparatorStyle(Enum):
17
+ """Different separator style."""
18
+ SINGLE = auto()
19
+ TWO = auto()
20
+ LLAMA_2 = auto()
21
+
22
+
23
+ @dataclasses.dataclass
24
+ class Conversation:
25
+ """A class that keeps all conversation history."""
26
+ system: str
27
+ roles: List[str]
28
+ messages: List[List[str]]
29
+ offset: int
30
+ # system_img: List[Image.Image] = []
31
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
32
+ sep: str = "###"
33
+ sep2: str = None
34
+
35
+ skip_next: bool = False
36
+ conv_id: Any = None
37
+
38
+ def get_prompt(self):
39
+ if self.sep_style == SeparatorStyle.SINGLE:
40
+ ret = self.system + self.sep
41
+ for role, message in self.messages:
42
+ if message:
43
+ ret += role + ": " + message + self.sep
44
+ else:
45
+ ret += role + ":"
46
+ return ret
47
+ elif self.sep_style == SeparatorStyle.TWO:
48
+ seps = [self.sep, self.sep2]
49
+ ret = self.system + seps[0]
50
+ for i, (role, message) in enumerate(self.messages):
51
+ if message:
52
+ ret += role + ": " + message + seps[i % 2]
53
+ else:
54
+ ret += role + ":"
55
+ return ret
56
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
57
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
58
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
59
+ ret = ""
60
+
61
+ for i, (role, message) in enumerate(self.messages):
62
+ if i == 0:
63
+ assert message, "first message should not be None"
64
+ assert role == self.roles[0], "first message should come from user"
65
+ if message:
66
+ if type(message) is tuple:
67
+ message, _, _ = message
68
+ if i == 0: message = wrap_sys(self.system) + message
69
+ if i % 2 == 0:
70
+ message = wrap_inst(message)
71
+ ret += self.sep + message
72
+ else:
73
+ ret += " " + message + " " + self.sep2
74
+ else:
75
+ ret += ""
76
+
77
+ ret = ret.lstrip(self.sep)
78
+ return ret
79
+ else:
80
+ raise ValueError(f"Invalid style: {self.sep_style}")
81
+
82
+ def append_message(self, role, message):
83
+ self.messages.append([role, message])
84
+
85
+ def to_gradio_chatbot(self):
86
+ ret = []
87
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
88
+ if i % 2 == 0:
89
+ ret.append([msg, None])
90
+ else:
91
+ ret[-1][-1] = msg
92
+ return ret
93
+
94
+ def copy(self):
95
+ return Conversation(
96
+ system=self.system,
97
+ # system_img=self.system_img,
98
+ roles=self.roles,
99
+ messages=[[x, y] for x, y in self.messages],
100
+ offset=self.offset,
101
+ sep_style=self.sep_style,
102
+ sep=self.sep,
103
+ sep2=self.sep2,
104
+ conv_id=self.conv_id)
105
+
106
+ def dict(self):
107
+ return {
108
+ "system": self.system,
109
+ # "system_img": self.system_img,
110
+ "roles": self.roles,
111
+ "messages": self.messages,
112
+ "offset": self.offset,
113
+ "sep": self.sep,
114
+ "sep2": self.sep2,
115
+ "conv_id": self.conv_id,
116
+ }
117
+
118
+
119
+ class StoppingCriteriaSub(StoppingCriteria):
120
+
121
+ def __init__(self, stops=[], encounters=1):
122
+ super().__init__()
123
+ self.stops = stops
124
+
125
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
126
+ for stop in self.stops:
127
+ if torch.all((stop == input_ids[0][-len(stop):])).item():
128
+ return True
129
+
130
+ return False
131
+
132
+
133
+ CONV_VISION = Conversation(
134
+ system='',
135
+ roles=("USER", "ASSISTANT"),
136
+ messages=[],
137
+ offset=0,
138
+ sep_style=SeparatorStyle.LLAMA_2,
139
+ sep="<s>",
140
+ sep2="</s>",
141
+ )
142
+
143
+
144
+
145
+ class Chat:
146
+ def __init__(self, model, vis_processor, device='cuda:0'):
147
+ self.device = device
148
+ self.model = model
149
+ self.vis_processor = vis_processor
150
+ stop_words_ids = [torch.tensor([2]).to(self.device)]
151
+ self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
152
+
153
+ def ask(self, text, conv):
154
+ if len(conv.messages) > 0 and conv.messages[-1][0] == conv.roles[0] \
155
+ and conv.messages[-1][1][-6:] == '</Img>': # last message is image.
156
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], text])
157
+ else:
158
+ conv.append_message(conv.roles[0], text)
159
+
160
+ def batch_answer(self, batch_raw_img_list, batch_context, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
161
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000,
162
+ update_layer=16):
163
+ embeds_list = []
164
+ all_imgs = []
165
+ for raw_img_list in batch_raw_img_list:
166
+ images = []
167
+ for raw_image in raw_img_list:
168
+ raw_image = Image.open(raw_image).convert('RGB')
169
+ img = self.vis_processor(raw_image).unsqueeze(1).to(self.device)
170
+ images.append(img)
171
+ images = torch.cat(images, 1)
172
+ all_imgs.append(images)
173
+ all_imgs = torch.stack(all_imgs, 0)
174
+
175
+ img_list, vit_list, att_list = [], [], []
176
+ for j in range(all_imgs.size(2)):
177
+ image = all_imgs[:,:,j,:,:]
178
+ image_emb, image_att, vit_emb = self.model.encode_img(image)
179
+ img_list.append(image_emb)
180
+ vit_list.append(vit_emb)
181
+ att_list.append(image_att)
182
+
183
+ conv_list = []
184
+ for context in batch_context:
185
+ chat_state = CONV_VISION.copy()
186
+ img_embd_list = []
187
+ for i, text in enumerate(context.split("<ImageHere>")):
188
+ if text != '' and text.strip()!='':
189
+ self.ask(text, chat_state)
190
+ if i < len(raw_img_list):
191
+ if len(chat_state.messages)>0:
192
+ chat_state.messages[-1][1] = ' '.join([chat_state.messages[-1][1], "<Img><HereForImage></Img>"])
193
+ else:
194
+ chat_state.append_message(chat_state.roles[0], "<Img><HereForImage></Img>")
195
+ chat_state.append_message(chat_state.roles[1], None)
196
+ conv_list.append(chat_state)
197
+
198
+ split_prompt = []
199
+ for conv in conv_list:
200
+ prompt = conv.get_prompt()
201
+ cur_split_prompt = prompt.split('<HereForImage>')
202
+ assert len(cur_split_prompt) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
203
+ split_prompt.append(cur_split_prompt)
204
+ embs, attention_mask, img_position_list, input_part_targets_len = self.batch_get_context_emb(split_prompt, img_list, att_list)
205
+
206
+ assert embs.shape[1] + max_new_tokens < max_length
207
+ with self.model.maybe_autocast():
208
+ outputs = self.model.llama_model.generate(
209
+ inputs_embeds=embs,
210
+ attention_mask=attention_mask,
211
+ max_new_tokens=max_new_tokens,
212
+ stopping_criteria=self.stopping_criteria,
213
+ num_beams=num_beams,
214
+ do_sample=True,
215
+ min_length=min_length,
216
+ top_p=top_p,
217
+ repetition_penalty=repetition_penalty,
218
+ length_penalty=length_penalty,
219
+ temperature=temperature,
220
+ # new add
221
+ update_layer = update_layer,
222
+ image_position_list = img_position_list,
223
+ input_part_targets_len = input_part_targets_len,
224
+ all_image_embeds = torch.stack(vit_list,dim=1)
225
+ )
226
+
227
+ batch_outputs = []
228
+ for output_token in outputs:
229
+ if output_token[0] == 0:
230
+ output_token = output_token[1:]
231
+ if output_token[0] == 1:
232
+ output_token = output_token[1:]
233
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
234
+ output_text = output_text.split('</s>')[0] # remove the stop sign '</s>'
235
+ batch_outputs.append(output_text)
236
+ return batch_outputs
237
+
238
+
239
+ def answer(self, raw_img_list, context, max_new_tokens=300, num_beams=1, min_length=1, top_p=0.9,
240
+ repetition_penalty=1.0, length_penalty=1, temperature=1.0, max_length=2000,
241
+ update_layer=16):
242
+ chat_state = CONV_VISION.copy()
243
+ img_embd_list = []
244
+ vit_list = []
245
+ for i, text in enumerate(context.split("<Img><HereForImage></Img>")):
246
+ if text != '' and text.strip()!='':
247
+ self.ask(text, chat_state)
248
+ if i < len(raw_img_list):
249
+ self.upload_img(raw_img_list[i], chat_state, img_embd_list, vit_list)
250
+ chat_state.append_message(chat_state.roles[1], None)
251
+ embs, img_position_list, input_part_targets_len = self.get_context_emb(chat_state, img_embd_list)
252
+
253
+ assert embs.shape[1] + max_new_tokens < max_length
254
+
255
+ with self.model.maybe_autocast():
256
+ outputs = self.model.llama_model.generate(
257
+ inputs_embeds=embs,
258
+ max_new_tokens=max_new_tokens,
259
+ stopping_criteria=self.stopping_criteria,
260
+ num_beams=num_beams,
261
+ do_sample=True,
262
+ min_length=min_length,
263
+ top_p=top_p,
264
+ repetition_penalty=repetition_penalty,
265
+ length_penalty=length_penalty,
266
+ temperature=temperature,
267
+ # new add
268
+ update_layer = update_layer,
269
+ image_position_list = img_position_list,
270
+ input_part_targets_len = input_part_targets_len,
271
+ all_image_embeds = torch.stack(vit_list,dim=1)
272
+ )
273
+ output_token = outputs[0]
274
+ if output_token[0] == 0:
275
+ output_token = output_token[1:]
276
+ if output_token[0] == 1:
277
+ output_token = output_token[1:]
278
+ output_text = self.model.llama_tokenizer.decode(output_token, add_special_tokens=False)
279
+ output_text = output_text.split('</s>')[0] # remove the stop sign '</s>'
280
+ return output_text
281
+
282
+ def upload_img(self, image, conv, img_list, vit_list, att_list=None):
283
+ if isinstance(image, str): # is an image path
284
+ raw_image = Image.open(image).convert('RGB')
285
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
286
+ elif isinstance(image, Image.Image):
287
+ raw_image = image
288
+ image = self.vis_processor(raw_image).unsqueeze(0).to(self.device)
289
+ elif isinstance(image, torch.Tensor):
290
+ if len(image.shape) == 3:
291
+ image = image.unsqueeze(0)
292
+ image = image.to(self.device)
293
+
294
+ image_emb, image_att, vit_emb = self.model.encode_img(image)
295
+ img_list.append(image_emb)
296
+ vit_list.append(vit_emb)
297
+ if att_list is not None:
298
+ att_list.append(image_att)
299
+ if isinstance(conv, list):
300
+ for c in conv:
301
+ if len(c.messages)>0:
302
+ c.messages[-1][1] = ' '.join([c.messages[-1][1], "<Img><HereForImage></Img>"])
303
+ else:
304
+ c.append_message(c.roles[0], "<Img><HereForImage></Img>")
305
+ else:
306
+ if len(conv.messages)>0:
307
+ conv.messages[-1][1] = ' '.join([conv.messages[-1][1], "<Img><HereForImage></Img>"])
308
+ else:
309
+ conv.append_message(conv.roles[0], "<Img><HereForImage></Img>")
310
+ msg = "Received."
311
+ return msg
312
+
313
+ def get_context_emb(self, conv, img_list):
314
+ prompt = conv.get_prompt()
315
+ # print(prompt)
316
+ prompt_segs = prompt.split('<HereForImage>')
317
+ assert len(prompt_segs) == len(img_list) + 1, "Unmatched numbers of image placeholders and images."
318
+ seg_tokens = [
319
+ self.model.llama_tokenizer(
320
+ seg, return_tensors="pt", add_special_tokens=i == 0).to(self.device).input_ids
321
+ # only add bos to the first seg
322
+ for i, seg in enumerate(prompt_segs)
323
+ ]
324
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t) for seg_t in seg_tokens]
325
+
326
+ mixed_embs = []
327
+ img_position_list = []
328
+ img_start = 0
329
+ for i in range(len(prompt_segs)):
330
+ mixed_embs.append(seg_embs[i])
331
+ if i != len(img_list):
332
+ mixed_embs.append(img_list[i])
333
+ img_start += seg_embs[i].size(1)
334
+ img_end = img_start + img_list[i].size(1)
335
+ img_position_list.append((img_start, img_end))
336
+ img_start = img_end
337
+
338
+ mixed_embs = torch.cat(mixed_embs, dim=1)
339
+
340
+ input_part_targets_len = []
341
+ for i in range(mixed_embs.size(0)):
342
+ input_part_targets_len.append(mixed_embs.size(1)-1)
343
+ input_part_targets_len = torch.tensor(input_part_targets_len)
344
+ return mixed_embs, img_position_list, input_part_targets_len
345
+
346
+ def batch_get_context_emb(self, split_prompt, img_list, img_attns):
347
+ prompt_segs = []
348
+ for i in range(len(img_list) + 1):
349
+ prompt_segs.append([p[i] for p in split_prompt])
350
+
351
+ self.model.llama_tokenizer.padding_side = "left"
352
+
353
+ seg_tokens = [
354
+ self.model.llama_tokenizer(
355
+ seg, return_tensors="pt", padding=True, add_special_tokens=i == 0).to(self.device)
356
+ # only add bos to the first seg
357
+ for i, seg in enumerate(prompt_segs)
358
+ ]
359
+ seg_embs = [self.model.llama_model.model.embed_tokens(seg_t.input_ids) for seg_t in seg_tokens]
360
+ seg_attns = [seg_t.attention_mask for seg_t in seg_tokens]
361
+
362
+ mixed_embs = []
363
+ mixed_attns = []
364
+ img_position_list = []
365
+ img_start = 0
366
+ for i in range(len(prompt_segs)):
367
+ mixed_embs.append(seg_embs[i])
368
+ mixed_attns.append(seg_attns[i])
369
+ if i != len(img_list):
370
+ mixed_embs.append(img_list[i])
371
+ mixed_attns.append(img_attns[i])
372
+ img_start += seg_embs[i].size(1)
373
+ img_end = img_start + img_list[i].size(1)
374
+ img_position_list.append((img_start, img_end))
375
+ img_start = img_end
376
+
377
+ mixed_embs = torch.cat(mixed_embs, dim=1)
378
+ mixed_attns = torch.cat(mixed_attns, dim=1)
379
+
380
+ input_part_targets_len = []
381
+ for i in range(mixed_embs.size(0)):
382
+ input_part_targets_len.append(mixed_embs.size(1)-1)
383
+ input_part_targets_len = torch.tensor(input_part_targets_len)
384
+
385
+ return mixed_embs, mixed_attns, img_position_list, input_part_targets_len
386
+
387
+
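To make the LLAMA_2 separator style concrete, here is what CONV_VISION.get_prompt() produces for a single user turn; the expected output in the trailing comments is derived from the wrap_sys/wrap_inst logic above and is illustrative, not part of the commit:

    from cheetah.conversation.conversation_llama2 import CONV_VISION

    conv = CONV_VISION.copy()
    conv.append_message(conv.roles[0], "<Img><HereForImage></Img> What is unusual about this image?")
    conv.append_message(conv.roles[1], None)
    print(conv.get_prompt())
    # [INST] <<SYS>>
    #
    # <</SYS>>
    #
    # <Img><HereForImage></Img> What is unusual about this image? [/INST]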
cheetah/models/Qformer.py ADDED
@@ -0,0 +1,1216 @@
1
+ """
2
+ * Copyright (c) 2023, salesforce.com, inc.
3
+ * All rights reserved.
4
+ * SPDX-License-Identifier: BSD-3-Clause
5
+ * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ * By Junnan Li
7
+ * Based on huggingface code base
8
+ * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert
9
+ """
10
+
11
+ import math
12
+ import os
13
+ import warnings
14
+ from dataclasses import dataclass
15
+ from typing import Optional, Tuple, Dict, Any
16
+
17
+ import torch
18
+ from torch import Tensor, device, dtype, nn
19
+ import torch.utils.checkpoint
20
+ from torch import nn
21
+ from torch.nn import CrossEntropyLoss
22
+ import torch.nn.functional as F
23
+
24
+ from transformers.activations import ACT2FN
25
+ from transformers.file_utils import (
26
+ ModelOutput,
27
+ )
28
+ from transformers.modeling_outputs import (
29
+ BaseModelOutputWithPastAndCrossAttentions,
30
+ BaseModelOutputWithPoolingAndCrossAttentions,
31
+ CausalLMOutputWithCrossAttentions,
32
+ MaskedLMOutput,
33
+ MultipleChoiceModelOutput,
34
+ NextSentencePredictorOutput,
35
+ QuestionAnsweringModelOutput,
36
+ SequenceClassifierOutput,
37
+ TokenClassifierOutput,
38
+ )
39
+ from transformers.modeling_utils import (
40
+ PreTrainedModel,
41
+ apply_chunking_to_forward,
42
+ find_pruneable_heads_and_indices,
43
+ prune_linear_layer,
44
+ )
45
+ from transformers.utils import logging
46
+ from transformers.models.bert.configuration_bert import BertConfig
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+
51
+ class BertEmbeddings(nn.Module):
52
+ """Construct the embeddings from word and position embeddings."""
53
+
54
+ def __init__(self, config):
55
+ super().__init__()
56
+ self.word_embeddings = nn.Embedding(
57
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id
58
+ )
59
+ self.position_embeddings = nn.Embedding(
60
+ config.max_position_embeddings, config.hidden_size
61
+ )
62
+
63
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
64
+ # any TensorFlow checkpoint file
65
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
66
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
67
+
68
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
69
+ self.register_buffer(
70
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))
71
+ )
72
+ self.position_embedding_type = getattr(
73
+ config, "position_embedding_type", "absolute"
74
+ )
75
+
76
+ self.config = config
77
+
78
+ def forward(
79
+ self,
80
+ input_ids=None,
81
+ position_ids=None,
82
+ query_embeds=None,
83
+ past_key_values_length=0,
84
+ ):
85
+ if input_ids is not None:
86
+ seq_length = input_ids.size()[1]
87
+ else:
88
+ seq_length = 0
89
+
90
+ if position_ids is None:
91
+ position_ids = self.position_ids[
92
+ :, past_key_values_length : seq_length + past_key_values_length
93
+ ].clone()
94
+
95
+ if input_ids is not None:
96
+ embeddings = self.word_embeddings(input_ids)
97
+ if self.position_embedding_type == "absolute":
98
+ position_embeddings = self.position_embeddings(position_ids)
99
+ embeddings = embeddings + position_embeddings
100
+
101
+ if query_embeds is not None:
102
+ embeddings = torch.cat((query_embeds, embeddings), dim=1)
103
+ else:
104
+ embeddings = query_embeds
105
+
106
+ embeddings = self.LayerNorm(embeddings)
107
+ embeddings = self.dropout(embeddings)
108
+ return embeddings
109
+
110
+
111
+ class BertSelfAttention(nn.Module):
112
+ def __init__(self, config, is_cross_attention):
113
+ super().__init__()
114
+ self.config = config
115
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
116
+ config, "embedding_size"
117
+ ):
118
+ raise ValueError(
119
+ "The hidden size (%d) is not a multiple of the number of attention "
120
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
121
+ )
122
+
123
+ self.num_attention_heads = config.num_attention_heads
124
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
125
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
126
+
127
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
128
+ if is_cross_attention:
129
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
130
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
131
+ else:
132
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
133
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
134
+
135
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
136
+ self.position_embedding_type = getattr(
137
+ config, "position_embedding_type", "absolute"
138
+ )
139
+ if (
140
+ self.position_embedding_type == "relative_key"
141
+ or self.position_embedding_type == "relative_key_query"
142
+ ):
143
+ self.max_position_embeddings = config.max_position_embeddings
144
+ self.distance_embedding = nn.Embedding(
145
+ 2 * config.max_position_embeddings - 1, self.attention_head_size
146
+ )
147
+ self.save_attention = False
148
+
149
+ def save_attn_gradients(self, attn_gradients):
150
+ self.attn_gradients = attn_gradients
151
+
152
+ def get_attn_gradients(self):
153
+ return self.attn_gradients
154
+
155
+ def save_attention_map(self, attention_map):
156
+ self.attention_map = attention_map
157
+
158
+ def get_attention_map(self):
159
+ return self.attention_map
160
+
161
+ def transpose_for_scores(self, x):
162
+ new_x_shape = x.size()[:-1] + (
163
+ self.num_attention_heads,
164
+ self.attention_head_size,
165
+ )
166
+ x = x.view(*new_x_shape)
167
+ return x.permute(0, 2, 1, 3)
168
+
169
+ def forward(
170
+ self,
171
+ hidden_states,
172
+ attention_mask=None,
173
+ head_mask=None,
174
+ encoder_hidden_states=None,
175
+ encoder_attention_mask=None,
176
+ past_key_value=None,
177
+ output_attentions=False,
178
+ ):
179
+
180
+ # If this is instantiated as a cross-attention module, the keys
181
+ # and values come from an encoder; the attention mask needs to be
182
+ # such that the encoder's padding tokens are not attended to.
183
+ is_cross_attention = encoder_hidden_states is not None
184
+
185
+ if is_cross_attention:
186
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
187
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
188
+ attention_mask = encoder_attention_mask
189
+ elif past_key_value is not None:
190
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
191
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
192
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
193
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
194
+ else:
195
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
196
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
197
+
198
+ mixed_query_layer = self.query(hidden_states)
199
+
200
+ query_layer = self.transpose_for_scores(mixed_query_layer)
201
+
202
+ past_key_value = (key_layer, value_layer)
203
+
204
+ # Take the dot product between "query" and "key" to get the raw attention scores.
205
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
206
+
207
+ if (
208
+ self.position_embedding_type == "relative_key"
209
+ or self.position_embedding_type == "relative_key_query"
210
+ ):
211
+ seq_length = hidden_states.size()[1]
212
+ position_ids_l = torch.arange(
213
+ seq_length, dtype=torch.long, device=hidden_states.device
214
+ ).view(-1, 1)
215
+ position_ids_r = torch.arange(
216
+ seq_length, dtype=torch.long, device=hidden_states.device
217
+ ).view(1, -1)
218
+ distance = position_ids_l - position_ids_r
219
+ positional_embedding = self.distance_embedding(
220
+ distance + self.max_position_embeddings - 1
221
+ )
222
+ positional_embedding = positional_embedding.to(
223
+ dtype=query_layer.dtype
224
+ ) # fp16 compatibility
225
+
226
+ if self.position_embedding_type == "relative_key":
227
+ relative_position_scores = torch.einsum(
228
+ "bhld,lrd->bhlr", query_layer, positional_embedding
229
+ )
230
+ attention_scores = attention_scores + relative_position_scores
231
+ elif self.position_embedding_type == "relative_key_query":
232
+ relative_position_scores_query = torch.einsum(
233
+ "bhld,lrd->bhlr", query_layer, positional_embedding
234
+ )
235
+ relative_position_scores_key = torch.einsum(
236
+ "bhrd,lrd->bhlr", key_layer, positional_embedding
237
+ )
238
+ attention_scores = (
239
+ attention_scores
240
+ + relative_position_scores_query
241
+ + relative_position_scores_key
242
+ )
243
+
244
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
245
+ if attention_mask is not None:
246
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
247
+ attention_scores = attention_scores + attention_mask
248
+
249
+ # Normalize the attention scores to probabilities.
250
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
251
+
252
+ if is_cross_attention and self.save_attention:
253
+ self.save_attention_map(attention_probs)
254
+ attention_probs.register_hook(self.save_attn_gradients)
255
+
256
+ # This is actually dropping out entire tokens to attend to, which might
257
+ # seem a bit unusual, but is taken from the original Transformer paper.
258
+ attention_probs_dropped = self.dropout(attention_probs)
259
+
260
+ # Mask heads if we want to
261
+ if head_mask is not None:
262
+ attention_probs_dropped = attention_probs_dropped * head_mask
263
+
264
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
265
+
266
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
267
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
268
+ context_layer = context_layer.view(*new_context_layer_shape)
269
+
270
+ outputs = (
271
+ (context_layer, attention_probs) if output_attentions else (context_layer,)
272
+ )
273
+
274
+ outputs = outputs + (past_key_value,)
275
+ return outputs
276
+
277
+
278
+ class BertSelfOutput(nn.Module):
279
+ def __init__(self, config):
280
+ super().__init__()
281
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
282
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
283
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
284
+
285
+ def forward(self, hidden_states, input_tensor):
286
+ hidden_states = self.dense(hidden_states)
287
+ hidden_states = self.dropout(hidden_states)
288
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
289
+ return hidden_states
290
+
291
+
292
+ class BertAttention(nn.Module):
293
+ def __init__(self, config, is_cross_attention=False):
294
+ super().__init__()
295
+ self.self = BertSelfAttention(config, is_cross_attention)
296
+ self.output = BertSelfOutput(config)
297
+ self.pruned_heads = set()
298
+
299
+ def prune_heads(self, heads):
300
+ if len(heads) == 0:
301
+ return
302
+ heads, index = find_pruneable_heads_and_indices(
303
+ heads,
304
+ self.self.num_attention_heads,
305
+ self.self.attention_head_size,
306
+ self.pruned_heads,
307
+ )
308
+
309
+ # Prune linear layers
310
+ self.self.query = prune_linear_layer(self.self.query, index)
311
+ self.self.key = prune_linear_layer(self.self.key, index)
312
+ self.self.value = prune_linear_layer(self.self.value, index)
313
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
314
+
315
+ # Update hyper params and store pruned heads
316
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
317
+ self.self.all_head_size = (
318
+ self.self.attention_head_size * self.self.num_attention_heads
319
+ )
320
+ self.pruned_heads = self.pruned_heads.union(heads)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states,
325
+ attention_mask=None,
326
+ head_mask=None,
327
+ encoder_hidden_states=None,
328
+ encoder_attention_mask=None,
329
+ past_key_value=None,
330
+ output_attentions=False,
331
+ ):
332
+ self_outputs = self.self(
333
+ hidden_states,
334
+ attention_mask,
335
+ head_mask,
336
+ encoder_hidden_states,
337
+ encoder_attention_mask,
338
+ past_key_value,
339
+ output_attentions,
340
+ )
341
+ attention_output = self.output(self_outputs[0], hidden_states)
342
+
343
+ outputs = (attention_output,) + self_outputs[
344
+ 1:
345
+ ] # add attentions if we output them
346
+ return outputs
347
+
348
+
349
+ class BertIntermediate(nn.Module):
350
+ def __init__(self, config):
351
+ super().__init__()
352
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
353
+ if isinstance(config.hidden_act, str):
354
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
355
+ else:
356
+ self.intermediate_act_fn = config.hidden_act
357
+
358
+ def forward(self, hidden_states):
359
+ hidden_states = self.dense(hidden_states)
360
+ hidden_states = self.intermediate_act_fn(hidden_states)
361
+ return hidden_states
362
+
363
+
364
+ class BertOutput(nn.Module):
365
+ def __init__(self, config):
366
+ super().__init__()
367
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
368
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
369
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
370
+
371
+ def forward(self, hidden_states, input_tensor):
372
+ hidden_states = self.dense(hidden_states)
373
+ hidden_states = self.dropout(hidden_states)
374
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
375
+ return hidden_states
376
+
377
+
378
+ class BertLayer(nn.Module):
379
+ def __init__(self, config, layer_num):
380
+ super().__init__()
381
+ self.config = config
382
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
383
+ self.seq_len_dim = 1
384
+ self.attention = BertAttention(config)
385
+ self.layer_num = layer_num
386
+ if (
387
+ self.config.add_cross_attention
388
+ and layer_num % self.config.cross_attention_freq == 0
389
+ ):
390
+ self.crossattention = BertAttention(
391
+ config, is_cross_attention=self.config.add_cross_attention
392
+ )
393
+ self.has_cross_attention = True
394
+ else:
395
+ self.has_cross_attention = False
396
+ self.intermediate = BertIntermediate(config)
397
+ self.output = BertOutput(config)
398
+
399
+ self.intermediate_query = BertIntermediate(config)
400
+ self.output_query = BertOutput(config)
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states,
405
+ attention_mask=None,
406
+ head_mask=None,
407
+ encoder_hidden_states=None,
408
+ encoder_attention_mask=None,
409
+ past_key_value=None,
410
+ output_attentions=False,
411
+ query_length=0,
412
+ ):
413
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
414
+ self_attn_past_key_value = (
415
+ past_key_value[:2] if past_key_value is not None else None
416
+ )
417
+ self_attention_outputs = self.attention(
418
+ hidden_states,
419
+ attention_mask,
420
+ head_mask,
421
+ output_attentions=output_attentions,
422
+ past_key_value=self_attn_past_key_value,
423
+ )
424
+ attention_output = self_attention_outputs[0]
425
+ outputs = self_attention_outputs[1:-1]
426
+
427
+ present_key_value = self_attention_outputs[-1]
428
+
429
+ if query_length > 0:
430
+ query_attention_output = attention_output[:, :query_length, :]
431
+
432
+ if self.has_cross_attention:
433
+ assert (
434
+ encoder_hidden_states is not None
435
+ ), "encoder_hidden_states must be given for cross-attention layers"
436
+ cross_attention_outputs = self.crossattention(
437
+ query_attention_output,
438
+ attention_mask,
439
+ head_mask,
440
+ encoder_hidden_states,
441
+ encoder_attention_mask,
442
+ output_attentions=output_attentions,
443
+ )
444
+ query_attention_output = cross_attention_outputs[0]
445
+ outputs = (
446
+ outputs + cross_attention_outputs[1:-1]
447
+ ) # add cross attentions if we output attention weights
448
+
449
+ layer_output = apply_chunking_to_forward(
450
+ self.feed_forward_chunk_query,
451
+ self.chunk_size_feed_forward,
452
+ self.seq_len_dim,
453
+ query_attention_output,
454
+ )
455
+ if attention_output.shape[1] > query_length:
456
+ layer_output_text = apply_chunking_to_forward(
457
+ self.feed_forward_chunk,
458
+ self.chunk_size_feed_forward,
459
+ self.seq_len_dim,
460
+ attention_output[:, query_length:, :],
461
+ )
462
+ layer_output = torch.cat([layer_output, layer_output_text], dim=1)
463
+ else:
464
+ layer_output = apply_chunking_to_forward(
465
+ self.feed_forward_chunk,
466
+ self.chunk_size_feed_forward,
467
+ self.seq_len_dim,
468
+ attention_output,
469
+ )
470
+ outputs = (layer_output,) + outputs
471
+
472
+ outputs = outputs + (present_key_value,)
473
+
474
+ return outputs
475
+
476
+ def feed_forward_chunk(self, attention_output):
477
+ intermediate_output = self.intermediate(attention_output)
478
+ layer_output = self.output(intermediate_output, attention_output)
479
+ return layer_output
480
+
481
+ def feed_forward_chunk_query(self, attention_output):
482
+ intermediate_output = self.intermediate_query(attention_output)
483
+ layer_output = self.output_query(intermediate_output, attention_output)
484
+ return layer_output
485
+
486
+
487
+ class BertEncoder(nn.Module):
488
+ def __init__(self, config):
489
+ super().__init__()
490
+ self.config = config
491
+ self.layer = nn.ModuleList(
492
+ [BertLayer(config, i) for i in range(config.num_hidden_layers)]
493
+ )
494
+
495
+ def forward(
496
+ self,
497
+ hidden_states,
498
+ attention_mask=None,
499
+ head_mask=None,
500
+ encoder_hidden_states=None,
501
+ encoder_attention_mask=None,
502
+ past_key_values=None,
503
+ use_cache=None,
504
+ output_attentions=False,
505
+ output_hidden_states=False,
506
+ return_dict=True,
507
+ query_length=0,
508
+ ):
509
+ all_hidden_states = () if output_hidden_states else None
510
+ all_self_attentions = () if output_attentions else None
511
+ all_cross_attentions = (
512
+ () if output_attentions and self.config.add_cross_attention else None
513
+ )
514
+
515
+ next_decoder_cache = () if use_cache else None
516
+
517
+ for i in range(self.config.num_hidden_layers):
518
+ layer_module = self.layer[i]
519
+ if output_hidden_states:
520
+ all_hidden_states = all_hidden_states + (hidden_states,)
521
+
522
+ layer_head_mask = head_mask[i] if head_mask is not None else None
523
+ past_key_value = past_key_values[i] if past_key_values is not None else None
524
+
525
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
526
+
527
+ if use_cache:
528
+ logger.warn(
529
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
530
+ )
531
+ use_cache = False
532
+
533
+ def create_custom_forward(module):
534
+ def custom_forward(*inputs):
535
+ return module(
536
+ *inputs, past_key_value, output_attentions, query_length
537
+ )
538
+
539
+ return custom_forward
540
+
541
+ layer_outputs = torch.utils.checkpoint.checkpoint(
542
+ create_custom_forward(layer_module),
543
+ hidden_states,
544
+ attention_mask,
545
+ layer_head_mask,
546
+ encoder_hidden_states,
547
+ encoder_attention_mask,
548
+ )
549
+ else:
550
+ layer_outputs = layer_module(
551
+ hidden_states,
552
+ attention_mask,
553
+ layer_head_mask,
554
+ encoder_hidden_states,
555
+ encoder_attention_mask,
556
+ past_key_value,
557
+ output_attentions,
558
+ query_length,
559
+ )
560
+
561
+ hidden_states = layer_outputs[0]
562
+ if use_cache:
563
+ next_decoder_cache += (layer_outputs[-1],)
564
+ if output_attentions:
565
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
566
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
567
+
568
+ if output_hidden_states:
569
+ all_hidden_states = all_hidden_states + (hidden_states,)
570
+
571
+ if not return_dict:
572
+ return tuple(
573
+ v
574
+ for v in [
575
+ hidden_states,
576
+ next_decoder_cache,
577
+ all_hidden_states,
578
+ all_self_attentions,
579
+ all_cross_attentions,
580
+ ]
581
+ if v is not None
582
+ )
583
+ return BaseModelOutputWithPastAndCrossAttentions(
584
+ last_hidden_state=hidden_states,
585
+ past_key_values=next_decoder_cache,
586
+ hidden_states=all_hidden_states,
587
+ attentions=all_self_attentions,
588
+ cross_attentions=all_cross_attentions,
589
+ )
590
+
591
+
592
+ class BertPooler(nn.Module):
593
+ def __init__(self, config):
594
+ super().__init__()
595
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
596
+ self.activation = nn.Tanh()
597
+
598
+ def forward(self, hidden_states):
599
+ # We "pool" the model by simply taking the hidden state corresponding
600
+ # to the first token.
601
+ first_token_tensor = hidden_states[:, 0]
602
+ pooled_output = self.dense(first_token_tensor)
603
+ pooled_output = self.activation(pooled_output)
604
+ return pooled_output
605
+
606
+
607
+ class BertPredictionHeadTransform(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
611
+ if isinstance(config.hidden_act, str):
612
+ self.transform_act_fn = ACT2FN[config.hidden_act]
613
+ else:
614
+ self.transform_act_fn = config.hidden_act
615
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
616
+
617
+ def forward(self, hidden_states):
618
+ hidden_states = self.dense(hidden_states)
619
+ hidden_states = self.transform_act_fn(hidden_states)
620
+ hidden_states = self.LayerNorm(hidden_states)
621
+ return hidden_states
622
+
623
+
624
+ class BertLMPredictionHead(nn.Module):
625
+ def __init__(self, config):
626
+ super().__init__()
627
+ self.transform = BertPredictionHeadTransform(config)
628
+
629
+ # The output weights are the same as the input embeddings, but there is
630
+ # an output-only bias for each token.
631
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
632
+
633
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
634
+
635
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
636
+ self.decoder.bias = self.bias
637
+
638
+ def forward(self, hidden_states):
639
+ hidden_states = self.transform(hidden_states)
640
+ hidden_states = self.decoder(hidden_states)
641
+ return hidden_states
642
+
643
+
644
+ class BertOnlyMLMHead(nn.Module):
645
+ def __init__(self, config):
646
+ super().__init__()
647
+ self.predictions = BertLMPredictionHead(config)
648
+
649
+ def forward(self, sequence_output):
650
+ prediction_scores = self.predictions(sequence_output)
651
+ return prediction_scores
652
+
653
+
654
+ class BertPreTrainedModel(PreTrainedModel):
655
+ """
656
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
657
+ models.
658
+ """
659
+
660
+ config_class = BertConfig
661
+ base_model_prefix = "bert"
662
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
663
+
664
+ def _init_weights(self, module):
665
+ """Initialize the weights"""
666
+ if isinstance(module, (nn.Linear, nn.Embedding)):
667
+ # Slightly different from the TF version which uses truncated_normal for initialization
668
+ # cf https://github.com/pytorch/pytorch/pull/5617
669
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
670
+ elif isinstance(module, nn.LayerNorm):
671
+ module.bias.data.zero_()
672
+ module.weight.data.fill_(1.0)
673
+ if isinstance(module, nn.Linear) and module.bias is not None:
674
+ module.bias.data.zero_()
675
+
676
+
677
+ class BertModel(BertPreTrainedModel):
678
+ """
679
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
680
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
681
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
682
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
683
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
684
+ input to the forward pass.
685
+ """
686
+
687
+ def __init__(self, config, add_pooling_layer=False):
688
+ super().__init__(config)
689
+ self.config = config
690
+
691
+ self.embeddings = BertEmbeddings(config)
692
+
693
+ self.encoder = BertEncoder(config)
694
+
695
+ self.pooler = BertPooler(config) if add_pooling_layer else None
696
+
697
+ self.init_weights()
698
+
699
+ def get_input_embeddings(self):
700
+ return self.embeddings.word_embeddings
701
+
702
+ def set_input_embeddings(self, value):
703
+ self.embeddings.word_embeddings = value
704
+
705
+ def _prune_heads(self, heads_to_prune):
706
+ """
707
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
708
+ class PreTrainedModel
709
+ """
710
+ for layer, heads in heads_to_prune.items():
711
+ self.encoder.layer[layer].attention.prune_heads(heads)
712
+
713
+ def get_extended_attention_mask(
714
+ self,
715
+ attention_mask: Tensor,
716
+ input_shape: Tuple[int],
717
+ device: device,
718
+ is_decoder: bool,
719
+ has_query: bool = False,
720
+ ) -> Tensor:
721
+ """
722
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
723
+
724
+ Arguments:
725
+ attention_mask (:obj:`torch.Tensor`):
726
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
727
+ input_shape (:obj:`Tuple[int]`):
728
+ The shape of the input to the model.
729
+ device: (:obj:`torch.device`):
730
+ The device of the input to the model.
731
+
732
+ Returns:
733
+ :obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
734
+ """
735
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
736
+ # ourselves in which case we just need to make it broadcastable to all heads.
737
+ if attention_mask.dim() == 3:
738
+ extended_attention_mask = attention_mask[:, None, :, :]
739
+ elif attention_mask.dim() == 2:
740
+ # Provided a padding mask of dimensions [batch_size, seq_length]
741
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
742
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
743
+ if is_decoder:
744
+ batch_size, seq_length = input_shape
745
+
746
+ seq_ids = torch.arange(seq_length, device=device)
747
+ causal_mask = (
748
+ seq_ids[None, None, :].repeat(batch_size, seq_length, 1)
749
+ <= seq_ids[None, :, None]
750
+ )
751
+
752
+ # add a prefix ones mask to the causal mask
753
+ # causal and attention masks must have same type with pytorch version < 1.3
754
+ causal_mask = causal_mask.to(attention_mask.dtype)
755
+
756
+ if causal_mask.shape[1] < attention_mask.shape[1]:
757
+ prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
758
+ if has_query: # UniLM style attention mask
759
+ causal_mask = torch.cat(
760
+ [
761
+ torch.zeros(
762
+ (batch_size, prefix_seq_len, seq_length),
763
+ device=device,
764
+ dtype=causal_mask.dtype,
765
+ ),
766
+ causal_mask,
767
+ ],
768
+ axis=1,
769
+ )
770
+ causal_mask = torch.cat(
771
+ [
772
+ torch.ones(
773
+ (batch_size, causal_mask.shape[1], prefix_seq_len),
774
+ device=device,
775
+ dtype=causal_mask.dtype,
776
+ ),
777
+ causal_mask,
778
+ ],
779
+ axis=-1,
780
+ )
781
+ extended_attention_mask = (
782
+ causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
783
+ )
784
+ else:
785
+ extended_attention_mask = attention_mask[:, None, None, :]
786
+ else:
787
+ raise ValueError(
788
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
789
+ input_shape, attention_mask.shape
790
+ )
791
+ )
792
+
793
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
794
+ # masked positions, this operation will create a tensor which is 0.0 for
795
+ # positions we want to attend and -10000.0 for masked positions.
796
+ # Since we are adding it to the raw scores before the softmax, this is
797
+ # effectively the same as removing these entirely.
798
+ extended_attention_mask = extended_attention_mask.to(
799
+ dtype=self.dtype
800
+ ) # fp16 compatibility
801
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
802
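+ # Illustrative example (added for clarity): a padding mask [[1, 1, 0]] of shape
+ # (batch, seq) is broadcast to (batch, 1, 1, seq) and becomes [[0., 0., -10000.]],
+ # so masked positions are effectively excluded from the attention softmax.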
+ return extended_attention_mask
803
+
804
+ def forward(
805
+ self,
806
+ input_ids=None,
807
+ attention_mask=None,
808
+ position_ids=None,
809
+ head_mask=None,
810
+ query_embeds=None,
811
+ encoder_hidden_states=None,
812
+ encoder_attention_mask=None,
813
+ past_key_values=None,
814
+ use_cache=None,
815
+ output_attentions=None,
816
+ output_hidden_states=None,
817
+ return_dict=None,
818
+ is_decoder=False,
819
+ ):
820
+ r"""
821
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
822
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
823
+ the model is configured as a decoder.
824
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
825
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
826
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
827
+ - 1 for tokens that are **not masked**,
828
+ - 0 for tokens that are **masked**.
829
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
830
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
831
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
832
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
833
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
834
+ use_cache (:obj:`bool`, `optional`):
835
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
836
+ decoding (see :obj:`past_key_values`).
837
+ """
838
+ output_attentions = (
839
+ output_attentions
840
+ if output_attentions is not None
841
+ else self.config.output_attentions
842
+ )
843
+ output_hidden_states = (
844
+ output_hidden_states
845
+ if output_hidden_states is not None
846
+ else self.config.output_hidden_states
847
+ )
848
+ return_dict = (
849
+ return_dict if return_dict is not None else self.config.use_return_dict
850
+ )
851
+
852
+ # use_cache = use_cache if use_cache is not None else self.config.use_cache
853
+
854
+ if input_ids is None:
855
+ assert (
856
+ query_embeds is not None
857
+ ), "You have to specify query_embeds when input_ids is None"
858
+
859
+ # past_key_values_length
860
+ past_key_values_length = (
861
+ past_key_values[0][0].shape[2] - self.config.query_length
862
+ if past_key_values is not None
863
+ else 0
864
+ )
865
+
866
+ query_length = query_embeds.shape[1] if query_embeds is not None else 0
867
+
868
+ embedding_output = self.embeddings(
869
+ input_ids=input_ids,
870
+ position_ids=position_ids,
871
+ query_embeds=query_embeds,
872
+ past_key_values_length=past_key_values_length,
873
+ )
874
+
875
+ input_shape = embedding_output.size()[:-1]
876
+ batch_size, seq_length = input_shape
877
+ device = embedding_output.device
878
+
879
+ if attention_mask is None:
880
+ attention_mask = torch.ones(
881
+ ((batch_size, seq_length + past_key_values_length)), device=device
882
+ )
883
+
884
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
885
+ # ourselves in which case we just need to make it broadcastable to all heads.
886
+ if is_decoder:
887
+ extended_attention_mask = self.get_extended_attention_mask(
888
+ attention_mask,
889
+ input_ids.shape,
890
+ device,
891
+ is_decoder,
892
+ has_query=(query_embeds is not None),
893
+ )
894
+ else:
895
+ extended_attention_mask = self.get_extended_attention_mask(
896
+ attention_mask, input_shape, device, is_decoder
897
+ )
898
+
899
+ # If a 2D or 3D attention mask is provided for the cross-attention
900
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
901
+ if encoder_hidden_states is not None:
902
+ if type(encoder_hidden_states) == list:
903
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[
904
+ 0
905
+ ].size()
906
+ else:
907
+ (
908
+ encoder_batch_size,
909
+ encoder_sequence_length,
910
+ _,
911
+ ) = encoder_hidden_states.size()
912
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
913
+
914
+ if type(encoder_attention_mask) == list:
915
+ encoder_extended_attention_mask = [
916
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask
917
+ ]
918
+ elif encoder_attention_mask is None:
919
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
920
+ encoder_extended_attention_mask = self.invert_attention_mask(
921
+ encoder_attention_mask
922
+ )
923
+ else:
924
+ encoder_extended_attention_mask = self.invert_attention_mask(
925
+ encoder_attention_mask
926
+ )
927
+ else:
928
+ encoder_extended_attention_mask = None
929
+
930
+ # Prepare head mask if needed
931
+ # 1.0 in head_mask indicate we keep the head
932
+ # attention_probs has shape bsz x n_heads x N x N
933
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
934
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
935
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
936
+
937
+ encoder_outputs = self.encoder(
938
+ embedding_output,
939
+ attention_mask=extended_attention_mask,
940
+ head_mask=head_mask,
941
+ encoder_hidden_states=encoder_hidden_states,
942
+ encoder_attention_mask=encoder_extended_attention_mask,
943
+ past_key_values=past_key_values,
944
+ use_cache=use_cache,
945
+ output_attentions=output_attentions,
946
+ output_hidden_states=output_hidden_states,
947
+ return_dict=return_dict,
948
+ query_length=query_length,
949
+ )
950
+ sequence_output = encoder_outputs[0]
951
+ pooled_output = (
952
+ self.pooler(sequence_output) if self.pooler is not None else None
953
+ )
954
+
955
+ if not return_dict:
956
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
957
+
958
+ return BaseModelOutputWithPoolingAndCrossAttentions(
959
+ last_hidden_state=sequence_output,
960
+ pooler_output=pooled_output,
961
+ past_key_values=encoder_outputs.past_key_values,
962
+ hidden_states=encoder_outputs.hidden_states,
963
+ attentions=encoder_outputs.attentions,
964
+ cross_attentions=encoder_outputs.cross_attentions,
965
+ )
966
+
967
+
968
+ class BertLMHeadModel(BertPreTrainedModel):
969
+
970
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
971
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
972
+
973
+ def __init__(self, config):
974
+ super().__init__(config)
975
+
976
+ self.bert = BertModel(config, add_pooling_layer=False)
977
+ self.cls = BertOnlyMLMHead(config)
978
+
979
+ self.init_weights()
980
+
981
+ def get_output_embeddings(self):
982
+ return self.cls.predictions.decoder
983
+
984
+ def set_output_embeddings(self, new_embeddings):
985
+ self.cls.predictions.decoder = new_embeddings
986
+
987
+ def forward(
988
+ self,
989
+ input_ids=None,
990
+ attention_mask=None,
991
+ position_ids=None,
992
+ head_mask=None,
993
+ query_embeds=None,
994
+ encoder_hidden_states=None,
995
+ encoder_attention_mask=None,
996
+ labels=None,
997
+ past_key_values=None,
998
+ use_cache=True,
999
+ output_attentions=None,
1000
+ output_hidden_states=None,
1001
+ return_dict=None,
1002
+ return_logits=False,
1003
+ is_decoder=True,
1004
+ reduction="mean",
1005
+ ):
1006
+ r"""
1007
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1008
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1009
+ the model is configured as a decoder.
1010
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1011
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1012
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1013
+ - 1 for tokens that are **not masked**,
1014
+ - 0 for tokens that are **masked**.
1015
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1016
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1017
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1018
+ ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1019
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1020
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1021
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1022
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1023
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1024
+ use_cache (:obj:`bool`, `optional`):
1025
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1026
+ decoding (see :obj:`past_key_values`).
1027
+ Returns:
1028
+ Example::
1029
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1030
+ >>> import torch
1031
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1032
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1033
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1034
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1035
+ >>> outputs = model(**inputs)
1036
+ >>> prediction_logits = outputs.logits
1037
+ """
1038
+ return_dict = (
1039
+ return_dict if return_dict is not None else self.config.use_return_dict
1040
+ )
1041
+ if labels is not None:
1042
+ use_cache = False
1043
+ if past_key_values is not None:
1044
+ query_embeds = None
1045
+
1046
+ outputs = self.bert(
1047
+ input_ids,
1048
+ attention_mask=attention_mask,
1049
+ position_ids=position_ids,
1050
+ head_mask=head_mask,
1051
+ query_embeds=query_embeds,
1052
+ encoder_hidden_states=encoder_hidden_states,
1053
+ encoder_attention_mask=encoder_attention_mask,
1054
+ past_key_values=past_key_values,
1055
+ use_cache=use_cache,
1056
+ output_attentions=output_attentions,
1057
+ output_hidden_states=output_hidden_states,
1058
+ return_dict=return_dict,
1059
+ is_decoder=is_decoder,
1060
+ )
1061
+
1062
+ sequence_output = outputs[0]
1063
+ if query_embeds is not None:
1064
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1065
+
1066
+ prediction_scores = self.cls(sequence_output)
1067
+
1068
+ if return_logits:
1069
+ return prediction_scores[:, :-1, :].contiguous()
1070
+
1071
+ lm_loss = None
1072
+ if labels is not None:
1073
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1074
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1075
+ labels = labels[:, 1:].contiguous()
1076
+ loss_fct = CrossEntropyLoss(reduction=reduction, label_smoothing=0.1)
1077
+ lm_loss = loss_fct(
1078
+ shifted_prediction_scores.view(-1, self.config.vocab_size),
1079
+ labels.view(-1),
1080
+ )
1081
+ if reduction == "none":
1082
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1083
+
1084
+ if not return_dict:
1085
+ output = (prediction_scores,) + outputs[2:]
1086
+ return ((lm_loss,) + output) if lm_loss is not None else output
1087
+
1088
+ return CausalLMOutputWithCrossAttentions(
1089
+ loss=lm_loss,
1090
+ logits=prediction_scores,
1091
+ past_key_values=outputs.past_key_values,
1092
+ hidden_states=outputs.hidden_states,
1093
+ attentions=outputs.attentions,
1094
+ cross_attentions=outputs.cross_attentions,
1095
+ )
1096
+
1097
+ def prepare_inputs_for_generation(
1098
+ self, input_ids, query_embeds, past=None, attention_mask=None, **model_kwargs
1099
+ ):
1100
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1101
+ if attention_mask is None:
1102
+ attention_mask = input_ids.new_ones(input_ids.shape)
1103
+ query_mask = input_ids.new_ones(query_embeds.shape[:-1])
1104
+ attention_mask = torch.cat([query_mask, attention_mask], dim=-1)
1105
+
1106
+ # cut decoder_input_ids if past is used
1107
+ if past is not None:
1108
+ input_ids = input_ids[:, -1:]
1109
+
1110
+ return {
1111
+ "input_ids": input_ids,
1112
+ "query_embeds": query_embeds,
1113
+ "attention_mask": attention_mask,
1114
+ "past_key_values": past,
1115
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1116
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1117
+ "is_decoder": True,
1118
+ }
1119
+
1120
+ def _reorder_cache(self, past, beam_idx):
1121
+ reordered_past = ()
1122
+ for layer_past in past:
1123
+ reordered_past += (
1124
+ tuple(
1125
+ past_state.index_select(0, beam_idx) for past_state in layer_past
1126
+ ),
1127
+ )
1128
+ return reordered_past
1129
+
1130
+
1131
+ class BertForMaskedLM(BertPreTrainedModel):
1132
+
1133
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1134
+ _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1135
+
1136
+ def __init__(self, config):
1137
+ super().__init__(config)
1138
+
1139
+ self.bert = BertModel(config, add_pooling_layer=False)
1140
+ self.cls = BertOnlyMLMHead(config)
1141
+
1142
+ self.init_weights()
1143
+
1144
+ def get_output_embeddings(self):
1145
+ return self.cls.predictions.decoder
1146
+
1147
+ def set_output_embeddings(self, new_embeddings):
1148
+ self.cls.predictions.decoder = new_embeddings
1149
+
1150
+ def forward(
1151
+ self,
1152
+ input_ids=None,
1153
+ attention_mask=None,
1154
+ position_ids=None,
1155
+ head_mask=None,
1156
+ query_embeds=None,
1157
+ encoder_hidden_states=None,
1158
+ encoder_attention_mask=None,
1159
+ labels=None,
1160
+ output_attentions=None,
1161
+ output_hidden_states=None,
1162
+ return_dict=None,
1163
+ return_logits=False,
1164
+ is_decoder=False,
1165
+ ):
1166
+ r"""
1167
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1168
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1169
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1170
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1171
+ """
1172
+
1173
+ return_dict = (
1174
+ return_dict if return_dict is not None else self.config.use_return_dict
1175
+ )
1176
+
1177
+ outputs = self.bert(
1178
+ input_ids,
1179
+ attention_mask=attention_mask,
1180
+ position_ids=position_ids,
1181
+ head_mask=head_mask,
1182
+ query_embeds=query_embeds,
1183
+ encoder_hidden_states=encoder_hidden_states,
1184
+ encoder_attention_mask=encoder_attention_mask,
1185
+ output_attentions=output_attentions,
1186
+ output_hidden_states=output_hidden_states,
1187
+ return_dict=return_dict,
1188
+ is_decoder=is_decoder,
1189
+ )
1190
+
1191
+ if query_embeds is not None:
1192
+ sequence_output = outputs[0][:, query_embeds.shape[1] :, :]
1193
+ prediction_scores = self.cls(sequence_output)
1194
+
1195
+ if return_logits:
1196
+ return prediction_scores
1197
+
1198
+ masked_lm_loss = None
1199
+ if labels is not None:
1200
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1201
+ masked_lm_loss = loss_fct(
1202
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)
1203
+ )
1204
+
1205
+ if not return_dict:
1206
+ output = (prediction_scores,) + outputs[2:]
1207
+ return (
1208
+ ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1209
+ )
1210
+
1211
+ return MaskedLMOutput(
1212
+ loss=masked_lm_loss,
1213
+ logits=prediction_scores,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
cheetah/models/__init__.py ADDED
@@ -0,0 +1,202 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import torch
10
+ from omegaconf import OmegaConf
11
+
12
+ from cheetah.common.registry import registry
13
+ from cheetah.models.base_model import BaseModel
14
+ from cheetah.models.blip2 import Blip2Base
15
+ from cheetah.models.cheetah_vicuna import Cheetah_Vicuna
16
+ from cheetah.models.cheetah_llama2 import Cheetah_Llama2
17
+ from cheetah.processors.base_processor import BaseProcessor
18
+
19
+
20
+ __all__ = [
21
+ "load_model",
22
+ "BaseModel",
23
+ "Blip2Base",
24
+ "Cheetah_Vicuna",
25
+ "Cheetah_Llama2"
26
+ ]
27
+
28
+
29
+ def load_model(name, model_type, is_eval=False, device="cpu", checkpoint=None):
30
+ """
31
+ Load supported models.
32
+
33
+ To list all available models and types in registry:
34
+ >>> from cheetah.models import model_zoo
35
+ >>> print(model_zoo)
36
+
37
+ Args:
38
+ name (str): name of the model.
39
+ model_type (str): type of the model.
40
+ is_eval (bool): whether the model is in eval mode. Default: False.
41
+ device (str): device to use. Default: "cpu".
42
+ checkpoint (str): path or URL to a checkpoint. Default: None.
43
+ Note that the checkpoint is expected to have the same keys in its state_dict as the model.
44
+
45
+ Returns:
46
+ model (torch.nn.Module): model.
47
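+
+ Example (illustrative; assumes the default "cheetah_llama2" checkpoints are reachable):
+ >>> model = load_model("cheetah_llama2", "pretrain_llama2", is_eval=True, device="cpu")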
+ """
48
+
49
+ model = registry.get_model_class(name).from_pretrained(model_type=model_type)
50
+
51
+ if checkpoint is not None:
52
+ model.load_checkpoint(checkpoint)
53
+
54
+ if is_eval:
55
+ model.eval()
56
+
57
+ if device == "cpu":
58
+ model = model.float()
59
+
60
+ return model.to(device)
61
+
62
+
63
+ def load_preprocess(config):
64
+ """
65
+ Load preprocessor configs and construct preprocessors.
66
+
67
+ If no preprocessor is specified, return BaseProcessor, which does not do any preprocessing.
68
+
69
+ Args:
70
+ config (dict): preprocessor configs.
71
+
72
+ Returns:
73
+ vis_processors (dict): preprocessors for visual inputs.
74
+ txt_processors (dict): preprocessors for text inputs.
75
+
76
+ Key is "train" or "eval" for processors used in training and evaluation respectively.
77
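+
+ Example (illustrative sketch; `preprocess_cfg` is assumed to be the `preprocess` section of a model config such as configs/models/cheetah_llama2.yaml, and `raw_image` a PIL image):
+ >>> vis_processors, txt_processors = load_preprocess(preprocess_cfg)
+ >>> image_tensor = vis_processors["eval"](raw_image)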
+ """
78
+
79
+ def _build_proc_from_cfg(cfg):
80
+ return (
81
+ registry.get_processor_class(cfg.name).from_config(cfg)
82
+ if cfg is not None
83
+ else BaseProcessor()
84
+ )
85
+
86
+ vis_processors = dict()
87
+ txt_processors = dict()
88
+
89
+ vis_proc_cfg = config.get("vis_processor")
90
+ txt_proc_cfg = config.get("text_processor")
91
+
92
+ if vis_proc_cfg is not None:
93
+ vis_train_cfg = vis_proc_cfg.get("train")
94
+ vis_eval_cfg = vis_proc_cfg.get("eval")
95
+ else:
96
+ vis_train_cfg = None
97
+ vis_eval_cfg = None
98
+
99
+ vis_processors["train"] = _build_proc_from_cfg(vis_train_cfg)
100
+ vis_processors["eval"] = _build_proc_from_cfg(vis_eval_cfg)
101
+
102
+ if txt_proc_cfg is not None:
103
+ txt_train_cfg = txt_proc_cfg.get("train")
104
+ txt_eval_cfg = txt_proc_cfg.get("eval")
105
+ else:
106
+ txt_train_cfg = None
107
+ txt_eval_cfg = None
108
+
109
+ txt_processors["train"] = _build_proc_from_cfg(txt_train_cfg)
110
+ txt_processors["eval"] = _build_proc_from_cfg(txt_eval_cfg)
111
+
112
+ return vis_processors, txt_processors
113
+
114
+
115
+ def load_model_and_preprocess(name, model_type, is_eval=False, device="cpu"):
116
+ """
117
+ Load model and its related preprocessors.
118
+
119
+ List all available models and types in registry:
120
+ >>> from cheetah.models import model_zoo
121
+ >>> print(model_zoo)
122
+
123
+ Args:
124
+ name (str): name of the model.
125
+ model_type (str): type of the model.
126
+ is_eval (bool): whether the model is in eval mode. Default: False.
127
+ device (str): device to use. Default: "cpu".
128
+
129
+ Returns:
130
+ model (torch.nn.Module): model.
131
+ vis_processors (dict): preprocessors for visual inputs.
132
+ txt_processors (dict): preprocessors for text inputs.
133
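+
+ Example (illustrative; assumes the "cheetah_llama2" architecture registered in this package):
+ >>> model, vis_processors, txt_processors = load_model_and_preprocess(
+ ... "cheetah_llama2", "pretrain_llama2", is_eval=True, device="cpu")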
+ """
134
+ model_cls = registry.get_model_class(name)
135
+
136
+ # load model
137
+ model = model_cls.from_pretrained(model_type=model_type)
138
+
139
+ if is_eval:
140
+ model.eval()
141
+
142
+ # load preprocess
143
+ cfg = OmegaConf.load(model_cls.default_config_path(model_type))
144
+ if cfg is not None:
145
+ preprocess_cfg = cfg.preprocess
146
+
147
+ vis_processors, txt_processors = load_preprocess(preprocess_cfg)
148
+ else:
149
+ vis_processors, txt_processors = None, None
150
+ logging.info(
151
+ f"""No default preprocess for model {name} ({model_type}).
152
+ This can happen if the model is not finetuned on downstream datasets,
153
+ or it is not intended for direct use without finetuning.
154
+ """
155
+ )
156
+
157
+ if device == "cpu" or device == torch.device("cpu"):
158
+ model = model.float()
159
+
160
+ return model.to(device), vis_processors, txt_processors
161
+
162
+
163
+ class ModelZoo:
164
+ """
165
+ A utility class to create string representation of available model architectures and types.
166
+
167
+ >>> from cheetah.models import model_zoo
168
+ >>> # list all available models
169
+ >>> print(model_zoo)
170
+ >>> # show total number of models
171
+ >>> print(len(model_zoo))
172
+ """
173
+
174
+ def __init__(self) -> None:
175
+ self.model_zoo = {
176
+ k: list(v.PRETRAINED_MODEL_CONFIG_DICT.keys())
177
+ for k, v in registry.mapping["model_name_mapping"].items()
178
+ }
179
+
180
+ def __str__(self) -> str:
181
+ return (
182
+ "=" * 50
183
+ + "\n"
184
+ + f"{'Architectures':<30} {'Types'}\n"
185
+ + "=" * 50
186
+ + "\n"
187
+ + "\n".join(
188
+ [
189
+ f"{name:<30} {', '.join(types)}"
190
+ for name, types in self.model_zoo.items()
191
+ ]
192
+ )
193
+ )
194
+
195
+ def __iter__(self):
196
+ return iter(self.model_zoo.items())
197
+
198
+ def __len__(self):
199
+ return sum([len(v) for v in self.model_zoo.values()])
200
+
201
+
202
+ model_zoo = ModelZoo()
cheetah/models/__pycache__/Qformer.cpython-310.pyc ADDED
Binary file (30.6 kB). View file
 
cheetah/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (6.17 kB). View file
 
cheetah/models/__pycache__/base_model.cpython-310.pyc ADDED
Binary file (8.65 kB). View file
 
cheetah/models/__pycache__/blip2.cpython-310.pyc ADDED
Binary file (6.43 kB). View file
 
cheetah/models/__pycache__/cheetah_llama2.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
cheetah/models/__pycache__/cheetah_vicuna.cpython-310.pyc ADDED
Binary file (10.6 kB). View file
 
cheetah/models/__pycache__/eva_vit.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
cheetah/models/__pycache__/modeling_llama.cpython-310.pyc ADDED
Binary file (26.7 kB). View file
 
cheetah/models/__pycache__/modeling_llama2.cpython-310.pyc ADDED
Binary file (35 kB). View file
 
cheetah/models/base_model.py ADDED
@@ -0,0 +1,247 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ import logging
9
+ import os
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from cheetah.common.dist_utils import download_cached_file, is_dist_avail_and_initialized
15
+ from cheetah.common.utils import get_abs_path, is_url
16
+ from omegaconf import OmegaConf
17
+
18
+
19
+ class BaseModel(nn.Module):
20
+ """Base class for models."""
21
+
22
+ def __init__(self):
23
+ super().__init__()
24
+
25
+ @property
26
+ def device(self):
27
+ return list(self.parameters())[0].device
28
+
29
+ def load_checkpoint(self, url_or_filename):
30
+ """
31
+ Load from a finetuned checkpoint.
32
+
33
+ This should expect no mismatch in the model keys and the checkpoint keys.
34
+ """
35
+
36
+ if is_url(url_or_filename):
37
+ cached_file = download_cached_file(
38
+ url_or_filename, check_hash=False, progress=True
39
+ )
40
+ checkpoint = torch.load(cached_file, map_location="cpu")
41
+ elif os.path.isfile(url_or_filename):
42
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
43
+ else:
44
+ raise RuntimeError("checkpoint url or path is invalid")
45
+
46
+ if "model" in checkpoint.keys():
47
+ state_dict = checkpoint["model"]
48
+ else:
49
+ state_dict = checkpoint
50
+
51
+ msg = self.load_state_dict(state_dict, strict=False)
52
+
53
+ logging.info("Missing keys {}".format(msg.missing_keys))
54
+ logging.info("load checkpoint from %s" % url_or_filename)
55
+
56
+ return msg
57
+
58
+ @classmethod
59
+ def from_pretrained(cls, model_type):
60
+ """
61
+ Build a pretrained model from default configuration file, specified by model_type.
62
+
63
+ Args:
64
+ - model_type (str): model type, specifying architecture and checkpoints.
65
+
66
+ Returns:
67
+ - model (nn.Module): pretrained or finetuned model, depending on the configuration.
68
+ """
69
+ model_cfg = OmegaConf.load(cls.default_config_path(model_type)).model
70
+ model = cls.from_config(model_cfg)
71
+
72
+ return model
73
+
74
+ @classmethod
75
+ def default_config_path(cls, model_type):
76
+ assert (
77
+ model_type in cls.PRETRAINED_MODEL_CONFIG_DICT
78
+ ), "Unknown model type {}".format(model_type)
79
+ return get_abs_path(cls.PRETRAINED_MODEL_CONFIG_DICT[model_type])
80
+
81
+ def load_checkpoint_from_config(self, cfg, **kwargs):
82
+ """
83
+ Load checkpoint as specified in the config file.
84
+
85
+ If load_finetuned is True, load the finetuned model; otherwise, load the pretrained model.
86
+ When loading the pretrained model, each task-specific architecture may define its
87
+ own load_from_pretrained() method.
88
+ """
89
+ load_finetuned = cfg.get("load_finetuned", True)
90
+ if load_finetuned:
91
+ finetune_path = cfg.get("finetuned", None)
92
+ assert (
93
+ finetune_path is not None
94
+ ), "Found load_finetuned is True, but finetune_path is None."
95
+ self.load_checkpoint(url_or_filename=finetune_path)
96
+ else:
97
+ # load pre-trained weights
98
+ pretrain_path = cfg.get("pretrained", None)
99
+ assert "Found load_finetuned is False, but pretrain_path is None."
100
+ self.load_from_pretrained(url_or_filename=pretrain_path, **kwargs)
101
+
102
+ def before_evaluation(self, **kwargs):
103
+ pass
104
+
105
+ def show_n_params(self, return_str=True):
106
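+ # Multiplies out each parameter tensor's shape; e.g. a model with 7e9 parameters
+ # is reported as "7000.0M" (the string form only uses "M" and "K" suffixes).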
+ tot = 0
107
+ for p in self.parameters():
108
+ w = 1
109
+ for x in p.shape:
110
+ w *= x
111
+ tot += w
112
+ if return_str:
113
+ if tot >= 1e6:
114
+ return "{:.1f}M".format(tot / 1e6)
115
+ else:
116
+ return "{:.1f}K".format(tot / 1e3)
117
+ else:
118
+ return tot
119
+
120
+
121
+ class BaseEncoder(nn.Module):
122
+ """
123
+ Base class for primitive encoders, such as ViT, TimeSformer, etc.
124
+ """
125
+
126
+ def __init__(self):
127
+ super().__init__()
128
+
129
+ def forward_features(self, samples, **kwargs):
130
+ raise NotImplementedError
131
+
132
+ @property
133
+ def device(self):
134
+ return list(self.parameters())[0].device
135
+
136
+
137
+ class SharedQueueMixin:
138
+ @torch.no_grad()
139
+ def _dequeue_and_enqueue(self, image_feat, text_feat, idxs=None):
140
+ # gather keys before updating queue
141
+ image_feats = concat_all_gather(image_feat)
142
+ text_feats = concat_all_gather(text_feat)
143
+
144
+ batch_size = image_feats.shape[0]
145
+
146
+ ptr = int(self.queue_ptr)
147
+ assert self.queue_size % batch_size == 0 # for simplicity
148
+
149
+ # replace the keys at ptr (dequeue and enqueue)
150
+ self.image_queue[:, ptr : ptr + batch_size] = image_feats.T
151
+ self.text_queue[:, ptr : ptr + batch_size] = text_feats.T
152
+
153
+ if idxs is not None:
154
+ idxs = concat_all_gather(idxs)
155
+ self.idx_queue[:, ptr : ptr + batch_size] = idxs.T
156
+
157
+ ptr = (ptr + batch_size) % self.queue_size # move pointer
158
+ self.queue_ptr[0] = ptr
159
+
160
+
161
+ class MomentumDistilationMixin:
162
+ @torch.no_grad()
163
+ def copy_params(self):
164
+ for model_pair in self.model_pairs:
165
+ for param, param_m in zip(
166
+ model_pair[0].parameters(), model_pair[1].parameters()
167
+ ):
168
+ param_m.data.copy_(param.data) # initialize
169
+ param_m.requires_grad = False # not update by gradient
170
+
171
+ @torch.no_grad()
172
+ def _momentum_update(self):
173
+ for model_pair in self.model_pairs:
174
+ for param, param_m in zip(
175
+ model_pair[0].parameters(), model_pair[1].parameters()
176
+ ):
177
+ param_m.data = param_m.data * self.momentum + param.data * (
178
+ 1.0 - self.momentum
179
+ )
180
+
181
+
182
+ class GatherLayer(torch.autograd.Function):
183
+ """
184
+ Gather tensors from all workers with support for backward propagation:
185
+ This implementation does not cut the gradients as torch.distributed.all_gather does.
186
+ """
187
+
188
+ @staticmethod
189
+ def forward(ctx, x):
190
+ output = [
191
+ torch.zeros_like(x) for _ in range(torch.distributed.get_world_size())
192
+ ]
193
+ torch.distributed.all_gather(output, x)
194
+ return tuple(output)
195
+
196
+ @staticmethod
197
+ def backward(ctx, *grads):
198
+ all_gradients = torch.stack(grads)
199
+ torch.distributed.all_reduce(all_gradients)
200
+ return all_gradients[torch.distributed.get_rank()]
201
+
202
+
203
+ def all_gather_with_grad(tensors):
204
+ """
205
+ Performs all_gather operation on the provided tensors.
206
+ Graph remains connected for backward grad computation.
207
+ """
208
+ # Queue the gathered tensors
209
+ world_size = torch.distributed.get_world_size()
210
+ # There is no need for reduction in the single-proc case
211
+ if world_size == 1:
212
+ return tensors
213
+
214
+ # tensor_all = GatherLayer.apply(tensors)
215
+ tensor_all = GatherLayer.apply(tensors)
216
+
217
+ return torch.cat(tensor_all, dim=0)
218
+
219
+
220
+ @torch.no_grad()
221
+ def concat_all_gather(tensor):
222
+ """
223
+ Performs all_gather operation on the provided tensors.
224
+ *** Warning ***: torch.distributed.all_gather has no gradient.
225
+ """
226
+ # if use distributed training
227
+ if not is_dist_avail_and_initialized():
228
+ return tensor
229
+
230
+ tensors_gather = [
231
+ torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())
232
+ ]
233
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
234
+
235
+ output = torch.cat(tensors_gather, dim=0)
236
+ return output
237
+
238
+
239
+ def tile(x, dim, n_tile):
240
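+ # Illustrative behaviour: tile(torch.tensor([1, 2]), dim=0, n_tile=3) returns
+ # tensor([1, 1, 1, 2, 2, 2]); each entry along `dim` is repeated n_tile times
+ # consecutively (like torch.repeat_interleave along that dimension).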
+ init_dim = x.size(dim)
241
+ repeat_idx = [1] * x.dim()
242
+ repeat_idx[dim] = n_tile
243
+ x = x.repeat(*(repeat_idx))
244
+ order_index = torch.LongTensor(
245
+ np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])
246
+ )
247
+ return torch.index_select(x, dim, order_index.to(x.device))
cheetah/models/blip2.py ADDED
@@ -0,0 +1,221 @@
1
+ """
2
+ Copyright (c) 2023, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+ import contextlib
8
+ import logging
9
+ import os
10
+ import time
11
+ import datetime
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.distributed as dist
16
+ import torch.nn.functional as F
17
+
18
+ import cheetah.common.dist_utils as dist_utils
19
+ from cheetah.common.dist_utils import download_cached_file
20
+ from cheetah.common.utils import is_url
21
+ from cheetah.common.logger import MetricLogger
22
+ from cheetah.models.base_model import BaseModel
23
+ from cheetah.models.Qformer import BertConfig, BertLMHeadModel
24
+ from cheetah.models.eva_vit import create_eva_vit_g
25
+ from transformers import BertTokenizer
26
+
27
+
28
+ class Blip2Base(BaseModel):
29
+ @classmethod
30
+ def init_tokenizer(cls):
31
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
32
+ tokenizer.add_special_tokens({"bos_token": "[DEC]"})
33
+ return tokenizer
34
+
35
+ def maybe_autocast(self, dtype=torch.float16):
36
+ # if on cpu, don't use autocast
37
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
38
+ enable_autocast = self.device != torch.device("cpu")
39
+
40
+ if enable_autocast:
41
+ return torch.cuda.amp.autocast(dtype=dtype)
42
+ else:
43
+ return contextlib.nullcontext()
44
+
45
+ @classmethod
46
+ def init_Qformer(cls, num_query_token, vision_width, cross_attention_freq=2):
47
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
48
+ encoder_config.encoder_width = vision_width
49
+ # insert cross-attention layer every other block
50
+ encoder_config.add_cross_attention = True
51
+ encoder_config.cross_attention_freq = cross_attention_freq
52
+ encoder_config.query_length = num_query_token
53
+ Qformer = BertLMHeadModel(config=encoder_config)
54
+ query_tokens = nn.Parameter(
55
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
56
+ )
57
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
58
+ return Qformer, query_tokens
59
+
60
+ @classmethod
61
+ def init_vision_encoder(
62
+ cls, model_name, img_size, drop_path_rate, use_grad_checkpoint, precision
63
+ ):
64
+ assert model_name == "eva_clip_g", "vit model must be eva_clip_g for current version of Cheetah"
65
+ visual_encoder = create_eva_vit_g(
66
+ img_size, drop_path_rate, use_grad_checkpoint, precision
67
+ )
68
+
69
+ ln_vision = LayerNorm(visual_encoder.num_features)
70
+ return visual_encoder, ln_vision
71
+
72
+ def load_from_pretrained(self, url_or_filename):
73
+ if is_url(url_or_filename):
74
+ cached_file = download_cached_file(
75
+ url_or_filename, check_hash=False, progress=True
76
+ )
77
+ checkpoint = torch.load(cached_file, map_location="cpu")
78
+ elif os.path.isfile(url_or_filename):
79
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
80
+ else:
81
+ raise RuntimeError("checkpoint url or path is invalid")
82
+
83
+ state_dict = checkpoint["model"]
84
+
85
+ msg = self.load_state_dict(state_dict, strict=False)
86
+
87
+ # logging.info("Missing keys {}".format(msg.missing_keys))
88
+ logging.info("load checkpoint from %s" % url_or_filename)
89
+
90
+ return msg
91
+
92
+
93
+ def disabled_train(self, mode=True):
94
+ """Overwrite model.train with this function to make sure train/eval mode
95
+ does not change anymore."""
96
+ return self
97
+
98
+
99
+ class LayerNorm(nn.LayerNorm):
100
+ """Subclass torch's LayerNorm to handle fp16."""
101
+
102
+ def forward(self, x: torch.Tensor):
103
+ orig_type = x.dtype
104
+ ret = super().forward(x.type(torch.float32))
105
+ return ret.type(orig_type)
106
+
107
+
108
+ def compute_sim_matrix(model, data_loader, **kwargs):
109
+ k_test = kwargs.pop("k_test")
110
+
111
+ metric_logger = MetricLogger(delimiter=" ")
112
+ header = "Evaluation:"
113
+
114
+ logging.info("Computing features for evaluation...")
115
+ start_time = time.time()
116
+
117
+ texts = data_loader.dataset.text
118
+ num_text = len(texts)
119
+ text_bs = 256
120
+ text_ids = []
121
+ text_embeds = []
122
+ text_atts = []
123
+ for i in range(0, num_text, text_bs):
124
+ text = texts[i : min(num_text, i + text_bs)]
125
+ text_input = model.tokenizer(
126
+ text,
127
+ padding="max_length",
128
+ truncation=True,
129
+ max_length=35,
130
+ return_tensors="pt",
131
+ ).to(model.device)
132
+ text_feat = model.forward_text(text_input)
133
+ text_embed = F.normalize(model.text_proj(text_feat))
134
+ text_embeds.append(text_embed)
135
+ text_ids.append(text_input.input_ids)
136
+ text_atts.append(text_input.attention_mask)
137
+
138
+ text_embeds = torch.cat(text_embeds, dim=0)
139
+ text_ids = torch.cat(text_ids, dim=0)
140
+ text_atts = torch.cat(text_atts, dim=0)
141
+
142
+ vit_feats = []
143
+ image_embeds = []
144
+ for samples in data_loader:
145
+ image = samples["image"]
146
+
147
+ image = image.to(model.device)
148
+ image_feat, vit_feat = model.forward_image(image)
149
+ image_embed = model.vision_proj(image_feat)
150
+ image_embed = F.normalize(image_embed, dim=-1)
151
+
152
+ vit_feats.append(vit_feat.cpu())
153
+ image_embeds.append(image_embed)
154
+
155
+ vit_feats = torch.cat(vit_feats, dim=0)
156
+ image_embeds = torch.cat(image_embeds, dim=0)
157
+
158
+ sims_matrix = []
159
+ for image_embed in image_embeds:
160
+ sim_q2t = image_embed @ text_embeds.t()
161
+ sim_i2t, _ = sim_q2t.max(0)
162
+ sims_matrix.append(sim_i2t)
163
+ sims_matrix = torch.stack(sims_matrix, dim=0)
164
+
165
+ score_matrix_i2t = torch.full(
166
+ (len(data_loader.dataset.image), len(texts)), -100.0
167
+ ).to(model.device)
168
+
169
+ num_tasks = dist_utils.get_world_size()
170
+ rank = dist_utils.get_rank()
171
+ step = sims_matrix.size(0) // num_tasks + 1
172
+ start = rank * step
173
+ end = min(sims_matrix.size(0), start + step)
174
+
175
+ for i, sims in enumerate(
176
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
177
+ ):
178
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
179
+ image_inputs = vit_feats[start + i].repeat(k_test, 1, 1).to(model.device)
180
+ score = model.compute_itm(
181
+ image_inputs=image_inputs,
182
+ text_ids=text_ids[topk_idx],
183
+ text_atts=text_atts[topk_idx],
184
+ ).float()
185
+ score_matrix_i2t[start + i, topk_idx] = score + topk_sim
186
+
187
+ sims_matrix = sims_matrix.t()
188
+ score_matrix_t2i = torch.full(
189
+ (len(texts), len(data_loader.dataset.image)), -100.0
190
+ ).to(model.device)
191
+
192
+ step = sims_matrix.size(0) // num_tasks + 1
193
+ start = rank * step
194
+ end = min(sims_matrix.size(0), start + step)
195
+
196
+ for i, sims in enumerate(
197
+ metric_logger.log_every(sims_matrix[start:end], 50, header)
198
+ ):
199
+ topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
200
+ image_inputs = vit_feats[topk_idx.cpu()].to(model.device)
201
+ score = model.compute_itm(
202
+ image_inputs=image_inputs,
203
+ text_ids=text_ids[start + i].repeat(k_test, 1),
204
+ text_atts=text_atts[start + i].repeat(k_test, 1),
205
+ ).float()
206
+ score_matrix_t2i[start + i, topk_idx] = score + topk_sim
207
+
208
+ if dist_utils.is_dist_avail_and_initialized():
209
+ dist.barrier()
210
+ torch.distributed.all_reduce(
211
+ score_matrix_i2t, op=torch.distributed.ReduceOp.SUM
212
+ )
213
+ torch.distributed.all_reduce(
214
+ score_matrix_t2i, op=torch.distributed.ReduceOp.SUM
215
+ )
216
+
217
+ total_time = time.time() - start_time
218
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
219
+ logging.info("Evaluation time {}".format(total_time_str))
220
+
221
+ return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
cheetah/models/blip2_outputs.py ADDED
@@ -0,0 +1,110 @@
1
+ """
2
+ Copyright (c) 2022, salesforce.com, inc.
3
+ All rights reserved.
4
+ SPDX-License-Identifier: BSD-3-Clause
5
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from typing import Optional
10
+
11
+ import torch
12
+ from transformers.modeling_outputs import (
13
+ ModelOutput,
14
+ BaseModelOutputWithPoolingAndCrossAttentions,
15
+ CausalLMOutputWithCrossAttentions,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class BlipSimilarity(ModelOutput):
21
+ sim_i2t: torch.FloatTensor = None
22
+ sim_t2i: torch.FloatTensor = None
23
+
24
+ sim_i2t_m: Optional[torch.FloatTensor] = None
25
+ sim_t2i_m: Optional[torch.FloatTensor] = None
26
+
27
+ sim_i2t_targets: Optional[torch.FloatTensor] = None
28
+ sim_t2i_targets: Optional[torch.FloatTensor] = None
29
+
30
+
31
+ @dataclass
32
+ class BlipIntermediateOutput(ModelOutput):
33
+ """
34
+ Data class for intermediate outputs of BLIP models.
35
+
36
+ image_embeds (torch.FloatTensor): Image embeddings, shape (batch_size, num_patches, embed_dim).
37
+ text_embeds (torch.FloatTensor): Text embeddings, shape (batch_size, seq_len, embed_dim).
38
+
39
+ image_embeds_m (torch.FloatTensor): Image embeddings from momentum visual encoder, shape (batch_size, num_patches, embed_dim).
40
+ text_embeds_m (torch.FloatTensor): Text embeddings from momentum text encoder, shape (batch_size, seq_len, embed_dim).
41
+
42
+ encoder_output (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder.
43
+ encoder_output_neg (BaseModelOutputWithPoolingAndCrossAttentions): output from the image-grounded text encoder for negative pairs.
44
+
45
+ decoder_output (CausalLMOutputWithCrossAttentions): output from the image-grounded text decoder.
46
+ decoder_labels (torch.LongTensor): labels for the captioning loss.
47
+
48
+ itm_logits (torch.FloatTensor): logits for the image-text matching loss, shape (batch_size * 3, 2).
49
+ itm_labels (torch.LongTensor): labels for the image-text matching loss, shape (batch_size * 3,)
50
+
51
+ """
52
+
53
+ # uni-modal features
54
+ image_embeds: torch.FloatTensor = None
55
+ text_embeds: Optional[torch.FloatTensor] = None
56
+
57
+ image_embeds_m: Optional[torch.FloatTensor] = None
58
+ text_embeds_m: Optional[torch.FloatTensor] = None
59
+
60
+ # intermediate outputs of multimodal encoder
61
+ encoder_output: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
62
+ encoder_output_neg: Optional[BaseModelOutputWithPoolingAndCrossAttentions] = None
63
+
64
+ itm_logits: Optional[torch.FloatTensor] = None
65
+ itm_labels: Optional[torch.LongTensor] = None
66
+
67
+ # intermediate outputs of multimodal decoder
68
+ decoder_output: Optional[CausalLMOutputWithCrossAttentions] = None
69
+ decoder_labels: Optional[torch.LongTensor] = None
70
+
71
+
72
+ @dataclass
73
+ class BlipOutput(ModelOutput):
74
+ # some finetuned models (e.g. BlipVQA) do not compute similarity, thus optional.
75
+ sims: Optional[BlipSimilarity] = None
76
+
77
+ intermediate_output: BlipIntermediateOutput = None
78
+
79
+ loss: Optional[torch.FloatTensor] = None
80
+
81
+ loss_itc: Optional[torch.FloatTensor] = None
82
+
83
+ loss_itm: Optional[torch.FloatTensor] = None
84
+
85
+ loss_lm: Optional[torch.FloatTensor] = None
86
+
87
+
88
+ @dataclass
89
+ class BlipOutputFeatures(ModelOutput):
90
+ """
91
+ Data class of features from BlipFeatureExtractor.
92
+
93
+ Args:
94
+ image_embeds: (torch.FloatTensor) of shape (batch_size, num_patches+1, embed_dim), optional
95
+ image_features: (torch.FloatTensor) of shape (batch_size, num_patches+1, feature_dim), optional
96
+ text_embeds: (torch.FloatTensor) of shape (batch_size, sequence_length+1, embed_dim), optional
97
+ text_features: (torch.FloatTensor) of shape (batch_size, sequence_length+1, feature_dim), optional
98
+
99
+ The first embedding or feature is for the [CLS] token.
100
+
101
+ Features are obtained by projecting the corresponding embedding into a normalized low-dimensional space.
102
+ """
103
+
104
+ image_embeds: Optional[torch.FloatTensor] = None
105
+ image_embeds_proj: Optional[torch.FloatTensor] = None
106
+
107
+ text_embeds: Optional[torch.FloatTensor] = None
108
+ text_embeds_proj: Optional[torch.FloatTensor] = None
109
+
110
+ multimodal_embeds: Optional[torch.FloatTensor] = None
cheetah/models/cheetah_llama2.py ADDED
@@ -0,0 +1,388 @@
1
+ import os
2
+ import logging
3
+ import random
+ import contextlib
4
+
5
+ import torch
6
+ from torch.cuda.amp import autocast as autocast
7
+ import torch.nn as nn
8
+
9
+ from cheetah.common.registry import registry
10
+ from cheetah.models.blip2 import Blip2Base, disabled_train
11
+ from cheetah.models.modeling_llama2 import LlamaForCausalLM
12
+ from transformers import LlamaTokenizer
13
+ from cheetah.models.Qformer import BertConfig
14
+
15
+ from collections import OrderedDict
16
+ from cheetah.common.dist_utils import download_cached_file
17
+ from cheetah.common.utils import is_url
18
+
19
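+ # Zero-initialises every parameter of a module. Used below so the newly added
+ # Q-Former/LLaMA projection layers start as no-ops with respect to the frozen
+ # LLaMA computation.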
+ def zero_module(module):
20
+ for p in module.parameters():
21
+ p.detach().zero_()
22
+ return module
23
+
24
+ @registry.register_model("cheetah_llama2")
25
+ class Cheetah_Llama2(Blip2Base):
26
+ """
27
+ BLIP-2 style Cheetah model with a LLaMA-2 language model.
28
+ """
29
+
30
+ PRETRAINED_MODEL_CONFIG_DICT = {
31
+ "pretrain_llama2": "configs/models/cheetah_llama2.yaml",
32
+ }
33
+
34
+ def __init__(
35
+ self,
36
+ vit_model="eva_clip_g",
37
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
38
+ img_size=224,
39
+ drop_path_rate=0,
40
+ use_grad_checkpoint=False,
41
+ vit_precision="bf16",
42
+ freeze_vit=True,
43
+ freeze_qformer=True,
44
+ freeze_llama_proj=True,
45
+ num_query_token=32,
46
+ llama_model="",
47
+ # prompt_path="",
48
+ prompt_template="",
49
+ max_txt_len=32,
50
+ end_sym='\n',
51
+ update_layer = 16,
52
+ # low_resource=False, # use 8 bit and put vit in cpu
53
+ # device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
54
+ ):
55
+ super().__init__()
56
+ print("the llama2 version of cheetah!")
57
+
58
+ self.tokenizer = self.init_tokenizer()
59
+
60
+ print('Loading VIT')
61
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
62
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
63
+ )
64
+ if freeze_vit:
65
+ for name, param in self.visual_encoder.named_parameters():
66
+ param.requires_grad = False
67
+ self.visual_encoder = self.visual_encoder.eval()
68
+ self.visual_encoder.train = disabled_train
69
+ for name, param in self.ln_vision.named_parameters():
70
+ param.requires_grad = False
71
+ self.ln_vision = self.ln_vision.eval()
72
+ self.ln_vision.train = disabled_train
73
+ logging.info("freeze vision encoder")
74
+ print('Loading VIT Done')
75
+
76
+ print('Loading Q-Former')
77
+ self.Qformer, self.query_tokens = self.init_Qformer(
78
+ num_query_token, self.visual_encoder.num_features
79
+ )
80
+ self.Qformer.cls = None
81
+ self.Qformer.bert.embeddings.word_embeddings = None
82
+ self.Qformer.bert.embeddings.position_embeddings = None
83
+ for layer in self.Qformer.bert.encoder.layer:
84
+ layer.output = None
85
+ layer.intermediate = None
86
+ self.load_from_pretrained(url_or_filename=q_former_model)
87
+
88
+ if freeze_qformer:
89
+ for name, param in self.Qformer.named_parameters():
90
+ param.requires_grad = False
91
+ self.query_tokens.requires_grad = False
92
+ logging.info("freeze Qformer")
93
+ print('Loading Q-Former Done')
94
+
95
+ print('Loading LLAMA')
96
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
97
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.unk_token
98
+
99
+ self.llama_model = LlamaForCausalLM.from_pretrained(
100
+ llama_model,
101
+ torch_dtype=torch.bfloat16,
102
+ )
103
+
104
+ for name, param in self.llama_model.named_parameters():
105
+ param.requires_grad = False
106
+ print('Loading LLAMA Done')
107
+
108
+ self.llama_proj = nn.Linear(
109
+ self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
110
+ )
111
+ if freeze_llama_proj:
112
+ for name, param in self.llama_proj.named_parameters():
113
+ param.requires_grad = False
114
+
115
+ self.max_txt_len = max_txt_len
116
+ self.end_sym = end_sym
117
+ self.prompt_template = prompt_template
118
+
119
+ new_query_tokens = self.init_query_tokens(num_query_token)
120
+
121
+ qformer_proj = zero_module(nn.Linear(
122
+ self.llama_model.config.hidden_size, self.Qformer.config.hidden_size
123
+ ))
124
+
125
+ new_llm_proj = zero_module(nn.Linear(
126
+ self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
127
+ ))
128
+ self.llama_model.set_qformer_and_proj(self.Qformer, qformer_proj, new_llm_proj, new_query_tokens)
129
+ self.init_query_tokens_value(url_or_filename=q_former_model)
130
+ self.update_layer = update_layer
131
+
132
+
133
+ def init_query_tokens_value(self, url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
134
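+ # Copies the pretrained BLIP-2 `query_tokens` from the checkpoint into the new
+ # query tokens registered on the wrapped LLaMA model (`llama_model.model.query_tokens`).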
+ if is_url(url_or_filename):
135
+ cached_file = download_cached_file(
136
+ url_or_filename, check_hash=False, progress=True
137
+ )
138
+ checkpoint = torch.load(cached_file, map_location="cpu")
139
+ elif os.path.isfile(url_or_filename):
140
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
141
+ else:
142
+ raise RuntimeError("checkpoint url or path is invalid")
143
+
144
+ state_dict = checkpoint["model"]
145
+ new_state_dict = OrderedDict()
146
+ for k in list(state_dict.keys()):
147
+ if 'query_tokens' in k:
148
+ new_state_dict[f'llama_model.model.query_tokens'] = state_dict[k]
149
+ self.load_state_dict(new_state_dict, strict=False)
150
+
151
+
152
+ @classmethod
153
+ def init_query_tokens(cls, num_query_token):
154
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
155
+ query_tokens = nn.Parameter(
156
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
157
+ )
158
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
159
+ return query_tokens
160
+
161
+ def maybe_autocast(self, dtype=torch.bfloat16):
162
+ # if on cpu, don't use autocast
163
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
164
+ enable_autocast = self.device != torch.device("cpu")
165
+
166
+ if enable_autocast:
167
+ return torch.cuda.amp.autocast(dtype=dtype)
168
+ else:
169
+ return contextlib.nullcontext()
170
+
171
+ def encode_img(self, image):
172
+ device = image.device
173
+
174
+ with self.maybe_autocast():
175
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
176
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
177
+
178
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
179
+ query_output = self.Qformer.bert(
180
+ query_embeds=query_tokens,
181
+ encoder_hidden_states=image_embeds,
182
+ encoder_attention_mask=image_atts,
183
+ return_dict=True,
184
+ )
185
+
186
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
187
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
188
+ return inputs_llama, atts_llama, image_embeds
189
+
190
+ def prompt_process(self, img_embeds, img_attens, input_text):
191
+ split_prompt = [txt.split('<HereForImage>') for txt in input_text]
192
+ for i in range(len(split_prompt)):
193
+ assert len(split_prompt[i]) == len(img_embeds) + 1, "Mismatched number of <HereForImage> placeholders and images"
194
+ prompt_segs = []
195
+ for i in range(len(img_embeds) + 1):
196
+ prompt_segs.append([p[i] for p in split_prompt])
197
+
198
+ seg_tokens = [
199
+ self.llama_tokenizer(
200
+ seg, return_tensors="pt", padding=True, add_special_tokens=False).to(self.device)
201
+ for i, seg in enumerate(prompt_segs)
202
+ ]
203
+
204
+ seg_embs = [self.llama_model.model.embed_tokens(seg_t.input_ids) for seg_t in seg_tokens]
205
+ seg_attns = [seg_t.attention_mask for seg_t in seg_tokens]
206
+
207
+ mixed_embs = []
208
+ mixed_attns = []
209
+ img_position_list = []
210
+ img_start = 1
211
+ for i in range(len(prompt_segs)):
212
+ mixed_embs.append(seg_embs[i])
213
+ mixed_attns.append(seg_attns[i])
214
+ if i != len(img_embeds):
215
+ mixed_embs.append(img_embeds[i])
216
+ mixed_attns.append(img_attens[i])
217
+ img_start += seg_embs[i].size(1)
218
+ img_end = img_start + img_embeds[i].size(1)
219
+ img_position_list.append((img_start, img_end))
220
+ img_start = img_end
221
+
222
+ mixed_embs = torch.cat(mixed_embs, dim=1)
223
+ mixed_attns = torch.cat(mixed_attns, dim=1)
224
+
225
+ return mixed_embs, mixed_attns, img_position_list
226
+
227
+ def concat_text_input_output(self, input_attns, input_embeds, output_attns, output_embeds, target_ids):
228
+ input_part_targets_len = []
229
+ empty_targets = (
230
+ torch.ones(input_attns.size(), dtype=torch.long).to(input_attns.device).fill_(-100)
231
+ )
232
+ llm_inputs = {"inputs_embeds": [], "attention_mask": [], "targets":[]}
233
+ for i in range(input_attns.size(0)):
234
+ this_input_ones = (torch.nonzero(input_attns[i]).squeeze())[-1]
235
+ input_part_targets_len.append(this_input_ones)
236
+ this_input_ones = this_input_ones + 1
237
+
238
+ llm_inputs["targets"].append(
239
+ torch.cat([
240
+ empty_targets[i][:this_input_ones],
241
+ target_ids[i][:],
242
+ empty_targets[i][this_input_ones:]
243
+ ])
244
+ )
245
+
246
+ llm_inputs["inputs_embeds"].append(
247
+ torch.cat([
248
+ input_embeds[i][:this_input_ones, :],
249
+ output_embeds[i][:, :],
250
+ input_embeds[i][this_input_ones:, :]
251
+ ])
252
+ )
253
+
254
+ llm_inputs["attention_mask"].append(
255
+ torch.cat([
256
+ input_attns[i][:this_input_ones],
257
+ output_attns[i][:],
258
+ input_attns[i][this_input_ones:]
259
+ ])
260
+ )
261
+
262
+ llm_inputs["inputs_embeds"] = torch.stack(llm_inputs["inputs_embeds"], dim=0)
263
+ llm_inputs["targets"] = torch.stack(llm_inputs["targets"], dim=0)
264
+ llm_inputs["attention_mask"] = torch.stack(llm_inputs["attention_mask"], dim=0)
265
+
266
+ return llm_inputs, input_part_targets_len
267
+
268
+ def forward(self, samples):
269
+ image = samples["image"]
270
+ img_list, vit_list, att_list = [], [], []
271
+ if image.dim() == 5:
272
+ for j in range(image.size(2)):
273
+ this_image = image[:,:,j,:,:]
274
+ image_emb, image_att, vit_emb = self.encode_img(this_image)
275
+ img_list.append(image_emb)
276
+ att_list.append(image_att)
277
+ vit_list.append(vit_emb)
278
+ else:
279
+ image_emb, image_att, vit_emb = self.encode_img(image)
280
+ img_list.append(image_emb)
281
+ att_list.append(image_att)
282
+ vit_list.append(vit_emb)
283
+
284
+ prompt = samples["text_input"]
285
+ prompt = [self.prompt_template.format(p) for p in prompt]
286
+
287
+ self.llama_tokenizer.padding_side = "right"
288
+ self.llama_tokenizer.truncation_side = 'left'
289
+ img_embeds, atts_img, img_position_list = self.prompt_process(img_list, att_list, prompt)
290
+
291
+ self.llama_tokenizer.padding_side = "right"
292
+ self.llama_tokenizer.truncation_side = 'right'
293
+
294
+ text = [t + self.end_sym for t in samples["text_output"]]
295
+
296
+ to_regress_tokens = self.llama_tokenizer(
297
+ text,
298
+ return_tensors="pt",
299
+ padding="longest",
300
+ truncation=True,
301
+ max_length=self.max_txt_len,
302
+ add_special_tokens=False
303
+ ).to(image.device)
304
+
305
+ to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
306
+
307
+ targets = to_regress_tokens.input_ids.masked_fill(
308
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
309
+ )
310
+
311
+ batch_size = img_embeds.shape[0]
312
+ bos = torch.ones([batch_size, 1],
313
+ dtype=to_regress_tokens.input_ids.dtype,
314
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
315
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
316
+ atts_bos = torch.ones((atts_img.size(0), 1), dtype=torch.long).to(atts_img.device)
317
+
318
+ img_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
319
+ atts_img = torch.cat([atts_bos, atts_img], dim=1)
320
+
321
+ llm_inputs, input_part_targets_len = self.concat_text_input_output(atts_img,
322
+ img_embeds,
323
+ to_regress_tokens['attention_mask'],
324
+ to_regress_embeds,
325
+ targets)
326
+
327
+ input_part_targets_len = torch.tensor(input_part_targets_len)
328
+
329
+ with self.maybe_autocast():
330
+ outputs = self.llama_model(
331
+ inputs_embeds=llm_inputs['inputs_embeds'],
332
+ attention_mask=llm_inputs['attention_mask'],
333
+ return_dict=True,
334
+ labels=llm_inputs['targets'],
335
+ update_layer = self.update_layer,
336
+ image_position_list = img_position_list,
337
+ input_part_targets_len = input_part_targets_len,
338
+ all_image_embeds = torch.stack(vit_list,dim=1),
339
+ )
340
+ loss = outputs.loss
341
+
342
+ return {"loss": loss}
343
+
344
+ @classmethod
345
+ def from_config(cls, cfg):
346
+ vit_model = cfg.get("vit_model", "eva_clip_g")
347
+ q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
348
+ img_size = cfg.get("image_size")
349
+ num_query_token = cfg.get("num_query_token")
350
+ llama_model = cfg.get("llama_model")
351
+
352
+ drop_path_rate = cfg.get("drop_path_rate", 0)
353
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
354
+ vit_precision = cfg.get("vit_precision", "bf16")
355
+ freeze_vit = cfg.get("freeze_vit", True)
356
+ freeze_qformer = cfg.get("freeze_qformer", True)
357
+ freeze_llama_proj = cfg.get("freeze_llama_proj", True)
358
+
359
+ prompt_template = cfg.get("prompt_template", "")
360
+ max_txt_len = cfg.get("max_txt_len", 32)
361
+ end_sym = cfg.get("end_sym", '\n')
362
+ update_layer = cfg.get("update_layer", 16)
363
+
364
+ model = cls(
365
+ vit_model=vit_model,
366
+ q_former_model=q_former_model,
367
+ img_size=img_size,
368
+ drop_path_rate=drop_path_rate,
369
+ use_grad_checkpoint=use_grad_checkpoint,
370
+ vit_precision=vit_precision,
371
+ freeze_vit=freeze_vit,
372
+ freeze_qformer=freeze_qformer,
373
+ freeze_llama_proj=freeze_llama_proj,
374
+ num_query_token=num_query_token,
375
+ llama_model=llama_model,
376
+ prompt_template=prompt_template,
377
+ max_txt_len=max_txt_len,
378
+ end_sym=end_sym,
379
+ update_layer=update_layer,
380
+ )
381
+
382
+ ckpt_path = cfg.get("ckpt", "") # load weights of Cheetah
383
+ if ckpt_path:
384
+ print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
385
+ ckpt = torch.load(ckpt_path, map_location="cpu")
386
+ msg = model.load_state_dict(ckpt['model'], strict=False)
387
+
388
+ return model
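The from_config keys above map one-to-one onto the constructor arguments, so the LLaMA-2 variant can be built through the registry. The sketch below is illustrative only and assumes the class is registered as "cheetah_llama2" (mirroring the "cheetah_vicuna" registration in the next file); the llama_model path and the other values are placeholders for local checkpoints, not values taken from this commit.

from cheetah.common.registry import registry

# Hedged usage sketch: build the model from a plain dict config (all values are placeholders).
cfg = {
    "image_size": 224,
    "num_query_token": 32,
    "llama_model": "/path/to/Llama-2-7b-chat-hf",  # placeholder local checkpoint
    "prompt_template": "{}",
    "max_txt_len": 128,
    "end_sym": "\n",
    "update_layer": 16,
}

ModelCls = registry.get_model_class("cheetah_llama2")  # assumed registry key
model = ModelCls.from_config(cfg).eval()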
cheetah/models/cheetah_vicuna.py ADDED
@@ -0,0 +1,387 @@
1
+ import contextlib
+ import os
2
+ import logging
3
+ import random
4
+
5
+ import torch
6
+ from torch.cuda.amp import autocast as autocast
7
+ import torch.nn as nn
8
+
9
+ from cheetah.common.registry import registry
10
+ from cheetah.models.blip2 import Blip2Base, disabled_train
11
+ from cheetah.models.modeling_llama import LlamaForCausalLM
12
+ from transformers import LlamaTokenizer
13
+ from cheetah.models.Qformer import BertConfig
14
+
15
+ from collections import OrderedDict
16
+ from cheetah.common.dist_utils import download_cached_file
17
+ from cheetah.common.utils import is_url
18
+
19
+ def zero_module(module):
20
+ for p in module.parameters():
21
+ p.detach().zero_()
22
+ return module
23
+
24
+ @registry.register_model("cheetah_vicuna")
25
+ class Cheetah_Vicuna(Blip2Base):
26
+ """
27
+ Cheetah model with a Vicuna language model (BLIP-2 ViT + Q-Former front end).
28
+ """
29
+
30
+ PRETRAINED_MODEL_CONFIG_DICT = {
31
+ "pretrain_vicuna": "configs/models/cheetah_vicuna.yaml",
32
+ }
33
+
34
+ def __init__(
35
+ self,
36
+ vit_model="eva_clip_g",
37
+ q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
38
+ img_size=224,
39
+ drop_path_rate=0,
40
+ use_grad_checkpoint=False,
41
+ vit_precision="fp16",
42
+ freeze_vit=True,
43
+ freeze_qformer=True,
44
+ freeze_llama_proj=True,
45
+ num_query_token=32,
46
+ llama_model="",
47
+ # prompt_path="",
48
+ prompt_template="",
49
+ max_txt_len=32,
50
+ end_sym='\n',
51
+ update_layer = 16,
52
+ # low_resource=False, # use 8 bit and put vit in cpu
53
+ # device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
54
+ ):
55
+ super().__init__()
56
+ print("the vicuna version of cheetah!")
57
+ self.tokenizer = self.init_tokenizer()
58
+
59
+ print('Loading VIT')
60
+ self.visual_encoder, self.ln_vision = self.init_vision_encoder(
61
+ vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
62
+ )
63
+ if freeze_vit:
64
+ for name, param in self.visual_encoder.named_parameters():
65
+ param.requires_grad = False
66
+ self.visual_encoder = self.visual_encoder.eval()
67
+ self.visual_encoder.train = disabled_train
68
+ for name, param in self.ln_vision.named_parameters():
69
+ param.requires_grad = False
70
+ self.ln_vision = self.ln_vision.eval()
71
+ self.ln_vision.train = disabled_train
72
+ logging.info("freeze vision encoder")
73
+ print('Loading VIT Done')
74
+
75
+ print('Loading Q-Former')
76
+ self.Qformer, self.query_tokens = self.init_Qformer(
77
+ num_query_token, self.visual_encoder.num_features
78
+ )
79
+ self.Qformer.cls = None
80
+ self.Qformer.bert.embeddings.word_embeddings = None
81
+ self.Qformer.bert.embeddings.position_embeddings = None
82
+ for layer in self.Qformer.bert.encoder.layer:
83
+ layer.output = None
84
+ layer.intermediate = None
85
+ self.load_from_pretrained(url_or_filename=q_former_model)
86
+
87
+ if freeze_qformer:
88
+ for name, param in self.Qformer.named_parameters():
89
+ param.requires_grad = False
90
+ self.query_tokens.requires_grad = False
91
+ logging.info("freeze Qformer")
92
+ print('Loading Q-Former Done')
93
+
94
+ print('Loading LLAMA')
95
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
96
+ self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
97
+
98
+ self.llama_model = LlamaForCausalLM.from_pretrained(
99
+ llama_model,
100
+ torch_dtype=torch.float16,
101
+ )
102
+
103
+ for name, param in self.llama_model.named_parameters():
104
+ param.requires_grad = False
105
+ print('Loading LLAMA Done')
106
+
107
+ self.llama_proj = nn.Linear(
108
+ self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
109
+ )
110
+ if freeze_llama_proj:
111
+ for name, param in self.llama_proj.named_parameters():
112
+ param.requires_grad = False
113
+
114
+ self.max_txt_len = max_txt_len
115
+ self.end_sym = end_sym
116
+ self.prompt_template = prompt_template
117
+
118
+ new_query_tokens = self.init_query_tokens(num_query_token)
119
+
120
+ qformer_proj = zero_module(nn.Linear(
121
+ self.llama_model.config.hidden_size, self.Qformer.config.hidden_size
122
+ ))
123
+
124
+ new_llm_proj = zero_module(nn.Linear(
125
+ self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
126
+ ))
127
+
128
+ self.llama_model.set_qformer_and_proj(self.Qformer, qformer_proj, new_llm_proj, new_query_tokens)
129
+ self.init_query_tokens_value(url_or_filename=q_former_model)
130
+ self.update_layer = update_layer
131
+
132
+ def init_query_tokens_value(self, url_or_filename="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth"):
133
+ if is_url(url_or_filename):
134
+ cached_file = download_cached_file(
135
+ url_or_filename, check_hash=False, progress=True
136
+ )
137
+ checkpoint = torch.load(cached_file, map_location="cpu")
138
+ elif os.path.isfile(url_or_filename):
139
+ checkpoint = torch.load(url_or_filename, map_location="cpu")
140
+ else:
141
+ raise RuntimeError("checkpoint url or path is invalid")
142
+
143
+ state_dict = checkpoint["model"]
144
+ new_state_dict = OrderedDict()
145
+ for k in list(state_dict.keys()):
146
+ if 'query_tokens' in k:
147
+ new_state_dict[f'llama_model.model.query_tokens'] = state_dict[k]
148
+ self.load_state_dict(new_state_dict, strict=False)
149
+
150
+
151
+ @classmethod
152
+ def init_query_tokens(cls, num_query_token):
153
+ encoder_config = BertConfig.from_pretrained("bert-base-uncased")
154
+ query_tokens = nn.Parameter(
155
+ torch.zeros(1, num_query_token, encoder_config.hidden_size)
156
+ )
157
+ query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
158
+ return query_tokens
159
+
160
+ def maybe_autocast(self, dtype=torch.float16):
161
+ # if on cpu, don't use autocast
162
+ # if on gpu, use autocast with dtype if provided, otherwise use torch.float16
163
+ enable_autocast = self.device != torch.device("cpu")
164
+
165
+ if enable_autocast:
166
+ return torch.cuda.amp.autocast(dtype=dtype)
167
+ else:
168
+ return contextlib.nullcontext()
169
+
170
+ def encode_img(self, image):
171
+ device = image.device
172
+
173
+ with self.maybe_autocast():
174
+ image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
175
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
176
+
177
+ query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
178
+ query_output = self.Qformer.bert(
179
+ query_embeds=query_tokens,
180
+ encoder_hidden_states=image_embeds,
181
+ encoder_attention_mask=image_atts,
182
+ return_dict=True,
183
+ )
184
+
185
+ inputs_llama = self.llama_proj(query_output.last_hidden_state)
186
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
187
+ return inputs_llama, atts_llama, image_embeds
188
+
189
+ def prompt_process(self, img_embeds, img_attens, input_text):
190
+ split_prompt = [txt.split('<HereForImage>') for txt in input_text]
191
+ for i in range(len(split_prompt)):
192
+ assert len(split_prompt[i]) == len(img_embeds) + 1, "Mismatched number of <HereForImage> placeholders and images."
193
+ prompt_segs = []
194
+ for i in range(len(img_embeds) + 1):
195
+ prompt_segs.append([p[i] for p in split_prompt])
196
+
197
+ seg_tokens = [
198
+ self.llama_tokenizer(
199
+ seg, return_tensors="pt", padding=True, add_special_tokens=False).to(self.device)
200
+ for i, seg in enumerate(prompt_segs)
201
+ ]
202
+
203
+ seg_embs = [self.llama_model.model.embed_tokens(seg_t.input_ids) for seg_t in seg_tokens]
204
+ seg_attns = [seg_t.attention_mask for seg_t in seg_tokens]
205
+
206
+ mixed_embs = []
207
+ mixed_attns = []
208
+ img_position_list = []
209
+ img_start = 1
210
+ for i in range(len(prompt_segs)):
211
+ mixed_embs.append(seg_embs[i])
212
+ mixed_attns.append(seg_attns[i])
213
+ if i != len(img_embeds):
214
+ mixed_embs.append(img_embeds[i])
215
+ mixed_attns.append(img_attens[i])
216
+ img_start += seg_embs[i].size(1)
217
+ img_end = img_start + img_embeds[i].size(1)
218
+ img_position_list.append((img_start, img_end))
219
+ img_start = img_end
220
+
221
+ mixed_embs = torch.cat(mixed_embs, dim=1)
222
+ mixed_attns = torch.cat(mixed_attns, dim=1)
223
+
224
+ return mixed_embs, mixed_attns, img_position_list
225
+
226
+ def concat_text_input_output(self, input_attns, input_embeds, output_attns, output_embeds, target_ids):
227
+ input_part_targets_len = []
228
+ empty_targets = (
229
+ torch.ones(input_attns.size(), dtype=torch.long).to(input_attns.device).fill_(-100)
230
+ )
231
+ llm_inputs = {"inputs_embeds": [], "attention_mask": [], "targets":[]}
232
+ for i in range(input_attns.size(0)):
233
+ this_input_ones = (torch.nonzero(input_attns[i]).squeeze())[-1]
234
+ input_part_targets_len.append(this_input_ones)
235
+ this_input_ones = this_input_ones + 1
236
+
237
+ llm_inputs["targets"].append(
238
+ torch.cat([
239
+ empty_targets[i][:this_input_ones],
240
+ target_ids[i][:],
241
+ empty_targets[i][this_input_ones:]
242
+ ])
243
+ )
244
+
245
+ llm_inputs["inputs_embeds"].append(
246
+ torch.cat([
247
+ input_embeds[i][:this_input_ones, :],
248
+ output_embeds[i][:, :],
249
+ input_embeds[i][this_input_ones:, :]
250
+ ])
251
+ )
252
+
253
+ llm_inputs["attention_mask"].append(
254
+ torch.cat([
255
+ input_attns[i][:this_input_ones],
256
+ output_attns[i][:],
257
+ input_attns[i][this_input_ones:]
258
+ ])
259
+ )
260
+
261
+ llm_inputs["inputs_embeds"] = torch.stack(llm_inputs["inputs_embeds"], dim=0)
262
+ llm_inputs["targets"] = torch.stack(llm_inputs["targets"], dim=0)
263
+ llm_inputs["attention_mask"] = torch.stack(llm_inputs["attention_mask"], dim=0)
264
+
265
+ return llm_inputs, input_part_targets_len
266
+
267
+ def forward(self, samples):
268
+ image = samples["image"]
269
+ img_list, vit_list, att_list = [], [], []
270
+ if image.dim() == 5:
271
+ for j in range(image.size(2)):
272
+ this_image = image[:,:,j,:,:]
273
+ image_emb, image_att, vit_emb = self.encode_img(this_image)
274
+ img_list.append(image_emb)
275
+ att_list.append(image_att)
276
+ vit_list.append(vit_emb)
277
+ else:
278
+ image_emb, image_att, vit_emb = self.encode_img(image)
279
+ img_list.append(image_emb)
280
+ att_list.append(image_att)
281
+ vit_list.append(vit_emb)
282
+
283
+ prompt = samples["text_input"]
284
+ prompt = [self.prompt_template.format(p) for p in prompt]
285
+
286
+ self.llama_tokenizer.padding_side = "right"
287
+ self.llama_tokenizer.truncation_side = 'left'
288
+ img_embeds, atts_img, img_position_list = self.prompt_process(img_list, att_list, prompt)
289
+
290
+ self.llama_tokenizer.padding_side = "right"
291
+ self.llama_tokenizer.truncation_side = 'right'
292
+
293
+ text = [t + self.end_sym for t in samples["text_output"]]
294
+
295
+ to_regress_tokens = self.llama_tokenizer(
296
+ text,
297
+ return_tensors="pt",
298
+ padding="longest",
299
+ truncation=True,
300
+ max_length=self.max_txt_len,
301
+ add_special_tokens=False
302
+ ).to(image.device)
303
+
304
+ to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
305
+
306
+ targets = to_regress_tokens.input_ids.masked_fill(
307
+ to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
308
+ )
309
+
310
+ batch_size = img_embeds.shape[0]
311
+ bos = torch.ones([batch_size, 1],
312
+ dtype=to_regress_tokens.input_ids.dtype,
313
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
314
+ bos_embeds = self.llama_model.model.embed_tokens(bos)
315
+ atts_bos = torch.ones((atts_img.size(0), 1), dtype=torch.long).to(atts_img.device)
316
+
317
+ img_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
318
+ atts_img = torch.cat([atts_bos, atts_img], dim=1)
319
+
320
+ llm_inputs, input_part_targets_len = self.concat_text_input_output(atts_img,
321
+ img_embeds,
322
+ to_regress_tokens['attention_mask'],
323
+ to_regress_embeds,
324
+ targets)
325
+
326
+ input_part_targets_len = torch.tensor(input_part_targets_len)
327
+
328
+ with self.maybe_autocast():
329
+ outputs = self.llama_model(
330
+ inputs_embeds=llm_inputs['inputs_embeds'],
331
+ attention_mask=llm_inputs['attention_mask'],
332
+ return_dict=True,
333
+ labels=llm_inputs['targets'],
334
+ update_layer = self.update_layer,
335
+ image_position_list = img_position_list,
336
+ input_part_targets_len = input_part_targets_len,
337
+ all_image_embeds = torch.stack(vit_list,dim=1),
338
+ )
339
+ loss = outputs.loss
340
+
341
+ return {"loss": loss}
342
+
343
+ @classmethod
344
+ def from_config(cls, cfg):
345
+ vit_model = cfg.get("vit_model", "eva_clip_g")
346
+ q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
347
+ img_size = cfg.get("image_size")
348
+ num_query_token = cfg.get("num_query_token")
349
+ llama_model = cfg.get("llama_model")
350
+
351
+ drop_path_rate = cfg.get("drop_path_rate", 0)
352
+ use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
353
+ vit_precision = cfg.get("vit_precision", "fp16")
354
+ freeze_vit = cfg.get("freeze_vit", True)
355
+ freeze_qformer = cfg.get("freeze_qformer", True)
356
+ freeze_llama_proj = cfg.get("freeze_llama_proj", True)
357
+
358
+ prompt_template = cfg.get("prompt_template", "")
359
+ max_txt_len = cfg.get("max_txt_len", 32)
360
+ end_sym = cfg.get("end_sym", '\n')
361
+ update_layer = cfg.get("update_layer", 16)
362
+
363
+ model = cls(
364
+ vit_model=vit_model,
365
+ q_former_model=q_former_model,
366
+ img_size=img_size,
367
+ drop_path_rate=drop_path_rate,
368
+ use_grad_checkpoint=use_grad_checkpoint,
369
+ vit_precision=vit_precision,
370
+ freeze_vit=freeze_vit,
371
+ freeze_qformer=freeze_qformer,
372
+ freeze_llama_proj=freeze_llama_proj,
373
+ num_query_token=num_query_token,
374
+ llama_model=llama_model,
375
+ prompt_template=prompt_template,
376
+ max_txt_len=max_txt_len,
377
+ end_sym=end_sym,
378
+ update_layer=update_layer,
379
+ )
380
+
381
+ ckpt_path = cfg.get("ckpt", "") # load weights of Cheetah
382
+ if ckpt_path:
383
+ print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
384
+ ckpt = torch.load(ckpt_path, map_location="cpu")
385
+ msg = model.load_state_dict(ckpt['model'], strict=False)
386
+
387
+ return model
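Both variants expose the same forward interface: samples["image"] is either a single 4-D batch or a 5-D tensor of shape (batch, channels, num_images, height, width), and each <HereForImage> placeholder in text_input marks where one image's projected query tokens are spliced into the prompt before the LLM sees it. A purely illustrative batch (shapes and strings are made up for this sketch) would look like this:

import torch

# Illustrative two-image batch for Cheetah_Vicuna.forward / Cheetah_LLaMA2.forward.
# image is (batch, channels, num_images, height, width), i.e. the dim() == 5 branch above.
samples = {
    "image": torch.randn(2, 3, 2, 224, 224),
    "text_input": [
        "Compare <HereForImage> with <HereForImage> and describe the difference.",
        "What changed between <HereForImage> and <HereForImage>?",
    ],
    "text_output": [
        "The second image is a darker crop of the first.",
        "A person entered the scene in the second image.",
    ],
}
# loss_dict = model(samples)  # returns {"loss": ...} once a model is loaded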
cheetah/models/eva_vit.py ADDED
@@ -0,0 +1,442 @@
1
+ # Based on EVA, BEIT, timm and DeiT code bases
2
+ # https://github.com/baaivision/EVA
3
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
4
+ # https://github.com/microsoft/unilm/tree/master/beit
5
+ # https://github.com/facebookresearch/deit/
6
+ # https://github.com/facebookresearch/dino
7
+ # --------------------------------------------------------'
8
+ import math
9
+ from functools import partial
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.utils.checkpoint as checkpoint
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ from timm.models.registry import register_model
17
+
18
+ from cheetah.common.dist_utils import download_cached_file
19
+
20
+ def _cfg(url='', **kwargs):
21
+ return {
22
+ 'url': url,
23
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
24
+ 'crop_pct': .9, 'interpolation': 'bicubic',
25
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
26
+ **kwargs
27
+ }
28
+
29
+
30
+ class DropPath(nn.Module):
31
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
32
+ """
33
+ def __init__(self, drop_prob=None):
34
+ super(DropPath, self).__init__()
35
+ self.drop_prob = drop_prob
36
+
37
+ def forward(self, x):
38
+ return drop_path(x, self.drop_prob, self.training)
39
+
40
+ def extra_repr(self) -> str:
41
+ return 'p={}'.format(self.drop_prob)
42
+
43
+
44
+ class Mlp(nn.Module):
45
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
46
+ super().__init__()
47
+ out_features = out_features or in_features
48
+ hidden_features = hidden_features or in_features
49
+ self.fc1 = nn.Linear(in_features, hidden_features)
50
+ self.act = act_layer()
51
+ self.fc2 = nn.Linear(hidden_features, out_features)
52
+ self.drop = nn.Dropout(drop)
53
+
54
+ def forward(self, x):
55
+ x = self.fc1(x)
56
+ x = self.act(x)
57
+ # x = self.drop(x)
58
+ # dropout after fc1 is skipped, following the original BERT implementation
59
+ x = self.fc2(x)
60
+ x = self.drop(x)
61
+ return x
62
+
63
+
64
+ class Attention(nn.Module):
65
+ def __init__(
66
+ self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.,
67
+ proj_drop=0., window_size=None, attn_head_dim=None):
68
+ super().__init__()
69
+ self.num_heads = num_heads
70
+ head_dim = dim // num_heads
71
+ if attn_head_dim is not None:
72
+ head_dim = attn_head_dim
73
+ all_head_dim = head_dim * self.num_heads
74
+ self.scale = qk_scale or head_dim ** -0.5
75
+
76
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
77
+ if qkv_bias:
78
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
79
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
80
+ else:
81
+ self.q_bias = None
82
+ self.v_bias = None
83
+
84
+ if window_size:
85
+ self.window_size = window_size
86
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
87
+ self.relative_position_bias_table = nn.Parameter(
88
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
89
+ # cls to token & token 2 cls & cls to cls
90
+
91
+ # get pair-wise relative position index for each token inside the window
92
+ coords_h = torch.arange(window_size[0])
93
+ coords_w = torch.arange(window_size[1])
94
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
95
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
96
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
97
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
98
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
99
+ relative_coords[:, :, 1] += window_size[1] - 1
100
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
101
+ relative_position_index = \
102
+ torch.zeros(size=(window_size[0] * window_size[1] + 1, ) * 2, dtype=relative_coords.dtype)
103
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
104
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
105
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
106
+ relative_position_index[0, 0] = self.num_relative_distance - 1
107
+
108
+ self.register_buffer("relative_position_index", relative_position_index)
109
+ else:
110
+ self.window_size = None
111
+ self.relative_position_bias_table = None
112
+ self.relative_position_index = None
113
+
114
+ self.attn_drop = nn.Dropout(attn_drop)
115
+ self.proj = nn.Linear(all_head_dim, dim)
116
+ self.proj_drop = nn.Dropout(proj_drop)
117
+
118
+ def forward(self, x, rel_pos_bias=None):
119
+ B, N, C = x.shape
120
+ qkv_bias = None
121
+ if self.q_bias is not None:
122
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
123
+ # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
124
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
125
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
126
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
127
+
128
+ q = q * self.scale
129
+ attn = (q @ k.transpose(-2, -1))
130
+
131
+ if self.relative_position_bias_table is not None:
132
+ relative_position_bias = \
133
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
134
+ self.window_size[0] * self.window_size[1] + 1,
135
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
136
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
137
+ attn = attn + relative_position_bias.unsqueeze(0)
138
+
139
+ if rel_pos_bias is not None:
140
+ attn = attn + rel_pos_bias
141
+
142
+ attn = attn.softmax(dim=-1)
143
+ attn = self.attn_drop(attn)
144
+
145
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
146
+ x = self.proj(x)
147
+ x = self.proj_drop(x)
148
+ return x
149
+
150
+
151
+ class Block(nn.Module):
152
+
153
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
154
+ drop_path=0., init_values=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm,
155
+ window_size=None, attn_head_dim=None):
156
+ super().__init__()
157
+ self.norm1 = norm_layer(dim)
158
+ self.attn = Attention(
159
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
160
+ attn_drop=attn_drop, proj_drop=drop, window_size=window_size, attn_head_dim=attn_head_dim)
161
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
162
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
163
+ self.norm2 = norm_layer(dim)
164
+ mlp_hidden_dim = int(dim * mlp_ratio)
165
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
166
+
167
+ if init_values is not None and init_values > 0:
168
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
169
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)),requires_grad=True)
170
+ else:
171
+ self.gamma_1, self.gamma_2 = None, None
172
+
173
+ def forward(self, x, rel_pos_bias=None):
174
+ if self.gamma_1 is None:
175
+ x = x + self.drop_path(self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
176
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
177
+ else:
178
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias))
179
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
180
+ return x
181
+
182
+
183
+ class PatchEmbed(nn.Module):
184
+ """ Image to Patch Embedding
185
+ """
186
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
187
+ super().__init__()
188
+ img_size = to_2tuple(img_size)
189
+ patch_size = to_2tuple(patch_size)
190
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
191
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
192
+ self.img_size = img_size
193
+ self.patch_size = patch_size
194
+ self.num_patches = num_patches
195
+
196
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, x, **kwargs):
199
+ B, C, H, W = x.shape
200
+ # FIXME look at relaxing size constraints
201
+ assert H == self.img_size[0] and W == self.img_size[1], \
202
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
203
+ x = self.proj(x).flatten(2).transpose(1, 2)
204
+ return x
205
+
206
+
207
+ class RelativePositionBias(nn.Module):
208
+
209
+ def __init__(self, window_size, num_heads):
210
+ super().__init__()
211
+ self.window_size = window_size
212
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
213
+ self.relative_position_bias_table = nn.Parameter(
214
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
215
+ # cls to token & token 2 cls & cls to cls
216
+
217
+ # get pair-wise relative position index for each token inside the window
218
+ coords_h = torch.arange(window_size[0])
219
+ coords_w = torch.arange(window_size[1])
220
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
221
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
222
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
223
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
224
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
225
+ relative_coords[:, :, 1] += window_size[1] - 1
226
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
227
+ relative_position_index = \
228
+ torch.zeros(size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype)
229
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
230
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
231
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
232
+ relative_position_index[0, 0] = self.num_relative_distance - 1
233
+
234
+ self.register_buffer("relative_position_index", relative_position_index)
235
+
236
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
237
+
238
+ def forward(self):
239
+ relative_position_bias = \
240
+ self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
241
+ self.window_size[0] * self.window_size[1] + 1,
242
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
243
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
244
+
245
+
246
+ class VisionTransformer(nn.Module):
247
+ """ Vision Transformer with support for patch or hybrid CNN input stage
248
+ """
249
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
250
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
251
+ drop_path_rate=0., norm_layer=nn.LayerNorm, init_values=None,
252
+ use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False,
253
+ use_mean_pooling=True, init_scale=0.001, use_checkpoint=False):
254
+ super().__init__()
255
+ self.image_size = img_size
256
+ self.num_classes = num_classes
257
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
258
+
259
+ self.patch_embed = PatchEmbed(
260
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
261
+ num_patches = self.patch_embed.num_patches
262
+
263
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
264
+ if use_abs_pos_emb:
265
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
266
+ else:
267
+ self.pos_embed = None
268
+ self.pos_drop = nn.Dropout(p=drop_rate)
269
+
270
+ if use_shared_rel_pos_bias:
271
+ self.rel_pos_bias = RelativePositionBias(window_size=self.patch_embed.patch_shape, num_heads=num_heads)
272
+ else:
273
+ self.rel_pos_bias = None
274
+ self.use_checkpoint = use_checkpoint
275
+
276
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
277
+ self.use_rel_pos_bias = use_rel_pos_bias
278
+ self.blocks = nn.ModuleList([
279
+ Block(
280
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
281
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
282
+ init_values=init_values, window_size=self.patch_embed.patch_shape if use_rel_pos_bias else None)
283
+ for i in range(depth)])
284
+ # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
285
+ # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
286
+ # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
287
+
288
+ if self.pos_embed is not None:
289
+ trunc_normal_(self.pos_embed, std=.02)
290
+ trunc_normal_(self.cls_token, std=.02)
291
+ # trunc_normal_(self.mask_token, std=.02)
292
+ # if isinstance(self.head, nn.Linear):
293
+ # trunc_normal_(self.head.weight, std=.02)
294
+ self.apply(self._init_weights)
295
+ self.fix_init_weight()
296
+ # if isinstance(self.head, nn.Linear):
297
+ # self.head.weight.data.mul_(init_scale)
298
+ # self.head.bias.data.mul_(init_scale)
299
+
300
+ def fix_init_weight(self):
301
+ def rescale(param, layer_id):
302
+ param.div_(math.sqrt(2.0 * layer_id))
303
+
304
+ for layer_id, layer in enumerate(self.blocks):
305
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
306
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
307
+
308
+ def _init_weights(self, m):
309
+ if isinstance(m, nn.Linear):
310
+ trunc_normal_(m.weight, std=.02)
311
+ if isinstance(m, nn.Linear) and m.bias is not None:
312
+ nn.init.constant_(m.bias, 0)
313
+ elif isinstance(m, nn.LayerNorm):
314
+ nn.init.constant_(m.bias, 0)
315
+ nn.init.constant_(m.weight, 1.0)
316
+
317
+ def get_classifier(self):
318
+ return self.head
319
+
320
+ def reset_classifier(self, num_classes, global_pool=''):
321
+ self.num_classes = num_classes
322
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
323
+
324
+ def forward_features(self, x):
325
+ x = self.patch_embed(x)
326
+ batch_size, seq_len, _ = x.size()
327
+
328
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
329
+ x = torch.cat((cls_tokens, x), dim=1)
330
+ if self.pos_embed is not None:
331
+ x = x + self.pos_embed
332
+ x = self.pos_drop(x)
333
+
334
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
335
+ for blk in self.blocks:
336
+ if self.use_checkpoint:
337
+ x = checkpoint.checkpoint(blk, x, rel_pos_bias)
338
+ else:
339
+ x = blk(x, rel_pos_bias)
340
+ return x
341
+ # x = self.norm(x)
342
+
343
+ # if self.fc_norm is not None:
344
+ # t = x[:, 1:, :]
345
+ # return self.fc_norm(t.mean(1))
346
+ # else:
347
+ # return x[:, 0]
348
+
349
+ def forward(self, x):
350
+ x = self.forward_features(x)
351
+ # x = self.head(x)
352
+ return x
353
+
354
+ def get_intermediate_layers(self, x):
355
+ x = self.patch_embed(x)
356
+ batch_size, seq_len, _ = x.size()
357
+
358
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
359
+ x = torch.cat((cls_tokens, x), dim=1)
360
+ if self.pos_embed is not None:
361
+ x = x + self.pos_embed
362
+ x = self.pos_drop(x)
363
+
364
+ features = []
365
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
366
+ for blk in self.blocks:
367
+ x = blk(x, rel_pos_bias)
368
+ features.append(x)
369
+
370
+ return features
371
+
372
+
373
+ def interpolate_pos_embed(model, checkpoint_model):
374
+ if 'pos_embed' in checkpoint_model:
375
+ pos_embed_checkpoint = checkpoint_model['pos_embed'].float()
376
+ embedding_size = pos_embed_checkpoint.shape[-1]
377
+ num_patches = model.patch_embed.num_patches
378
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
379
+ # height (== width) for the checkpoint position embedding
380
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
381
+ # height (== width) for the new position embedding
382
+ new_size = int(num_patches ** 0.5)
383
+ # class_token and dist_token are kept unchanged
384
+ if orig_size != new_size:
385
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
386
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
387
+ # only the position tokens are interpolated
388
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
389
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
390
+ pos_tokens = torch.nn.functional.interpolate(
391
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
392
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
393
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
394
+ checkpoint_model['pos_embed'] = new_pos_embed
395
+
396
+
397
+ def convert_weights_to_fp16(model: nn.Module):
398
+ """Convert applicable model parameters to fp16"""
399
+
400
+ def _convert_weights_to_fp16(l):
401
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
402
+ l.weight.data = l.weight.data.half()
403
+ if l.bias is not None:
404
+ l.bias.data = l.bias.data.half()
405
+
406
+ # if isinstance(l, (nn.MultiheadAttention, Attention)):
407
+ # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
408
+ # tensor = getattr(l, attr)
409
+ # if tensor is not None:
410
+ # tensor.data = tensor.data.half()
411
+
412
+ model.apply(_convert_weights_to_fp16)
413
+
414
+
415
+ def create_eva_vit_g(img_size=224,drop_path_rate=0.4,use_checkpoint=False,precision="fp16"):
416
+ model = VisionTransformer(
417
+ img_size=img_size,
418
+ patch_size=14,
419
+ use_mean_pooling=False,
420
+ embed_dim=1408,
421
+ depth=39,
422
+ num_heads=1408//88,
423
+ mlp_ratio=4.3637,
424
+ qkv_bias=True,
425
+ drop_path_rate=drop_path_rate,
426
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
427
+ use_checkpoint=use_checkpoint,
428
+ )
429
+ url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"
430
+ cached_file = download_cached_file(
431
+ url, check_hash=False, progress=True
432
+ )
433
+ state_dict = torch.load(cached_file, map_location="cpu")
434
+ interpolate_pos_embed(model,state_dict)
435
+
436
+ incompatible_keys = model.load_state_dict(state_dict, strict=False)
437
+ # print(incompatible_keys)
438
+
439
+ if precision == "fp16":
440
+ # model.to("cuda")
441
+ convert_weights_to_fp16(model)
442
+ return model
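The EVA ViT-g factory above can be exercised on its own as a sanity check: with a 224x224 input and patch size 14 it returns 1 CLS token plus 16x16 patch tokens of width 1408 (the num_features consumed by the Q-Former). A small sketch, assuming the pretrained weights download succeeds; passing precision="fp32" simply skips the fp16 conversion:

import torch
from cheetah.models.eva_vit import create_eva_vit_g

# Sketch: run the frozen EVA ViT-g visual encoder on a dummy image.
vit = create_eva_vit_g(img_size=224, drop_path_rate=0.0, use_checkpoint=False, precision="fp32")
vit.eval()

with torch.no_grad():
    feats = vit(torch.randn(1, 3, 224, 224))

print(feats.shape)  # expected: torch.Size([1, 257, 1408]) -> 1 CLS + 16*16 patches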
cheetah/models/modeling_llama.py ADDED
@@ -0,0 +1,803 @@
1
+ # This script is based on https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
2
+
3
+ """ PyTorch LLaMA model."""
4
+ import math
5
+ from typing import List, Optional, Tuple, Union
6
+
7
+ import torch
8
+ import torch.utils.checkpoint
9
+ from torch import nn
10
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
11
+
12
+ from transformers.activations import ACT2FN
13
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
14
+ from transformers.modeling_utils import PreTrainedModel
15
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
16
+ from transformers.models.llama.configuration_llama import LlamaConfig
17
+
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+ _CONFIG_FOR_DOC = "LlamaConfig"
22
+
23
+
24
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
25
+ def _make_causal_mask(
26
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
27
+ ):
28
+ """
29
+ Make causal mask used for bi-directional self-attention.
30
+ """
31
+ bsz, tgt_len = input_ids_shape
32
+ mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
33
+ mask_cond = torch.arange(mask.size(-1), device=device)
34
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
35
+ mask = mask.to(dtype)
36
+
37
+ if past_key_values_length > 0:
38
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
39
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
40
+
41
+
42
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
43
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
44
+ """
45
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
46
+ """
47
+ bsz, src_len = mask.size()
48
+ tgt_len = tgt_len if tgt_len is not None else src_len
49
+
50
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
51
+
52
+ inverted_mask = 1.0 - expanded_mask
53
+
54
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
55
+
56
+
57
+ class LlamaRMSNorm(nn.Module):
58
+ def __init__(self, hidden_size, eps=1e-6):
59
+ """
60
+ LlamaRMSNorm is equivalent to T5LayerNorm
61
+ """
62
+ super().__init__()
63
+ self.weight = nn.Parameter(torch.ones(hidden_size))
64
+ self.variance_epsilon = eps
65
+
66
+ def forward(self, hidden_states):
67
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
68
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
69
+
70
+ # convert into half-precision if necessary
71
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
72
+ hidden_states = hidden_states.to(self.weight.dtype)
73
+
74
+ return self.weight * hidden_states
75
+
76
+
77
+ class LlamaRotaryEmbedding(torch.nn.Module):
78
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
79
+ super().__init__()
80
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
81
+ self.register_buffer("inv_freq", inv_freq)
82
+
83
+ # Build here to make `torch.jit.trace` work.
84
+ self.max_seq_len_cached = max_position_embeddings
85
+ t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
86
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
87
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
88
+ emb = torch.cat((freqs, freqs), dim=-1)
89
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
90
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
91
+
92
+ def forward(self, x, seq_len=None):
93
+ # x: [bs, num_attention_heads, seq_len, head_size]
94
+ # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
95
+ if seq_len > self.max_seq_len_cached:
96
+ self.max_seq_len_cached = seq_len
97
+ t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
98
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
99
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
100
+ emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
101
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
102
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
103
+ return (
104
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
105
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
106
+ )
107
+
108
+
109
+ def rotate_half(x):
110
+ """Rotates half the hidden dims of the input."""
111
+ x1 = x[..., : x.shape[-1] // 2]
112
+ x2 = x[..., x.shape[-1] // 2 :]
113
+ return torch.cat((-x2, x1), dim=-1)
114
+
115
+
116
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
117
+ gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
118
+ gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
119
+ cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
120
+ sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
121
+ q_embed = (q * cos) + (rotate_half(q) * sin)
122
+ k_embed = (k * cos) + (rotate_half(k) * sin)
123
+ return q_embed, k_embed
124
+
125
+
126
+ class LlamaMLP(nn.Module):
127
+ def __init__(
128
+ self,
129
+ hidden_size: int,
130
+ intermediate_size: int,
131
+ hidden_act: str,
132
+ ):
133
+ super().__init__()
134
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
135
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
136
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
137
+ self.act_fn = ACT2FN[hidden_act]
138
+
139
+ def forward(self, x):
140
+ return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
141
+
142
+
143
+ class LlamaAttention(nn.Module):
144
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
145
+
146
+ def __init__(self, config: LlamaConfig):
147
+ super().__init__()
148
+ self.config = config
149
+ self.hidden_size = config.hidden_size
150
+ self.num_heads = config.num_attention_heads
151
+ self.head_dim = self.hidden_size // self.num_heads
152
+ self.max_position_embeddings = config.max_position_embeddings
153
+
154
+ if (self.head_dim * self.num_heads) != self.hidden_size:
155
+ raise ValueError(
156
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
157
+ f" and `num_heads`: {self.num_heads})."
158
+ )
159
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
160
+ self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
161
+ self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
162
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
163
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
164
+
165
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
166
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
167
+
168
+ def forward(
169
+ self,
170
+ hidden_states: torch.Tensor,
171
+ attention_mask: Optional[torch.Tensor] = None,
172
+ position_ids: Optional[torch.LongTensor] = None,
173
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
174
+ output_attentions: bool = False,
175
+ use_cache: bool = False,
176
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
177
+ bsz, q_len, _ = hidden_states.size()
178
+
179
+ query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
180
+ key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
181
+ value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
182
+
183
+ kv_seq_len = key_states.shape[-2]
184
+ if past_key_value is not None:
185
+ kv_seq_len += past_key_value[0].shape[-2]
186
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
187
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
188
+ # [bsz, nh, t, hd]
189
+
190
+ if past_key_value is not None:
191
+ # reuse k, v, self_attention
192
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
193
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
194
+
195
+ past_key_value = (key_states, value_states) if use_cache else None
196
+
197
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
198
+
199
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
200
+ raise ValueError(
201
+ f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
202
+ f" {attn_weights.size()}"
203
+ )
204
+
205
+ if attention_mask is not None:
206
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
207
+ raise ValueError(
208
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
209
+ )
210
+ attn_weights = attn_weights + attention_mask
211
+ attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
212
+
213
+ # upcast attention to fp32
214
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
215
+ attn_output = torch.matmul(attn_weights, value_states)
216
+
217
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
218
+ raise ValueError(
219
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
220
+ f" {attn_output.size()}"
221
+ )
222
+
223
+ attn_output = attn_output.transpose(1, 2)
224
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
225
+
226
+ attn_output = self.o_proj(attn_output)
227
+
228
+ if not output_attentions:
229
+ attn_weights = None
230
+
231
+ return attn_output, attn_weights, past_key_value
232
+
233
+
234
+ class LlamaDecoderLayer(nn.Module):
235
+ def __init__(self, config: LlamaConfig):
236
+ super().__init__()
237
+ self.hidden_size = config.hidden_size
238
+ self.self_attn = LlamaAttention(config=config)
239
+ self.mlp = LlamaMLP(
240
+ hidden_size=self.hidden_size,
241
+ intermediate_size=config.intermediate_size,
242
+ hidden_act=config.hidden_act,
243
+ )
244
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
245
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
246
+
247
+ def forward(
248
+ self,
249
+ hidden_states: torch.Tensor,
250
+ attention_mask: Optional[torch.Tensor] = None,
251
+ position_ids: Optional[torch.LongTensor] = None,
252
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
253
+ output_attentions: Optional[bool] = False,
254
+ use_cache: Optional[bool] = False,
255
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
256
+ """
257
+ Args:
258
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
259
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
260
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
261
+ output_attentions (`bool`, *optional*):
262
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
263
+ returned tensors for more detail.
264
+ use_cache (`bool`, *optional*):
265
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
266
+ (see `past_key_values`).
267
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
268
+ """
269
+
270
+ residual = hidden_states
271
+
272
+ hidden_states = self.input_layernorm(hidden_states)
273
+
274
+ # Self Attention
275
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
276
+ hidden_states=hidden_states,
277
+ attention_mask=attention_mask,
278
+ position_ids=position_ids,
279
+ past_key_value=past_key_value,
280
+ output_attentions=output_attentions,
281
+ use_cache=use_cache,
282
+ )
283
+ hidden_states = residual + hidden_states
284
+
285
+ # Fully Connected
286
+ residual = hidden_states
287
+ hidden_states = self.post_attention_layernorm(hidden_states)
288
+ hidden_states = self.mlp(hidden_states)
289
+ hidden_states = residual + hidden_states
290
+
291
+ outputs = (hidden_states,)
292
+
293
+ if output_attentions:
294
+ outputs += (self_attn_weights,)
295
+
296
+ if use_cache:
297
+ outputs += (present_key_value,)
298
+
299
+ return outputs
300
+
301
+
302
+ LLAMA_START_DOCSTRING = r"""
303
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
304
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
305
+ etc.)
306
+
307
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
308
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
309
+ and behavior.
310
+
311
+ Parameters:
312
+ config ([`LlamaConfig`]):
313
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
314
+ load the weights associated with the model, only the configuration. Check out the
315
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
316
+ """
317
+
318
+
319
+ @add_start_docstrings(
320
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
321
+ LLAMA_START_DOCSTRING,
322
+ )
323
+ class LlamaPreTrainedModel(PreTrainedModel):
324
+ config_class = LlamaConfig
325
+ base_model_prefix = "model"
326
+ supports_gradient_checkpointing = True
327
+ _no_split_modules = ["LlamaDecoderLayer"]
328
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
329
+
330
+ def _init_weights(self, module):
331
+ std = self.config.initializer_range
332
+ if isinstance(module, nn.Linear):
333
+ module.weight.data.normal_(mean=0.0, std=std)
334
+ if module.bias is not None:
335
+ module.bias.data.zero_()
336
+ elif isinstance(module, nn.Embedding):
337
+ module.weight.data.normal_(mean=0.0, std=std)
338
+ if module.padding_idx is not None:
339
+ module.weight.data[module.padding_idx].zero_()
340
+
341
+ def _set_gradient_checkpointing(self, module, value=False):
342
+ if isinstance(module, LlamaModel):
343
+ module.gradient_checkpointing = value
344
+
345
+
346
+ LLAMA_INPUTS_DOCSTRING = r"""
347
+ Args:
348
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
349
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
350
+ it.
351
+
352
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
353
+ [`PreTrainedTokenizer.__call__`] for details.
354
+
355
+ [What are input IDs?](../glossary#input-ids)
356
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
357
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
358
+
359
+ - 1 for tokens that are **not masked**,
360
+ - 0 for tokens that are **masked**.
361
+
362
+ [What are attention masks?](../glossary#attention-mask)
363
+
364
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
365
+ [`PreTrainedTokenizer.__call__`] for details.
366
+
367
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
368
+ `past_key_values`).
369
+
370
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
371
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
372
+ information on the default strategy.
373
+
374
+ - 1 indicates the head is **not masked**,
375
+ - 0 indicates the head is **masked**.
376
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
377
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
378
+ config.n_positions - 1]`.
379
+
380
+ [What are position IDs?](../glossary#position-ids)
381
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
382
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
383
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
384
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
385
+
386
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
387
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
388
+
389
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
390
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
391
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
392
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
393
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
394
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
395
+ model's internal embedding lookup matrix.
396
+ use_cache (`bool`, *optional*):
397
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
398
+ `past_key_values`).
399
+ output_attentions (`bool`, *optional*):
400
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
401
+ tensors for more detail.
402
+ output_hidden_states (`bool`, *optional*):
403
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
404
+ more detail.
405
+ return_dict (`bool`, *optional*):
406
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
407
+ """
408
+
409
+
410
+ @add_start_docstrings(
411
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
412
+ LLAMA_START_DOCSTRING,
413
+ )
414
+ class LlamaModel(LlamaPreTrainedModel):
415
+ """
416
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
417
+
418
+ Args:
419
+ config: LlamaConfig
420
+ """
421
+
422
+ def __init__(self, config: LlamaConfig):
423
+ super().__init__(config)
424
+ self.padding_idx = config.pad_token_id
425
+ self.vocab_size = config.vocab_size
426
+
427
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
428
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
429
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
430
+
431
+ self.gradient_checkpointing = False
432
+ # Initialize weights and apply final processing
433
+ self.post_init()
434
+
435
+ def get_input_embeddings(self):
436
+ return self.embed_tokens
437
+
438
+ def set_input_embeddings(self, value):
439
+ self.embed_tokens = value
440
+
441
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
442
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
443
+ # create causal mask
444
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
445
+ combined_attention_mask = None
446
+ if input_shape[-1] > 1:
447
+ combined_attention_mask = _make_causal_mask(
448
+ input_shape,
449
+ inputs_embeds.dtype,
450
+ device=inputs_embeds.device,
451
+ past_key_values_length=past_key_values_length,
452
+ )
453
+
454
+ if attention_mask is not None:
455
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
456
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
457
+ inputs_embeds.device
458
+ )
459
+ combined_attention_mask = (
460
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
461
+ )
462
+
463
+ return combined_attention_mask
464
+
465
+ def set_qformer_and_proj(self, Qformer, qformer_proj, llm_proj, query_tokens):
466
+ self.Qformer = Qformer
467
+ self.qformer_proj = qformer_proj
468
+ self.llm_proj = llm_proj
469
+ self.query_tokens = query_tokens
470
+
471
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
472
+ def forward(
473
+ self,
474
+ input_ids: torch.LongTensor = None,
475
+ attention_mask: Optional[torch.Tensor] = None,
476
+ position_ids: Optional[torch.LongTensor] = None,
477
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
478
+ inputs_embeds: Optional[torch.FloatTensor] = None,
479
+ query_embeds: Optional[torch.FloatTensor] = None,
480
+ use_cache: Optional[bool] = None,
481
+ output_attentions: Optional[bool] = None,
482
+ output_hidden_states: Optional[bool] = None,
483
+ return_dict: Optional[bool] = None,
484
+ update_layer: Optional[int] = 16,
485
+ image_position_list = None,
486
+ input_part_targets_len = None,
487
+ all_image_embeds = None
488
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
489
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
490
+ output_hidden_states = (
491
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
492
+ )
493
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
494
+
495
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
496
+
497
+ # retrieve input_ids and inputs_embeds
498
+ if input_ids is not None and inputs_embeds is not None:
499
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
500
+ elif input_ids is not None:
501
+ batch_size, seq_length = input_ids.shape
502
+ elif inputs_embeds is not None:
503
+ batch_size, seq_length, _ = inputs_embeds.shape
504
+ else:
505
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
506
+
507
+ if inputs_embeds is None:
508
+ inputs_embeds = self.embed_tokens(input_ids)
509
+ if query_embeds is not None:
510
+ inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
511
+ batch_size, seq_length, _ = inputs_embeds.shape
512
+
513
+ seq_length_with_past = seq_length
514
+ past_key_values_length = 0
515
+
516
+ if past_key_values is not None:
517
+ past_key_values_length = past_key_values[0][0].shape[2]
518
+ seq_length_with_past = seq_length_with_past + past_key_values_length
519
+
520
+ if position_ids is None:
521
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
522
+ position_ids = torch.arange(
523
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
524
+ )
525
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
526
+ else:
527
+ position_ids = position_ids.view(-1, seq_length).long()
528
+
529
+ # embed positions
530
+ if attention_mask is None:
531
+ attention_mask = torch.ones(
532
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
533
+ )
534
+ attention_mask = self._prepare_decoder_attention_mask(
535
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
536
+ )
537
+
538
+ hidden_states = inputs_embeds
539
+
540
+ if self.gradient_checkpointing and self.training:
541
+ if use_cache:
542
+ logger.warning_once(
543
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
544
+ )
545
+ use_cache = False
546
+
547
+ # decoder layers
548
+ all_hidden_states = () if output_hidden_states else None
549
+ all_self_attns = () if output_attentions else None
550
+ next_decoder_cache = () if use_cache else None
551
+
552
+ for idx, decoder_layer in enumerate(self.layers):
553
+ if output_hidden_states:
554
+ all_hidden_states += (hidden_states,)
555
+
556
+ if idx == update_layer and past_key_values is None:
557
+ hidden_output_list = []
558
+ for i, l in enumerate(input_part_targets_len):
559
+ hidden_output = hidden_states[i, l, :]
560
+ hidden_output_list.append(hidden_output)
561
+ hidden_outputs = torch.stack(hidden_output_list, dim=0).unsqueeze(1)
562
+ hidden_outputs = self.qformer_proj(hidden_outputs)
563
+ new_query_tokens = hidden_outputs + self.query_tokens
564
+ assert all_image_embeds.size(1) == len(image_position_list)
565
+ for i in range(all_image_embeds.size(1)):
566
+ image_embeds = all_image_embeds[:,i,:,:]
567
+ image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
568
+ query_output = self.Qformer.bert(
569
+ query_embeds=new_query_tokens,
570
+ encoder_hidden_states=image_embeds,
571
+ encoder_attention_mask=image_atts,
572
+ return_dict=True,
573
+ )
574
+ new_hidden_state = self.llm_proj(query_output.last_hidden_state[:,:new_query_tokens.size(1),:])
575
+ img_start, img_end = image_position_list[i][0], image_position_list[i][1]
576
+ hidden_states[:, img_start:img_end, :] += new_hidden_state[:, :, :]
577
+
578
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
579
+
580
+ if self.gradient_checkpointing and self.training:
581
+
582
+ def create_custom_forward(module):
583
+ def custom_forward(*inputs):
584
+ # None for past_key_value
585
+ return module(*inputs, output_attentions, None)
586
+
587
+ return custom_forward
588
+
589
+ layer_outputs = torch.utils.checkpoint.checkpoint(
590
+ create_custom_forward(decoder_layer),
591
+ hidden_states,
592
+ attention_mask,
593
+ position_ids,
594
+ None,
595
+ )
596
+ else:
597
+ layer_outputs = decoder_layer(
598
+ hidden_states,
599
+ attention_mask=attention_mask,
600
+ position_ids=position_ids,
601
+ past_key_value=past_key_value,
602
+ output_attentions=output_attentions,
603
+ use_cache=use_cache,
604
+ )
605
+
606
+ hidden_states = layer_outputs[0]
607
+
608
+ if use_cache:
609
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
610
+
611
+ if output_attentions:
612
+ all_self_attns += (layer_outputs[1],)
613
+
614
+ hidden_states = self.norm(hidden_states)
615
+
616
+ # add hidden states from the last decoder layer
617
+ if output_hidden_states:
618
+ all_hidden_states += (hidden_states,)
619
+
620
+ next_cache = next_decoder_cache if use_cache else None
621
+ if not return_dict:
622
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
623
+ return BaseModelOutputWithPast(
624
+ last_hidden_state=hidden_states,
625
+ past_key_values=next_cache,
626
+ hidden_states=all_hidden_states,
627
+ attentions=all_self_attns,
628
+ )
629
+
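The `idx == update_layer` block inside `LlamaModel.forward` above is the Cheetah-specific change to the stock LLaMA decoder: at one chosen layer, a hidden state at each sample's instruction anchor position is projected, added to the learned `query_tokens`, pushed through the Q-Former against every image's features, and the `llm_proj` output is added back onto the hidden states occupying that image's placeholder span. Below is a minimal sketch of that data flow only, using stand-in modules and made-up shapes rather than the repository's actual `Qformer`, `qformer_proj`, `llm_proj`, and `query_tokens`:

```python
# Minimal sketch of the mid-layer visual update; stand-in modules, NOT the repo's own.
# Assumed shapes: 2 samples, 40 tokens, one image occupying positions 4..36.
import torch
import torch.nn as nn

bsz, seq_len, llm_dim, qformer_dim, n_query, img_tokens = 2, 40, 4096, 768, 32, 257
hidden_states = torch.randn(bsz, seq_len, llm_dim)              # activations at layer `update_layer`
image_embeds = torch.randn(bsz, img_tokens, qformer_dim)        # frozen vision features for one image
instr_pos = torch.tensor([20, 25])                              # per-sample anchor (input_part_targets_len)
img_start, img_end = 4, 36                                      # the image's 32 placeholder positions

qformer_proj = nn.Linear(llm_dim, qformer_dim)                  # stand-in projections
llm_proj = nn.Linear(qformer_dim, llm_dim)
query_tokens = nn.Parameter(torch.zeros(1, n_query, qformer_dim))
cross_attn = nn.MultiheadAttention(qformer_dim, num_heads=8, batch_first=True)  # stand-in for Qformer.bert

anchor = hidden_states[torch.arange(bsz), instr_pos]            # [bsz, llm_dim]
new_queries = qformer_proj(anchor).unsqueeze(1) + query_tokens  # instruction-aware queries
attended, _ = cross_attn(new_queries, image_embeds, image_embeds)
hidden_states[:, img_start:img_end, :] += llm_proj(attended)    # write back onto the image span
print(hidden_states.shape)                                      # torch.Size([2, 40, 4096])
```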
630
+
631
+ class LlamaForCausalLM(LlamaPreTrainedModel):
632
+ def __init__(self, config):
633
+ super().__init__(config)
634
+ self.model = LlamaModel(config)
635
+
636
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
637
+
638
+ # Initialize weights and apply final processing
639
+ self.post_init()
640
+
641
+ def get_input_embeddings(self):
642
+ return self.model.embed_tokens
643
+
644
+ def set_input_embeddings(self, value):
645
+ self.model.embed_tokens = value
646
+
647
+ def get_output_embeddings(self):
648
+ return self.lm_head
649
+
650
+ def set_output_embeddings(self, new_embeddings):
651
+ self.lm_head = new_embeddings
652
+
653
+ def set_decoder(self, decoder):
654
+ self.model = decoder
655
+
656
+ def get_decoder(self):
657
+ return self.model
658
+
659
+ def set_qformer_and_proj(self, Qformer, qformer_proj, llm_proj, query_tokens):
660
+ self.model.set_qformer_and_proj(Qformer, qformer_proj, llm_proj, query_tokens)
661
+
662
+ @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
663
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
664
+ def forward(
665
+ self,
666
+ input_ids: torch.LongTensor = None,
667
+ attention_mask: Optional[torch.Tensor] = None,
668
+ position_ids: Optional[torch.LongTensor] = None,
669
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
670
+ inputs_embeds: Optional[torch.FloatTensor] = None,
671
+ query_embeds: Optional[torch.FloatTensor] = None,
672
+ labels: Optional[torch.LongTensor] = None,
673
+ use_cache: Optional[bool] = None,
674
+ output_attentions: Optional[bool] = None,
675
+ output_hidden_states: Optional[bool] = None,
676
+ return_dict: Optional[bool] = None,
677
+ update_layer: Optional[int] = 16,
678
+ image_position_list = None,
679
+ input_part_targets_len = None,
680
+ all_image_embeds = None,
681
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
682
+ r"""
683
+ Args:
684
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
685
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
686
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
687
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
688
+
689
+ Returns:
690
+
691
+ Example:
692
+
693
+ ```python
694
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
695
+
696
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
697
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
698
+
699
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
700
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
701
+
702
+ >>> # Generate
703
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
704
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
705
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
706
+ ```"""
707
+
708
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
709
+ output_hidden_states = (
710
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
711
+ )
712
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
713
+
714
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
715
+ outputs = self.model(
716
+ input_ids=input_ids,
717
+ attention_mask=attention_mask,
718
+ position_ids=position_ids,
719
+ past_key_values=past_key_values,
720
+ inputs_embeds=inputs_embeds,
721
+ query_embeds=query_embeds,
722
+ use_cache=use_cache,
723
+ output_attentions=output_attentions,
724
+ output_hidden_states=output_hidden_states,
725
+ return_dict=return_dict,
726
+ update_layer = update_layer,
727
+ image_position_list = image_position_list,
728
+ input_part_targets_len = input_part_targets_len,
729
+ all_image_embeds = all_image_embeds,
730
+ )
731
+
732
+ hidden_states = outputs[0]
733
+ logits = self.lm_head(hidden_states)
734
+
735
+ loss = None
736
+ if labels is not None:
737
+ # Shift so that tokens < n predict n
738
+ shift_logits = logits[..., :-1, :].contiguous()
739
+ shift_labels = labels[..., 1:].contiguous()
740
+ # Flatten the tokens
741
+ loss_fct = CrossEntropyLoss()
742
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
743
+ shift_labels = shift_labels.view(-1)
744
+ # Enable model parallelism
745
+ shift_labels = shift_labels.to(shift_logits.device)
746
+ loss = loss_fct(shift_logits, shift_labels)
747
+
748
+ if not return_dict:
749
+ output = (logits,) + outputs[1:]
750
+ return (loss,) + output if loss is not None else output
751
+
752
+ return CausalLMOutputWithPast(
753
+ loss=loss,
754
+ logits=logits,
755
+ past_key_values=outputs.past_key_values,
756
+ hidden_states=outputs.hidden_states,
757
+ attentions=outputs.attentions,
758
+ )
759
+
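The loss in `forward` above scores the logit at position t against the label at position t+1 and ignores `-100` entries. A tiny self-contained example of that shift (made-up logits and labels, not values from this model):

```python
import torch
from torch.nn import CrossEntropyLoss

vocab_size = 10
logits = torch.randn(1, 5, vocab_size)          # predictions for positions 0..4
labels = torch.tensor([[3, 7, 2, -100, 9]])     # -100 marks positions excluded from the loss

shift_logits = logits[..., :-1, :].contiguous() # keep positions 0..3
shift_labels = labels[..., 1:].contiguous()     # their targets are tokens 1..4
loss = CrossEntropyLoss()(shift_logits.view(-1, vocab_size), shift_labels.view(-1))
print(loss.item())                              # scalar loss; the -100 target is ignored
```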
760
+ def prepare_inputs_for_generation(
761
+ self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
762
+ image_position_list=None,input_part_targets_len=None,all_image_embeds=None,update_layer=16,**kwargs
763
+ ):
764
+ if past_key_values:
765
+ input_ids = input_ids[:, -1:]
766
+
767
+ position_ids = kwargs.get("position_ids", None)
768
+ if attention_mask is not None and position_ids is None:
769
+ # create position_ids on the fly for batch generation
770
+ position_ids = attention_mask.long().cumsum(-1) - 1
771
+ position_ids.masked_fill_(attention_mask == 0, 1)
772
+ if past_key_values:
773
+ position_ids = position_ids[:, -1].unsqueeze(-1)
774
+ query_embeds = None
775
+
776
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
777
+ if inputs_embeds is not None and past_key_values is None:
778
+ model_inputs = {"inputs_embeds": inputs_embeds}
779
+ else:
780
+ model_inputs = {"input_ids": input_ids}
781
+
782
+ model_inputs.update(
783
+ {
784
+ "position_ids": position_ids,
785
+ "query_embeds": query_embeds,
786
+ "past_key_values": past_key_values,
787
+ "use_cache": kwargs.get("use_cache"),
788
+ "attention_mask": attention_mask,
789
+ "image_position_list": image_position_list,
790
+ "input_part_targets_len": input_part_targets_len,
791
+ "all_image_embeds": all_image_embeds,
792
+ "update_layer": update_layer
793
+ }
794
+ )
795
+ return model_inputs
796
+
797
+ @staticmethod
798
+ def _reorder_cache(past_key_values, beam_idx):
799
+ reordered_past = ()
800
+ for layer_past in past_key_values:
801
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
802
+ return reordered_past
803
+
cheetah/models/modeling_llama2.py ADDED
@@ -0,0 +1,1070 @@
1
+ # coding=utf-8
2
+ # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
5
+ # and OPT implementations in this library. It has been modified from its
6
+ # original forms to accommodate minor architectural differences compared
7
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+ """ PyTorch LLaMA model."""
21
+ import math
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
34
+ from transformers.models.llama.configuration_llama import LlamaConfig
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+ _CONFIG_FOR_DOC = "LlamaConfig"
40
+
41
+
42
+ # Copied from transformers.models.bart.modeling_bart._make_causal_mask
43
+ def _make_causal_mask(
44
+ input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
45
+ ):
46
+ """
47
+ Make causal mask used for bi-directional self-attention.
48
+ """
49
+ bsz, tgt_len = input_ids_shape
50
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
51
+ mask_cond = torch.arange(mask.size(-1), device=device)
52
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
53
+ mask = mask.to(dtype)
54
+
55
+ if past_key_values_length > 0:
56
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
57
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
58
+
59
+
60
+ # Copied from transformers.models.bart.modeling_bart._expand_mask
61
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
62
+ """
63
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
64
+ """
65
+ bsz, src_len = mask.size()
66
+ tgt_len = tgt_len if tgt_len is not None else src_len
67
+
68
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
69
+
70
+ inverted_mask = 1.0 - expanded_mask
71
+
72
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
73
+
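`_make_causal_mask` and `_expand_mask` build additive masks: allowed positions contribute 0 and blocked positions the most negative finite value of the dtype, so they vanish after softmax. A small illustration with assumed values (sequence length 4, last token padded):

```python
import torch

dtype = torch.float32
tgt_len = 4
neg_inf = torch.finfo(dtype).min

# causal part: a query at position i may only attend to keys j <= i
causal = torch.full((tgt_len, tgt_len), neg_inf)
idx = torch.arange(tgt_len)
causal.masked_fill_(idx < (idx + 1).view(tgt_len, 1), 0.0)

# padding part: the last token is padding, so its key column is blocked everywhere
attention_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
expanded = (1.0 - attention_mask)[:, None, None, :] * neg_inf

combined = causal[None, None, :, :] + expanded   # [1, 1, 4, 4] additive mask
print((combined == 0).int()[0, 0])               # 1 where attention is allowed
```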
74
+
75
+ class LlamaRMSNorm(nn.Module):
76
+ def __init__(self, hidden_size, eps=1e-6):
77
+ """
78
+ LlamaRMSNorm is equivalent to T5LayerNorm
79
+ """
80
+ super().__init__()
81
+ self.weight = nn.Parameter(torch.ones(hidden_size))
82
+ self.variance_epsilon = eps
83
+
84
+ def forward(self, hidden_states):
85
+ input_dtype = hidden_states.dtype
86
+ hidden_states = hidden_states.to(torch.float32)
87
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
88
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
89
+ return self.weight * hidden_states.to(input_dtype)
90
+
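`LlamaRMSNorm` rescales by the root-mean-square of the features without mean-centering (unlike LayerNorm): y = w * x / sqrt(mean(x^2) + eps). A quick numerical restatement of the same computation:

```python
import torch

x = torch.randn(2, 5, 8)
weight = torch.ones(8)
eps = 1e-6

variance = x.pow(2).mean(-1, keepdim=True)   # mean of squares (no mean subtraction)
y = weight * x * torch.rsqrt(variance + eps)
print(y.shape)                               # torch.Size([2, 5, 8])
print(y.pow(2).mean(-1))                     # ~1.0 per position when weight is all ones
```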
91
+
92
+ class LlamaRotaryEmbedding(torch.nn.Module):
93
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
94
+ super().__init__()
95
+
96
+ self.dim = dim
97
+ self.max_position_embeddings = max_position_embeddings
98
+ self.base = base
99
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
100
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
101
+
102
+ # Build here to make `torch.jit.trace` work.
103
+ self._set_cos_sin_cache(
104
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
105
+ )
106
+
107
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
108
+ self.max_seq_len_cached = seq_len
109
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
110
+
111
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
112
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
113
+ emb = torch.cat((freqs, freqs), dim=-1)
114
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
115
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
116
+
117
+ def forward(self, x, seq_len=None):
118
+ # x: [bs, num_attention_heads, seq_len, head_size]
119
+ if seq_len > self.max_seq_len_cached:
120
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
121
+
122
+ return (
123
+ self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
124
+ self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
125
+ )
126
+
127
+
128
+ class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
129
+ """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
130
+
131
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
132
+ self.scaling_factor = scaling_factor
133
+ super().__init__(dim, max_position_embeddings, base, device)
134
+
135
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
136
+ self.max_seq_len_cached = seq_len
137
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
138
+ t = t / self.scaling_factor
139
+
140
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
141
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
142
+ emb = torch.cat((freqs, freqs), dim=-1)
143
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
144
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
145
+
146
+
147
+ class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
148
+ """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
149
+
150
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
151
+ self.scaling_factor = scaling_factor
152
+ super().__init__(dim, max_position_embeddings, base, device)
153
+
154
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
155
+ self.max_seq_len_cached = seq_len
156
+
157
+ if seq_len > self.max_position_embeddings:
158
+ base = self.base * (
159
+ (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
160
+ ) ** (self.dim / (self.dim - 2))
161
+ inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
162
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
163
+
164
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
165
+
166
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
167
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
168
+ emb = torch.cat((freqs, freqs), dim=-1)
169
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
170
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
171
+
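`LlamaDynamicNTKScalingRotaryEmbedding` enlarges the rotary base once the sequence exceeds `max_position_embeddings`, stretching the low-frequency rotations instead of linearly compressing positions. Worked arithmetic with assumed values (not the checkpoint's actual config):

```python
# Assumed: dim=128, base=10000, max_position_embeddings=4096, scaling_factor=2.0, seq_len=8192
dim, base, max_pos, factor, seq_len = 128, 10000.0, 4096, 2.0, 8192

new_base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
print(new_base)  # roughly 3x the original base, so inv_freq shrinks and rotations slow down
```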
172
+
173
+ def rotate_half(x):
174
+ """Rotates half the hidden dims of the input."""
175
+ x1 = x[..., : x.shape[-1] // 2]
176
+ x2 = x[..., x.shape[-1] // 2 :]
177
+ return torch.cat((-x2, x1), dim=-1)
178
+
179
+
180
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
181
+ # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
182
+ cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
183
+ sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
184
+ cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
185
+ sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
186
+ q_embed = (q * cos) + (rotate_half(q) * sin)
187
+ k_embed = (k * cos) + (rotate_half(k) * sin)
188
+ return q_embed, k_embed
189
+
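`rotate_half` and `apply_rotary_pos_emb` implement rotary position embeddings: each query/key channel pair is rotated by an angle proportional to its absolute position, so attention scores end up depending on relative offsets. A minimal single-head sketch of the same formula (assumed head_dim=4):

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

head_dim, seq_len = 4, 6
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.einsum("i,j->ij", torch.arange(seq_len).float(), inv_freq)
emb = torch.cat((angles, angles), dim=-1)   # [seq_len, head_dim]
cos, sin = emb.cos(), emb.sin()

q = torch.randn(seq_len, head_dim)
q_rot = q * cos + rotate_half(q) * sin      # same formula as apply_rotary_pos_emb
print(q_rot.shape)                          # torch.Size([6, 4])
```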
190
+
191
+ class LlamaMLP(nn.Module):
192
+ def __init__(self, config):
193
+ super().__init__()
194
+ self.config = config
195
+ self.hidden_size = config.hidden_size
196
+ self.intermediate_size = config.intermediate_size
197
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
198
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
199
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
200
+ self.act_fn = ACT2FN[config.hidden_act]
201
+
202
+ def forward(self, x):
203
+ if self.config.pretraining_tp > 1:
204
+ slice = self.intermediate_size // self.config.pretraining_tp
205
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
206
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
207
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
208
+
209
+ gate_proj = torch.cat(
210
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
211
+ )
212
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
213
+
214
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
215
+ down_proj = [
216
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
217
+ ]
218
+ down_proj = sum(down_proj)
219
+ else:
220
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
221
+
222
+ return down_proj
223
+
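In the common case (`config.pretraining_tp == 1`) the MLP above is the gated SiLU ("SwiGLU"-style) feed-forward, down_proj(silu(gate_proj(x)) * up_proj(x)); the tensor-parallel branch only splits the same matmuls into slices. A compact restatement under assumed toy sizes:

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 16, 44      # assumed toy sizes, not the checkpoint's
gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act_fn = nn.SiLU()                           # LLaMA's "silu" hidden activation

x = torch.randn(2, 7, hidden_size)
out = down_proj(act_fn(gate_proj(x)) * up_proj(x))
print(out.shape)                             # torch.Size([2, 7, 16])
```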
224
+
225
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
226
+ """
227
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
228
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
229
+ """
230
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
231
+ if n_rep == 1:
232
+ return hidden_states
233
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
234
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
235
+
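`repeat_kv` supports grouped-query attention: when there are fewer key/value heads than query heads, each KV head is tiled `n_rep` times so the attention matmuls line up. A quick shape check (assumed 8 KV heads serving 32 query heads):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

kv = torch.randn(2, 8, 16, 128)                            # (batch, n_kv_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=4)                               # -> 32 effective heads
assert out.shape == (2, 32, 16, 128)
assert torch.equal(out, kv.repeat_interleave(4, dim=1))    # matches the docstring's claim
```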
236
+
237
+ class LlamaAttention(nn.Module):
238
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
239
+
240
+ def __init__(self, config: LlamaConfig):
241
+ super().__init__()
242
+ self.config = config
243
+ self.hidden_size = config.hidden_size
244
+ self.num_heads = config.num_attention_heads
245
+ self.head_dim = self.hidden_size // self.num_heads
246
+ self.num_key_value_heads = config.num_key_value_heads
247
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
248
+ self.max_position_embeddings = config.max_position_embeddings
249
+
250
+ if (self.head_dim * self.num_heads) != self.hidden_size:
251
+ raise ValueError(
252
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
253
+ f" and `num_heads`: {self.num_heads})."
254
+ )
255
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
256
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
257
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
258
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
259
+ self._init_rope()
260
+
261
+ def _init_rope(self):
262
+ if self.config.rope_scaling is None:
263
+ self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
264
+ else:
265
+ scaling_type = self.config.rope_scaling["type"]
266
+ scaling_factor = self.config.rope_scaling["factor"]
267
+ if scaling_type == "linear":
268
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
269
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
270
+ )
271
+ elif scaling_type == "dynamic":
272
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
273
+ self.head_dim, max_position_embeddings=self.max_position_embeddings, scaling_factor=scaling_factor
274
+ )
275
+ else:
276
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
277
+
278
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
279
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
280
+
281
+ def forward(
282
+ self,
283
+ hidden_states: torch.Tensor,
284
+ attention_mask: Optional[torch.Tensor] = None,
285
+ position_ids: Optional[torch.LongTensor] = None,
286
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
287
+ output_attentions: bool = False,
288
+ use_cache: bool = False,
289
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
290
+ bsz, q_len, _ = hidden_states.size()
291
+
292
+ if self.config.pretraining_tp > 1:
293
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
294
+ query_slices = self.q_proj.weight.split(
295
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
296
+ )
297
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
298
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
299
+
300
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
301
+ query_states = torch.cat(query_states, dim=-1)
302
+
303
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
304
+ key_states = torch.cat(key_states, dim=-1)
305
+
306
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
307
+ value_states = torch.cat(value_states, dim=-1)
308
+
309
+ else:
310
+ query_states = self.q_proj(hidden_states)
311
+ key_states = self.k_proj(hidden_states)
312
+ value_states = self.v_proj(hidden_states)
313
+
314
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
315
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
316
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
317
+
318
+ kv_seq_len = key_states.shape[-2]
319
+ if past_key_value is not None:
320
+ kv_seq_len += past_key_value[0].shape[-2]
321
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
322
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
323
+
324
+ if past_key_value is not None:
325
+ # reuse k, v, self_attention
326
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
327
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
328
+
329
+ past_key_value = (key_states, value_states) if use_cache else None
330
+
331
+ # repeat k/v heads if n_kv_heads < n_heads
332
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
333
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
334
+
335
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
336
+
337
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
338
+ raise ValueError(
339
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
340
+ f" {attn_weights.size()}"
341
+ )
342
+
343
+ if attention_mask is not None:
344
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
345
+ raise ValueError(
346
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
347
+ )
348
+ attn_weights = attn_weights + attention_mask
349
+
350
+ # upcast attention to fp32
351
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
352
+ attn_output = torch.matmul(attn_weights, value_states)
353
+
354
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
355
+ raise ValueError(
356
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
357
+ f" {attn_output.size()}"
358
+ )
359
+
360
+ attn_output = attn_output.transpose(1, 2).contiguous()
361
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
362
+
363
+ if self.config.pretraining_tp > 1:
364
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
365
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
366
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
367
+ else:
368
+ attn_output = self.o_proj(attn_output)
369
+
370
+ if not output_attentions:
371
+ attn_weights = None
372
+
373
+ return attn_output, attn_weights, past_key_value
374
+
375
+
376
+ class LlamaDecoderLayer(nn.Module):
377
+ def __init__(self, config: LlamaConfig):
378
+ super().__init__()
379
+ self.hidden_size = config.hidden_size
380
+ self.self_attn = LlamaAttention(config=config)
381
+ self.mlp = LlamaMLP(config)
382
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
383
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
384
+
385
+ def forward(
386
+ self,
387
+ hidden_states: torch.Tensor,
388
+ attention_mask: Optional[torch.Tensor] = None,
389
+ position_ids: Optional[torch.LongTensor] = None,
390
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
391
+ output_attentions: Optional[bool] = False,
392
+ use_cache: Optional[bool] = False,
393
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
394
+ """
395
+ Args:
396
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
397
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
398
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
399
+ output_attentions (`bool`, *optional*):
400
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
401
+ returned tensors for more detail.
402
+ use_cache (`bool`, *optional*):
403
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
404
+ (see `past_key_values`).
405
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
406
+ """
407
+
408
+ residual = hidden_states
409
+
410
+ hidden_states = self.input_layernorm(hidden_states)
411
+
412
+ # Self Attention
413
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
414
+ hidden_states=hidden_states,
415
+ attention_mask=attention_mask,
416
+ position_ids=position_ids,
417
+ past_key_value=past_key_value,
418
+ output_attentions=output_attentions,
419
+ use_cache=use_cache,
420
+ )
421
+ hidden_states = residual + hidden_states
422
+
423
+ # Fully Connected
424
+ residual = hidden_states
425
+ hidden_states = self.post_attention_layernorm(hidden_states)
426
+ hidden_states = self.mlp(hidden_states)
427
+ hidden_states = residual + hidden_states
428
+
429
+ outputs = (hidden_states,)
430
+
431
+ if output_attentions:
432
+ outputs += (self_attn_weights,)
433
+
434
+ if use_cache:
435
+ outputs += (present_key_value,)
436
+
437
+ return outputs
438
+
439
+
440
+ LLAMA_START_DOCSTRING = r"""
441
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
442
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
443
+ etc.)
444
+
445
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
446
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
447
+ and behavior.
448
+
449
+ Parameters:
450
+ config ([`LlamaConfig`]):
451
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
452
+ load the weights associated with the model, only the configuration. Check out the
453
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
454
+ """
455
+
456
+
457
+ @add_start_docstrings(
458
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
459
+ LLAMA_START_DOCSTRING,
460
+ )
461
+ class LlamaPreTrainedModel(PreTrainedModel):
462
+ config_class = LlamaConfig
463
+ base_model_prefix = "model"
464
+ supports_gradient_checkpointing = True
465
+ _no_split_modules = ["LlamaDecoderLayer"]
466
+ _skip_keys_device_placement = "past_key_values"
467
+
468
+ def _init_weights(self, module):
469
+ std = self.config.initializer_range
470
+ if isinstance(module, nn.Linear):
471
+ module.weight.data.normal_(mean=0.0, std=std)
472
+ if module.bias is not None:
473
+ module.bias.data.zero_()
474
+ elif isinstance(module, nn.Embedding):
475
+ module.weight.data.normal_(mean=0.0, std=std)
476
+ if module.padding_idx is not None:
477
+ module.weight.data[module.padding_idx].zero_()
478
+
479
+ def _set_gradient_checkpointing(self, module, value=False):
480
+ if isinstance(module, LlamaModel):
481
+ module.gradient_checkpointing = value
482
+
483
+
484
+ LLAMA_INPUTS_DOCSTRING = r"""
485
+ Args:
486
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
487
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
488
+ it.
489
+
490
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
491
+ [`PreTrainedTokenizer.__call__`] for details.
492
+
493
+ [What are input IDs?](../glossary#input-ids)
494
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
495
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
496
+
497
+ - 1 for tokens that are **not masked**,
498
+ - 0 for tokens that are **masked**.
499
+
500
+ [What are attention masks?](../glossary#attention-mask)
501
+
502
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
503
+ [`PreTrainedTokenizer.__call__`] for details.
504
+
505
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
506
+ `past_key_values`).
507
+
508
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
509
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
510
+ information on the default strategy.
511
+
512
+ - 1 indicates the head is **not masked**,
513
+ - 0 indicates the head is **masked**.
514
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
515
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
516
+ config.n_positions - 1]`.
517
+
518
+ [What are position IDs?](../glossary#position-ids)
519
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
520
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
521
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
522
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
523
+
524
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
525
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
526
+
527
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
528
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
529
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
530
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
531
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
532
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
533
+ model's internal embedding lookup matrix.
534
+ use_cache (`bool`, *optional*):
535
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
536
+ `past_key_values`).
537
+ output_attentions (`bool`, *optional*):
538
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
539
+ tensors for more detail.
540
+ output_hidden_states (`bool`, *optional*):
541
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
542
+ more detail.
543
+ return_dict (`bool`, *optional*):
544
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
545
+ """
546
+
547
+
548
+ @add_start_docstrings(
549
+ "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
550
+ LLAMA_START_DOCSTRING,
551
+ )
552
+ class LlamaModel(LlamaPreTrainedModel):
553
+ """
554
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
555
+
556
+ Args:
557
+ config: LlamaConfig
558
+ """
559
+
560
+ def __init__(self, config: LlamaConfig):
561
+ super().__init__(config)
562
+ self.padding_idx = config.pad_token_id
563
+ self.vocab_size = config.vocab_size
564
+
565
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
566
+ self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
567
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
568
+
569
+ self.gradient_checkpointing = False
570
+ # Initialize weights and apply final processing
571
+ self.post_init()
572
+
573
+ def get_input_embeddings(self):
574
+ return self.embed_tokens
575
+
576
+ def set_input_embeddings(self, value):
577
+ self.embed_tokens = value
578
+
579
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
580
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
581
+ # create causal mask
582
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
583
+ combined_attention_mask = None
584
+ if input_shape[-1] > 1:
585
+ combined_attention_mask = _make_causal_mask(
586
+ input_shape,
587
+ inputs_embeds.dtype,
588
+ device=inputs_embeds.device,
589
+ past_key_values_length=past_key_values_length,
590
+ )
591
+
592
+ if attention_mask is not None:
593
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
594
+ expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
595
+ inputs_embeds.device
596
+ )
597
+ combined_attention_mask = (
598
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
599
+ )
600
+
601
+ return combined_attention_mask
602
+
603
+ def set_qformer_and_proj(self, Qformer, qformer_proj, llm_proj, query_tokens):
604
+ self.Qformer = Qformer
605
+ self.qformer_proj = qformer_proj
606
+ self.llm_proj = llm_proj
607
+ self.query_tokens = query_tokens
608
+
609
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        query_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        update_layer: Optional[int] = 16,
+        image_position_list=None,
+        input_part_targets_len=None,
+        all_image_embeds=None
+    ) -> Union[Tuple, BaseModelOutputWithPast]:
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # retrieve input_ids and inputs_embeds
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+        elif input_ids is not None:
+            batch_size, seq_length = input_ids.shape
+        elif inputs_embeds is not None:
+            batch_size, seq_length, _ = inputs_embeds.shape
+        else:
+            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+        if query_embeds is not None:
+            inputs_embeds = torch.cat([query_embeds, inputs_embeds], dim=1)
+            batch_size, seq_length, _ = inputs_embeds.shape
+
+        seq_length_with_past = seq_length
+        past_key_values_length = 0
+
+        if past_key_values is not None:
+            past_key_values_length = past_key_values[0][0].shape[2]
+            seq_length_with_past = seq_length_with_past + past_key_values_length
+
+        if position_ids is None:
+            device = input_ids.device if input_ids is not None else inputs_embeds.device
+            position_ids = torch.arange(
+                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+            )
+            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+        else:
+            position_ids = position_ids.view(-1, seq_length).long()
+
+        # embed positions
+        if attention_mask is None:
+            attention_mask = torch.ones(
+                (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
+            )
+        attention_mask = self._prepare_decoder_attention_mask(
+            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+        )
+
+        hidden_states = inputs_embeds
+
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
+        # decoder layers
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None
+
+        for idx, decoder_layer in enumerate(self.layers):
+            if output_hidden_states:
+                all_hidden_states += (hidden_states,)
+
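+            # At `update_layer` (prefill only: `past_key_values is None`), condition the
+            # Q-Former on the decoder's own state: the hidden state at each sample's
+            # `input_part_targets_len` position is projected by `qformer_proj`, added to
+            # the learned `query_tokens`, and used to re-query every image's encoder
+            # features. The `llm_proj`-projected result is added onto that image's
+            # placeholder span in `hidden_states`.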
+            if idx == update_layer and past_key_values is None:
+                hidden_output_list = []
+                for i, l in enumerate(input_part_targets_len):
+                    hidden_output = hidden_states[i, l, :]
+                    hidden_output_list.append(hidden_output)
+                hidden_outputs = torch.stack(hidden_output_list, dim=0).unsqueeze(1)
+                hidden_outputs = self.qformer_proj(hidden_outputs)
+                new_query_tokens = hidden_outputs + self.query_tokens
+                assert all_image_embeds.size(1) == len(image_position_list)
+                for i in range(all_image_embeds.size(1)):
+                    image_embeds = all_image_embeds[:, i, :, :]
+                    image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image_embeds.device)
+                    query_output = self.Qformer.bert(
+                        query_embeds=new_query_tokens,
+                        encoder_hidden_states=image_embeds,
+                        encoder_attention_mask=image_atts,
+                        return_dict=True,
+                    )
+                    new_hidden_state = self.llm_proj(query_output.last_hidden_state[:, :new_query_tokens.size(1), :])
+                    img_start, img_end = image_position_list[i][0], image_position_list[i][1]
+                    hidden_states[:, img_start:img_end, :] += new_hidden_state[:, :, :]
+
+            past_key_value = past_key_values[idx] if past_key_values is not None else None
+
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, output_attentions, None)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(decoder_layer),
+                    hidden_states,
+                    attention_mask,
+                    position_ids,
+                    None,
+                )
+            else:
+                layer_outputs = decoder_layer(
+                    hidden_states,
+                    attention_mask=attention_mask,
+                    position_ids=position_ids,
+                    past_key_value=past_key_value,
+                    output_attentions=output_attentions,
+                    use_cache=use_cache,
+                )
+
+            hidden_states = layer_outputs[0]
+
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+
+            if output_attentions:
+                all_self_attns += (layer_outputs[1],)
+
+        hidden_states = self.norm(hidden_states)
+
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            all_hidden_states += (hidden_states,)
+
+        next_cache = next_decoder_cache if use_cache else None
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
+            last_hidden_state=hidden_states,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+        )
+
+
+class LlamaForCausalLM(LlamaPreTrainedModel):
+    _tied_weights_keys = ["lm_head.weight"]
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.model = LlamaModel(config)
+        self.vocab_size = config.vocab_size
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    def get_output_embeddings(self):
+        return self.lm_head
+
+    def set_output_embeddings(self, new_embeddings):
+        self.lm_head = new_embeddings
+
+    def set_decoder(self, decoder):
+        self.model = decoder
+
+    def get_decoder(self):
+        return self.model
+
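+    # Register the shared Q-Former, projection layers, and query tokens on the inner
+    # decoder so the mid-layer visual re-injection in `LlamaModel.forward` can use them.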
+    def set_qformer_and_proj(self, Qformer, qformer_proj, llm_proj, query_tokens):
+        self.model.set_qformer_and_proj(Qformer, qformer_proj, llm_proj, query_tokens)
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        query_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        update_layer: Optional[int] = 16,
+        image_position_list=None,
+        input_part_targets_len=None,
+        all_image_embeds=None,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        r"""
+        Args:
+            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Returns:
+
+        Example:
+
+        ```python
+        >>> from transformers import AutoTokenizer, LlamaForCausalLM
+
+        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+        ```"""
+
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            query_embeds=query_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            update_layer=update_layer,
+            image_position_list=image_position_list,
+            input_part_targets_len=input_part_targets_len,
+            all_image_embeds=all_image_embeds,
+        )
+
+        hidden_states = outputs[0]
+        if self.config.pretraining_tp > 1:
+            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
+            logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
+            logits = torch.cat(logits, dim=-1)
+        else:
+            logits = self.lm_head(hidden_states)
+        logits = logits.float()
+
+        loss = None
+        if labels is not None:
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+            # Flatten the tokens
+            loss_fct = CrossEntropyLoss()
+            shift_logits = shift_logits.view(-1, self.config.vocab_size)
+            shift_labels = shift_labels.view(-1)
+            # Enable model parallelism
+            shift_labels = shift_labels.to(shift_logits.device)
+            loss = loss_fct(shift_logits, shift_labels)
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
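+    # `generate()` passes any extra model kwargs through this hook, so the multimodal
+    # arguments (image_position_list, input_part_targets_len, all_image_embeds,
+    # update_layer) reach every decoding step; `query_embeds` is dropped once a cache
+    # exists because the prefix is already encoded in `past_key_values`.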
+    def prepare_inputs_for_generation(
+        self, input_ids, query_embeds=None, past_key_values=None, attention_mask=None, inputs_embeds=None,
+        image_position_list=None, input_part_targets_len=None, all_image_embeds=None, update_layer=16, **kwargs
+    ):
+        if past_key_values:
+            input_ids = input_ids[:, -1:]
+
+        position_ids = kwargs.get("position_ids", None)
+        if attention_mask is not None and position_ids is None:
+            # create position_ids on the fly for batch generation
+            position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids.masked_fill_(attention_mask == 0, 1)
+            if past_key_values:
+                position_ids = position_ids[:, -1].unsqueeze(-1)
+                query_embeds = None
+
+        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+        if inputs_embeds is not None and past_key_values is None:
+            model_inputs = {"inputs_embeds": inputs_embeds}
+        else:
+            model_inputs = {"input_ids": input_ids}
+
+        model_inputs.update(
+            {
+                "position_ids": position_ids,
+                "query_embeds": query_embeds,
+                "past_key_values": past_key_values,
+                "use_cache": kwargs.get("use_cache"),
+                "attention_mask": attention_mask,
+                "image_position_list": image_position_list,
+                "input_part_targets_len": input_part_targets_len,
+                "all_image_embeds": all_image_embeds,
+                "update_layer": update_layer
+            }
+        )
+        return model_inputs
+
+    @staticmethod
+    def _reorder_cache(past_key_values, beam_idx):
+        reordered_past = ()
+        for layer_past in past_key_values:
+            reordered_past += (
+                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+            )
+        return reordered_past
+
+
+@add_start_docstrings(
+    """
+    The LLaMa Model transformer with a sequence classification head on top (linear layer).
+
+    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+    (e.g. GPT-2) do.
+
+    Since it does classification on the last token, it requires to know the position of the last token. If a
+    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+    each row of the batch).
+    """,
+    LLAMA_START_DOCSTRING,
+)
+class LlamaForSequenceClassification(LlamaPreTrainedModel):
+    def __init__(self, config):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = LlamaModel(config)
+        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+        r"""
+        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        transformer_outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        hidden_states = transformer_outputs[0]
+        logits = self.score(hidden_states)
+
+        if input_ids is not None:
+            batch_size = input_ids.shape[0]
+        else:
+            batch_size = inputs_embeds.shape[0]
+
+        if self.config.pad_token_id is None and batch_size != 1:
+            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+        if self.config.pad_token_id is None:
+            sequence_lengths = -1
+        else:
+            if input_ids is not None:
+                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+            else:
+                sequence_lengths = -1
+
+        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
+
+        loss = None
+        if labels is not None:
+            labels = labels.to(logits.device)
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
+        if not return_dict:
+            output = (pooled_logits,) + transformer_outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutputWithPast(
+            loss=loss,
+            logits=pooled_logits,
+            past_key_values=transformer_outputs.past_key_values,
+            hidden_states=transformer_outputs.hidden_states,
+            attentions=transformer_outputs.attentions,
+        )
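The indexing convention used by the `update_layer` block in `LlamaModel.forward` above can be hard to follow in a raw diff, so here is a small self-contained toy of just that step. All names, shapes, and values are illustrative assumptions; the real model produces the per-image blocks with the Q-Former and `llm_proj` rather than the constant tensors used here. Each image owns a contiguous placeholder span `(img_start, img_end)` in the sequence, and the re-queried visual features are added in place onto exactly those positions.

```python
import torch

batch, seq_len, hidden, n_query = 2, 20, 8, 4
hidden_states = torch.zeros(batch, seq_len, hidden)

# Each image occupies a fixed span of n_query placeholder positions in the prompt.
image_position_list = [(3, 7), (10, 14)]

# Stand-ins for the Q-Former + llm_proj output: one (batch, n_query, hidden) block per image.
requeried = [torch.full((batch, n_query, hidden), float(i + 1)) for i in range(len(image_position_list))]

for i, (img_start, img_end) in enumerate(image_position_list):
    hidden_states[:, img_start:img_end, :] += requeried[i]

print(hidden_states[0, :, 0])  # positions 3-6 read 1.0, positions 10-13 read 2.0, the rest stay 0.0
```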
cheetah/processors/__init__.py ADDED
@@ -0,0 +1,33 @@
+"""
+ Copyright (c) 2022, salesforce.com, inc.
+ All rights reserved.
+ SPDX-License-Identifier: BSD-3-Clause
+ For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+"""
+
+from cheetah.processors.base_processor import BaseProcessor
+from cheetah.processors.blip_processors import (
+    Blip2ImageTrainProcessor,
+    Blip2ImageEvalProcessor,
+    BlipCaptionProcessor,
+)
+
+from cheetah.common.registry import registry
+
+__all__ = [
+    "BaseProcessor",
+    "Blip2ImageTrainProcessor",
+    "Blip2ImageEvalProcessor",
+    "BlipCaptionProcessor",
+]
+
+
+def load_processor(name, cfg=None):
+    """
+    Example
+
+    >>> processor = load_processor("alpro_video_train", cfg=None)
+    """
+    processor = registry.get_processor_class(name).from_config(cfg)
+
+    return processor
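The docstring example above (`alpro_video_train`) is carried over from LAVIS; only the BLIP processors are imported in this package, so a more representative call would look like the sketch below. The registry string `"blip2_image_eval"` is an assumption based on the usual `@registry.register_processor(...)` decorators and should be checked against `cheetah/processors/blip_processors.py`.

```python
from PIL import Image

from cheetah.processors import load_processor

# "blip2_image_eval" is an assumed registry name; verify it against blip_processors.py.
vis_processor = load_processor("blip2_image_eval", cfg=None)

image = Image.new("RGB", (448, 448))  # stand-in for a real input image
image_tensor = vis_processor(image)   # preprocessed tensor ready for the vision encoder
```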
cheetah/processors/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (998 Bytes). View file
 
cheetah/processors/__pycache__/base_processor.cpython-310.pyc ADDED
Binary file (1.36 kB). View file