Johnny050407 commited on
Commit
9ed01de
·
verified ·
1 Parent(s): 5ebc0df

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. .gitignore +9 -0
  3. DeQA-Score/.gitignore +23 -0
  4. DeQA-Score/DeQA_Score.egg-info/PKG-INFO +322 -0
  5. DeQA-Score/DeQA_Score.egg-info/SOURCES.txt +36 -0
  6. DeQA-Score/DeQA_Score.egg-info/dependency_links.txt +1 -0
  7. DeQA-Score/DeQA_Score.egg-info/requires.txt +30 -0
  8. DeQA-Score/DeQA_Score.egg-info/top_level.txt +4 -0
  9. DeQA-Score/LICENSE +21 -0
  10. DeQA-Score/README.md +281 -0
  11. DeQA-Score/build_soft_labels/config.json +35 -0
  12. DeQA-Score/build_soft_labels/gen_soft_label.py +209 -0
  13. DeQA-Score/fig/boy_colorful.jpg +3 -0
  14. DeQA-Score/fig/model.png +3 -0
  15. DeQA-Score/fig/singapore_flyer.jpg +3 -0
  16. DeQA-Score/fig/teaser.png +3 -0
  17. DeQA-Score/preprocessor/preprocessor_config.json +19 -0
  18. DeQA-Score/preprocessor/special_tokens_map.json +24 -0
  19. DeQA-Score/preprocessor/tokenizer.model +3 -0
  20. DeQA-Score/preprocessor/tokenizer_config.json +35 -0
  21. DeQA-Score/pyproject.toml +35 -0
  22. DeQA-Score/scripts/eval_dist.sh +23 -0
  23. DeQA-Score/scripts/eval_score.sh +23 -0
  24. DeQA-Score/scripts/infer.sh +17 -0
  25. DeQA-Score/scripts/infer_lora.sh +18 -0
  26. DeQA-Score/scripts/train.sh +48 -0
  27. DeQA-Score/scripts/train_lora.sh +49 -0
  28. DeQA-Score/scripts/zero3.json +28 -0
  29. DeQA-Score/scripts/zero3_offload.json +56 -0
  30. DeQA-Score/src/__init__.py +2 -0
  31. DeQA-Score/src/constants.py +9 -0
  32. DeQA-Score/src/conversation.py +301 -0
  33. DeQA-Score/src/datasets/__init__.py +11 -0
  34. DeQA-Score/src/datasets/pair_dataset.py +276 -0
  35. DeQA-Score/src/datasets/single_dataset.py +244 -0
  36. DeQA-Score/src/datasets/utils.py +317 -0
  37. DeQA-Score/src/evaluate/__init__.py +1 -0
  38. DeQA-Score/src/evaluate/cal_distribution_gap.py +143 -0
  39. DeQA-Score/src/evaluate/cal_plcc_srcc.py +115 -0
  40. DeQA-Score/src/evaluate/eval_qbench_mcq.py +138 -0
  41. DeQA-Score/src/evaluate/iqa_eval.py +184 -0
  42. DeQA-Score/src/evaluate/scorer.py +63 -0
  43. DeQA-Score/src/evaluate/scorer_coco.py +103 -0
  44. DeQA-Score/src/mm_utils.py +112 -0
  45. DeQA-Score/src/model/__init__.py +2 -0
  46. DeQA-Score/src/model/builder.py +166 -0
  47. DeQA-Score/src/model/configuration_mplug_owl2.py +334 -0
  48. DeQA-Score/src/model/convert_mplug_owl2_weight_to_hf.py +395 -0
  49. DeQA-Score/src/model/modeling_attn_mask_utils.py +247 -0
  50. DeQA-Score/src/model/modeling_llama2.py +834 -0
.gitattributes CHANGED
@@ -34,3 +34,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  CLIP-Huge-Flickr-Flat/faiss_IVPQ_PCA.index filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  CLIP-Huge-Flickr-Flat/faiss_IVPQ_PCA.index filter=lfs diff=lfs merge=lfs -text
37
+ DeQA-Score/fig/boy_colorful.jpg filter=lfs diff=lfs merge=lfs -text
38
+ DeQA-Score/fig/model.png filter=lfs diff=lfs merge=lfs -text
39
+ DeQA-Score/fig/singapore_flyer.jpg filter=lfs diff=lfs merge=lfs -text
40
+ DeQA-Score/fig/teaser.png filter=lfs diff=lfs merge=lfs -text
41
+ imgs/results1.png filter=lfs diff=lfs merge=lfs -text
42
+ imgs/results2.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ coco_data/
2
+ __pycache__/
3
+ coco_feats/
4
+ coco_faiss_indexes/
5
+ faiss_indexes/
6
+ processed_data/
7
+ outputs/
8
+ wandb/
9
+ results/
DeQA-Score/.gitignore ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # source file related
2
+ *__pycache__*
3
+ *.pyc
4
+ *.o
5
+ *.so
6
+ *.egg
7
+ *.egg-info
8
+
9
+ # training related
10
+ *log*
11
+ *.log
12
+ *.pth
13
+ *.pt
14
+
15
+ # result related
16
+ *answer*
17
+ *ckpt*
18
+ *.json
19
+ *.jsonl
20
+ res_*
21
+
22
+ # mac hidden
23
+ *.DS_Store*
DeQA-Score/DeQA_Score.egg-info/PKG-INFO ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.4
2
+ Name: DeQA-Score
3
+ Version: 1.2.0
4
+ Summary: Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution (based on mPLUG-Owl2)
5
+ Project-URL: Bug Tracker, https://github.com/zhiyuanyou/DeQA-Score/issues
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: Apache Software License
8
+ Requires-Python: >=3.8
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+ Requires-Dist: torch==2.0.1
12
+ Requires-Dist: torchvision==0.15.2
13
+ Requires-Dist: transformers==4.36.1
14
+ Requires-Dist: tokenizers==0.15.0
15
+ Requires-Dist: sentencepiece==0.1.99
16
+ Requires-Dist: shortuuid
17
+ Requires-Dist: accelerate==0.21.0
18
+ Requires-Dist: peft==0.4.0
19
+ Requires-Dist: bitsandbytes==0.41.0
20
+ Requires-Dist: pydantic<2,>=1
21
+ Requires-Dist: markdown2[all]
22
+ Requires-Dist: numpy
23
+ Requires-Dist: scikit-learn==1.2.2
24
+ Requires-Dist: gradio==3.35.2
25
+ Requires-Dist: gradio_client==0.2.9
26
+ Requires-Dist: requests
27
+ Requires-Dist: httpx==0.24.0
28
+ Requires-Dist: uvicorn
29
+ Requires-Dist: fastapi
30
+ Requires-Dist: icecream
31
+ Requires-Dist: einops==0.6.1
32
+ Requires-Dist: einops-exts==0.0.4
33
+ Requires-Dist: timm==0.6.13
34
+ Requires-Dist: decord
35
+ Requires-Dist: scipy
36
+ Provides-Extra: train
37
+ Requires-Dist: deepspeed==0.9.5; extra == "train"
38
+ Requires-Dist: ninja; extra == "train"
39
+ Requires-Dist: wandb; extra == "train"
40
+ Dynamic: license-file
41
+
42
+ <div align="center">
43
+ <h1>Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution</h1>
44
+
45
+ <div>
46
+ <a href="https://zhiyuanyou.github.io/" target="_blank">Zhiyuan You</a><sup>12</sup>,
47
+ <a href="https://caixin98.github.io/" target="_blank">Xin Cai</a><sup>2</sup>,
48
+ <a href="https://www.jasongt.com/" target="_blank">Jinjin Gu</a><sup>4</sup>,
49
+ <a href="https://tianfan.info/" target="_blank">Tianfan Xue</a><sup>235</sup><sup>#</sup>,
50
+ <a href="https://xpixel.group/2010/01/20/chaodong.html" target="_blank">Chao Dong</a><sup>134</sup><sup>#</sup>
51
+ </div>
52
+
53
+ <div>
54
+ <sup>1</sup>Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences, <sup>2</sup>Multimedia Laboratory, The Chinese University of Hong Kong,
55
+ <sup>3</sup>Shanghai AI Laboratory, <sup>4</sup>Shenzhen University of Advanced Technology, <sup>5</sup>CPII under InnoHK
56
+ </div>
57
+
58
+ <div><sup>#</sup>Corresponding author.</div>
59
+
60
+ <div>
61
+ <a href="https://depictqa.github.io/deqa-score/" target="_blank"><strong>Homepage</strong></a> |
62
+ <strong>Model Weights</strong> (
63
+ <a href="https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3" target="_blank"><strong>Full Tuning</strong></a> /
64
+ <a href="https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3" target="_blank"><strong>LoRA Tuning</strong></a>
65
+ ) |
66
+ <a href="https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score" target="_blank"><strong>Datasets</strong></a> |
67
+ <a href="https://arxiv.org/abs/2501.11561" target="_blank"><strong>Paper</strong></a>
68
+ </div>
69
+
70
+ <h2>Motivation</h2>
71
+
72
+ <div style="width: 100%; text-align: center; margin:auto;">
73
+ <img style="width: 75%" src="fig/teaser.png">
74
+ </div>
75
+
76
+ <h2>Model Architecture</h2>
77
+
78
+ <div style="width: 100%; text-align: center; margin:auto;">
79
+ <img style="width: 100%" src="fig/model.png">
80
+ </div>
81
+ </div>
82
+
83
+
84
+ ## [Installation Free!] Quicker Start with Hugging Face AutoModel
85
+
86
+ [2025.12] Thanks to @[lyf1212](https://github.com/lyf1212)'s suggestion, we add support on `transformers==4.46.3` with minor code modifications. See [details](https://github.com/zhiyuanyou/DeQA-Score/issues/32).
87
+
88
+ The following code could be run directly with `transformers==4.36.1`. No need to install this GitHub repo.
89
+
90
+ ```python
91
+ import requests
92
+ import torch
93
+ from transformers import AutoModelForCausalLM
94
+
95
+ model = AutoModelForCausalLM.from_pretrained(
96
+ "zhiyuanyou/DeQA-Score-Mix3",
97
+ trust_remote_code=True,
98
+ attn_implementation="eager",
99
+ torch_dtype=torch.float16,
100
+ device_map="auto",
101
+ )
102
+
103
+ from PIL import Image
104
+
105
+ # The inputs should be a list of multiple PIL images
106
+ model.score(
107
+ [Image.open(requests.get(
108
+ "https://raw.githubusercontent.com/zhiyuanyou/DeQA-Score/main/fig/singapore_flyer.jpg", stream=True
109
+ ).raw)]
110
+ )
111
+ ```
112
+
113
+ ## Installation
114
+
115
+ If you only need to infer / evaluate:
116
+
117
+ ```shell
118
+ git clone https://github.com/zhiyuanyou/DeQA-Score.git
119
+ cd DeQA-Score
120
+ pip install -e .
121
+ ```
122
+
123
+ For training, you need to further install additional dependencies as follows:
124
+
125
+ ```shell
126
+ pip install -e ".[train]"
127
+ pip install flash_attn --no-build-isolation
128
+ ```
129
+
130
+
131
+ ## Quick Start
132
+
133
+
134
+ ### Image Quality Scorer
135
+
136
+ - CLI Interface
137
+
138
+ ```shell
139
+ python src/evaluate/scorer.py --img_path fig/singapore_flyer.jpg
140
+ ```
141
+
142
+ - Python API
143
+
144
+ ```python
145
+ from src import Scorer
146
+ from PIL import Image
147
+
148
+ scorer = Scorer()
149
+ img_list = [Image.open("fig/singapore_flyer.jpg")] # can be a list of multiple PIL images
150
+ print(scorer(img_list).tolist())
151
+ ```
152
+
153
+
154
+ ## Training, Inference & Evaluation
155
+
156
+ ### Datasets
157
+
158
+ <a id="datasets"></a>
159
+
160
+ - Download our meta files from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score).
161
+
162
+ - Download source images from [KonIQ](https://database.mmsp-kn.de/koniq-10k-database.html),
163
+ [SPAQ](https://github.com/h4nwei/SPAQ),
164
+ [KADID](https://database.mmsp-kn.de/kadid-10k-database.html),
165
+ [PIPAL](https://github.com/HaomingCai/PIPAL-dataset),
166
+ [LIVE-Wild](https://live.ece.utexas.edu/research/ChallengeDB/index.html),
167
+ [AGIQA](https://github.com/lcysyzxdxc/AGIQA-3k-Database),
168
+ [TID2013](https://www.ponomarenko.info/tid2013.htm),
169
+ and [CSIQ](https://s2.smu.edu/~eclarson/csiq.html).
170
+
171
+ - Arrange the folders as follows:
172
+
173
+ ```
174
+ |-- DeQA-Score
175
+ |-- Data-DeQA-Score
176
+ |-- KONIQ
177
+ |-- images/*.jpg
178
+ |-- metas
179
+ |-- SPAQ
180
+ |-- images/*.jpg
181
+ |-- metas
182
+ |-- KADID10K
183
+ |-- images/*.png
184
+ |-- metas
185
+ |-- PIPAL
186
+ |-- images/Distortion_*/*.bmp
187
+ |-- metas
188
+ |-- LIVE-WILD
189
+ |-- images/*.bmp
190
+ |-- metas
191
+ |-- AGIQA3K
192
+ |-- images/*.jpg
193
+ |-- metas
194
+ |-- TID2013
195
+ |-- images/distorted_images/*.bmp
196
+ |-- metas
197
+ |-- CSIQ
198
+ |-- images/dst_imgs/*/*.png
199
+ |-- metas
200
+ ```
201
+
202
+ ### Pretrained Weights
203
+
204
+ <a id="pretrained_weights"></a>
205
+
206
+ We provide two model weights (full tuning and LoRA tuning) with similar performance.
207
+
208
+ | | Training Datasets | Weights |
209
+ |-----|-----|-----|
210
+ | Full Tuning | KonIQ, SPAQ, KADID | [Huggingface Full](https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3) |
211
+ | LoRA Tuning | KonIQ, SPAQ, KADID | [Huggingface LoRA](https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3) |
212
+
213
+ Download one of the above model weights, then arrange the folders as follows:
214
+
215
+ ```
216
+ |-- DeQA-Score
217
+ |-- checkpoints
218
+ |-- DeQA-Score-Mix3
219
+ |-- DeQA-Score-LoRA-Mix3
220
+ ```
221
+
222
+ If you would like to use the LoRA tuning weights, you need to download the base mPLUG-Owl2 weights from [Huggingface mPLUG-Owl2](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b), then arrange the folders as follows:
223
+
224
+ ```
225
+ |-- DeQA-Score
226
+ |-- ModelZoo
227
+ |-- mplug-owl2-llama2-7b
228
+ ```
229
+
230
+ ### Inference
231
+
232
+ After preparing the datasets, you can infer using pre-trained **DeQA-Score** or **DeQA-Score-LoRA**:
233
+
234
+ ```shell
235
+ sh scripts/infer.sh $ONE_GPU_ID
236
+ ```
237
+
238
+ ```shell
239
+ sh scripts/infer_lora.sh $ONE_GPU_ID
240
+ ```
241
+
242
+ ### Evaluation
243
+
244
+ After inference, you can evaluate the inference results:
245
+
246
+ - SRCC / PLCC for quality score.
247
+
248
+ ```shell
249
+ sh scripts/eval_score.sh
250
+ ```
251
+
252
+ - KL Divergence / JS Divergence / Wasserstein Distance for score distribution.
253
+
254
+ ```shell
255
+ sh scripts/eval_dist.sh
256
+ ```
257
+
258
+ ### Fine-tuning
259
+
260
+ Fine-tuning needs to download the mPLUG-Owl2 weights as in [Pretrained Weights](#pretrained_weights).
261
+
262
+ #### LoRA Fine-tuning
263
+
264
+ - Only **2 RTX3090 GPUs** are required. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
265
+
266
+ ```shell
267
+ sh scripts/train_lora.sh $GPU_IDs
268
+ ```
269
+
270
+ #### Full Fine-tuning from the Scratch
271
+
272
+ - At least **8 A6000 GPUs** or **4 A100 GPUs** will be enough. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
273
+
274
+ ```shell
275
+ sh scripts/train.sh $GPU_IDs
276
+ ```
277
+
278
+
279
+ ## Soft Label Construction
280
+
281
+ - Download `split.json` (training & test split info) and `mos.json` (mos & std info) of KonIQ, SPAQ, and KADID from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score), and arrange the folders as in [Datasets](#datasets).
282
+
283
+ - Run the following scripts to construct the distribution-based soft labels.
284
+
285
+ ```shell
286
+ cd build_soft_labels
287
+ python gen_soft_label.py
288
+ ```
289
+
290
+
291
+ ## Acknowledgements
292
+
293
+ This work is based on [Q-Align](https://github.com/Q-Future/Q-Align). Sincerely thanks for this awesome work.
294
+
295
+ ## Citation
296
+
297
+ If you find our work useful for your research and applications, please cite using the BibTeX:
298
+
299
+ ```bibtex
300
+ @inproceedings{deqa_score,
301
+ title={Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution},
302
+ author={You, Zhiyuan and Cai, Xin and Gu, Jinjin and Xue, Tianfan and Dong, Chao},
303
+ booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
304
+ pages={14483--14494},
305
+ year={2025}
306
+ }
307
+
308
+ @article{depictqa_v2,
309
+ title={Enhancing Descriptive Image Quality Assessment with A Large-scale Multi-modal Dataset},
310
+ author={You, Zhiyuan and Gu, Jinjin and Cai, Xin and Li, Zheyuan and Zhu, Kaiwen and Dong, Chao and Xue, Tianfan},
311
+ journal={IEEE Transactions on Image Processing},
312
+ year={2025}
313
+ }
314
+
315
+ @inproceedings{depictqa_v1,
316
+ title={Depicting Beyond Scores: Advancing Image Quality Assessment through Multi-modal Language Models},
317
+ author={You, Zhiyuan and Li, Zheyuan and Gu, Jinjin and Yin, Zhenfei and Xue, Tianfan and Dong, Chao},
318
+ booktitle={European Conference on Computer Vision},
319
+ pages={259--276},
320
+ year={2024}
321
+ }
322
+ ```
DeQA-Score/DeQA_Score.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ DeQA_Score.egg-info/PKG-INFO
5
+ DeQA_Score.egg-info/SOURCES.txt
6
+ DeQA_Score.egg-info/dependency_links.txt
7
+ DeQA_Score.egg-info/requires.txt
8
+ DeQA_Score.egg-info/top_level.txt
9
+ build_soft_labels/gen_soft_label.py
10
+ src/__init__.py
11
+ src/constants.py
12
+ src/conversation.py
13
+ src/mm_utils.py
14
+ src/utils.py
15
+ src/datasets/__init__.py
16
+ src/datasets/pair_dataset.py
17
+ src/datasets/single_dataset.py
18
+ src/datasets/utils.py
19
+ src/evaluate/__init__.py
20
+ src/evaluate/cal_distribution_gap.py
21
+ src/evaluate/cal_plcc_srcc.py
22
+ src/evaluate/eval_qbench_mcq.py
23
+ src/evaluate/iqa_eval.py
24
+ src/evaluate/scorer.py
25
+ src/evaluate/scorer_coco.py
26
+ src/model/__init__.py
27
+ src/model/builder.py
28
+ src/model/configuration_mplug_owl2.py
29
+ src/model/convert_mplug_owl2_weight_to_hf.py
30
+ src/model/modeling_attn_mask_utils.py
31
+ src/model/modeling_llama2.py
32
+ src/model/modeling_mplug_owl2.py
33
+ src/model/utils.py
34
+ src/model/visual_encoder.py
35
+ src/train/mplug_owl2_trainer.py
36
+ src/train/train_mem.py
DeQA-Score/DeQA_Score.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
DeQA-Score/DeQA_Score.egg-info/requires.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ transformers==4.36.1
4
+ tokenizers==0.15.0
5
+ sentencepiece==0.1.99
6
+ shortuuid
7
+ accelerate==0.21.0
8
+ peft==0.4.0
9
+ bitsandbytes==0.41.0
10
+ pydantic<2,>=1
11
+ markdown2[all]
12
+ numpy
13
+ scikit-learn==1.2.2
14
+ gradio==3.35.2
15
+ gradio_client==0.2.9
16
+ requests
17
+ httpx==0.24.0
18
+ uvicorn
19
+ fastapi
20
+ icecream
21
+ einops==0.6.1
22
+ einops-exts==0.0.4
23
+ timm==0.6.13
24
+ decord
25
+ scipy
26
+
27
+ [train]
28
+ deepspeed==0.9.5
29
+ ninja
30
+ wandb
DeQA-Score/DeQA_Score.egg-info/top_level.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ build_soft_labels
2
+ fig
3
+ preprocessor
4
+ src
DeQA-Score/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Depicted image Quality Assessment (DepictQA / DeQA)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
DeQA-Score/README.md ADDED
@@ -0,0 +1,281 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+ <h1>Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution</h1>
3
+
4
+ <div>
5
+ <a href="https://zhiyuanyou.github.io/" target="_blank">Zhiyuan You</a><sup>12</sup>,
6
+ <a href="https://caixin98.github.io/" target="_blank">Xin Cai</a><sup>2</sup>,
7
+ <a href="https://www.jasongt.com/" target="_blank">Jinjin Gu</a><sup>4</sup>,
8
+ <a href="https://tianfan.info/" target="_blank">Tianfan Xue</a><sup>235</sup><sup>#</sup>,
9
+ <a href="https://xpixel.group/2010/01/20/chaodong.html" target="_blank">Chao Dong</a><sup>134</sup><sup>#</sup>
10
+ </div>
11
+
12
+ <div>
13
+ <sup>1</sup>Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences, <sup>2</sup>Multimedia Laboratory, The Chinese University of Hong Kong,
14
+ <sup>3</sup>Shanghai AI Laboratory, <sup>4</sup>Shenzhen University of Advanced Technology, <sup>5</sup>CPII under InnoHK
15
+ </div>
16
+
17
+ <div><sup>#</sup>Corresponding author.</div>
18
+
19
+ <div>
20
+ <a href="https://depictqa.github.io/deqa-score/" target="_blank"><strong>Homepage</strong></a> |
21
+ <strong>Model Weights</strong> (
22
+ <a href="https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3" target="_blank"><strong>Full Tuning</strong></a> /
23
+ <a href="https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3" target="_blank"><strong>LoRA Tuning</strong></a>
24
+ ) |
25
+ <a href="https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score" target="_blank"><strong>Datasets</strong></a> |
26
+ <a href="https://arxiv.org/abs/2501.11561" target="_blank"><strong>Paper</strong></a>
27
+ </div>
28
+
29
+ <h2>Motivation</h2>
30
+
31
+ <div style="width: 100%; text-align: center; margin:auto;">
32
+ <img style="width: 75%" src="fig/teaser.png">
33
+ </div>
34
+
35
+ <h2>Model Architecture</h2>
36
+
37
+ <div style="width: 100%; text-align: center; margin:auto;">
38
+ <img style="width: 100%" src="fig/model.png">
39
+ </div>
40
+ </div>
41
+
42
+
43
+ ## [Installation Free!] Quicker Start with Hugging Face AutoModel
44
+
45
+ [2025.12] Thanks to @[lyf1212](https://github.com/lyf1212)'s suggestion, we add support on `transformers==4.46.3` with minor code modifications. See [details](https://github.com/zhiyuanyou/DeQA-Score/issues/32).
46
+
47
+ The following code could be run directly with `transformers==4.36.1`. No need to install this GitHub repo.
48
+
49
+ ```python
50
+ import requests
51
+ import torch
52
+ from transformers import AutoModelForCausalLM
53
+
54
+ model = AutoModelForCausalLM.from_pretrained(
55
+ "zhiyuanyou/DeQA-Score-Mix3",
56
+ trust_remote_code=True,
57
+ attn_implementation="eager",
58
+ torch_dtype=torch.float16,
59
+ device_map="auto",
60
+ )
61
+
62
+ from PIL import Image
63
+
64
+ # The inputs should be a list of multiple PIL images
65
+ model.score(
66
+ [Image.open(requests.get(
67
+ "https://raw.githubusercontent.com/zhiyuanyou/DeQA-Score/main/fig/singapore_flyer.jpg", stream=True
68
+ ).raw)]
69
+ )
70
+ ```
71
+
72
+ ## Installation
73
+
74
+ If you only need to infer / evaluate:
75
+
76
+ ```shell
77
+ git clone https://github.com/zhiyuanyou/DeQA-Score.git
78
+ cd DeQA-Score
79
+ pip install -e .
80
+ ```
81
+
82
+ For training, you need to further install additional dependencies as follows:
83
+
84
+ ```shell
85
+ pip install -e ".[train]"
86
+ pip install flash_attn --no-build-isolation
87
+ ```
88
+
89
+
90
+ ## Quick Start
91
+
92
+
93
+ ### Image Quality Scorer
94
+
95
+ - CLI Interface
96
+
97
+ ```shell
98
+ python src/evaluate/scorer.py --img_path fig/singapore_flyer.jpg
99
+ ```
100
+
101
+ - Python API
102
+
103
+ ```python
104
+ from src import Scorer
105
+ from PIL import Image
106
+
107
+ scorer = Scorer()
108
+ img_list = [Image.open("fig/singapore_flyer.jpg")] # can be a list of multiple PIL images
109
+ print(scorer(img_list).tolist())
110
+ ```
111
+
112
+
113
+ ## Training, Inference & Evaluation
114
+
115
+ ### Datasets
116
+
117
+ <a id="datasets"></a>
118
+
119
+ - Download our meta files from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score).
120
+
121
+ - Download source images from [KonIQ](https://database.mmsp-kn.de/koniq-10k-database.html),
122
+ [SPAQ](https://github.com/h4nwei/SPAQ),
123
+ [KADID](https://database.mmsp-kn.de/kadid-10k-database.html),
124
+ [PIPAL](https://github.com/HaomingCai/PIPAL-dataset),
125
+ [LIVE-Wild](https://live.ece.utexas.edu/research/ChallengeDB/index.html),
126
+ [AGIQA](https://github.com/lcysyzxdxc/AGIQA-3k-Database),
127
+ [TID2013](https://www.ponomarenko.info/tid2013.htm),
128
+ and [CSIQ](https://s2.smu.edu/~eclarson/csiq.html).
129
+
130
+ - Arrange the folders as follows:
131
+
132
+ ```
133
+ |-- DeQA-Score
134
+ |-- Data-DeQA-Score
135
+ |-- KONIQ
136
+ |-- images/*.jpg
137
+ |-- metas
138
+ |-- SPAQ
139
+ |-- images/*.jpg
140
+ |-- metas
141
+ |-- KADID10K
142
+ |-- images/*.png
143
+ |-- metas
144
+ |-- PIPAL
145
+ |-- images/Distortion_*/*.bmp
146
+ |-- metas
147
+ |-- LIVE-WILD
148
+ |-- images/*.bmp
149
+ |-- metas
150
+ |-- AGIQA3K
151
+ |-- images/*.jpg
152
+ |-- metas
153
+ |-- TID2013
154
+ |-- images/distorted_images/*.bmp
155
+ |-- metas
156
+ |-- CSIQ
157
+ |-- images/dst_imgs/*/*.png
158
+ |-- metas
159
+ ```
160
+
161
+ ### Pretrained Weights
162
+
163
+ <a id="pretrained_weights"></a>
164
+
165
+ We provide two model weights (full tuning and LoRA tuning) with similar performance.
166
+
167
+ | | Training Datasets | Weights |
168
+ |-----|-----|-----|
169
+ | Full Tuning | KonIQ, SPAQ, KADID | [Huggingface Full](https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3) |
170
+ | LoRA Tuning | KonIQ, SPAQ, KADID | [Huggingface LoRA](https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3) |
171
+
172
+ Download one of the above model weights, then arrange the folders as follows:
173
+
174
+ ```
175
+ |-- DeQA-Score
176
+ |-- checkpoints
177
+ |-- DeQA-Score-Mix3
178
+ |-- DeQA-Score-LoRA-Mix3
179
+ ```
180
+
181
+ If you would like to use the LoRA tuning weights, you need to download the base mPLUG-Owl2 weights from [Huggingface mPLUG-Owl2](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b), then arrange the folders as follows:
182
+
183
+ ```
184
+ |-- DeQA-Score
185
+ |-- ModelZoo
186
+ |-- mplug-owl2-llama2-7b
187
+ ```
188
+
189
+ ### Inference
190
+
191
+ After preparing the datasets, you can infer using pre-trained **DeQA-Score** or **DeQA-Score-LoRA**:
192
+
193
+ ```shell
194
+ sh scripts/infer.sh $ONE_GPU_ID
195
+ ```
196
+
197
+ ```shell
198
+ sh scripts/infer_lora.sh $ONE_GPU_ID
199
+ ```
200
+
201
+ ### Evaluation
202
+
203
+ After inference, you can evaluate the inference results:
204
+
205
+ - SRCC / PLCC for quality score.
206
+
207
+ ```shell
208
+ sh scripts/eval_score.sh
209
+ ```
210
+
211
+ - KL Divergence / JS Divergence / Wasserstein Distance for score distribution.
212
+
213
+ ```shell
214
+ sh scripts/eval_dist.sh
215
+ ```
216
+
217
+ ### Fine-tuning
218
+
219
+ Fine-tuning needs to download the mPLUG-Owl2 weights as in [Pretrained Weights](#pretrained_weights).
220
+
221
+ #### LoRA Fine-tuning
222
+
223
+ - Only **2 RTX3090 GPUs** are required. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
224
+
225
+ ```shell
226
+ sh scripts/train_lora.sh $GPU_IDs
227
+ ```
228
+
229
+ #### Full Fine-tuning from the Scratch
230
+
231
+ - At least **8 A6000 GPUs** or **4 A100 GPUs** will be enough. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
232
+
233
+ ```shell
234
+ sh scripts/train.sh $GPU_IDs
235
+ ```
236
+
237
+
238
+ ## Soft Label Construction
239
+
240
+ - Download `split.json` (training & test split info) and `mos.json` (mos & std info) of KonIQ, SPAQ, and KADID from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score), and arrange the folders as in [Datasets](#datasets).
241
+
242
+ - Run the following scripts to construct the distribution-based soft labels.
243
+
244
+ ```shell
245
+ cd build_soft_labels
246
+ python gen_soft_label.py
247
+ ```
248
+
249
+
250
+ ## Acknowledgements
251
+
252
+ This work is based on [Q-Align](https://github.com/Q-Future/Q-Align). Sincerely thanks for this awesome work.
253
+
254
+ ## Citation
255
+
256
+ If you find our work useful for your research and applications, please cite using the BibTeX:
257
+
258
+ ```bibtex
259
+ @inproceedings{deqa_score,
260
+ title={Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution},
261
+ author={You, Zhiyuan and Cai, Xin and Gu, Jinjin and Xue, Tianfan and Dong, Chao},
262
+ booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
263
+ pages={14483--14494},
264
+ year={2025}
265
+ }
266
+
267
+ @article{depictqa_v2,
268
+ title={Enhancing Descriptive Image Quality Assessment with A Large-scale Multi-modal Dataset},
269
+ author={You, Zhiyuan and Gu, Jinjin and Cai, Xin and Li, Zheyuan and Zhu, Kaiwen and Dong, Chao and Xue, Tianfan},
270
+ journal={IEEE Transactions on Image Processing},
271
+ year={2025}
272
+ }
273
+
274
+ @inproceedings{depictqa_v1,
275
+ title={Depicting Beyond Scores: Advancing Image Quality Assessment through Multi-modal Language Models},
276
+ author={You, Zhiyuan and Li, Zheyuan and Gu, Jinjin and Yin, Zhenfei and Xue, Tianfan and Dong, Chao},
277
+ booktitle={European Conference on Computer Vision},
278
+ pages={259--276},
279
+ year={2024}
280
+ }
281
+ ```
DeQA-Score/build_soft_labels/config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "answer": "The quality of the image is {}.",
3
+ "dataset_params": {
4
+ "koniq": {
5
+ "split_json": "../../Data-DeQA-Score/KONIQ/metas/split.json",
6
+ "mos_json": "../../Data-DeQA-Score/KONIQ/metas/mos.json",
7
+ "save_train": "../../Data-DeQA-Score/KONIQ/metas/train_koniq_7k_new.json",
8
+ "save_test": "../../Data-DeQA-Score/KONIQ/metas/test_koniq_2k_new.json",
9
+ "img_dir": "KONIQ/images",
10
+ "density_type": "pdf",
11
+ "thre_std": 0.2,
12
+ "thre_diff": 0.1
13
+ },
14
+ "spaq": {
15
+ "split_json": "../../Data-DeQA-Score/SPAQ/metas/split.json",
16
+ "mos_json": "../../Data-DeQA-Score/SPAQ/metas/mos.json",
17
+ "save_train": "../../Data-DeQA-Score/SPAQ/metas/train_spaq_9k_new.json",
18
+ "save_test": "../../Data-DeQA-Score/SPAQ/metas/test_spaq_2k_new.json",
19
+ "img_dir": "SPAQ/images",
20
+ "density_type": "cdf",
21
+ "thre_std": 0.2,
22
+ "thre_diff": 0.1
23
+ },
24
+ "kadid": {
25
+ "split_json": "../../Data-DeQA-Score/KADID10K/metas/split.json",
26
+ "mos_json": "../../Data-DeQA-Score/KADID10K/metas/mos.json",
27
+ "save_train": "../../Data-DeQA-Score/KADID10K/metas/train_kadid_8k_new.json",
28
+ "save_test": "../../Data-DeQA-Score/KADID10K/metas/test_kadid_2k_new.json",
29
+ "img_dir": "KADID10K/images",
30
+ "density_type": "pdf",
31
+ "thre_std": 0.2,
32
+ "thre_diff": 0.1
33
+ }
34
+ }
35
+ }
DeQA-Score/build_soft_labels/gen_soft_label.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import numpy as np
4
+ import os
5
+ import random
6
+ from scipy.stats import norm, pearsonr, spearmanr
7
+
8
+
9
+ def parse_args():
10
+ parser = argparse.ArgumentParser(description="label parameters for DeQA-Score")
11
+ parser.add_argument("--config", type=str, default="./config.json")
12
+ args = parser.parse_args()
13
+ return args
14
+
15
+
16
+ questions = [
17
+ "What do you think about the quality of this image?",
18
+ "Can you rate the quality of this picture?",
19
+ "Can you judge the quality of this image?",
20
+ "How would you rate the quality of this image?",
21
+ "How would you judge the quality of this image?",
22
+ "What is your quality rating for this image?",
23
+ "What's your opinion on the quality of this picture?",
24
+ "Rate the quality of this image.",
25
+ "Could you evaluate the quality of this image?",
26
+ "How do you assess the quality of this image?",
27
+ ]
28
+
29
+
30
+ def calculate_srcc_plcc(pred, mos):
31
+ srcc, _ = spearmanr(pred, mos)
32
+ plcc, _ = pearsonr(pred, mos)
33
+ return srcc, plcc
34
+
35
+
36
+ def get_level(mos, min_mos, max_mos):
37
+ eps = 1e-8
38
+ texts = ["bad", "poor", "fair", "good", "excellent"]
39
+ for idx in range(1, len(texts) + 1):
40
+ mos_left = min_mos + (idx - 1) / 5 * (max_mos - min_mos) - eps
41
+ mos_right = min_mos + idx / 5 * (max_mos - min_mos) + eps
42
+ if mos > mos_left and mos <= mos_right:
43
+ level = idx
44
+ break
45
+ text = texts[level - 1]
46
+ return text
47
+
48
+
49
+ def adjust_gaussian_bar(probs, score):
50
+ """
51
+ alpha * (a + b + c + d + e) + 5 * beta = 1
52
+ alpha * (5a + 4b + 3c + 2d + e) + 15 beta = score
53
+ ==>
54
+ alpha * A + 5 * beta = 1
55
+ alpha * B + 15 * beta = score
56
+ """
57
+ A = np.array(probs).sum()
58
+ B = np.inner(np.array(probs), np.array([5, 4, 3, 2, 1]))
59
+ alpha = (score - 3) / (B - 3. * A + 1e-9)
60
+ beta = (1. - alpha * A) / 5.
61
+ return alpha, beta
62
+
63
+
64
+ def get_binary_probs(mos, min_mos=1.0, max_mos=5.0):
65
+ eps = 1e-8
66
+ probs = [0, 0, 0, 0, 0]
67
+ for idx in range(1, len(probs)):
68
+ mos_left = min_mos + (idx - 1) / 4 * (max_mos - min_mos) - eps
69
+ mos_right = min_mos + idx / 4 * (max_mos - min_mos) + eps
70
+ if mos > mos_left and mos <= mos_right:
71
+ probs[idx - 1] = (mos_right - mos) / (mos_right - mos_left)
72
+ probs[idx] = (mos - mos_left) / (mos_right - mos_left)
73
+ break
74
+ assert np.array((np.array(probs) == 0)).sum() == 3
75
+ assert round(np.array(probs).sum(), 5) == 1
76
+ probs = probs[::-1] # should start with "excellent" & end with "bad"
77
+ return probs
78
+
79
+
80
+ def main(cfg):
81
+ density_type = cfg["density_type"] # ["pdf", "cdf"]
82
+ thre_std = cfg["thre_std"]
83
+ thre_diff = cfg["thre_diff"]
84
+ with open(cfg["split_json"]) as fr:
85
+ split = json.load(fr)
86
+ with open(cfg["mos_json"]) as fr:
87
+ mos_dict = json.load(fr)
88
+ save_train = cfg["save_train"]
89
+ save_test = cfg["save_test"]
90
+ img_dir = cfg["img_dir"]
91
+
92
+ moses, stds, imgs = [], [], []
93
+ for img in mos_dict:
94
+ moses.append(mos_dict[img]["mos"])
95
+ stds.append(mos_dict[img]["std"])
96
+ imgs.append(img)
97
+ max_mos = max([float(_) for _ in moses])
98
+ min_mos = min([float(_) for _ in moses])
99
+
100
+ num_binary, idx = 0, 0
101
+ preds, gts, raw_diffs, diffs, alphas, betas = [], [], [], [], [], []
102
+ train_metas, test_metas = [], []
103
+ for img, mos_str, std_str in zip(imgs, moses, stds):
104
+ mos, std = float(mos_str), float(std_str)
105
+ if os.path.basename(img) in split["train"]:
106
+ training = True
107
+ elif os.path.basename(img) in split["test"]:
108
+ training = False
109
+ else:
110
+ idx += 1
111
+ # print(idx, img)
112
+ continue
113
+
114
+ text = get_level(mos, min_mos, max_mos)
115
+ query = random.choice(questions)
116
+ resp = answer.replace("{}", text)
117
+
118
+ # norm mos and std
119
+ mos_norm = 4 * (mos - min_mos) / (max_mos - min_mos) + 1 # [0, 1] -> [1, 5]
120
+ std_norm = 4 * std / (max_mos - min_mos)
121
+
122
+ # ["excellent", "good", "fair", "poor", "bad"] -> [5, 4, 3, 2, 1]
123
+ probs = []
124
+ for x in range(5, 0, -1):
125
+ if density_type == "cdf":
126
+ # better for smaller std dataset (see Appendix) like SPAQ
127
+ prob = norm.cdf(x+0.5, mos_norm, std_norm) - norm.cdf(x-0.5, mos_norm, std_norm)
128
+ else:
129
+ # better for larger std dataset (see Appendix) like KonIQ and KADID
130
+ assert density_type == "pdf"
131
+ prob = norm.pdf(x, loc=mos_norm, scale=std_norm)
132
+ probs.append(prob)
133
+
134
+ mos_rec = np.inner(np.array(probs), np.array([5, 4, 3, 2, 1]))
135
+ raw_diff = abs(mos_rec - mos_norm)
136
+ raw_diffs.append(raw_diff)
137
+
138
+ alpha, beta = adjust_gaussian_bar(probs, mos_norm)
139
+ probs_norm = [max(_ * alpha + beta, 0) for _ in probs]
140
+ mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1]))
141
+ diff = abs(mos_rec - mos_norm)
142
+
143
+ if std_norm < thre_std or diff > thre_diff:
144
+ # if std is too small, use binary probs (see Appendix)
145
+ probs_norm = get_binary_probs(mos_norm)
146
+ mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1]))
147
+ diff, alpha, beta = abs(mos_rec - mos_norm), 1., 0.
148
+ num_binary += 1
149
+
150
+ preds.append(mos_rec)
151
+ gts.append(mos_norm)
152
+ diffs.append(diff)
153
+ alphas.append(alpha)
154
+ betas.append(beta)
155
+
156
+ meta = {
157
+ "id": os.path.basename(img) + f"->{mos_str}",
158
+ "image": os.path.join(img_dir, img),
159
+ "gt_score": mos,
160
+ "gt_score_norm": mos_norm,
161
+ "level_probs_org": probs,
162
+ "level_probs": probs_norm,
163
+ "std": std,
164
+ "std_norm": std_norm,
165
+ }
166
+ if training:
167
+ conversations = [
168
+ {
169
+ "from": "human",
170
+ "value": query + "\n<|image|>",
171
+ },
172
+ {
173
+ "from": "gpt",
174
+ "value": resp,
175
+ },
176
+ ]
177
+ meta["conversations"] = conversations
178
+ train_metas.append(meta)
179
+ else:
180
+ del meta["level_probs_org"]
181
+ del meta["level_probs"]
182
+ test_metas.append(meta)
183
+
184
+ print("=" * 100)
185
+ print(f"save {len(train_metas)} into {save_train}")
186
+ with open(save_train, "w") as fw:
187
+ fw.write(json.dumps(train_metas, indent=4))
188
+
189
+ print(f"save {len(test_metas)} into {save_test}")
190
+ with open(save_test, "w") as fw:
191
+ fw.write(json.dumps(test_metas, indent=4))
192
+
193
+ srcc, plcc = calculate_srcc_plcc(preds, gts)
194
+ print("srcc:", srcc, "plcc:", plcc)
195
+ print("[raw_diff]", "l1:", sum(raw_diffs) / len(raw_diffs), "l2:", np.sqrt((np.array(raw_diffs)**2).mean()))
196
+ print("[diff]", "l1:", sum(diffs) / len(diffs), "l2:", np.sqrt((np.array(diffs)**2).mean()))
197
+ print("[alpha]", "mean:", np.mean(alphas), "std:", np.std(alphas))
198
+ print("[beta]", "mean:", np.mean(betas), "std:", np.std(betas))
199
+ print("binary / all:", num_binary, "/", len(train_metas) + len(test_metas))
200
+
201
+
202
+ if __name__ == "__main__":
203
+ args = parse_args()
204
+ with open(args.config) as fr:
205
+ cfg = json.load(fr)
206
+ answer = cfg["answer"]
207
+ for dataset in cfg["dataset_params"]:
208
+ random.seed(131)
209
+ main(cfg["dataset_params"][dataset])
DeQA-Score/fig/boy_colorful.jpg ADDED

Git LFS Details

  • SHA256: 4bbfc21a36df2dfe57b3f96d935a8e257657a27382455dfb1c2aeb0804746997
  • Pointer size: 131 Bytes
  • Size of remote file: 120 kB
DeQA-Score/fig/model.png ADDED

Git LFS Details

  • SHA256: 4952c5a6bc6b81a549bb034473d8eec461bf1582f66e3112da21753aff02263c
  • Pointer size: 131 Bytes
  • Size of remote file: 270 kB
DeQA-Score/fig/singapore_flyer.jpg ADDED

Git LFS Details

  • SHA256: 09f86ecd97a2a16a79a8ccbc3acc8d8fa435e53e34da6d9fe144083446d2c644
  • Pointer size: 131 Bytes
  • Size of remote file: 144 kB
DeQA-Score/fig/teaser.png ADDED

Git LFS Details

  • SHA256: 6e1ab9a280eee4d7fc14bab91057d84c97a1bdde42e67be840a3a6fdce4ad16b
  • Pointer size: 131 Bytes
  • Size of remote file: 566 kB
DeQA-Score/preprocessor/preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": 448,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "feature_extractor_type": "CLIPFeatureExtractor",
7
+ "image_mean": [
8
+ 0.48145466,
9
+ 0.4578275,
10
+ 0.40821073
11
+ ],
12
+ "image_std": [
13
+ 0.26862954,
14
+ 0.26130258,
15
+ 0.27577711
16
+ ],
17
+ "resample": 3,
18
+ "size": 448
19
+ }
DeQA-Score/preprocessor/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "<unk>",
17
+ "unk_token": {
18
+ "content": "<unk>",
19
+ "lstrip": false,
20
+ "normalized": false,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
DeQA-Score/preprocessor/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
3
+ size 499723
DeQA-Score/preprocessor/tokenizer_config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "legacy": false,
22
+ "model_max_length": 2048,
23
+ "pad_token": null,
24
+ "padding_side": "right",
25
+ "sp_model_kwargs": {},
26
+ "tokenizer_class": "LlamaTokenizer",
27
+ "unk_token": {
28
+ "__type": "AddedToken",
29
+ "content": "<unk>",
30
+ "lstrip": false,
31
+ "normalized": false,
32
+ "rstrip": false,
33
+ "single_word": false
34
+ }
35
+ }
DeQA-Score/pyproject.toml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "DeQA-Score"
7
+ version = "1.2.0"
8
+ description = "Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution (based on mPLUG-Owl2)"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: Apache Software License",
14
+ ]
15
+ dependencies = [
16
+ "torch==2.0.1", "torchvision==0.15.2",
17
+ "transformers==4.36.1", "tokenizers==0.15.0", "sentencepiece==0.1.99", "shortuuid",
18
+ "accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
19
+ "pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
20
+ "gradio==3.35.2", "gradio_client==0.2.9",
21
+ "requests", "httpx==0.24.0", "uvicorn", "fastapi", "icecream",
22
+ "einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", "decord", "scipy",
23
+ ]
24
+
25
+ [project.optional-dependencies]
26
+ train = ["deepspeed==0.9.5", "ninja", "wandb"]
27
+
28
+ [project.urls]
29
+ "Bug Tracker" = "https://github.com/zhiyuanyou/DeQA-Score/issues"
30
+
31
+ [tool.setuptools.packages.find]
32
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
33
+
34
+ [tool.wheel]
35
+ exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
DeQA-Score/scripts/eval_dist.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export PYTHONPATH=./:$PYTHONPATH
2
+
3
+ res_dir=./results/res_deqa_mix3/
4
+ gt_dir=../Data-DeQA-Score/
5
+
6
+ python src/evaluate/cal_distribution_gap.py \
7
+ --level_names excellent good fair poor bad \
8
+ --pred_paths $res_dir/test_koniq_2k.json \
9
+ $res_dir/test_spaq_2k.json \
10
+ $res_dir/test_kadid_2k.json \
11
+ $res_dir/test_pipal_5k.json \
12
+ $res_dir/test_livew_1k.json \
13
+ $res_dir/test_agiqa_3k.json \
14
+ $res_dir/test_tid2013_3k.json \
15
+ $res_dir/test_csiq_866.json \
16
+ --gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \
17
+ $gt_dir/SPAQ/metas/test_spaq_2k.json \
18
+ $gt_dir/KADID10K/metas/test_kadid_2k.json \
19
+ $gt_dir/PIPAL/metas/test_pipal_5k.json \
20
+ $gt_dir/LIVE-WILD/metas/test_livew_1k.json \
21
+ $gt_dir/AGIQA3K/metas/test_agiqa_3k.json \
22
+ $gt_dir/TID2013/metas/test_tid2013_3k.json \
23
+ $gt_dir/CSIQ/metas/test_csiq_866.json \
DeQA-Score/scripts/eval_score.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export PYTHONPATH=./:$PYTHONPATH
2
+
3
+ res_dir=./results/res_deqa_mix3/
4
+ gt_dir=../Data-DeQA-Score/
5
+
6
+ python src/evaluate/cal_plcc_srcc.py \
7
+ --level_names excellent good fair poor bad \
8
+ --pred_paths $res_dir/test_koniq_2k.json \
9
+ $res_dir/test_spaq_2k.json \
10
+ $res_dir/test_kadid_2k.json \
11
+ $res_dir/test_pipal_5k.json \
12
+ $res_dir/test_livew_1k.json \
13
+ $res_dir/test_agiqa_3k.json \
14
+ $res_dir/test_tid2013_3k.json \
15
+ $res_dir/test_csiq_866.json \
16
+ --gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \
17
+ $gt_dir/SPAQ/metas/test_spaq_2k.json \
18
+ $gt_dir/KADID10K/metas/test_kadid_2k.json \
19
+ $gt_dir/PIPAL/metas/test_pipal_5k.json \
20
+ $gt_dir/LIVE-WILD/metas/test_livew_1k.json \
21
+ $gt_dir/AGIQA3K/metas/test_agiqa_3k.json \
22
+ $gt_dir/TID2013/metas/test_tid2013_3k.json \
23
+ $gt_dir/CSIQ/metas/test_csiq_866.json \
DeQA-Score/scripts/infer.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=$1
2
+ export PYTHONPATH=./:$PYTHONPATH
3
+
4
+ python src/evaluate/iqa_eval.py \
5
+ --level-names excellent good fair poor bad \
6
+ --model-path checkpoints/DeQA-Score-Mix3/ \
7
+ --save-dir results/res_deqa_mix3/ \
8
+ --preprocessor-path ./preprocessor/ \
9
+ --root-dir ../Data-DeQA-Score/ \
10
+ --meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \
11
+ ../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \
12
+ ../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \
13
+ ../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \
14
+ ../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \
15
+ ../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \
16
+ ../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \
17
+ ../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \
DeQA-Score/scripts/infer_lora.sh ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ export CUDA_VISIBLE_DEVICES=$1
2
+ export PYTHONPATH=./:$PYTHONPATH
3
+
4
+ python src/evaluate/iqa_eval.py \
5
+ --level-names excellent good fair poor bad \
6
+ --model-path checkpoints/DeQA-Score-LoRA-Mix3/ \
7
+ --model-base ../ModelZoo/mplug-owl2-llama2-7b/ \
8
+ --save-dir results/res_deqa_lora_mix3/ \
9
+ --preprocessor-path ./preprocessor/ \
10
+ --root-dir ../Data-DeQA-Score/ \
11
+ --meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \
12
+ ../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \
13
+ ../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \
14
+ ../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \
15
+ ../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \
16
+ ../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \
17
+ ../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \
18
+ ../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \
DeQA-Score/scripts/train.sh ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ export PYTHONPATH=./:$PYTHONPATH
3
+
4
+ LOAD="../ModelZoo/mplug-owl2-llama2-7b/"
5
+
6
+ deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \
7
+ --deepspeed scripts/zero3.json \
8
+ --model_name_or_path $LOAD \
9
+ --version v1 \
10
+ --dataset_type pair \
11
+ --level_prefix "The quality of the image is" \
12
+ --level_names excellent good fair poor bad \
13
+ --softkl_loss True \
14
+ --weight_rank 1.0 \
15
+ --weight_softkl 1.0 \
16
+ --weight_next_token 0.05 \
17
+ --continuous_rating_loss True \
18
+ --closeset_rating_loss True \
19
+ --use_fix_std True \
20
+ --detach_pred_std True \
21
+ --data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \
22
+ ../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \
23
+ ../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \
24
+ --data_weights 1 1 1 \
25
+ --image_folder ../Data-DeQA-Score/ \
26
+ --output_dir ./checkpoints/deqa_mix3_rank1.0_next0.05_kl1.0/ \
27
+ --image_aspect_ratio pad \
28
+ --group_by_modality_length True \
29
+ --bf16 True \
30
+ --num_train_epochs 3 \
31
+ --per_device_train_batch_size 16 \
32
+ --per_device_eval_batch_size 4 \
33
+ --gradient_accumulation_steps 1 \
34
+ --evaluation_strategy "no" \
35
+ --save_strategy "no" \
36
+ --learning_rate 2e-5 \
37
+ --weight_decay 0. \
38
+ --warmup_ratio 0.03 \
39
+ --lr_scheduler_type "cosine" \
40
+ --logging_steps 1 \
41
+ --tf32 True \
42
+ --model_max_length 2048 \
43
+ --gradient_checkpointing True \
44
+ --tune_visual_abstractor True \
45
+ --freeze_vision_model False \
46
+ --dataloader_num_workers 4 \
47
+ --lazy_preprocess True \
48
+ --report_to tensorboard
DeQA-Score/scripts/train_lora.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ export PYTHONPATH=./:$PYTHONPATH
3
+
4
+ LOAD="../ModelZoo/mplug-owl2-llama2-7b/"
5
+
6
+ deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \
7
+ --deepspeed scripts/zero3.json \
8
+ --lora_enable True \
9
+ --model_name_or_path $LOAD \
10
+ --version v1 \
11
+ --dataset_type pair \
12
+ --level_prefix "The quality of the image is" \
13
+ --level_names excellent good fair poor bad \
14
+ --softkl_loss True \
15
+ --weight_rank 1.0 \
16
+ --weight_softkl 1.0 \
17
+ --weight_next_token 0.05 \
18
+ --continuous_rating_loss True \
19
+ --closeset_rating_loss True \
20
+ --use_fix_std True \
21
+ --detach_pred_std True \
22
+ --data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \
23
+ ../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \
24
+ ../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \
25
+ --data_weights 1 1 1 \
26
+ --image_folder ../Data-DeQA-Score/ \
27
+ --output_dir ./checkpoints/deqa_lora_mix3_rank1.0_next0.05_kl1.0 \
28
+ --image_aspect_ratio pad \
29
+ --group_by_modality_length True \
30
+ --bf16 True \
31
+ --num_train_epochs 3 \
32
+ --per_device_train_batch_size 16 \
33
+ --per_device_eval_batch_size 4 \
34
+ --gradient_accumulation_steps 1 \
35
+ --evaluation_strategy "no" \
36
+ --save_strategy "no" \
37
+ --learning_rate 2e-5 \
38
+ --weight_decay 0. \
39
+ --warmup_ratio 0.03 \
40
+ --lr_scheduler_type "cosine" \
41
+ --logging_steps 1 \
42
+ --tf32 True \
43
+ --model_max_length 2048 \
44
+ --gradient_checkpointing True \
45
+ --tune_visual_abstractor True \
46
+ --freeze_vision_model False \
47
+ --dataloader_num_workers 4 \
48
+ --lazy_preprocess True \
49
+ --report_to tensorboard
DeQA-Score/scripts/zero3.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 3,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": "auto",
22
+ "stage3_param_persistence_threshold": "auto",
23
+ "stage3_max_live_parameters": 0,
24
+ "stage3_max_reuse_distance": 0,
25
+ "stage3_prefetch_bucket_size": 0,
26
+ "stage3_gather_16bit_weights_on_model_save": true
27
+ }
28
+ }
DeQA-Score/scripts/zero3_offload.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupLR",
24
+ "params": {
25
+ "warmup_min_lr": "auto",
26
+ "warmup_max_lr": "auto",
27
+ "warmup_num_steps": "auto"
28
+ }
29
+ },
30
+ "zero_optimization": {
31
+ "stage": 3,
32
+ "offload_optimizer": {
33
+ "device": "cpu",
34
+ "pin_memory": true
35
+ },
36
+ "offload_param": {
37
+ "device": "cpu",
38
+ "pin_memory": true
39
+ },
40
+ "overlap_comm": true,
41
+ "contiguous_gradients": true,
42
+ "sub_group_size": 1e9,
43
+ "reduce_bucket_size": "auto",
44
+ "stage3_prefetch_bucket_size": "auto",
45
+ "stage3_param_persistence_threshold": "auto",
46
+ "stage3_max_live_parameters": 1e9,
47
+ "stage3_max_reuse_distance": 1e9,
48
+ "gather_16bit_weights_on_model_save": true
49
+ },
50
+ "gradient_accumulation_steps": "auto",
51
+ "gradient_clipping": "auto",
52
+ "train_batch_size": "auto",
53
+ "train_micro_batch_size_per_gpu": "auto",
54
+ "steps_per_print": 1e5,
55
+ "wall_clock_breakdown": false
56
+ }
DeQA-Score/src/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .model import MPLUGOwl2LlamaForCausalLM
2
+ from .evaluate import Scorer
DeQA-Score/src/constants.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "./demo_logs"
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ IMAGE_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<|image|>"
DeQA-Score/src/conversation.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ from src.constants import DEFAULT_IMAGE_TOKEN
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ TWO_NO_SYS = auto()
11
+ MPT = auto()
12
+ PLAIN = auto()
13
+ LLAMA_2 = auto()
14
+
15
+
16
+ @dataclasses.dataclass
17
+ class Conversation:
18
+ """A class that keeps all conversation history."""
19
+ system: str
20
+ roles: List[str]
21
+ messages: List[List[str]]
22
+ offset: int
23
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
24
+ sep: str = "###"
25
+ sep2: str = None
26
+ version: str = "Unknown"
27
+
28
+ skip_next: bool = False
29
+
30
+ def get_prompt(self):
31
+ messages = self.messages
32
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
33
+ messages = self.messages.copy()
34
+ init_role, init_msg = messages[0].copy()
35
+ # init_msg = init_msg[0].replace("<image>", "").strip()
36
+ # if 'mmtag' in self.version:
37
+ # messages[0] = (init_role, init_msg)
38
+ # messages.insert(0, (self.roles[0], "<Image><image></Image>"))
39
+ # messages.insert(1, (self.roles[1], "Received."))
40
+ # else:
41
+ # messages[0] = (init_role, "<image>\n" + init_msg)
42
+ init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
43
+ messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)
44
+
45
+ if self.sep_style == SeparatorStyle.SINGLE:
46
+ ret = self.system + self.sep
47
+ for role, message in messages:
48
+ if message:
49
+ if type(message) is tuple:
50
+ message, _, _ = message
51
+ ret += role + ": " + message + self.sep
52
+ else:
53
+ ret += role + ":"
54
+ elif self.sep_style == SeparatorStyle.TWO:
55
+ seps = [self.sep, self.sep2]
56
+ ret = self.system + seps[0]
57
+ for i, (role, message) in enumerate(messages):
58
+ if message:
59
+ if type(message) is tuple:
60
+ message, _, _ = message
61
+ ret += role + ": " + message + seps[i % 2]
62
+ else:
63
+ ret += role + ":"
64
+ elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
65
+ seps = [self.sep, self.sep2]
66
+ ret = ""
67
+ for i, (role, message) in enumerate(messages):
68
+ if message:
69
+ if type(message) is tuple:
70
+ message, _, _ = message
71
+ ret += role + ": " + message + seps[i % 2]
72
+ else:
73
+ ret += role + ":"
74
+ elif self.sep_style == SeparatorStyle.MPT:
75
+ ret = self.system + self.sep
76
+ for role, message in messages:
77
+ if message:
78
+ if type(message) is tuple:
79
+ message, _, _ = message
80
+ ret += role + message + self.sep
81
+ else:
82
+ ret += role
83
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
84
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
85
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
86
+ ret = ""
87
+
88
+ for i, (role, message) in enumerate(messages):
89
+ if i == 0:
90
+ assert message, "first message should not be none"
91
+ assert role == self.roles[0], "first message should come from user"
92
+ if message:
93
+ if type(message) is tuple:
94
+ message, _, _ = message
95
+ if i == 0: message = wrap_sys(self.system) + message
96
+ if i % 2 == 0:
97
+ message = wrap_inst(message)
98
+ ret += self.sep + message
99
+ else:
100
+ ret += " " + message + " " + self.sep2
101
+ else:
102
+ ret += ""
103
+ ret = ret.lstrip(self.sep)
104
+ elif self.sep_style == SeparatorStyle.PLAIN:
105
+ seps = [self.sep, self.sep2]
106
+ ret = self.system
107
+ for i, (role, message) in enumerate(messages):
108
+ if message:
109
+ if type(message) is tuple:
110
+ message, _, _ = message
111
+ ret += message + seps[i % 2]
112
+ else:
113
+ ret += ""
114
+ else:
115
+ raise ValueError(f"Invalid style: {self.sep_style}")
116
+
117
+ return ret
118
+
119
+ def append_message(self, role, message):
120
+ self.messages.append([role, message])
121
+
122
+ def get_images(self, return_pil=False):
123
+ images = []
124
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
125
+ if i % 2 == 0:
126
+ if type(msg) is tuple:
127
+ import base64
128
+ from io import BytesIO
129
+ from PIL import Image
130
+ msg, image, image_process_mode = msg
131
+ if image_process_mode == "Pad":
132
+ def expand2square(pil_img, background_color=(122, 116, 104)):
133
+ width, height = pil_img.size
134
+ if width == height:
135
+ return pil_img
136
+ elif width > height:
137
+ result = Image.new(pil_img.mode, (width, width), background_color)
138
+ result.paste(pil_img, (0, (width - height) // 2))
139
+ return result
140
+ else:
141
+ result = Image.new(pil_img.mode, (height, height), background_color)
142
+ result.paste(pil_img, ((height - width) // 2, 0))
143
+ return result
144
+ image = expand2square(image)
145
+ elif image_process_mode in ["Default", "Crop"]:
146
+ pass
147
+ elif image_process_mode == "Resize":
148
+ image = image.resize((336, 336))
149
+ else:
150
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
151
+ max_hw, min_hw = max(image.size), min(image.size)
152
+ aspect_ratio = max_hw / min_hw
153
+ max_len, min_len = 800, 400
154
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
155
+ longest_edge = int(shortest_edge * aspect_ratio)
156
+ W, H = image.size
157
+ if longest_edge != max(image.size):
158
+ if H > W:
159
+ H, W = longest_edge, shortest_edge
160
+ else:
161
+ H, W = shortest_edge, longest_edge
162
+ image = image.resize((W, H))
163
+ if return_pil:
164
+ images.append(image)
165
+ else:
166
+ buffered = BytesIO()
167
+ image.save(buffered, format="PNG")
168
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
169
+ images.append(img_b64_str)
170
+ return images
171
+
172
+ def to_gradio_chatbot(self):
173
+ ret = []
174
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
175
+ if i % 2 == 0:
176
+ if type(msg) is tuple:
177
+ import base64
178
+ from io import BytesIO
179
+ msg, image, image_process_mode = msg
180
+ max_hw, min_hw = max(image.size), min(image.size)
181
+ aspect_ratio = max_hw / min_hw
182
+ max_len, min_len = 800, 400
183
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
184
+ longest_edge = int(shortest_edge * aspect_ratio)
185
+ W, H = image.size
186
+ if H > W:
187
+ H, W = longest_edge, shortest_edge
188
+ else:
189
+ H, W = shortest_edge, longest_edge
190
+ image = image.resize((W, H))
191
+ buffered = BytesIO()
192
+ image.save(buffered, format="JPEG")
193
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
194
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
195
+ msg = img_str + msg.replace('<|image|>', '').strip()
196
+ ret.append([msg, None])
197
+ else:
198
+ ret.append([msg, None])
199
+ else:
200
+ ret[-1][-1] = msg
201
+ return ret
202
+
203
+ def copy(self):
204
+ return Conversation(
205
+ system=self.system,
206
+ roles=self.roles,
207
+ messages=[[x, y] for x, y in self.messages],
208
+ offset=self.offset,
209
+ sep_style=self.sep_style,
210
+ sep=self.sep,
211
+ sep2=self.sep2,
212
+ version=self.version)
213
+
214
+ def dict(self):
215
+ if len(self.get_images()) > 0:
216
+ return {
217
+ "system": self.system,
218
+ "roles": self.roles,
219
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
220
+ "offset": self.offset,
221
+ "sep": self.sep,
222
+ "sep2": self.sep2,
223
+ }
224
+ return {
225
+ "system": self.system,
226
+ "roles": self.roles,
227
+ "messages": self.messages,
228
+ "offset": self.offset,
229
+ "sep": self.sep,
230
+ "sep2": self.sep2,
231
+ }
232
+
233
+
234
+ conv_vicuna_v0 = Conversation(
235
+ system="A chat between a curious human and an artificial intelligence assistant. "
236
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
237
+ roles=("Human", "Assistant"),
238
+ messages=(
239
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
240
+ ("Assistant",
241
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
242
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
243
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
244
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
245
+ "renewable and non-renewable energy sources:\n"
246
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
247
+ "energy sources are finite and will eventually run out.\n"
248
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
249
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
250
+ "and other negative effects.\n"
251
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
252
+ "have lower operational costs than non-renewable sources.\n"
253
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
254
+ "locations than non-renewable sources.\n"
255
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
256
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
257
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
258
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
259
+ ),
260
+ offset=2,
261
+ sep_style=SeparatorStyle.SINGLE,
262
+ sep="###",
263
+ )
264
+
265
+ conv_vicuna_v1 = Conversation(
266
+ system="A chat between a curious user and an artificial intelligence assistant. "
267
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
268
+ roles=("USER", "ASSISTANT"),
269
+ version="v1",
270
+ messages=(),
271
+ offset=0,
272
+ sep_style=SeparatorStyle.TWO,
273
+ sep=" ",
274
+ sep2="</s>",
275
+ )
276
+
277
+ conv_mplug_owl2 = Conversation(
278
+ system="A chat between a curious human and an artificial intelligence assistant. "
279
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
280
+ roles=("USER", "ASSISTANT"),
281
+ version="v1",
282
+ messages=(),
283
+ offset=0,
284
+ sep_style=SeparatorStyle.TWO_NO_SYS,
285
+ sep=" ",
286
+ sep2="</s>",
287
+ )
288
+
289
+ # default_conversation = conv_vicuna_v1
290
+ default_conversation = conv_mplug_owl2
291
+ conv_templates = {
292
+ "default": conv_vicuna_v0,
293
+ "v0": conv_vicuna_v0,
294
+ "v1": conv_vicuna_v1,
295
+ "vicuna_v1": conv_vicuna_v1,
296
+ "mplug_owl2": conv_mplug_owl2,
297
+ }
298
+
299
+
300
+ if __name__ == "__main__":
301
+ print(default_conversation.get_prompt())
DeQA-Score/src/datasets/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .pair_dataset import make_pair_data_module
2
+ from .single_dataset import make_single_data_module
3
+
4
+
5
+ def make_data_module(tokenizer, data_args):
6
+ if data_args.dataset_type == "single":
7
+ return make_single_data_module(tokenizer, data_args)
8
+ elif data_args.dataset_type == "pair":
9
+ return make_pair_data_module(tokenizer, data_args)
10
+ else:
11
+ raise ValueError
DeQA-Score/src/datasets/pair_dataset.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ import random
5
+ from dataclasses import dataclass
6
+ from typing import Dict, Sequence
7
+
8
+ import torch
9
+ import transformers
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+
13
+ from src.constants import IGNORE_INDEX
14
+
15
+ from .utils import (expand2square, load_video, preprocess,
16
+ preprocess_multimodal, rank0_print)
17
+
18
+
19
+ class PairDataset(Dataset):
20
+ """Dataset for supervised fine-tuning."""
21
+
22
+ def __init__(
23
+ self,
24
+ data_paths,
25
+ data_weights,
26
+ tokenizer: transformers.PreTrainedTokenizer,
27
+ data_args,
28
+ ):
29
+ super(PairDataset, self).__init__()
30
+ dataset_list = [] # list (different datasets) of list (samples in one dataset)
31
+ for data_path, data_weight in zip(data_paths, data_weights):
32
+ data_list = json.load(open(data_path, "r"))
33
+ dataset_list.append(data_list * data_weight)
34
+ self.dataset_list = dataset_list
35
+
36
+ # Construct nums_data, nums_data[i] is the number of samples in 0-i th datasets
37
+ nums_eachdata = [len(_) for _ in self.dataset_list]
38
+ nums_predata = copy.deepcopy(nums_eachdata)
39
+ for idx in range(1, len(nums_predata)):
40
+ nums_predata[idx] = nums_predata[idx] + nums_predata[idx - 1]
41
+
42
+ rank0_print("Formatting inputs...Skip in lazy mode")
43
+ self.tokenizer = tokenizer
44
+ self.nums_eachdata = nums_eachdata
45
+ self.nums_predata = nums_predata
46
+ self.data_args = data_args
47
+ assert self.nums_predata[-1] == sum(self.nums_eachdata)
48
+
49
+ def __len__(self):
50
+ return self.nums_predata[-1]
51
+
52
+ @property
53
+ def lengths(self):
54
+ length_list = []
55
+ for dataset in self.dataset_list:
56
+ for sample in dataset:
57
+ img_tokens = 128 if "image" in sample else 0
58
+ length_list.append(
59
+ sum(len(conv["value"].split()) for conv in sample["conversations"])
60
+ + img_tokens
61
+ )
62
+ return length_list
63
+
64
+ @property
65
+ def modality_lengths(self):
66
+ length_list = []
67
+ for dataset in self.dataset_list:
68
+ for sample in dataset:
69
+ cur_len = sum(
70
+ len(conv["value"].split()) for conv in sample["conversations"]
71
+ )
72
+ cur_len = cur_len if "image" in sample else -cur_len
73
+ length_list.append(cur_len)
74
+ return length_list
75
+
76
+ def next_rand(self):
77
+ return random.randint(0, len(self) - 1)
78
+
79
+ def __getitem__(self, i):
80
+ while True:
81
+ try:
82
+ # Get idx_dataset, idx_sample
83
+ if i < self.nums_predata[0]:
84
+ idx_dataset = 0
85
+ idx_sample = i
86
+ else:
87
+ for idx_dataset in range(1, len(self.nums_predata)):
88
+ if (
89
+ i < self.nums_predata[idx_dataset]
90
+ and i >= self.nums_predata[idx_dataset - 1]
91
+ ):
92
+ idx_sample = i - self.nums_predata[idx_dataset - 1]
93
+ break
94
+ # Sample two items
95
+ item_A = self.get_one_item(idx_dataset, idx_sample)
96
+ while True:
97
+ idx_sample_B = random.randint(
98
+ 0, self.nums_eachdata[idx_dataset] - 1
99
+ )
100
+ if idx_sample_B != idx_sample:
101
+ break
102
+ item_B = self.get_one_item(idx_dataset, idx_sample_B)
103
+ return {
104
+ "item_A": item_A,
105
+ "item_B": item_B,
106
+ }
107
+ except Exception as ex:
108
+ print(ex)
109
+ i = self.next_rand()
110
+ continue
111
+
112
+ def get_one_item(self, idx_dataset, idx_sample) -> Dict[str, torch.Tensor]:
113
+ # For IQA data, i must be int
114
+ sources = [self.dataset_list[idx_dataset][idx_sample]]
115
+ sources_org = copy.deepcopy(sources)
116
+ assert len(sources) == 1, "Don't know why it is wrapped to a list" # FIXME
117
+ if "image" in sources_org[0]:
118
+ image_file = sources[0]["image"]
119
+
120
+ image_folder = self.data_args.image_folder
121
+ processor = self.data_args.image_processor
122
+
123
+ if isinstance(image_file, list):
124
+ # Multiple Images as Input
125
+ image = [
126
+ Image.open(os.path.join(image_folder, imfile)).convert("RGB")
127
+ for imfile in image_file
128
+ ]
129
+
130
+ if self.data_args.image_aspect_ratio == "pad":
131
+ image = [
132
+ expand2square(
133
+ img,
134
+ tuple(int(x * 255) for x in processor.image_mean),
135
+ )
136
+ for img in image
137
+ ]
138
+ image = processor.preprocess(image, return_tensors="pt")[
139
+ "pixel_values"
140
+ ]
141
+ else:
142
+ image = processor.preprocess(image, return_tensors="pt")[
143
+ "pixel_values"
144
+ ]
145
+ elif os.path.join(image_folder, image_file).endswith("mp4"):
146
+ # Video as Input
147
+ image = load_video(os.path.join(image_folder, image_file))
148
+ if self.data_args.image_aspect_ratio == "pad":
149
+ image = [
150
+ expand2square(
151
+ img,
152
+ tuple(int(x * 255) for x in processor.image_mean),
153
+ )
154
+ for img in image
155
+ ]
156
+ image = processor.preprocess(image, return_tensors="pt")[
157
+ "pixel_values"
158
+ ]
159
+ else:
160
+ image = processor.preprocess(image, return_tensors="pt")[
161
+ "pixel_values"
162
+ ]
163
+ else:
164
+ image = Image.open(os.path.join(image_folder, image_file)).convert(
165
+ "RGB"
166
+ )
167
+ if self.data_args.image_aspect_ratio == "pad":
168
+ image = expand2square(
169
+ image, tuple(int(x * 255) for x in processor.image_mean)
170
+ )
171
+ image = processor.preprocess(image, return_tensors="pt")[
172
+ "pixel_values"
173
+ ]
174
+ else:
175
+ image = processor.preprocess(image, return_tensors="pt")[
176
+ "pixel_values"
177
+ ]
178
+ sources = preprocess_multimodal(
179
+ copy.deepcopy([e["conversations"] for e in sources]),
180
+ self.data_args,
181
+ )
182
+ else:
183
+ # Without images
184
+ sources = copy.deepcopy([e["conversations"] for e in sources])
185
+
186
+ data_dict = preprocess(
187
+ sources,
188
+ self.tokenizer,
189
+ has_image=("image" in sources_org[0]),
190
+ )
191
+ data_dict = dict(
192
+ input_ids=data_dict["input_ids"][0],
193
+ labels=data_dict["labels"][0],
194
+ )
195
+
196
+ # default task_type: "score", gt_socre & std: -10000, level_probs: [-10000] * 5
197
+ data_dict["task_type"] = sources_org[0].get("task_type", "score")
198
+ data_dict["gt_score"] = sources_org[0].get("gt_score", -10000)
199
+ data_dict["std"] = sources_org[0].get("std", -10000)
200
+ data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5)
201
+
202
+ # image exist in the data
203
+ if "image" in sources_org[0]:
204
+ data_dict["image_file"] = image_file
205
+ data_dict["image"] = image
206
+ elif self.data_args.is_multimodal:
207
+ # image does not exist in the data, but the model is multimodal
208
+ crop_size = self.data_args.image_processor.crop_size
209
+ data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
210
+ return data_dict
211
+
212
+
213
+ @dataclass
214
+ class DataCollatorForPairDataset(object):
215
+ """Collate examples for pair fine-tuning."""
216
+
217
+ tokenizer: transformers.PreTrainedTokenizer
218
+
219
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
220
+ instances_A = [instance["item_A"] for instance in instances]
221
+ instances_B = [instance["item_B"] for instance in instances]
222
+ batch = {
223
+ "input_type": "pair",
224
+ "item_A": self.collate_one(instances_A),
225
+ "item_B": self.collate_one(instances_B),
226
+ }
227
+ return batch
228
+
229
+ def collate_one(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
230
+ input_ids, labels = tuple(
231
+ [instance[key] for instance in instances] for key in ("input_ids", "labels")
232
+ )
233
+ input_ids = torch.nn.utils.rnn.pad_sequence(
234
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
235
+ )
236
+ labels = torch.nn.utils.rnn.pad_sequence(
237
+ labels, batch_first=True, padding_value=IGNORE_INDEX
238
+ )
239
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
240
+ labels = labels[:, : self.tokenizer.model_max_length]
241
+ batch = dict(
242
+ input_ids=input_ids,
243
+ labels=labels,
244
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
245
+ )
246
+
247
+ batch["task_types"] = [instance["task_type"] for instance in instances]
248
+ batch["gt_scores"] = torch.tensor([instance["gt_score"] for instance in instances])
249
+ batch["stds"] = torch.tensor([instance["std"] for instance in instances])
250
+ batch["level_probs"] = torch.tensor([instance["level_probs"] for instance in instances])
251
+
252
+ if "image" in instances[0]:
253
+ images = [instance["image"] for instance in instances]
254
+ if all(x is not None and x.shape == images[0].shape for x in images):
255
+ batch["images"] = torch.stack(images)
256
+ else:
257
+ batch["images"] = images
258
+ batch["image_files"] = [instance["image_file"] for instance in instances]
259
+
260
+ return batch
261
+
262
+
263
+ def make_pair_data_module(
264
+ tokenizer: transformers.PreTrainedTokenizer, data_args
265
+ ) -> Dict:
266
+ """Make dataset and collator for supervised fine-tuning."""
267
+ train_dataset = PairDataset(
268
+ tokenizer=tokenizer,
269
+ data_paths=data_args.data_paths,
270
+ data_weights=data_args.data_weights,
271
+ data_args=data_args,
272
+ )
273
+ data_collator = DataCollatorForPairDataset(tokenizer=tokenizer)
274
+ return dict(
275
+ train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
276
+ )
DeQA-Score/src/datasets/single_dataset.py ADDED
@@ -0,0 +1,244 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ from dataclasses import dataclass
5
+ from typing import Dict, Sequence
6
+
7
+ import torch
8
+ import transformers
9
+ from PIL import Image
10
+ from torch.utils.data import Dataset
11
+
12
+ from src.constants import IGNORE_INDEX
13
+
14
+ from .utils import (expand2square, load_video, preprocess,
15
+ preprocess_multimodal, rank0_print)
16
+
17
+
18
+ class SingleDataset(Dataset):
19
+ """Dataset for supervised fine-tuning."""
20
+
21
+ def __init__(
22
+ self,
23
+ data_paths: str,
24
+ data_weights: str,
25
+ tokenizer: transformers.PreTrainedTokenizer,
26
+ data_args,
27
+ ):
28
+ super(SingleDataset, self).__init__()
29
+ list_data_dict = []
30
+ for data_path, data_weight in zip(data_paths, data_weights):
31
+ data_dict = json.load(open(data_path, "r"))
32
+ list_data_dict += data_dict * data_weight
33
+
34
+ rank0_print("Formatting inputs...Skip in lazy mode")
35
+ self.tokenizer = tokenizer
36
+ self.list_data_dict = list_data_dict
37
+ self.data_args = data_args
38
+
39
+ def __len__(self):
40
+ return len(self.list_data_dict)
41
+
42
+ @property
43
+ def lengths(self):
44
+ length_list = []
45
+ for sample in self.list_data_dict:
46
+ img_tokens = 128 if "image" in sample else 0
47
+ length_list.append(
48
+ sum(len(conv["value"].split()) for conv in sample["conversations"])
49
+ + img_tokens
50
+ )
51
+ return length_list
52
+
53
+ @property
54
+ def modality_lengths(self):
55
+ length_list = []
56
+ for sample in self.list_data_dict:
57
+ cur_len = sum(
58
+ len(conv["value"].split()) for conv in sample["conversations"]
59
+ )
60
+ cur_len = cur_len if "image" in sample else -cur_len
61
+ length_list.append(cur_len)
62
+ return length_list
63
+
64
+ def next_rand(self):
65
+ import random
66
+
67
+ return random.randint(0, len(self) - 1)
68
+
69
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
70
+ while True:
71
+ try:
72
+ sources = self.list_data_dict[i]
73
+ if isinstance(i, int):
74
+ sources = [sources]
75
+ sources_org = copy.deepcopy(sources)
76
+ assert (
77
+ len(sources) == 1
78
+ ), "Don't know why it is wrapped to a list" # FIXME
79
+ if "image" in sources_org[0]:
80
+ image_file = sources_org[0]["image"]
81
+
82
+ image_folder = self.data_args.image_folder
83
+ processor = self.data_args.image_processor
84
+ from pathlib import Path
85
+
86
+ # if not Path(os.path.join(image_folder, image_file)).exists():
87
+ # i = self.next_rand()
88
+ # continue
89
+ if isinstance(image_file, list):
90
+ # Multiple Images as Input
91
+ try:
92
+ image = [
93
+ Image.open(os.path.join(image_folder, imfile)).convert(
94
+ "RGB"
95
+ )
96
+ for imfile in image_file
97
+ ]
98
+ except Exception as ex:
99
+ print(ex)
100
+ i = self.next_rand()
101
+ continue
102
+ if self.data_args.image_aspect_ratio == "pad":
103
+ image = [
104
+ expand2square(
105
+ img,
106
+ tuple(int(x * 255) for x in processor.image_mean),
107
+ )
108
+ for img in image
109
+ ]
110
+ image = processor.preprocess(image, return_tensors="pt")[
111
+ "pixel_values"
112
+ ]
113
+ else:
114
+ image = processor.preprocess(image, return_tensors="pt")[
115
+ "pixel_values"
116
+ ]
117
+ elif os.path.join(image_folder, image_file).endswith("mp4"):
118
+ # Video as Input
119
+ image = load_video(os.path.join(image_folder, image_file))
120
+ if self.data_args.image_aspect_ratio == "pad":
121
+ image = [
122
+ expand2square(
123
+ img,
124
+ tuple(int(x * 255) for x in processor.image_mean),
125
+ )
126
+ for img in image
127
+ ]
128
+ image = processor.preprocess(image, return_tensors="pt")[
129
+ "pixel_values"
130
+ ]
131
+ else:
132
+ image = processor.preprocess(image, return_tensors="pt")[
133
+ "pixel_values"
134
+ ]
135
+ else:
136
+ try:
137
+ image = Image.open(
138
+ os.path.join(image_folder, image_file)
139
+ ).convert("RGB")
140
+ except Exception as ex:
141
+ print(ex)
142
+ i = self.next_rand()
143
+ continue
144
+ if self.data_args.image_aspect_ratio == "pad":
145
+ image = expand2square(
146
+ image, tuple(int(x * 255) for x in processor.image_mean)
147
+ )
148
+ image = processor.preprocess(image, return_tensors="pt")[
149
+ "pixel_values"
150
+ ]
151
+ else:
152
+ image = processor.preprocess(image, return_tensors="pt")[
153
+ "pixel_values"
154
+ ]
155
+ sources = preprocess_multimodal(
156
+ copy.deepcopy([e["conversations"] for e in sources]),
157
+ self.data_args,
158
+ )
159
+ else:
160
+
161
+ sources = copy.deepcopy([e["conversations"] for e in sources])
162
+ data_dict = preprocess(
163
+ sources,
164
+ self.tokenizer,
165
+ has_image=("image" in sources_org[0]),
166
+ )
167
+ if isinstance(i, int):
168
+ data_dict = dict(
169
+ input_ids=data_dict["input_ids"][0],
170
+ labels=data_dict["labels"][0],
171
+ )
172
+
173
+ # default task_type: "score", level_probs: [-10000] * 5
174
+ data_dict["task_type"] = sources_org[0].get("task_type", "score")
175
+ data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5)
176
+
177
+ # image exist in the data
178
+ if "image" in sources_org[0]:
179
+ data_dict["image"] = image
180
+ elif self.data_args.is_multimodal:
181
+ # image does not exist in the data, but the model is multimodal
182
+ crop_size = self.data_args.image_processor.crop_size
183
+ data_dict["image"] = torch.zeros(
184
+ 3, crop_size["height"], crop_size["width"]
185
+ )
186
+ return data_dict
187
+ except Exception as ex:
188
+ print(ex)
189
+ i = self.next_rand()
190
+ continue
191
+
192
+
193
+ @dataclass
194
+ class DataCollatorForSupervisedDataset(object):
195
+ """Collate examples for supervised fine-tuning."""
196
+
197
+ tokenizer: transformers.PreTrainedTokenizer
198
+
199
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
200
+ input_ids, labels = tuple(
201
+ [instance[key] for instance in instances] for key in ("input_ids", "labels")
202
+ )
203
+ input_ids = torch.nn.utils.rnn.pad_sequence(
204
+ input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
205
+ )
206
+ labels = torch.nn.utils.rnn.pad_sequence(
207
+ labels, batch_first=True, padding_value=IGNORE_INDEX
208
+ )
209
+ input_ids = input_ids[:, : self.tokenizer.model_max_length]
210
+ labels = labels[:, : self.tokenizer.model_max_length]
211
+ batch = dict(
212
+ input_type="single",
213
+ input_ids=input_ids,
214
+ labels=labels,
215
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
216
+ )
217
+
218
+ batch["task_types"] = [instance["task_type"] for instance in instances]
219
+ batch["level_probs"] = torch.tensor([instance["level_probs"] for instance in instances])
220
+
221
+ if "image" in instances[0]:
222
+ images = [instance["image"] for instance in instances]
223
+ if all(x is not None and x.shape == images[0].shape for x in images):
224
+ batch["images"] = torch.stack(images)
225
+ else:
226
+ batch["images"] = images
227
+
228
+ return batch
229
+
230
+
231
+ def make_single_data_module(
232
+ tokenizer: transformers.PreTrainedTokenizer, data_args
233
+ ) -> Dict:
234
+ """Make dataset and collator for supervised fine-tuning."""
235
+ train_dataset = SingleDataset(
236
+ tokenizer=tokenizer,
237
+ data_paths=data_args.data_paths,
238
+ data_weights=data_args.data_weights,
239
+ data_args=data_args,
240
+ )
241
+ data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
242
+ return dict(
243
+ train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
244
+ )
DeQA-Score/src/datasets/utils.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from dataclasses import dataclass, field
3
+ from typing import Dict, List, Optional, Sequence
4
+
5
+ from PIL import ImageFile
6
+
7
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
8
+
9
+ from dataclasses import dataclass, field
10
+ from typing import List, Optional
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ import transformers
15
+ from PIL import Image
16
+
17
+ from src import conversation as conversation_lib
18
+ from src.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
19
+ from src.mm_utils import tokenizer_image_token
20
+
21
+
22
+ def rank0_print(*args):
23
+ try:
24
+ if dist.get_rank() == 0:
25
+ print(*args)
26
+ except:
27
+ print(*args)
28
+
29
+
30
+ @dataclass
31
+ class DataArguments:
32
+ data_paths: List[str] = field(default_factory=lambda: [])
33
+ lazy_preprocess: bool = False
34
+ is_multimodal: bool = False
35
+ image_folder: Optional[str] = field(default=None)
36
+ image_aspect_ratio: str = "square"
37
+ image_grid_pinpoints: Optional[str] = field(default=None)
38
+
39
+
40
+ def _tokenize_fn(
41
+ strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
42
+ ) -> Dict:
43
+ """Tokenize a list of strings."""
44
+ tokenized_list = [
45
+ tokenizer(
46
+ text,
47
+ return_tensors="pt",
48
+ padding="longest",
49
+ max_length=tokenizer.model_max_length,
50
+ truncation=True,
51
+ )
52
+ for text in strings
53
+ ]
54
+ input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
55
+ input_ids_lens = labels_lens = [
56
+ tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
57
+ for tokenized in tokenized_list
58
+ ]
59
+ return dict(
60
+ input_ids=input_ids,
61
+ labels=labels,
62
+ input_ids_lens=input_ids_lens,
63
+ labels_lens=labels_lens,
64
+ )
65
+
66
+
67
+ def _mask_targets(target, tokenized_lens, speakers):
68
+ # cur_idx = 0
69
+ cur_idx = tokenized_lens[0]
70
+ tokenized_lens = tokenized_lens[1:]
71
+ target[:cur_idx] = IGNORE_INDEX
72
+ for tokenized_len, speaker in zip(tokenized_lens, speakers):
73
+ if speaker == "human":
74
+ target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
75
+ cur_idx += tokenized_len
76
+
77
+
78
+ def _add_speaker_and_signal(header, source, get_conversation=True):
79
+ """Add speaker and start/end signal on each round."""
80
+ BEGIN_SIGNAL = "### "
81
+ END_SIGNAL = "\n"
82
+ conversation = header
83
+ for sentence in source:
84
+ from_str = sentence["from"]
85
+ if from_str.lower() == "human":
86
+ from_str = conversation_lib.default_conversation.roles[0]
87
+ elif from_str.lower() == "gpt":
88
+ from_str = conversation_lib.default_conversation.roles[1]
89
+ else:
90
+ from_str = "unknown"
91
+ sentence["value"] = (
92
+ BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
93
+ )
94
+ if get_conversation:
95
+ conversation += sentence["value"]
96
+ conversation += BEGIN_SIGNAL
97
+ return conversation
98
+
99
+
100
+ def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
101
+ is_multimodal = data_args.is_multimodal
102
+ if not is_multimodal:
103
+ return sources
104
+
105
+ for source in sources:
106
+ for sentence in source:
107
+ if DEFAULT_IMAGE_TOKEN in sentence["value"]:
108
+ sentence["value"] = (
109
+ sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
110
+ )
111
+ sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
112
+ sentence["value"] = sentence["value"].strip()
113
+
114
+ replace_token = DEFAULT_IMAGE_TOKEN
115
+ sentence["value"] = sentence["value"].replace(
116
+ DEFAULT_IMAGE_TOKEN, replace_token
117
+ )
118
+
119
+ return sources
120
+
121
+
122
+ def preprocess_v1(
123
+ sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
124
+ ) -> Dict:
125
+ conv = conversation_lib.default_conversation.copy()
126
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
127
+
128
+ # Apply prompt templates
129
+ conversations = []
130
+ for i, source in enumerate(sources):
131
+ if roles[source[0]["from"]] != conv.roles[0]:
132
+ # Skip the first one if it is not from human
133
+ source = source[1:]
134
+
135
+ conv.messages = []
136
+ for j, sentence in enumerate(source):
137
+ role = roles[sentence["from"]]
138
+ assert role == conv.roles[j % 2], f"{i}"
139
+ conv.append_message(role, sentence["value"])
140
+ conversations.append(conv.get_prompt())
141
+
142
+ # Tokenize conversations
143
+
144
+ if has_image:
145
+ input_ids = torch.stack(
146
+ [
147
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
148
+ for prompt in conversations
149
+ ],
150
+ dim=0,
151
+ )
152
+ else:
153
+ input_ids = tokenizer(
154
+ conversations,
155
+ return_tensors="pt",
156
+ padding="longest",
157
+ max_length=tokenizer.model_max_length,
158
+ truncation=True,
159
+ ).input_ids
160
+
161
+ targets = input_ids.clone()
162
+
163
+ assert (
164
+ conv.sep_style == conversation_lib.SeparatorStyle.TWO
165
+ or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
166
+ )
167
+
168
+ # Mask targets
169
+ sep = conv.sep + conv.roles[1] + ": "
170
+ for conversation, target in zip(conversations, targets):
171
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
172
+
173
+ rounds = conversation.split(conv.sep2)
174
+ cur_len = 1 + 1
175
+ target[:cur_len] = IGNORE_INDEX
176
+ for i, rou in enumerate(rounds):
177
+ if rou == "":
178
+ break
179
+
180
+ parts = rou.split(sep)
181
+ if len(parts) != 2:
182
+ break
183
+ parts[0] += sep
184
+
185
+ if has_image:
186
+ round_len = len(tokenizer_image_token(rou, tokenizer))
187
+ instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 3
188
+ else:
189
+ round_len = len(tokenizer(rou).input_ids)
190
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
191
+ round_len -= 1
192
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
193
+
194
+ cur_len += round_len
195
+ target[cur_len:] = IGNORE_INDEX
196
+
197
+ if cur_len < tokenizer.model_max_length:
198
+ if cur_len != total_len:
199
+ target[:] = IGNORE_INDEX
200
+ print(
201
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
202
+ f" (ignored)"
203
+ )
204
+ return dict(
205
+ input_ids=input_ids,
206
+ labels=targets,
207
+ )
208
+
209
+
210
+ def preprocess_plain(
211
+ sources: Sequence[str],
212
+ tokenizer: transformers.PreTrainedTokenizer,
213
+ ) -> Dict:
214
+ # add end signal and concatenate together
215
+ conversations = []
216
+ for source in sources:
217
+ assert len(source) == 2
218
+ assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
219
+ source[0]["value"] = DEFAULT_IMAGE_TOKEN
220
+ conversation = (
221
+ source[0]["value"]
222
+ + source[1]["value"]
223
+ + conversation_lib.default_conversation.sep
224
+ )
225
+ conversations.append(conversation)
226
+ # tokenize conversations
227
+ input_ids = [
228
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
229
+ for prompt in conversations
230
+ ]
231
+ targets = copy.deepcopy(input_ids)
232
+ for target, source in zip(targets, sources):
233
+ tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
234
+ target[:tokenized_len] = IGNORE_INDEX
235
+
236
+ return dict(input_ids=input_ids, labels=targets)
237
+
238
+
239
+ def preprocess(
240
+ sources: Sequence[str],
241
+ tokenizer: transformers.PreTrainedTokenizer,
242
+ has_image: bool = False,
243
+ ) -> Dict:
244
+ """
245
+ Given a list of sources, each is a conversation list. This transform:
246
+ 1. Add signal '### ' at the beginning each sentence, with end signal '\n';
247
+ 2. Concatenate conversations together;
248
+ 3. Tokenize the concatenated conversation;
249
+ 4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
250
+ """
251
+ if (
252
+ conversation_lib.default_conversation.sep_style
253
+ == conversation_lib.SeparatorStyle.PLAIN
254
+ ):
255
+ return preprocess_plain(sources, tokenizer)
256
+ if conversation_lib.default_conversation.version.startswith("v1"):
257
+ return preprocess_v1(sources, tokenizer, has_image=has_image)
258
+ # add end signal and concatenate together
259
+ conversations = []
260
+ for source in sources:
261
+ header = f"{conversation_lib.default_conversation.system}\n\n"
262
+ conversation = _add_speaker_and_signal(header, source)
263
+ conversations.append(conversation)
264
+
265
+ # tokenize conversations
266
+ def get_tokenize_len(prompts):
267
+ return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]
268
+
269
+ if has_image:
270
+ input_ids = [
271
+ tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
272
+ for prompt in conversations
273
+ ]
274
+ else:
275
+ conversations_tokenized = _tokenize_fn(conversations, tokenizer)
276
+ input_ids = conversations_tokenized["input_ids"]
277
+
278
+ targets = copy.deepcopy(input_ids)
279
+ for target, source in zip(targets, sources):
280
+ if has_image:
281
+ tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
282
+ else:
283
+ tokenized_lens = _tokenize_fn(
284
+ [header] + [s["value"] for s in source], tokenizer
285
+ )["input_ids_lens"]
286
+ speakers = [sentence["from"] for sentence in source]
287
+ _mask_targets(target, tokenized_lens, speakers)
288
+
289
+ return dict(input_ids=input_ids, labels=targets)
290
+
291
+
292
+ def load_video(video_file):
293
+ from decord import VideoReader
294
+
295
+ vr = VideoReader(video_file)
296
+
297
+ # Get video frame rate
298
+ fps = vr.get_avg_fps()
299
+
300
+ # Calculate frame indices for 1fps
301
+ frame_indices = [int(fps * i) for i in range(int(len(vr) / fps))]
302
+ frames = vr.get_batch(frame_indices).asnumpy()
303
+ return [Image.fromarray(frames[i]) for i in range(int(len(vr) / fps))]
304
+
305
+
306
+ def expand2square(pil_img, background_color):
307
+ width, height = pil_img.size
308
+ if width == height:
309
+ return pil_img
310
+ elif width > height:
311
+ result = Image.new(pil_img.mode, (width, width), background_color)
312
+ result.paste(pil_img, (0, (width - height) // 2))
313
+ return result
314
+ else:
315
+ result = Image.new(pil_img.mode, (height, height), background_color)
316
+ result.paste(pil_img, ((height - width) // 2, 0))
317
+ return result
DeQA-Score/src/evaluate/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .scorer import Scorer
DeQA-Score/src/evaluate/cal_distribution_gap.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ import numpy as np
5
+
6
+
7
+ def parse_args():
8
+ parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score")
9
+ parser.add_argument("--level_names", type=str, required=True, nargs="+")
10
+ parser.add_argument("--pred_paths", type=str, required=True, nargs="+")
11
+ parser.add_argument("--gt_paths", type=str, required=True, nargs="+")
12
+ parser.add_argument("--use_openset_probs", action="store_true")
13
+ args = parser.parse_args()
14
+ return args
15
+
16
+
17
+ def kl_divergence(mu_1, mu_2, sigma_1, sigma_2):
18
+ """
19
+ Calculate the Kullback-Leibler (KL) divergence between two Gaussian distributions for numpy arrays.
20
+
21
+ Parameters:
22
+ mu_1 (np.array): Mean of the first distribution (array of size N).
23
+ mu_2 (np.array): Mean of the second distribution (array of size N).
24
+ sigma_1 (np.array): Standard deviation of the first distribution (array of size N).
25
+ sigma_2 (np.array): Standard deviation of the second distribution (array of size N).
26
+
27
+ Returns:
28
+ np.array: KL divergence from distribution 1 to distribution 2 (array of size N).
29
+ """
30
+ eps = 1e-8
31
+ return np.log(sigma_2 / (sigma_1 + eps)) + (sigma_1**2 + (mu_1 - mu_2)**2) / (2 * sigma_2**2 + eps) - 0.5
32
+
33
+
34
+ def js_divergence(mu_1, mu_2, sigma_1, sigma_2):
35
+ """
36
+ Calculate the Jensen-Shannon (JS) divergence between two Gaussian distributions for numpy arrays.
37
+
38
+ Parameters:
39
+ mu_1 (np.array): Mean of the first distribution (array of size N).
40
+ mu_2 (np.array): Mean of the second distribution (array of size N).
41
+ sigma_1 (np.array): Standard deviation of the first distribution (array of size N).
42
+ sigma_2 (np.array): Standard deviation of the second distribution (array of size N).
43
+
44
+ Returns:
45
+ np.array: JS divergence between the two distributions (array of size N).
46
+ """
47
+ # Midpoint distribution parameters
48
+ mu_m = 0.5 * (mu_1 + mu_2)
49
+ sigma_m = np.sqrt(0.5 * (sigma_1**2 + sigma_2**2))
50
+
51
+ # JS divergence as the average of the KL divergences
52
+ return 0.5 * kl_divergence(mu_1, mu_m, sigma_1, sigma_m) + 0.5 * kl_divergence(mu_2, mu_m, sigma_2, sigma_m)
53
+
54
+
55
+ def wasserstein_distance(mu_1, mu_2, sigma_1, sigma_2):
56
+ """
57
+ Calculate the Wasserstein distance between two Gaussian distributions for numpy arrays.
58
+
59
+ Parameters:
60
+ mu_1 (np.array): Mean of the first distribution (array of size N).
61
+ mu_2 (np.array): Mean of the second distribution (array of size N).
62
+ sigma_1 (np.array): Standard deviation of the first distribution (array of size N).
63
+ sigma_2 (np.array): Standard deviation of the second distribution (array of size N).
64
+
65
+ Returns:
66
+ np.array: Wasserstein distance between the two distributions (array of size N).
67
+ """
68
+ return np.sqrt((mu_1 - mu_2)**2 + (sigma_1 - sigma_2)**2)
69
+
70
+
71
+ def cal_score(level_names, logits=None, probs=None, use_openset_probs=False):
72
+ if use_openset_probs:
73
+ assert logits is None
74
+ probs = np.array([probs[_] for _ in level_names])
75
+ else:
76
+ assert probs is None
77
+ logprobs = np.array([logits[_] for _ in level_names])
78
+ probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
79
+ score = np.inner(probs, np.array([5., 4., 3., 2., 1.]))
80
+ return score, probs
81
+
82
+
83
+ def cal_std(score, probs):
84
+ variance = (np.array([5., 4., 3., 2., 1.]) - score) * (np.array([5., 4., 3., 2., 1.]) - score)
85
+ variance = np.inner(probs, variance)
86
+ std = np.sqrt(variance)
87
+ return std
88
+
89
+
90
+ if __name__ == "__main__":
91
+ args = parse_args()
92
+ level_names = args.level_names
93
+ pred_paths = args.pred_paths
94
+ gt_paths = args.gt_paths
95
+ use_openset_probs = args.use_openset_probs
96
+
97
+ for pred_path, gt_path in zip(pred_paths, gt_paths):
98
+ print("=" * 100)
99
+ print("Pred: ", pred_path)
100
+ print("GT: ", gt_path)
101
+
102
+ # load predict results
103
+ pred_metas = []
104
+ with open(pred_path) as fr:
105
+ for line in fr:
106
+ pred_meta = json.loads(line)
107
+ pred_metas.append(pred_meta)
108
+
109
+ # load gt results
110
+ with open(gt_path) as fr:
111
+ gt_metas = json.load(fr)
112
+
113
+ pred_metas.sort(key=lambda x: x["id"])
114
+ gt_metas.sort(key=lambda x: x["id"])
115
+
116
+ mu_preds = []
117
+ std_preds = []
118
+ mu_gts = []
119
+ std_gts = []
120
+ for pred_meta, gt_meta in zip(pred_metas, gt_metas):
121
+ assert pred_meta["id"] == gt_meta["id"]
122
+ if use_openset_probs:
123
+ pred_score, probs = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=True)
124
+ else:
125
+ pred_score, probs = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=False)
126
+ pred_std = cal_std(pred_score, probs)
127
+ mu_preds.append(pred_score)
128
+ std_preds.append(pred_std)
129
+ mu_gts.append(gt_meta["gt_score_norm"])
130
+ std_gts.append(gt_meta["std_norm"])
131
+
132
+ mu_preds = np.array(mu_preds)
133
+ std_preds = np.array(std_preds)
134
+ mu_gts = np.array(mu_gts)
135
+ std_gts = np.array(std_gts)
136
+
137
+ kl = kl_divergence(mu_gts, mu_preds, std_gts, std_preds).mean()
138
+ js = js_divergence(mu_gts, mu_preds, std_gts, std_preds).mean()
139
+ wd = wasserstein_distance(mu_gts, mu_preds, std_gts, std_preds).mean()
140
+
141
+ print(f"KL: {kl}")
142
+ print(f"JS: {js}")
143
+ print(f"WD: {wd}")
DeQA-Score/src/evaluate/cal_plcc_srcc.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+
4
+ import numpy as np
5
+ from scipy.optimize import curve_fit
6
+ from scipy.stats import pearsonr, spearmanr
7
+
8
+
9
+ def parse_args():
10
+ parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score")
11
+ parser.add_argument("--level_names", type=str, required=True, nargs="+")
12
+ parser.add_argument("--pred_paths", type=str, required=True, nargs="+")
13
+ parser.add_argument("--gt_paths", type=str, required=True, nargs="+")
14
+ parser.add_argument("--use_openset_probs", action="store_true")
15
+ args = parser.parse_args()
16
+ return args
17
+
18
+
19
+ def calculate_srcc(pred, mos):
20
+ srcc, _ = spearmanr(pred, mos)
21
+ return srcc
22
+
23
+
24
+ def calculate_plcc(pred, mos):
25
+ plcc, _ = pearsonr(pred, mos)
26
+ return plcc
27
+
28
+
29
+ def fit_curve(x, y, curve_type="logistic_4params"):
30
+ r"""Fit the scale of predict scores to MOS scores using logistic regression suggested by VQEG.
31
+ The function with 4 params is more commonly used.
32
+ The 5 params function takes from DBCNN:
33
+ - https://github.com/zwx8981/DBCNN/blob/master/dbcnn/tools/verify_performance.m
34
+ """
35
+ assert curve_type in [
36
+ "logistic_4params",
37
+ "logistic_5params",
38
+ ], f"curve type should be in [logistic_4params, logistic_5params], but got {curve_type}."
39
+
40
+ betas_init_4params = [np.max(y), np.min(y), np.mean(x), np.std(x) / 4.0]
41
+
42
+ def logistic_4params(x, beta1, beta2, beta3, beta4):
43
+ yhat = (beta1 - beta2) / (1 + np.exp(-(x - beta3) / beta4)) + beta2
44
+ return yhat
45
+
46
+ betas_init_5params = [10, 0, np.mean(y), 0.1, 0.1]
47
+
48
+ def logistic_5params(x, beta1, beta2, beta3, beta4, beta5):
49
+ logistic_part = 0.5 - 1.0 / (1 + np.exp(beta2 * (x - beta3)))
50
+ yhat = beta1 * logistic_part + beta4 * x + beta5
51
+ return yhat
52
+
53
+ if curve_type == "logistic_4params":
54
+ logistic = logistic_4params
55
+ betas_init = betas_init_4params
56
+ elif curve_type == "logistic_5params":
57
+ logistic = logistic_5params
58
+ betas_init = betas_init_5params
59
+
60
+ betas, _ = curve_fit(logistic, x, y, p0=betas_init, maxfev=10000)
61
+ yhat = logistic(x, *betas)
62
+ return yhat
63
+
64
+
65
+ def cal_score(level_names, logits=None, probs=None, use_openset_probs=False):
66
+ if use_openset_probs:
67
+ assert logits is None
68
+ probs = np.array([probs[_] for _ in level_names])
69
+ else:
70
+ assert probs is None
71
+ logprobs = np.array([logits[_] for _ in level_names])
72
+ probs = np.exp(logprobs) / np.sum(np.exp(logprobs))
73
+ score = np.inner(probs, np.array([5., 4., 3., 2., 1.]))
74
+ return score
75
+
76
+
77
+ if __name__ == "__main__":
78
+ args = parse_args()
79
+ level_names = args.level_names
80
+ pred_paths = args.pred_paths
81
+ gt_paths = args.gt_paths
82
+ use_openset_probs = args.use_openset_probs
83
+
84
+ for pred_path, gt_path in zip(pred_paths, gt_paths):
85
+ print("=" * 100)
86
+ print("Pred: ", pred_path)
87
+ print("GT: ", gt_path)
88
+
89
+ # load predict results
90
+ pred_metas = []
91
+ with open(pred_path) as fr:
92
+ for line in fr:
93
+ pred_meta = json.loads(line)
94
+ pred_metas.append(pred_meta)
95
+
96
+ # load gt results
97
+ with open(gt_path) as fr:
98
+ gt_metas = json.load(fr)
99
+
100
+ preds = []
101
+ gts = []
102
+ for pred_meta, gt_meta in zip(pred_metas, gt_metas):
103
+ assert pred_meta["id"] == gt_meta["id"]
104
+ if use_openset_probs:
105
+ pred_score = cal_score(level_names, probs=pred_meta["probs"], use_openset_probs=True)
106
+ else:
107
+ pred_score = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=False)
108
+ preds.append(pred_score)
109
+ gts.append(gt_meta["gt_score"])
110
+
111
+ preds_fit = fit_curve(preds, gts)
112
+ srcc = calculate_srcc(preds_fit, gts)
113
+ plcc = calculate_plcc(preds_fit, gts)
114
+ print(f"SRCC: {srcc}")
115
+ print(f"PLCC: {plcc}")
DeQA-Score/src/evaluate/eval_qbench_mcq.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import torch
3
+
4
+ from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
5
+ from src.conversation import conv_templates, SeparatorStyle
6
+ from src.model.builder import load_pretrained_model
7
+ from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
8
+
9
+ from PIL import Image
10
+
11
+ import requests
12
+ from PIL import Image
13
+ from io import BytesIO
14
+ from transformers import TextStreamer
15
+
16
+ import json
17
+ from tqdm import tqdm
18
+
19
+ import os
20
+
21
+ def disable_torch_init():
22
+ """
23
+ Disable the redundant torch default initialization to accelerate model creation.
24
+ """
25
+ import torch
26
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
27
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
28
+
29
+
30
+ def load_image(image_file):
31
+ if image_file.startswith('http://') or image_file.startswith('https://'):
32
+ response = requests.get(image_file)
33
+ image = Image.open(BytesIO(response.content)).convert('RGB')
34
+ else:
35
+ image = Image.open(image_file).convert('RGB')
36
+ return image
37
+
38
+
39
+ def main(args):
40
+ # Model
41
+ disable_torch_init()
42
+
43
+ model_name = get_model_name_from_path(args.model_path)
44
+ tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
45
+
46
+ os.makedirs(args.save_dir, exist_ok=True)
47
+ with open(args.meta_path) as f:
48
+ llvqa_data = json.load(f)
49
+
50
+ pbar = tqdm(total=len(llvqa_data))
51
+
52
+ conv_mode = "mplug_owl2"
53
+
54
+ if args.conv_mode is not None and conv_mode != args.conv_mode:
55
+ print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
56
+ else:
57
+ args.conv_mode = conv_mode
58
+
59
+ conv = conv_templates[args.conv_mode].copy()
60
+ roles = conv.roles
61
+
62
+ correct = 0
63
+ for i, llddata in enumerate((llvqa_data)):
64
+ filename = llddata["img_path"]
65
+
66
+ message = llddata["question"] + "\n"
67
+ for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
68
+ message += f"{choice} {ans}\n"
69
+ if "correct_ans" in llddata and ans == llddata["correct_ans"]:
70
+ correct_choice = choice[0]
71
+ message = message + "Answer with the option's letter from the given choices directly.\n"
72
+
73
+ inp = message
74
+
75
+ conv = conv_templates[args.conv_mode].copy()
76
+ inp = "The input image:" + DEFAULT_IMAGE_TOKEN + inp
77
+ conv.append_message(conv.roles[0], inp)
78
+ conv.append_message(conv.roles[1], None)
79
+ prompt = conv.get_prompt()
80
+
81
+ print(prompt)
82
+
83
+ image = load_image(os.path.join(args.root_dir, filename))
84
+ image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(model.device)
85
+
86
+ input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
87
+ stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
88
+ keywords = [stop_str]
89
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
90
+ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
91
+
92
+ with torch.inference_mode():
93
+ output_ids = model.generate(
94
+ input_ids,
95
+ attention_mask=torch.ones_like(input_ids),
96
+ images=image_tensor,
97
+ do_sample=False,
98
+ temperature=args.temperature,
99
+ max_new_tokens=args.max_new_tokens,
100
+ num_beams=1,
101
+ streamer=streamer,
102
+ use_cache=True,
103
+ stopping_criteria=[stopping_criteria])
104
+
105
+ outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
106
+ llddata["response"] = outputs
107
+
108
+ if correct_choice in outputs:
109
+ correct += 1
110
+
111
+ pbar.update(1)
112
+ pbar.set_description("[Running Accuracy]: {:.4f},[Response]: {}, [Correct Ans]: {}, , [Prog]: {}".format(correct/(i+1), outputs, llddata.get("correct_ans", -1), i+1))
113
+
114
+ save_path = os.path.join(args.save_dir, os.path.basename(args.meta_path))
115
+ with open(save_path, "a") as fw:
116
+ fw.write(json.dumps(llddata) + "\n")
117
+
118
+ if args.debug:
119
+ print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
120
+
121
+
122
+ if __name__ == "__main__":
123
+ parser = argparse.ArgumentParser()
124
+ parser.add_argument("--model-path", type=str, required=True)
125
+ parser.add_argument("--model-base", type=str, default=None)
126
+ parser.add_argument("--root-dir", type=str, required=True)
127
+ parser.add_argument("--save-dir", type=str, required=True)
128
+ parser.add_argument("--meta-path", type=str, required=True)
129
+ parser.add_argument("--device", type=str, default="cuda")
130
+ parser.add_argument("--conv-mode", type=str, default=None)
131
+ parser.add_argument("--temperature", type=float, default=0.2)
132
+ parser.add_argument("--max-new-tokens", type=int, default=512)
133
+ parser.add_argument("--load-8bit", action="store_true")
134
+ parser.add_argument("--load-4bit", action="store_true")
135
+ parser.add_argument("--debug", action="store_true")
136
+ parser.add_argument("--image-aspect-ratio", type=str, default='pad')
137
+ args = parser.parse_args()
138
+ main(args)
DeQA-Score/src/evaluate/iqa_eval.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import os
4
+ from collections import defaultdict
5
+ from io import BytesIO
6
+
7
+ import requests
8
+ import torch
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ from src.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
13
+ from src.conversation import conv_templates
14
+ from src.mm_utils import get_model_name_from_path, tokenizer_image_token
15
+ from src.model.builder import load_pretrained_model
16
+
17
+
18
+ def disable_torch_init():
19
+ """
20
+ Disable the redundant torch default initialization to accelerate model creation.
21
+ """
22
+ import torch
23
+
24
+ setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
25
+ setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)
26
+
27
+
28
+ def load_image(image_file):
29
+ if image_file.startswith("http://") or image_file.startswith("https://"):
30
+ response = requests.get(image_file)
31
+ image = Image.open(BytesIO(response.content)).convert("RGB")
32
+ else:
33
+ image = Image.open(image_file).convert("RGB")
34
+ return image
35
+
36
+
37
+ def main(args):
38
+ # Model
39
+ disable_torch_init()
40
+
41
+ model_name = get_model_name_from_path(args.model_path)
42
+ tokenizer, model, image_processor, context_len = load_pretrained_model(
43
+ args.model_path,
44
+ args.model_base,
45
+ model_name,
46
+ args.load_8bit,
47
+ args.load_4bit,
48
+ device=args.device,
49
+ preprocessor_path=args.preprocessor_path,
50
+ )
51
+
52
+ meta_paths = args.meta_paths
53
+ root_dir = args.root_dir
54
+ batch_size = args.batch_size
55
+ save_dir = args.save_dir
56
+ os.makedirs(save_dir, exist_ok=True)
57
+ with_prob = args.with_prob
58
+
59
+ conv_mode = "mplug_owl2"
60
+ inp = "How would you rate the quality of this image?"
61
+
62
+ conv = conv_templates[conv_mode].copy()
63
+ inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
64
+ conv.append_message(conv.roles[0], inp)
65
+ image = None
66
+
67
+ conv.append_message(conv.roles[1], None)
68
+ prompt = conv.get_prompt() + " The quality of the image is"
69
+
70
+ toks = args.level_names
71
+ print(toks)
72
+ ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
73
+ print(ids_)
74
+
75
+ input_ids = (
76
+ tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
77
+ .unsqueeze(0)
78
+ .to(args.device)
79
+ )
80
+
81
+ for meta_path in meta_paths:
82
+ with open(meta_path) as f:
83
+ iqadata = json.load(f)
84
+
85
+ image_tensors = []
86
+ batch_data = []
87
+
88
+ imgs_handled = []
89
+ save_path = os.path.join(save_dir, os.path.basename(meta_path))
90
+ if os.path.exists(save_path):
91
+ with open(save_path) as fr:
92
+ for line in fr:
93
+ meta_res = json.loads(line)
94
+ imgs_handled.append(meta_res["image"])
95
+
96
+ meta_name = os.path.basename(meta_path)
97
+ for i, llddata in enumerate(tqdm(iqadata, desc=f"Evaluating [{meta_name}]")):
98
+ try:
99
+ filename = llddata["image"]
100
+ except:
101
+ filename = llddata["img_path"]
102
+ if filename in imgs_handled:
103
+ continue
104
+
105
+ llddata["logits"] = defaultdict(float)
106
+ llddata["probs"] = defaultdict(float)
107
+
108
+ image = load_image(os.path.join(root_dir, filename))
109
+
110
+ def expand2square(pil_img, background_color):
111
+ width, height = pil_img.size
112
+ if width == height:
113
+ return pil_img
114
+ elif width > height:
115
+ result = Image.new(pil_img.mode, (width, width), background_color)
116
+ result.paste(pil_img, (0, (width - height) // 2))
117
+ return result
118
+ else:
119
+ result = Image.new(pil_img.mode, (height, height), background_color)
120
+ result.paste(pil_img, ((height - width) // 2, 0))
121
+ return result
122
+
123
+ image = expand2square(
124
+ image, tuple(int(x * 255) for x in image_processor.image_mean)
125
+ )
126
+ image_tensor = (
127
+ image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
128
+ .half()
129
+ .to(args.device)
130
+ )
131
+
132
+ image_tensors.append(image_tensor)
133
+ batch_data.append(llddata)
134
+
135
+ if (i + 1) % batch_size == 0 or i == len(iqadata) - 1:
136
+ with torch.inference_mode():
137
+ output_logits = model(
138
+ input_ids=input_ids.repeat(len(image_tensors), 1),
139
+ images=torch.cat(image_tensors, 0),
140
+ )["logits"][:, -1]
141
+ if with_prob:
142
+ output_probs = torch.softmax(output_logits, dim=1)
143
+
144
+ for j, xllddata in enumerate(batch_data):
145
+ for tok, id_ in zip(toks, ids_):
146
+ xllddata["logits"][tok] += output_logits[j, id_].item()
147
+ if with_prob:
148
+ xllddata["probs"][tok] += output_probs[j, id_].item()
149
+ meta_res = {
150
+ "id": xllddata["id"],
151
+ "image": xllddata["image"],
152
+ "gt_score": xllddata["gt_score"],
153
+ "logits": xllddata["logits"],
154
+ }
155
+ if with_prob:
156
+ meta_res["probs"] = xllddata["probs"]
157
+ with open(save_path, "a") as fw:
158
+ fw.write(json.dumps(meta_res) + "\n")
159
+
160
+ image_tensors = []
161
+ batch_data = []
162
+
163
+
164
+ if __name__ == "__main__":
165
+ parser = argparse.ArgumentParser()
166
+ parser.add_argument("--model-path", type=str, required=True)
167
+ parser.add_argument("--model-base", type=str, default=None)
168
+ parser.add_argument("--preprocessor-path", type=str, default=None)
169
+ parser.add_argument("--meta-paths", type=str, required=True, nargs="+")
170
+ parser.add_argument("--root-dir", type=str, required=True)
171
+ parser.add_argument("--save-dir", type=str, default="results")
172
+ parser.add_argument("--level-names", type=str, required=True, nargs="+")
173
+ parser.add_argument("--with-prob", type=bool, default=False) # whether to save openset prob
174
+ parser.add_argument("--device", type=str, default="cuda:0")
175
+ parser.add_argument("--conv-mode", type=str, default=None)
176
+ parser.add_argument("--batch-size", type=int, default=16)
177
+ parser.add_argument("--temperature", type=float, default=0.2)
178
+ parser.add_argument("--max-new-tokens", type=int, default=512)
179
+ parser.add_argument("--load-8bit", action="store_true")
180
+ parser.add_argument("--load-4bit", action="store_true")
181
+ parser.add_argument("--debug", action="store_true")
182
+ parser.add_argument("--image-aspect-ratio", type=str, default="pad")
183
+ args = parser.parse_args()
184
+ main(args)
DeQA-Score/src/evaluate/scorer.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import torch.nn as nn
4
+ import torch
5
+
6
+ from typing import List
7
+
8
+ from src.model.builder import load_pretrained_model
9
+
10
+ from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
11
+ from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
12
+
13
+
14
+ class Scorer(nn.Module):
15
+ def __init__(self, pretrained="zhiyuanyou/DeQA-Score-Mix3", device="cuda:0"):
16
+ super().__init__()
17
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
18
+ prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
19
+
20
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
21
+ self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(model.device)
22
+
23
+ self.tokenizer = tokenizer
24
+ self.model = model
25
+ self.image_processor = image_processor
26
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
27
+
28
+ def expand2square(self, pil_img, background_color):
29
+ width, height = pil_img.size
30
+ if width == height:
31
+ return pil_img
32
+ elif width > height:
33
+ result = Image.new(pil_img.mode, (width, width), background_color)
34
+ result.paste(pil_img, (0, (width - height) // 2))
35
+ return result
36
+ else:
37
+ result = Image.new(pil_img.mode, (height, height), background_color)
38
+ result.paste(pil_img, ((height - width) // 2, 0))
39
+ return result
40
+
41
+ def forward(self, image: List[Image.Image]):
42
+ image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
43
+ with torch.inference_mode():
44
+ image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
45
+ output_logits = self.model(
46
+ input_ids=self.input_ids.repeat(image_tensor.shape[0], 1),
47
+ images=image_tensor
48
+ )["logits"][:,-1, self.preferential_ids_]
49
+
50
+ return torch.softmax(output_logits, -1) @ self.weight_tensor
51
+
52
+
53
+ if __name__ == "__main__":
54
+ import argparse
55
+
56
+ parser = argparse.ArgumentParser()
57
+ parser.add_argument("--model-path", type=str, default="zhiyuanyou/DeQA-Score-Mix3")
58
+ parser.add_argument("--device", type=str, default="cuda:0")
59
+ parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
60
+ args = parser.parse_args()
61
+
62
+ scorer = Scorer(pretrained=args.model_path, device=args.device)
63
+ print(scorer([Image.open(args.img_path)]).tolist())
DeQA-Score/src/evaluate/scorer_coco.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import torch.nn as nn
3
+ import torch
4
+ from typing import List
5
+ from src.model.builder import load_pretrained_model
6
+ from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
7
+ from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
8
+ import os
9
+
10
+ BASE_DIR = os.path.dirname(os.path.abspath(__file__))
11
+
12
+ def resolve_path(*parts):
13
+ return os.path.abspath(os.path.join(BASE_DIR, *parts))
14
+
15
+
16
+ class Scorer(nn.Module):
17
+ def __init__(self, pretrained="zhiyuanyou/DeQA-Score-Mix3", device="cuda:0"):
18
+ super().__init__()
19
+ tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
20
+ prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"
21
+
22
+ self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
23
+ self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(model.device)
24
+
25
+ self.tokenizer = tokenizer
26
+ self.model = model
27
+ self.image_processor = image_processor
28
+ self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
29
+
30
+ def expand2square(self, pil_img, background_color):
31
+ width, height = pil_img.size
32
+ if width == height:
33
+ return pil_img
34
+ elif width > height:
35
+ result = Image.new(pil_img.mode, (width, width), background_color)
36
+ result.paste(pil_img, (0, (width - height) // 2))
37
+ return result
38
+ else:
39
+ result = Image.new(pil_img.mode, (height, height), background_color)
40
+ result.paste(pil_img, ((height - width) // 2, 0))
41
+ return result
42
+
43
+ def forward(self, image: List[Image.Image]):
44
+ image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
45
+ with torch.inference_mode():
46
+ image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
47
+ output_logits = self.model(
48
+ input_ids=self.input_ids.repeat(image_tensor.shape[0], 1),
49
+ images=image_tensor
50
+ )["logits"][:,-1, self.preferential_ids_]
51
+
52
+ return torch.softmax(output_logits, -1) @ self.weight_tensor
53
+
54
+
55
+ if __name__ == "__main__":
56
+ import argparse
57
+
58
+ parser = argparse.ArgumentParser()
59
+ parser.add_argument("--model-path", type=str, default="zhiyuanyou/DeQA-Score-Mix3")
60
+ parser.add_argument("--device", type=str, default="cuda:0")
61
+ parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
62
+ args = parser.parse_args()
63
+
64
+ scorer = Scorer(pretrained=args.model_path, device=args.device)
65
+
66
+ from PIL import Image, ImageFile
67
+ from pycocotools.coco import COCO
68
+ from tqdm import tqdm
69
+ import os
70
+
71
+ data_IQA = {
72
+ "captions": [],
73
+ "IQAs": [],
74
+ "image_ids": []
75
+ }
76
+
77
+
78
+ ANN_PATH = resolve_path('../../../', 'coco_data', 'annotations', 'captions_train2017.json')
79
+ IMG_DIR = resolve_path('../../../', 'coco_data', 'train2017')
80
+
81
+
82
+ coco = COCO(ANN_PATH)
83
+ img_ids = coco.getImgIds()
84
+
85
+ for img_id in tqdm(img_ids):
86
+ img_info = coco.loadImgs(img_id)[0]
87
+ file_name = img_info["file_name"]
88
+ img_path = os.path.join(IMG_DIR, file_name)
89
+
90
+ IQA_score = scorer([Image.open(img_path).convert("RGB")])
91
+
92
+ ann_ids = coco.getAnnIds(imgIds=img_id)
93
+ anns = coco.loadAnns(ann_ids)
94
+ for ann in anns:
95
+ data_IQA["captions"].append(ann["caption"].strip())
96
+ data_IQA["image_ids"].append(img_id)
97
+ data_IQA["IQAs"].append(IQA_score.detach().cpu().item())
98
+ caption = ann["caption"].strip()
99
+
100
+
101
+ save_path = resolve_path('../../../', 'processed_data', 'coco', 'data_IQA.pt')
102
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
103
+ torch.save(data_IQA, save_path)
DeQA-Score/src/mm_utils.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import base64
4
+
5
+ import torch
6
+ from transformers import StoppingCriteria
7
+ from src.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
8
+ from icecream import ic
9
+
10
+
11
+ def load_image_from_base64(image):
12
+ return Image.open(BytesIO(base64.b64decode(image)))
13
+
14
+
15
+ def expand2square(pil_img, background_color):
16
+ width, height = pil_img.size
17
+ if width == height:
18
+ return pil_img
19
+ elif width > height:
20
+ result = Image.new(pil_img.mode, (width, width), background_color)
21
+ result.paste(pil_img, (0, (width - height) // 2))
22
+ return result
23
+ else:
24
+ result = Image.new(pil_img.mode, (height, height), background_color)
25
+ result.paste(pil_img, ((height - width) // 2, 0))
26
+ return result
27
+
28
+
29
+ def process_images(images, image_processor, model_cfg=None):
30
+ if model_cfg is not None:
31
+ image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
32
+ else:
33
+ image_aspect_ratio = 'resize'
34
+ new_images = []
35
+ if image_aspect_ratio == 'pad':
36
+ for image in images:
37
+ image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
38
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
39
+ new_images.append(image)
40
+ elif image_aspect_ratio == 'resize':
41
+ for image in images:
42
+ max_edge = max(image.size)
43
+ image = image.resize((max_edge, max_edge))
44
+ image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
45
+ new_images.append(image)
46
+ else:
47
+ return image_processor(images, return_tensors='pt')['pixel_values']
48
+ if all(x.shape == new_images[0].shape for x in new_images):
49
+ new_images = torch.stack(new_images, dim=0)
50
+ return new_images
51
+
52
+
53
+ def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
54
+ prompt_chunks = [tokenizer(chunk).input_ids if len(chunk) > 0 else [] for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]
55
+
56
+ def insert_separator(X, sep):
57
+ return [ele for sublist in zip(X, [sep]*len(X)) for ele in sublist][:-1]
58
+
59
+ input_ids = []
60
+ offset = 0
61
+ if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
62
+ offset = 1
63
+ input_ids.append(prompt_chunks[0][0])
64
+
65
+ for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
66
+ input_ids.extend(x[offset:])
67
+
68
+ if return_tensors is not None:
69
+ if return_tensors == 'pt':
70
+ return torch.tensor(input_ids, dtype=torch.long)
71
+ raise ValueError(f'Unsupported tensor type: {return_tensors}')
72
+ return input_ids
73
+
74
+
75
+ def get_model_name_from_path(model_path):
76
+ model_path = model_path.strip("/")
77
+ model_paths = model_path.split("/")
78
+ if model_paths[-1].startswith('checkpoint-'):
79
+ return model_paths[-2] + "_" + model_paths[-1]
80
+ else:
81
+ return model_paths[-1]
82
+
83
+
84
+
85
+
86
+ class KeywordsStoppingCriteria(StoppingCriteria):
87
+ def __init__(self, keywords, tokenizer, input_ids):
88
+ self.keywords = keywords
89
+ self.keyword_ids = []
90
+ self.max_keyword_len = 0
91
+ for keyword in keywords:
92
+ cur_keyword_ids = tokenizer(keyword).input_ids
93
+ if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
94
+ cur_keyword_ids = cur_keyword_ids[1:]
95
+ if len(cur_keyword_ids) > self.max_keyword_len:
96
+ self.max_keyword_len = len(cur_keyword_ids)
97
+ self.keyword_ids.append(torch.tensor(cur_keyword_ids))
98
+ self.tokenizer = tokenizer
99
+ self.start_len = input_ids.shape[1]
100
+
101
+ def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
102
+ assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)" # TODO
103
+ offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
104
+ self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
105
+ for keyword_id in self.keyword_ids:
106
+ if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
107
+ return True
108
+ outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
109
+ for keyword in self.keywords:
110
+ if keyword in outputs:
111
+ return True
112
+ return False
DeQA-Score/src/model/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
2
+ from .configuration_mplug_owl2 import MPLUGOwl2Config
DeQA-Score/src/model/builder.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 Haotian Liu
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import os
17
+ import warnings
18
+
19
+ import torch
20
+ from transformers import (
21
+ AutoConfig,
22
+ AutoModelForCausalLM,
23
+ AutoTokenizer,
24
+ BitsAndBytesConfig,
25
+ )
26
+ from transformers.models.clip.image_processing_clip import CLIPImageProcessor
27
+
28
+ from src.model import *
29
+
30
+
31
+ def load_pretrained_model(
32
+ model_path,
33
+ model_base,
34
+ model_name,
35
+ load_8bit=False,
36
+ load_4bit=False,
37
+ device_map="auto",
38
+ device="cuda",
39
+ preprocessor_path=None,
40
+ ):
41
+ kwargs = {"device_map": device_map}
42
+
43
+ if device != "cuda":
44
+ kwargs["device_map"] = {"": device}
45
+
46
+ if load_8bit:
47
+ kwargs["load_in_8bit"] = True
48
+ elif load_4bit:
49
+ kwargs["load_in_4bit"] = True
50
+ kwargs["quantization_config"] = BitsAndBytesConfig(
51
+ load_in_4bit=True,
52
+ bnb_4bit_compute_dtype=torch.float16,
53
+ bnb_4bit_use_double_quant=True,
54
+ bnb_4bit_quant_type="nf4",
55
+ )
56
+ else:
57
+ kwargs["torch_dtype"] = torch.float16
58
+
59
+ if preprocessor_path is None:
60
+ preprocessor_path = model_path
61
+
62
+ if "deqa" in model_name.lower():
63
+ # Load LLaVA model
64
+ if "lora" in model_name.lower() and model_base is None:
65
+ warnings.warn(
66
+ "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
67
+ )
68
+ if "lora" in model_name.lower() and model_base is not None:
69
+ lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
70
+ tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
71
+ print("Loading mPLUG-Owl2 from base model...")
72
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
73
+ model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
74
+ )
75
+ token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
76
+ if model.lm_head.weight.shape[0] != token_num:
77
+ model.lm_head.weight = torch.nn.Parameter(
78
+ torch.empty(
79
+ token_num, tokem_dim, device=model.device, dtype=model.dtype
80
+ )
81
+ )
82
+ model.model.embed_tokens.weight = torch.nn.Parameter(
83
+ torch.empty(
84
+ token_num, tokem_dim, device=model.device, dtype=model.dtype
85
+ )
86
+ )
87
+
88
+ print("Loading additional mPLUG-Owl2 weights...")
89
+ if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
90
+ non_lora_trainables = torch.load(
91
+ os.path.join(model_path, "non_lora_trainables.bin"),
92
+ map_location="cpu",
93
+ )
94
+ print(non_lora_trainables.keys())
95
+ else:
96
+ # this is probably from HF Hub
97
+ from huggingface_hub import hf_hub_download
98
+
99
+ def load_from_hf(repo_id, filename, subfolder=None):
100
+ cache_file = hf_hub_download(
101
+ repo_id=repo_id, filename=filename, subfolder=subfolder
102
+ )
103
+ return torch.load(cache_file, map_location="cpu")
104
+
105
+ non_lora_trainables = load_from_hf(
106
+ model_path, "non_lora_trainables.bin"
107
+ )
108
+ non_lora_trainables = {
109
+ (k[17:] if k.startswith("base_model.model.") else k): v
110
+ for k, v in non_lora_trainables.items()
111
+ }
112
+ model.load_state_dict(non_lora_trainables, strict=False)
113
+
114
+ from peft import PeftModel
115
+
116
+ print("Loading LoRA weights...")
117
+ model = PeftModel.from_pretrained(model, model_path)
118
+ print("Merging LoRA weights...")
119
+ model = model.merge_and_unload()
120
+ print("Model is loaded...")
121
+ elif model_base is not None:
122
+ # this may be mm projector only
123
+ print("Loading mPLUG-Owl2 from base model...")
124
+ tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
125
+ cfg_pretrained = AutoConfig.from_pretrained(model_path)
126
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
127
+ model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
128
+ )
129
+ else:
130
+ tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
131
+ model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
132
+ model_path, low_cpu_mem_usage=True, **kwargs
133
+ )
134
+ else:
135
+ # Load language model
136
+ if model_base is not None:
137
+ # PEFT model
138
+ from peft import PeftModel
139
+
140
+ tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
141
+ model = AutoModelForCausalLM.from_pretrained(
142
+ model_base, low_cpu_mem_usage=True, **kwargs
143
+ )
144
+ print(f"Loading LoRA weights from {model_path}")
145
+ model = PeftModel.from_pretrained(model, model_path)
146
+ print(f"Merging weights")
147
+ model = model.merge_and_unload()
148
+ print("Convert to FP16...")
149
+ model.to(torch.float16)
150
+ else:
151
+ tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
152
+ model = AutoModelForCausalLM.from_pretrained(
153
+ model_path, low_cpu_mem_usage=True, **kwargs
154
+ )
155
+
156
+ # vision_tower = model.get_model().vision_model
157
+ # print(vision_tower.device)
158
+ # vision_tower.to(device=device, dtype=torch.float16)
159
+ image_processor = CLIPImageProcessor.from_pretrained(preprocessor_path)
160
+
161
+ if hasattr(model.config, "max_sequence_length"):
162
+ context_len = model.config.max_sequence_length
163
+ else:
164
+ context_len = 2048
165
+
166
+ return tokenizer, model, image_processor, context_len
DeQA-Score/src/model/configuration_mplug_owl2.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Alibaba.
2
+ #
3
+ # This source code is licensed under the license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import copy
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
11
+ from transformers.utils import logging
12
+ from transformers.models.auto import CONFIG_MAPPING
13
+
14
+
15
+ class LlamaConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`LlamaModel`]. It is used to instantiate an LLaMA
18
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
19
+ defaults will yield a similar configuration to that of the LLaMA-7B.
20
+
21
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
22
+ documentation from [`PretrainedConfig`] for more information.
23
+
24
+
25
+ Args:
26
+ vocab_size (`int`, *optional*, defaults to 32000):
27
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
28
+ `inputs_ids` passed when calling [`LlamaModel`]
29
+ hidden_size (`int`, *optional*, defaults to 4096):
30
+ Dimension of the hidden representations.
31
+ intermediate_size (`int`, *optional*, defaults to 11008):
32
+ Dimension of the MLP representations.
33
+ num_hidden_layers (`int`, *optional*, defaults to 32):
34
+ Number of hidden layers in the Transformer decoder.
35
+ num_attention_heads (`int`, *optional*, defaults to 32):
36
+ Number of attention heads for each attention layer in the Transformer decoder.
37
+ num_key_value_heads (`int`, *optional*):
38
+ This is the number of key_value heads that should be used to implement Grouped Query Attention. If
39
+ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
40
+ `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
41
+ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
42
+ by meanpooling all the original heads within that group. For more details checkout [this
43
+ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
44
+ `num_attention_heads`.
45
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
46
+ The non-linear activation function (function or string) in the decoder.
47
+ max_position_embeddings (`int`, *optional*, defaults to 2048):
48
+ The maximum sequence length that this model might ever be used with. Llama 1 supports up to 2048 tokens,
49
+ Llama 2 up to 4096, CodeLlama up to 16384.
50
+ initializer_range (`float`, *optional*, defaults to 0.02):
51
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
52
+ rms_norm_eps (`float`, *optional*, defaults to 1e-06):
53
+ The epsilon used by the rms normalization layers.
54
+ use_cache (`bool`, *optional*, defaults to `True`):
55
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
56
+ relevant if `config.is_decoder=True`.
57
+ pad_token_id (`int`, *optional*):
58
+ Padding token id.
59
+ bos_token_id (`int`, *optional*, defaults to 1):
60
+ Beginning of stream token id.
61
+ eos_token_id (`int`, *optional*, defaults to 2):
62
+ End of stream token id.
63
+ pretraining_tp (`int`, *optional*, defaults to 1):
64
+ Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this
65
+ document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is
66
+ necessary to ensure exact reproducibility of the pretraining results. Please refer to [this
67
+ issue](https://github.com/pytorch/pytorch/issues/76232).
68
+ tie_word_embeddings (`bool`, *optional*, defaults to `False`):
69
+ Whether to tie weight embeddings
70
+ rope_theta (`float`, *optional*, defaults to 10000.0):
71
+ The base period of the RoPE embeddings.
72
+ rope_scaling (`Dict`, *optional*):
73
+ Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
74
+ strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
75
+ `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
76
+ `max_position_embeddings` to the expected new maximum. See the following thread for more information on how
77
+ these scaling strategies behave:
78
+ https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. This is an
79
+ experimental feature, subject to breaking API changes in future versions.
80
+ attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
81
+ Whether to use a bias in the query, key, value and output projection layers during self-attention.
82
+
83
+
84
+ ```python
85
+ >>> from transformers import LlamaModel, LlamaConfig
86
+
87
+ >>> # Initializing a LLaMA llama-7b style configuration
88
+ >>> configuration = LlamaConfig()
89
+
90
+ >>> # Initializing a model from the llama-7b style configuration
91
+ >>> model = LlamaModel(configuration)
92
+
93
+ >>> # Accessing the model configuration
94
+ >>> configuration = model.config
95
+ ```"""
96
+ model_type = "llama"
97
+ keys_to_ignore_at_inference = ["past_key_values"]
98
+
99
+ def __init__(
100
+ self,
101
+ vocab_size=32000,
102
+ hidden_size=4096,
103
+ intermediate_size=11008,
104
+ num_hidden_layers=32,
105
+ num_attention_heads=32,
106
+ num_key_value_heads=None,
107
+ hidden_act="silu",
108
+ max_position_embeddings=2048,
109
+ initializer_range=0.02,
110
+ rms_norm_eps=1e-6,
111
+ use_cache=True,
112
+ pad_token_id=None,
113
+ bos_token_id=1,
114
+ eos_token_id=2,
115
+ pretraining_tp=1,
116
+ tie_word_embeddings=False,
117
+ rope_theta=10000.0,
118
+ rope_scaling=None,
119
+ attention_bias=False,
120
+ attention_dropout=0.0,
121
+ **kwargs,
122
+ ):
123
+ self.vocab_size = vocab_size
124
+ self.max_position_embeddings = max_position_embeddings
125
+ self.hidden_size = hidden_size
126
+ self.intermediate_size = intermediate_size
127
+ self.num_hidden_layers = num_hidden_layers
128
+ self.num_attention_heads = num_attention_heads
129
+
130
+ # for backward compatibility
131
+ if num_key_value_heads is None:
132
+ num_key_value_heads = num_attention_heads
133
+
134
+ self.num_key_value_heads = num_key_value_heads
135
+ self.hidden_act = hidden_act
136
+ self.initializer_range = initializer_range
137
+ self.rms_norm_eps = rms_norm_eps
138
+ self.pretraining_tp = pretraining_tp
139
+ self.use_cache = use_cache
140
+ self.rope_theta = rope_theta
141
+ self.rope_scaling = rope_scaling
142
+ self._rope_scaling_validation()
143
+ self.attention_bias = attention_bias
144
+ self.attention_dropout = attention_dropout
145
+
146
+ super().__init__(
147
+ pad_token_id=pad_token_id,
148
+ bos_token_id=bos_token_id,
149
+ eos_token_id=eos_token_id,
150
+ tie_word_embeddings=tie_word_embeddings,
151
+ **kwargs,
152
+ )
153
+
154
+ def _rope_scaling_validation(self):
155
+ """
156
+ Validate the `rope_scaling` configuration.
157
+ """
158
+ if self.rope_scaling is None:
159
+ return
160
+
161
+ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
162
+ raise ValueError(
163
+ "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
164
+ f"got {self.rope_scaling}"
165
+ )
166
+ rope_scaling_type = self.rope_scaling.get("type", None)
167
+ rope_scaling_factor = self.rope_scaling.get("factor", None)
168
+ if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
169
+ raise ValueError(
170
+ f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
171
+ )
172
+ if rope_scaling_factor is None or not isinstance(rope_scaling_factor, float) or rope_scaling_factor <= 1.0:
173
+ raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {rope_scaling_factor}")
174
+
175
+
176
+ class MplugOwlVisionConfig(PretrainedConfig):
177
+ r"""
178
+ This is the configuration class to store the configuration of a [`MplugOwlVisionModel`]. It is used to instantiate
179
+ a
180
+ mPLUG-Owl vision encoder according to the specified arguments, defining the model architecture. Instantiating a
181
+ configuration defaults will yield a similar configuration to that of the mPLUG-Owl
182
+ [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b) architecture.
183
+
184
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
185
+ documentation from [`PretrainedConfig`] for more information.
186
+
187
+ Args:
188
+ hidden_size (`int`, *optional*, defaults to 768):
189
+ Dimensionality of the encoder layers and the pooler layer.
190
+ intermediate_size (`int`, *optional*, defaults to 3072):
191
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
192
+ num_hidden_layers (`int`, *optional*, defaults to 12):
193
+ Number of hidden layers in the Transformer encoder.
194
+ num_attention_heads (`int`, *optional*, defaults to 12):
195
+ Number of attention heads for each attention layer in the Transformer encoder.
196
+ image_size (`int`, *optional*, defaults to 224):
197
+ The size (resolution) of each image.
198
+ patch_size (`int`, *optional*, defaults to 32):
199
+ The size (resolution) of each patch.
200
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
201
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
202
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
203
+ layer_norm_eps (`float`, *optional*, defaults to 1e-5):
204
+ The epsilon used by the layer normalization layers.
205
+ attention_dropout (`float`, *optional*, defaults to 0.0):
206
+ The dropout ratio for the attention probabilities.
207
+ initializer_range (`float`, *optional*, defaults to 0.02):
208
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
209
+ initializer_factor (`float`, *optional*, defaults to 1):
210
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
211
+ testing).
212
+
213
+
214
+ ```"""
215
+
216
+ model_type = "mplug_owl_vision_model"
217
+
218
+ def __init__(
219
+ self,
220
+ hidden_size=1024,
221
+ intermediate_size=4096,
222
+ projection_dim=768,
223
+ num_hidden_layers=24,
224
+ num_attention_heads=16,
225
+ num_channels=3,
226
+ image_size=448,
227
+ patch_size=14,
228
+ hidden_act="quick_gelu",
229
+ layer_norm_eps=1e-6,
230
+ attention_dropout=0.0,
231
+ initializer_range=0.02,
232
+ initializer_factor=1.0,
233
+ use_flash_attn=False,
234
+ **kwargs,
235
+ ):
236
+ super().__init__(**kwargs)
237
+ self.hidden_size = hidden_size
238
+ self.intermediate_size = intermediate_size
239
+ self.projection_dim = projection_dim
240
+ self.num_hidden_layers = num_hidden_layers
241
+ self.num_attention_heads = num_attention_heads
242
+ self.num_channels = num_channels
243
+ self.patch_size = patch_size
244
+ self.image_size = image_size
245
+ self.initializer_range = initializer_range
246
+ self.initializer_factor = initializer_factor
247
+ self.attention_dropout = attention_dropout
248
+ self.layer_norm_eps = layer_norm_eps
249
+ self.hidden_act = hidden_act
250
+ self.use_flash_attn = use_flash_attn
251
+
252
+ @classmethod
253
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
254
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
255
+
256
+ # get the vision config dict if we are loading from MplugOwlConfig
257
+ if config_dict.get("model_type") == "mplug-owl":
258
+ config_dict = config_dict["vision_config"]
259
+
260
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
261
+ logger.warning(
262
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
263
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
264
+ )
265
+
266
+ return cls.from_dict(config_dict, **kwargs)
267
+
268
+
269
+ class MplugOwlVisualAbstractorConfig(PretrainedConfig):
270
+ model_type = "mplug_owl_visual_abstract"
271
+
272
+ def __init__(
273
+ self,
274
+ num_learnable_queries=64,
275
+ hidden_size=1024,
276
+ num_hidden_layers=6,
277
+ num_attention_heads=16,
278
+ intermediate_size=2816,
279
+ attention_probs_dropout_prob=0.,
280
+ initializer_range=0.02,
281
+ layer_norm_eps=1e-6,
282
+ encoder_hidden_size=1024,
283
+ grid_size=None,
284
+ **kwargs,
285
+ ):
286
+ super().__init__(**kwargs)
287
+ self.hidden_size = hidden_size
288
+ self.num_learnable_queries = num_learnable_queries
289
+ self.num_hidden_layers = num_hidden_layers
290
+ self.num_attention_heads = num_attention_heads
291
+ self.intermediate_size = intermediate_size
292
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
293
+ self.initializer_range = initializer_range
294
+ self.layer_norm_eps = layer_norm_eps
295
+ self.encoder_hidden_size = encoder_hidden_size
296
+ self.grid_size = grid_size if grid_size else 32
297
+
298
+ @classmethod
299
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
300
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
301
+
302
+ # get the visual_abstractor config dict if we are loading from MplugOwlConfig
303
+ if config_dict.get("model_type") == "mplug-owl":
304
+ config_dict = config_dict["abstractor_config"]
305
+
306
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
307
+ logger.warning(
308
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
309
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
310
+ )
311
+
312
+ return cls.from_dict(config_dict, **kwargs)
313
+
314
+
315
+
316
+ DEFAULT_VISUAL_CONFIG = {
317
+ "visual_model": MplugOwlVisionConfig().to_dict(),
318
+ "visual_abstractor": MplugOwlVisualAbstractorConfig().to_dict()
319
+ }
320
+
321
+ class MPLUGOwl2Config(LlamaConfig):
322
+ model_type = "mplug_owl2"
323
+ def __init__(self, visual_config=None, **kwargs):
324
+ if visual_config is None:
325
+ self.visual_config = DEFAULT_VISUAL_CONFIG
326
+ else:
327
+ self.visual_config = visual_config
328
+
329
+ super().__init__(
330
+ **kwargs,
331
+ )
332
+
333
+ if __name__ == "__main__":
334
+ print(MplugOwlVisionConfig().to_dict())
DeQA-Score/src/model/convert_mplug_owl2_weight_to_hf.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import gc
16
+ import json
17
+ import math
18
+ import os
19
+ import shutil
20
+ import warnings
21
+
22
+ import torch
23
+
24
+ from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
25
+ from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig
26
+ from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
27
+
28
+ try:
29
+ from transformers import LlamaTokenizerFast
30
+ except ImportError as e:
31
+ warnings.warn(e)
32
+ warnings.warn(
33
+ "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
34
+ )
35
+ LlamaTokenizerFast = None
36
+
37
+ """
38
+ Sample usage:
39
+
40
+ ```
41
+ python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
42
+ --input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
43
+ ```
44
+
45
+ Thereafter, models can be loaded via:
46
+
47
+ ```py
48
+ from transformers import LlamaForCausalLM, LlamaTokenizer
49
+
50
+ model = LlamaForCausalLM.from_pretrained("/output/path")
51
+ tokenizer = LlamaTokenizer.from_pretrained("/output/path")
52
+ ```
53
+
54
+ Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
55
+ come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
56
+ """
57
+
58
+ llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
59
+ llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
60
+ llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
61
+ 70: 28672} # should be (2/3)*4*d, but it isn't exaclty that
62
+ llama_s2hidden = {7: 4096, 13: 5120, 32: 6656, 65: 8192, 70: 8192}
63
+
64
+
65
+ def compute_intermediate_size(n):
66
+ return int(math.ceil(n * 8 / 3) + 255) // 256 * 256
67
+
68
+
69
+ def read_json(path):
70
+ with open(path, "r") as f:
71
+ return json.load(f)
72
+
73
+
74
+ def write_json(text, path):
75
+ with open(path, "w") as f:
76
+ json.dump(text, f)
77
+
78
+
79
+ def write_model(model_path,
80
+ input_base_path,
81
+ model_size,
82
+ num_input_shards=1,
83
+ num_output_shards=2,
84
+ skip_permute=True,
85
+ norm_eps=1e-05):
86
+ # if os.path.exists(model_path):
87
+ # shutil.rmtree(model_path)
88
+ os.makedirs(model_path, exist_ok=True)
89
+ # tmp_model_path = os.path.join(model_path, "tmp")
90
+ tmp_model_path = model_path
91
+ os.makedirs(tmp_model_path, exist_ok=True)
92
+
93
+ num_shards = num_input_shards
94
+ n_layers = llama_s2layer[model_size]
95
+ n_heads = llama_s2heads[model_size]
96
+ n_heads_per_shard = n_heads // num_shards
97
+ n_dense = llama_s2dense[model_size]
98
+ n_hidden = llama_s2hidden[model_size]
99
+ hidden_per_head = n_hidden // n_heads
100
+ base = 10000.0
101
+ inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))
102
+
103
+ # permute for sliced rotary
104
+ def permute(w, skip_permute=skip_permute):
105
+ if skip_permute:
106
+ return w
107
+ return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)
108
+
109
+ print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
110
+ # Load weights
111
+ if num_shards==1:
112
+ # Not sharded
113
+ # (The sharded implementation would also work, but this is simpler.)
114
+ # /pure-mlo-scratch/alhernan/megatron-data/checkpoints/llama2-7b-tp4-pp1-optim/release/mp_rank_00/model_optim_rng.pt
115
+ if os.path.exists(os.path.join(input_base_path, 'release')):
116
+ filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
117
+ elif input_base_path.split('/')[-1].startswith('iter_'):
118
+ iteration = eval(input_base_path.split('/')[-1].replace('iter_', '').lstrip('0'))
119
+ load_dir = '/'.join(input_base_path.split('/')[:-1])
120
+ filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
121
+ if not os.path.exists(filename):
122
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
123
+ else:
124
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
125
+ with open(tracker_filename, 'r') as f:
126
+ metastring = f.read().strip()
127
+ iteration = 'iter_{:07d}'.format(int(metastring))
128
+ filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
129
+ if not os.path.exists(filename):
130
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
131
+ original_filename = filename
132
+ loaded = torch.load(filename, map_location="cpu")['model']['language_model']
133
+
134
+ else:
135
+ # Sharded
136
+ filenames = []
137
+ for i in range(num_shards):
138
+ if os.path.exists(os.path.join(input_base_path, 'release')):
139
+ filename = os.path.join(input_base_path, 'release', f'mp_rank_{i:02d}', 'model_optim_rng.pt')
140
+ else:
141
+ tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
142
+ with open(tracker_filename, 'r') as f:
143
+ metastring = f.read().strip()
144
+ iteration = 'iter_{:07d}'.format(int(metastring))
145
+ filename = os.path.join(input_base_path, iteration, f'mp_rank_{i:02d}', 'model_optim_rng.pt')
146
+ if not os.path.exists(filename):
147
+ filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
148
+ filenames.append(filename)
149
+ loaded = [
150
+ torch.load(filenames[i], map_location="cpu")['model']['language_model']
151
+ for i in range(num_shards)
152
+ ]
153
+
154
+ print('Llama-Megatron Loaded!')
155
+ param_count = 0
156
+ index_dict = {"weight_map": {}}
157
+
158
+ print(f'Weighted Converting for {n_layers} layers...')
159
+ for layer_i in range(n_layers):
160
+ print(layer_i)
161
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
162
+ if num_shards == 1:
163
+ # Unsharded
164
+ state_dict = {
165
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
166
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
167
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
168
+ f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
169
+ f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
170
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
171
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
172
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
173
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
174
+ f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
175
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
176
+ f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
177
+ f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
178
+ }
179
+ else:
180
+ raise NotImplemented
181
+ # else:
182
+ # # Sharded
183
+ # # Note that attention.w{q,k,v,o}, feed_fordward.w[1,2,3], attention_norm.weight and ffn_norm.weight share
184
+ # # the same storage object, saving attention_norm and ffn_norm will save other weights too, which is
185
+ # # redundant as other weights will be stitched from multiple shards. To avoid that, they are cloned.
186
+
187
+ # state_dict = {
188
+ # f"model.layers.{layer_i}.input_layernorm.weight": loaded[0]['encoder'][
189
+ # f"layers.{layer_i}.input_layernorm.multiway.0.weight"
190
+ # ].clone(),
191
+ # f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[0]['encoder'][
192
+ # f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"
193
+ # ].clone(),
194
+ # }
195
+
196
+ # wqs, wks, wvs, ffn_w1s, ffn_w3s = [], [], [], [], []
197
+ # for shard_idx in range(num_shards):
198
+ # wqs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"])
199
+ # wks.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"])
200
+ # wvs.append(loaded[shard_idx]['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"])
201
+
202
+ # state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = permute(
203
+ # torch.cat(
204
+ # [
205
+ # wq.view(n_heads_per_shard, hidden_per_head, n_hidden)
206
+ # for wq in range(wqs)
207
+ # ],
208
+ # dim=0,
209
+ # ).reshape(n_hidden, n_hidden)
210
+ # )
211
+ # state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = permute(
212
+ # torch.cat(
213
+ # [
214
+ # wk.view(n_heads_per_shard, hidden_per_head, n_hidden)
215
+ # for wk in range(wks)
216
+ # ],
217
+ # dim=0,
218
+ # ).reshape(n_hidden, n_hidden)
219
+ # )
220
+ # state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat(
221
+ # [
222
+ # wv.view(n_heads_per_shard, hidden_per_head, n_hidden)
223
+ # for wv in range(wvs)
224
+ # ],
225
+ # dim=0,
226
+ # ).reshape(n_hidden, n_hidden)
227
+
228
+ # state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
229
+ # [loaded[i]['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"] for i in range(num_shards)], dim=1
230
+ # )
231
+ # state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
232
+ # [loaded[i]['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"] for i in range(num_shards)], dim=0
233
+ # )
234
+ # state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
235
+ # [loaded[i]['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"] for i in range(num_shards)], dim=1
236
+ # )
237
+ # state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
238
+ # [loaded[i]['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"] for i in range(num_shards)], dim=0
239
+ # )
240
+
241
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
242
+ for k, v in state_dict.items():
243
+ index_dict["weight_map"][k] = filename
244
+ param_count += v.numel()
245
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
246
+ print(f'Sharded file saved to {filename}')
247
+
248
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
249
+ if num_shards==1:
250
+ # Unsharded
251
+ state_dict = {
252
+ "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
253
+ "model.norm.weight": loaded['encoder']['norm.weight'],
254
+ "lm_head.weight": loaded['encoder']['lm_head.weight'],
255
+ }
256
+ else:
257
+ state_dict = {
258
+ "model.embed_tokens.weight": loaded[0]['embedding']['word_embeddings']['weight'],
259
+ "model.norm.weight": loaded[0]['encoder']['norm.weight'],
260
+ "lm_head.weight": loaded[0]['encoder']['lm_head.weight'],
261
+ }
262
+
263
+
264
+ loaded_all = torch.load(original_filename, map_location="cpu")['model']
265
+ # Vision Part
266
+ state_dict.update({
267
+ "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
268
+ "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
269
+ "model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
270
+ "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
271
+ "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
272
+ "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
273
+ "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
274
+ })
275
+ for v_layer_idx in range(24):
276
+ state_dict.update({
277
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
278
+ f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
279
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
280
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
281
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
282
+ f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
283
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
284
+ f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
285
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
286
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
287
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
288
+ f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
289
+ })
290
+
291
+ # Abstractor Part
292
+ state_dict.update({
293
+ "model.visual_abstractor.query_embeds": loaded_all['vision_abstractor']['learnable_queries'],
294
+ "model.visual_abstractor.visual_fc.bias": loaded_all['vision_abstractor']['visual_fc']['bias'],
295
+ "model.visual_abstractor.visual_fc.weight": loaded_all['vision_abstractor']['visual_fc']['weight'],
296
+ "model.visual_abstractor.vit_eos": loaded_all['vision_abstractor']['vit_eos'],
297
+ })
298
+ for v_layer_idx in range(6):
299
+ state_dict.update({
300
+ # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.k_pos_embed":
301
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.bias"],
302
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.weight"],
303
+ # f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.q_pos_embed": "pytorch_model-00004-of-00004.bin",
304
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.bias"],
305
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.weight"],
306
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.bias"],
307
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.weight"],
308
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.bias"],
309
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.weight"],
310
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.bias"],
311
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.weight"],
312
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.bias"],
313
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.weight"],
314
+
315
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.bias"],
316
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.weight"],
317
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.bias"],
318
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.weight"],
319
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.bias"],
320
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.weight"],
321
+
322
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.bias"],
323
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.weight"],
324
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.bias"],
325
+ f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.weight"],
326
+ })
327
+
328
+ for k, v in state_dict.items():
329
+ index_dict["weight_map"][k] = filename
330
+ param_count += v.numel()
331
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
332
+
333
+ # Write configs
334
+ index_dict["metadata"] = {"total_size": param_count * 2}
335
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
336
+
337
+ config = MPLUGOwl2Config()
338
+ config.save_pretrained(tmp_model_path)
339
+
340
+ # Make space so we can load the model properly now.
341
+ del state_dict
342
+ del loaded
343
+ del loaded_all
344
+ gc.collect()
345
+
346
+ def write_tokenizer(tokenizer_path, input_tokenizer_path):
347
+ # Initialize the tokenizer based on the `spm` model
348
+ tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
349
+ print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
350
+ tokenizer = tokenizer_class(input_tokenizer_path)
351
+ tokenizer.save_pretrained(tokenizer_path)
352
+
353
+
354
+ def main():
355
+ parser = argparse.ArgumentParser()
356
+ parser.add_argument(
357
+ "--input_dir",
358
+ help="Location of LLaMA_Megatron weights",
359
+ )
360
+ parser.add_argument(
361
+ "--model_size",
362
+ type=int,
363
+ default=7,
364
+ choices=[7, 13, 30, 65, 70],
365
+ )
366
+ parser.add_argument(
367
+ "--num_input_shards",
368
+ type=int,
369
+ default=1,
370
+ )
371
+ parser.add_argument(
372
+ "--num_output_shards",
373
+ type=int,
374
+ default=1,
375
+ )
376
+ parser.add_argument('--skip_permute', action='store_true')
377
+
378
+ parser.add_argument(
379
+ "--output_dir",
380
+ help="Location to write HF model and tokenizer",
381
+ )
382
+
383
+ args = parser.parse_args()
384
+ write_model(
385
+ model_path=args.output_dir,
386
+ input_base_path=args.input_dir,
387
+ model_size=args.model_size,
388
+ num_input_shards=args.num_input_shards,
389
+ num_output_shards=args.num_output_shards,
390
+ skip_permute=args.skip_permute
391
+ )
392
+
393
+
394
+ if __name__ == "__main__":
395
+ main()
DeQA-Score/src/model/modeling_attn_mask_utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List, Optional, Tuple, Union
15
+
16
+ import torch
17
+
18
+
19
+ class AttentionMaskConverter:
20
+ """
21
+ A utility attention mask class that allows one to:
22
+ - Create a causal 4d mask
23
+ - Create a causal 4d mask with slided window
24
+ - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
25
+ key_value_length) that can be multiplied with attention scores
26
+
27
+ Parameters:
28
+ is_causal (`bool`):
29
+ Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
30
+
31
+ sliding_window (`int`, *optional*):
32
+ Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
33
+ """
34
+
35
+ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
36
+ self.is_causal = is_causal
37
+ self.sliding_window = sliding_window
38
+
39
+ if self.sliding_window is not None and self.sliding_window <= 0:
40
+ raise ValueError(
41
+ f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
42
+ )
43
+
44
+ def to_causal_4d(
45
+ self,
46
+ batch_size: int,
47
+ query_length: int,
48
+ key_value_length: int,
49
+ dtype: torch.dtype = torch.float32,
50
+ device: Union[torch.device, "str"] = "cpu",
51
+ ) -> torch.Tensor:
52
+ """
53
+ Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
54
+ bias to upper right hand triangular matrix (causal mask).
55
+ """
56
+ if not self.is_causal:
57
+ raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")
58
+
59
+ # If shape is not cached, create a new causal mask and cache it
60
+ input_shape = (batch_size, query_length)
61
+ past_key_values_length = key_value_length - query_length
62
+
63
+ # create causal mask
64
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
65
+ causal_4d_mask = None
66
+ if input_shape[-1] > 1 or self.sliding_window is not None:
67
+ causal_4d_mask = self._make_causal_mask(
68
+ input_shape,
69
+ dtype,
70
+ device=device,
71
+ past_key_values_length=past_key_values_length,
72
+ sliding_window=self.sliding_window,
73
+ )
74
+
75
+ return causal_4d_mask
76
+
77
+ def to_4d(
78
+ self,
79
+ attention_mask_2d: torch.Tensor,
80
+ query_length: int,
81
+ key_value_length: Optional[int] = None,
82
+ dtype: torch.dtype = torch.float32,
83
+ ) -> torch.Tensor:
84
+ """
85
+ Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
86
+ key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
87
+ causal, a causal mask will be added.
88
+ """
89
+ input_shape = (attention_mask_2d.shape[0], query_length)
90
+
91
+ # create causal mask
92
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
93
+ causal_4d_mask = None
94
+ if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
95
+ if key_value_length is None:
96
+ raise ValueError(
97
+ "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
98
+ )
99
+
100
+ past_key_values_length = key_value_length - query_length
101
+ causal_4d_mask = self._make_causal_mask(
102
+ input_shape,
103
+ dtype,
104
+ device=attention_mask_2d.device,
105
+ past_key_values_length=past_key_values_length,
106
+ sliding_window=self.sliding_window,
107
+ )
108
+ elif self.sliding_window is not None:
109
+ raise NotImplementedError("Sliding window is currently only implemented for causal masking")
110
+
111
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
112
+ expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
113
+ attention_mask_2d.device
114
+ )
115
+ expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
116
+
117
+ return expanded_4d_mask
118
+
119
+ @staticmethod
120
+ def _make_causal_mask(
121
+ input_ids_shape: torch.Size,
122
+ dtype: torch.dtype,
123
+ device: torch.device,
124
+ past_key_values_length: int = 0,
125
+ sliding_window: Optional[int] = None,
126
+ ):
127
+ """
128
+ Make causal mask used for bi-directional self-attention.
129
+ """
130
+ bsz, tgt_len = input_ids_shape
131
+ mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
132
+ mask_cond = torch.arange(mask.size(-1), device=device)
133
+ mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
134
+
135
+ mask = mask.to(dtype)
136
+
137
+ if past_key_values_length > 0:
138
+ mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
139
+
140
+ # add lower triangular sliding window mask if necessary
141
+ if sliding_window is not None:
142
+ diagonal = past_key_values_length - sliding_window + 1
143
+
144
+ context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
145
+ mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
146
+
147
+ return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
148
+
149
+ @staticmethod
150
+ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
151
+ """
152
+ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
153
+ """
154
+ bsz, src_len = mask.size()
155
+ tgt_len = tgt_len if tgt_len is not None else src_len
156
+
157
+ expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
158
+
159
+ inverted_mask = 1.0 - expanded_mask
160
+
161
+ return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
162
+
163
+
164
+ def _prepare_4d_causal_attention_mask(
165
+ attention_mask: Optional[torch.Tensor],
166
+ input_shape: Union[torch.Size, Tuple, List],
167
+ inputs_embeds: torch.Tensor,
168
+ past_key_values_length: int,
169
+ sliding_window: Optional[int] = None,
170
+ ):
171
+ """
172
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
173
+ `(batch_size, key_value_length)`
174
+
175
+ Args:
176
+ attention_mask (`torch.Tensor` or `None`):
177
+ A 2D attention mask of shape `(batch_size, key_value_length)`
178
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
179
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
180
+ inputs_embeds (`torch.Tensor`):
181
+ The embedded inputs as a torch Tensor.
182
+ past_key_values_length (`int`):
183
+ The length of the key value cache.
184
+ sliding_window (`int`, *optional*):
185
+ If the model uses windowed attention, a sliding window should be passed.
186
+ """
187
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
188
+
189
+ key_value_length = input_shape[-1] + past_key_values_length
190
+
191
+ # 4d mask is passed through the layers
192
+ if attention_mask is not None:
193
+ attention_mask = attn_mask_converter.to_4d(
194
+ attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
195
+ )
196
+ else:
197
+ attention_mask = attn_mask_converter.to_causal_4d(
198
+ input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
199
+ )
200
+
201
+ return attention_mask
202
+
203
+
204
+ def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
205
+ """
206
+ Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
207
+ `(batch_size, key_value_length)`
208
+
209
+ Args:
210
+ mask (`torch.Tensor` or `None`):
211
+ A 2D attention mask of shape `(batch_size, key_value_length)`
212
+ dtype (`torch.dtype`):
213
+ The torch dtype the created mask shall have.
214
+ tgt_len (`int`):
215
+ The target length or query length the created mask shall have.
216
+ """
217
+ return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
218
+
219
+
220
+ def _create_4d_causal_attention_mask(
221
+ input_shape: Union[torch.Size, Tuple, List],
222
+ dtype: torch.dtype,
223
+ device: torch.device,
224
+ past_key_values_length: int = 0,
225
+ sliding_window: Optional[int] = None,
226
+ ):
227
+ """
228
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
229
+
230
+ Args:
231
+ input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
232
+ The input shape should be a tuple that defines `(batch_size, query_length)`.
233
+ dtype (`torch.dtype`):
234
+ The torch dtype the created mask shall have.
235
+ device (`int`):
236
+ The torch device the created mask shall have.
237
+ sliding_window (`int`, *optional*):
238
+ If the model uses windowed attention, a sliding window should be passed.
239
+ """
240
+ attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
241
+
242
+ key_value_length = past_key_values_length + input_shape[-1]
243
+ attention_mask = attn_mask_converter.to_causal_4d(
244
+ input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
245
+ )
246
+
247
+ return attention_mask
DeQA-Score/src/model/modeling_llama2.py ADDED
@@ -0,0 +1,834 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+ from functools import partial
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.utils.checkpoint
9
+ from torch import nn
10
+
11
+
12
+ import copy
13
+ import os
14
+ import sys
15
+
16
+ dir_path = os.path.dirname(os.path.realpath(__file__))
17
+ sys.path.insert(0, dir_path)
18
+
19
+ import transformers
20
+ from transformers.models.llama.modeling_llama import *
21
+
22
+ def _get_unpad_data(attention_mask):
23
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
24
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
25
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
26
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
27
+ return (
28
+ indices,
29
+ cu_seqlens,
30
+ max_seqlen_in_batch,
31
+ )
32
+
33
+
34
+ from transformers.configuration_utils import PretrainedConfig
35
+ from transformers.utils import logging
36
+
37
+ from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
38
+ from .configuration_mplug_owl2 import LlamaConfig
39
+
40
+ class MultiwayNetwork(nn.Module):
41
+
42
+ def __init__(self, module_provider, num_multiway=2):
43
+ super(MultiwayNetwork, self).__init__()
44
+
45
+ self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
46
+
47
+ def forward(self, hidden_states, multiway_indices):
48
+
49
+ if len(self.multiway) == 1:
50
+ return self.multiway[0](hidden_states)
51
+
52
+ output_hidden_states = torch.empty_like(hidden_states)
53
+
54
+ for idx, subway in enumerate(self.multiway):
55
+ local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
56
+ hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
57
+ if hidden.numel():
58
+ output = subway(hidden)
59
+ if isinstance(output, tuple):
60
+ output = output[0]
61
+ output = output.squeeze(1)
62
+ output_hidden_states[local_indices] = output
63
+
64
+ return output_hidden_states.contiguous()
65
+
66
+
67
+ class LlamaAttention(nn.Module):
68
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
69
+
70
+ def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
71
+ super().__init__()
72
+ self.config = config
73
+ self.layer_idx = layer_idx
74
+ if layer_idx is None:
75
+ logger.warning_once(
76
+ f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
77
+ "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
78
+ "when creating this class."
79
+ )
80
+
81
+ self.attention_dropout = config.attention_dropout
82
+ self.hidden_size = config.hidden_size
83
+ self.num_heads = config.num_attention_heads
84
+ self.head_dim = self.hidden_size // self.num_heads
85
+ self.num_key_value_heads = config.num_key_value_heads
86
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
87
+ self.max_position_embeddings = config.max_position_embeddings
88
+ self.rope_theta = config.rope_theta
89
+ self.is_causal = True
90
+
91
+ if (self.head_dim * self.num_heads) != self.hidden_size:
92
+ raise ValueError(
93
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
94
+ f" and `num_heads`: {self.num_heads})."
95
+ )
96
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
97
+ self.k_proj = MultiwayNetwork(module_provider=partial(
98
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
99
+ )
100
+ self.v_proj = MultiwayNetwork(module_provider=partial(
101
+ nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
102
+ )
103
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
104
+ self._init_rope()
105
+
106
+ def _init_rope(self):
107
+ if self.config.rope_scaling is None:
108
+ self.rotary_emb = LlamaRotaryEmbedding(
109
+ self.head_dim,
110
+ max_position_embeddings=self.max_position_embeddings,
111
+ base=self.rope_theta,
112
+ )
113
+ else:
114
+ scaling_type = self.config.rope_scaling["type"]
115
+ scaling_factor = self.config.rope_scaling["factor"]
116
+ if scaling_type == "linear":
117
+ self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
118
+ self.head_dim,
119
+ max_position_embeddings=self.max_position_embeddings,
120
+ scaling_factor=scaling_factor,
121
+ base=self.rope_theta,
122
+ )
123
+ elif scaling_type == "dynamic":
124
+ self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
125
+ self.head_dim,
126
+ max_position_embeddings=self.max_position_embeddings,
127
+ scaling_factor=scaling_factor,
128
+ base=self.rope_theta,
129
+ )
130
+ else:
131
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
132
+
133
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
134
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
135
+
136
+ def forward(
137
+ self,
138
+ hidden_states: torch.Tensor,
139
+ modality_indicators: torch.Tensor,
140
+ attention_mask: Optional[torch.Tensor] = None,
141
+ position_ids: Optional[torch.LongTensor] = None,
142
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
143
+ output_attentions: bool = False,
144
+ use_cache: bool = False,
145
+ padding_mask: Optional[torch.LongTensor] = None,
146
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
147
+ bsz, q_len, _ = hidden_states.size()
148
+
149
+ query_states = self.q_proj(hidden_states, )
150
+ key_states = self.k_proj(hidden_states, modality_indicators)
151
+ value_states = self.v_proj(hidden_states, modality_indicators)
152
+
153
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
154
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
155
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
156
+
157
+ kv_seq_len = key_states.shape[-2]
158
+ if past_key_value is not None:
159
+ kv_seq_len += past_key_value[0].shape[-2]
160
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
161
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
162
+
163
+ if past_key_value is not None:
164
+ # reuse k, v, self_attention
165
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
166
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
167
+
168
+ past_key_value = (key_states, value_states) if use_cache else None
169
+
170
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
171
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
172
+
173
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
174
+
175
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
176
+ raise ValueError(
177
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
178
+ f" {attn_weights.size()}"
179
+ )
180
+
181
+ if attention_mask is not None:
182
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
183
+ raise ValueError(
184
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
185
+ )
186
+ attn_weights = attn_weights + attention_mask
187
+
188
+ # upcast attention to fp32
189
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
190
+ attn_output = torch.matmul(attn_weights, value_states)
191
+
192
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
193
+ raise ValueError(
194
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
195
+ f" {attn_output.size()}"
196
+ )
197
+
198
+ attn_output = attn_output.transpose(1, 2).contiguous()
199
+
200
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
201
+
202
+ attn_output = self.o_proj(attn_output)
203
+
204
+ if not output_attentions:
205
+ attn_weights = None
206
+
207
+ return attn_output, attn_weights, past_key_value
208
+
209
+
210
+ class LlamaFlashAttention2(LlamaAttention):
211
+ """
212
+ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
213
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
214
+ flash attention and deal with padding tokens in case the input contains any of them.
215
+ """
216
+
217
+ def __init__(self, *args, **kwargs):
218
+ super().__init__(*args, **kwargs)
219
+
220
+ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
221
+ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
222
+ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
223
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
224
+
225
+ def forward(
226
+ self,
227
+ hidden_states: torch.Tensor,
228
+ modality_indicators: torch.Tensor,
229
+ attention_mask: Optional[torch.LongTensor] = None,
230
+ position_ids: Optional[torch.LongTensor] = None,
231
+ past_key_value: Optional[Cache] = None,
232
+ output_attentions: bool = False,
233
+ use_cache: bool = False,
234
+ **kwargs,
235
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
236
+ # LlamaFlashAttention2 attention does not support output_attentions
237
+ if "padding_mask" in kwargs:
238
+ warnings.warn(
239
+ "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
240
+ )
241
+
242
+ # overwrite attention_mask with padding_mask
243
+ attention_mask = kwargs.pop("padding_mask")
244
+
245
+ output_attentions = False
246
+
247
+ bsz, q_len, _ = hidden_states.size()
248
+
249
+ query_states = self.q_proj(hidden_states)
250
+ key_states = self.k_proj(hidden_states, modality_indicators)
251
+ value_states = self.v_proj(hidden_states, modality_indicators)
252
+
253
+ # Flash attention requires the input to have the shape
254
+ # batch_size x seq_length x head_dim x hidden_dim
255
+ # therefore we just need to keep the original shape
256
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
257
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
258
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
259
+
260
+ kv_seq_len = key_states.shape[-2]
261
+ if past_key_value is not None:
262
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
263
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
264
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
265
+
266
+ if past_key_value is not None:
267
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
268
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
269
+
270
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
271
+ # to be able to avoid many of these transpose/reshape/view.
272
+ query_states = query_states.transpose(1, 2)
273
+ key_states = key_states.transpose(1, 2)
274
+ value_states = value_states.transpose(1, 2)
275
+
276
+ dropout_rate = self.attention_dropout if self.training else 0.0
277
+
278
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
279
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
280
+ # cast them back in the correct dtype just to be sure everything works as expected.
281
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
282
+ # in fp32. (LlamaRMSNorm handles it correctly)
283
+
284
+ input_dtype = query_states.dtype
285
+ if input_dtype == torch.float32:
286
+ if torch.is_autocast_enabled():
287
+ target_dtype = torch.get_autocast_gpu_dtype()
288
+ # Handle the case where the model is quantized
289
+ elif hasattr(self.config, "_pre_quantization_dtype"):
290
+ target_dtype = self.config._pre_quantization_dtype
291
+ else:
292
+ target_dtype = self.q_proj.weight.dtype
293
+
294
+ logger.warning_once(
295
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
296
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
297
+ f" {target_dtype}."
298
+ )
299
+
300
+ query_states = query_states.to(target_dtype)
301
+ key_states = key_states.to(target_dtype)
302
+ value_states = value_states.to(target_dtype)
303
+
304
+ attn_output = self._flash_attention_forward(
305
+ query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
306
+ )
307
+
308
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
309
+ attn_output = self.o_proj(attn_output)
310
+
311
+ if not output_attentions:
312
+ attn_weights = None
313
+
314
+ return attn_output, attn_weights, past_key_value
315
+
316
+ def _flash_attention_forward(
317
+ self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
318
+ ):
319
+ """
320
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
321
+ first unpad the input, then computes the attention scores and pad the final attention scores.
322
+
323
+ Args:
324
+ query_states (`torch.Tensor`):
325
+ Input query states to be passed to Flash Attention API
326
+ key_states (`torch.Tensor`):
327
+ Input key states to be passed to Flash Attention API
328
+ value_states (`torch.Tensor`):
329
+ Input value states to be passed to Flash Attention API
330
+ attention_mask (`torch.Tensor`):
331
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
332
+ position of padding tokens and 1 for the position of non-padding tokens.
333
+ dropout (`int`, *optional*):
334
+ Attention dropout
335
+ softmax_scale (`float`, *optional*):
336
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
337
+ """
338
+ if not self._flash_attn_uses_top_left_mask:
339
+ causal = self.is_causal
340
+ else:
341
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
342
+ causal = self.is_causal and query_length != 1
343
+
344
+ # Contains at least one padding token in the sequence
345
+ if attention_mask is not None:
346
+ batch_size = query_states.shape[0]
347
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
348
+ query_states, key_states, value_states, attention_mask, query_length
349
+ )
350
+
351
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
352
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
353
+
354
+ attn_output_unpad = flash_attn_varlen_func(
355
+ query_states,
356
+ key_states,
357
+ value_states,
358
+ cu_seqlens_q=cu_seqlens_q,
359
+ cu_seqlens_k=cu_seqlens_k,
360
+ max_seqlen_q=max_seqlen_in_batch_q,
361
+ max_seqlen_k=max_seqlen_in_batch_k,
362
+ dropout_p=dropout,
363
+ softmax_scale=softmax_scale,
364
+ causal=causal,
365
+ )
366
+
367
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
368
+ else:
369
+ attn_output = flash_attn_func(
370
+ query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
371
+ )
372
+
373
+ return attn_output
374
+
375
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
376
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
377
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
378
+
379
+ key_layer = index_first_axis(
380
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
381
+ )
382
+ value_layer = index_first_axis(
383
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
384
+ )
385
+ if query_length == kv_seq_len:
386
+ query_layer = index_first_axis(
387
+ query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
388
+ )
389
+ cu_seqlens_q = cu_seqlens_k
390
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
391
+ indices_q = indices_k
392
+ elif query_length == 1:
393
+ max_seqlen_in_batch_q = 1
394
+ cu_seqlens_q = torch.arange(
395
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
396
+ ) # There is a memcpy here, that is very bad.
397
+ indices_q = cu_seqlens_q[:-1]
398
+ query_layer = query_layer.squeeze(1)
399
+ else:
400
+ # The -q_len: slice assumes left padding.
401
+ attention_mask = attention_mask[:, -query_length:]
402
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
403
+
404
+ return (
405
+ query_layer,
406
+ key_layer,
407
+ value_layer,
408
+ indices_q,
409
+ (cu_seqlens_q, cu_seqlens_k),
410
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
411
+ )
412
+
413
+
414
+ class LlamaSdpaAttention(LlamaAttention):
415
+ """
416
+ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
417
+ `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
418
+ SDPA API.
419
+ """
420
+
421
+ # Adapted from LlamaAttention.forward
422
+ def forward(
423
+ self,
424
+ hidden_states: torch.Tensor,
425
+ modality_indicators: torch.Tensor,
426
+ attention_mask: Optional[torch.Tensor] = None,
427
+ position_ids: Optional[torch.LongTensor] = None,
428
+ past_key_value: Optional[Cache] = None,
429
+ output_attentions: bool = False,
430
+ use_cache: bool = False,
431
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
432
+ if output_attentions:
433
+ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
434
+ logger.warning_once(
435
+ "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
436
+ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
437
+ )
438
+ return super().forward(
439
+ hidden_states=hidden_states,
440
+ modality_indicators=modality_indicators,
441
+ attention_mask=attention_mask,
442
+ position_ids=position_ids,
443
+ past_key_value=past_key_value,
444
+ output_attentions=output_attentions,
445
+ use_cache=use_cache,
446
+ )
447
+
448
+ bsz, q_len, _ = hidden_states.size()
449
+
450
+ query_states = self.q_proj(hidden_states)
451
+ key_states = self.k_proj(hidden_states, modality_indicators)
452
+ value_states = self.v_proj(hidden_states, modality_indicators)
453
+
454
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
455
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
456
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
457
+
458
+ kv_seq_len = key_states.shape[-2]
459
+ if past_key_value is not None:
460
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
461
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
462
+
463
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
464
+
465
+ if past_key_value is not None:
466
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
467
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
468
+
469
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
470
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
471
+
472
+ if attention_mask is not None:
473
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
474
+ raise ValueError(
475
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
476
+ )
477
+
478
+ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
479
+ # Reference: https://github.com/pytorch/pytorch/issues/112577.
480
+ if query_states.device.type == "cuda" and attention_mask is not None:
481
+ query_states = query_states.contiguous()
482
+ key_states = key_states.contiguous()
483
+ value_states = value_states.contiguous()
484
+
485
+ attn_output = torch.nn.functional.scaled_dot_product_attention(
486
+ query_states,
487
+ key_states,
488
+ value_states,
489
+ attn_mask=attention_mask,
490
+ dropout_p=self.attention_dropout if self.training else 0.0,
491
+ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
492
+ is_causal=self.is_causal and attention_mask is None and q_len > 1,
493
+ )
494
+
495
+ attn_output = attn_output.transpose(1, 2).contiguous()
496
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
497
+
498
+ attn_output = self.o_proj(attn_output)
499
+
500
+ return attn_output, None, past_key_value
501
+
502
+
503
+
504
+ LLAMA_ATTENTION_CLASSES = {
505
+ "eager": LlamaAttention,
506
+ "flash_attention_2": LlamaFlashAttention2,
507
+ "sdpa": LlamaSdpaAttention,
508
+ }
509
+
510
+ class LlamaDecoderLayer(nn.Module):
511
+ def __init__(self, config: LlamaConfig, layer_idx):
512
+ super().__init__()
513
+ self.hidden_size = config.hidden_size
514
+ self.self_attn = LlamaAttention(config=config)
515
+ self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
516
+ self.mlp = LlamaMLP(config)
517
+ self.input_layernorm = MultiwayNetwork(module_provider=partial(
518
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
519
+ ))
520
+ self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
521
+ LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
522
+ ))
523
+
524
+ def forward(
525
+ self,
526
+ hidden_states: torch.Tensor,
527
+ modality_indicators: torch.Tensor = None,
528
+ attention_mask: Optional[torch.Tensor] = None,
529
+ position_ids: Optional[torch.LongTensor] = None,
530
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
531
+ output_attentions: Optional[bool] = False,
532
+ use_cache: Optional[bool] = False,
533
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
534
+ """
535
+ Args:
536
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
537
+ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
538
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
539
+ output_attentions (`bool`, *optional*):
540
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
541
+ returned tensors for more detail.
542
+ use_cache (`bool`, *optional*):
543
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
544
+ (see `past_key_values`).
545
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
546
+ """
547
+
548
+ residual = hidden_states
549
+
550
+ hidden_states = self.input_layernorm(hidden_states, modality_indicators)
551
+
552
+ # Self Attention
553
+ hidden_states, self_attn_weights, present_key_value = self.self_attn(
554
+ hidden_states=hidden_states,
555
+ modality_indicators=modality_indicators,
556
+ attention_mask=attention_mask,
557
+ position_ids=position_ids,
558
+ past_key_value=past_key_value,
559
+ output_attentions=output_attentions,
560
+ use_cache=use_cache,
561
+ )
562
+ hidden_states = residual + hidden_states
563
+
564
+ # Fully Connected
565
+ residual = hidden_states
566
+ hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
567
+ hidden_states = self.mlp(hidden_states)
568
+ hidden_states = residual + hidden_states
569
+
570
+ outputs = (hidden_states,)
571
+
572
+ if output_attentions:
573
+ outputs += (self_attn_weights,)
574
+
575
+ if use_cache:
576
+ outputs += (present_key_value,)
577
+
578
+ return outputs
579
+
580
+
581
+ def model_forward(
582
+ self,
583
+ input_ids: torch.LongTensor = None,
584
+ modality_indicators: torch.Tensor = None,
585
+ attention_mask: Optional[torch.Tensor] = None,
586
+ position_ids: Optional[torch.LongTensor] = None,
587
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
588
+ inputs_embeds: Optional[torch.FloatTensor] = None,
589
+ use_cache: Optional[bool] = None,
590
+ output_attentions: Optional[bool] = None,
591
+ output_hidden_states: Optional[bool] = None,
592
+ return_dict: Optional[bool] = None,
593
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
594
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
595
+ output_hidden_states = (
596
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
597
+ )
598
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
599
+
600
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
601
+
602
+ # retrieve input_ids and inputs_embeds
603
+ if input_ids is not None and inputs_embeds is not None:
604
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
605
+ elif input_ids is not None:
606
+ batch_size, seq_length = input_ids.shape
607
+ elif inputs_embeds is not None:
608
+ batch_size, seq_length, _ = inputs_embeds.shape
609
+ else:
610
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
611
+
612
+ seq_length_with_past = seq_length
613
+ past_key_values_length = 0
614
+
615
+ if past_key_values is not None:
616
+ past_key_values_length = past_key_values[0][0].shape[2]
617
+ seq_length_with_past = seq_length_with_past + past_key_values_length
618
+
619
+ if position_ids is None:
620
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
621
+ position_ids = torch.arange(
622
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
623
+ )
624
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
625
+ else:
626
+ position_ids = position_ids.view(-1, seq_length).long()
627
+
628
+ if inputs_embeds is None:
629
+ inputs_embeds = self.embed_tokens(input_ids)
630
+ # embed positions
631
+ if attention_mask is None:
632
+ attention_mask = torch.ones(
633
+ (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
634
+ )
635
+
636
+ if self._use_flash_attention_2:
637
+ # 2d mask is passed through the layers
638
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
639
+ elif self._use_sdpa and not output_attentions:
640
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
641
+ # the manual implementation that requires a 4D causal mask in all cases.
642
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
643
+ attention_mask,
644
+ (batch_size, seq_length),
645
+ inputs_embeds,
646
+ past_key_values_length,
647
+ )
648
+ else:
649
+ # 4d mask is passed through the layers
650
+ attention_mask = _prepare_4d_causal_attention_mask(
651
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
652
+ )
653
+
654
+ hidden_states = inputs_embeds
655
+
656
+ if self.gradient_checkpointing and self.training:
657
+ if use_cache:
658
+ logger.warning_once(
659
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
660
+ )
661
+ use_cache = False
662
+
663
+ # decoder layers
664
+ all_hidden_states = () if output_hidden_states else None
665
+ all_self_attns = () if output_attentions else None
666
+ next_decoder_cache = () if use_cache else None
667
+
668
+ for idx, decoder_layer in enumerate(self.layers):
669
+ if output_hidden_states:
670
+ all_hidden_states += (hidden_states,)
671
+
672
+ past_key_value = past_key_values[idx] if past_key_values is not None else None
673
+
674
+ if self.gradient_checkpointing and self.training:
675
+
676
+ def create_custom_forward(module):
677
+ def custom_forward(*inputs):
678
+ # None for past_key_value
679
+ return module(*inputs, past_key_value, output_attentions)
680
+
681
+ return custom_forward
682
+
683
+ layer_outputs = torch.utils.checkpoint.checkpoint(
684
+ create_custom_forward(decoder_layer),
685
+ hidden_states,
686
+ modality_indicators,
687
+ attention_mask,
688
+ position_ids,
689
+ )
690
+ else:
691
+ layer_outputs = decoder_layer(
692
+ hidden_states,
693
+ modality_indicators=modality_indicators,
694
+ attention_mask=attention_mask,
695
+ position_ids=position_ids,
696
+ past_key_value=past_key_value,
697
+ output_attentions=output_attentions,
698
+ use_cache=use_cache,
699
+ )
700
+
701
+ hidden_states = layer_outputs[0]
702
+
703
+ if use_cache:
704
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
705
+
706
+ if output_attentions:
707
+ all_self_attns += (layer_outputs[1],)
708
+
709
+ hidden_states = self.norm(hidden_states)
710
+
711
+ # add hidden states from the last decoder layer
712
+ if output_hidden_states:
713
+ all_hidden_states += (hidden_states,)
714
+
715
+ next_cache = next_decoder_cache if use_cache else None
716
+ if not return_dict:
717
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
718
+ return BaseModelOutputWithPast(
719
+ last_hidden_state=hidden_states,
720
+ past_key_values=next_cache,
721
+ hidden_states=all_hidden_states,
722
+ attentions=all_self_attns,
723
+ )
724
+
725
+
726
+ def causal_model_forward(
727
+ self,
728
+ input_ids: torch.LongTensor = None,
729
+ modality_indicators: torch.Tensor = None,
730
+ attention_mask: Optional[torch.Tensor] = None,
731
+ position_ids: Optional[torch.LongTensor] = None,
732
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
733
+ inputs_embeds: Optional[torch.FloatTensor] = None,
734
+ labels: Optional[torch.LongTensor] = None,
735
+ use_cache: Optional[bool] = None,
736
+ output_attentions: Optional[bool] = None,
737
+ output_hidden_states: Optional[bool] = None,
738
+ return_dict: Optional[bool] = None,
739
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
740
+ r"""
741
+ Args:
742
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
743
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
744
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
745
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
746
+
747
+ Returns:
748
+
749
+ Example:
750
+
751
+ ```python
752
+ >>> from transformers import AutoTokenizer, LlamaForCausalLM
753
+
754
+ >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
755
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
756
+
757
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
758
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
759
+
760
+ >>> # Generate
761
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
762
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
763
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
764
+ ```"""
765
+
766
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
767
+ output_hidden_states = (
768
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
769
+ )
770
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
771
+
772
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
773
+ outputs = self.model(
774
+ input_ids=input_ids,
775
+ modality_indicators=modality_indicators,
776
+ attention_mask=attention_mask,
777
+ position_ids=position_ids,
778
+ past_key_values=past_key_values,
779
+ inputs_embeds=inputs_embeds,
780
+ use_cache=use_cache,
781
+ output_attentions=output_attentions,
782
+ output_hidden_states=output_hidden_states,
783
+ return_dict=return_dict,
784
+ )
785
+
786
+ hidden_states = outputs[0]
787
+ if self.config.pretraining_tp > 1:
788
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
789
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
790
+ logits = torch.cat(logits, dim=-1)
791
+ else:
792
+ logits = self.lm_head(hidden_states)
793
+ logits = logits.float()
794
+
795
+ loss = None
796
+ if labels is not None:
797
+ # Shift so that tokens < n predict n
798
+ shift_logits = logits[..., :-1, :].contiguous()
799
+ shift_labels = labels[..., 1:].contiguous()
800
+ # Flatten the tokens
801
+ loss_fct = CrossEntropyLoss()
802
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
803
+ shift_labels = shift_labels.view(-1)
804
+ # Enable model parallelism
805
+ shift_labels = shift_labels.to(shift_logits.device)
806
+ loss = loss_fct(shift_logits, shift_labels)
807
+
808
+ if not return_dict:
809
+ output = (logits,) + outputs[1:]
810
+ return (loss,) + output if loss is not None else output
811
+
812
+ return CausalLMOutputWithPast(
813
+ loss=loss,
814
+ logits=logits,
815
+ past_key_values=outputs.past_key_values,
816
+ hidden_states=outputs.hidden_states,
817
+ attentions=outputs.attentions,
818
+ )
819
+
820
+ def replace_llama_modality_adaptive():
821
+ transformers.models.llama.configuration_llama.LlamaConfig = LlamaConfig
822
+ transformers.models.llama.modeling_llama.LlamaAttention = LlamaAttention
823
+ transformers.models.llama.modeling_llama.LlamaFlashAttention2 = LlamaFlashAttention2
824
+ transformers.models.llama.modeling_llama.LlamaSdpaAttention = LlamaSdpaAttention
825
+ transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
826
+ transformers.models.llama.modeling_llama.LlamaModel.forward = model_forward
827
+ transformers.models.llama.modeling_llama.LlamaForCausalLM.forward = causal_model_forward
828
+
829
+
830
+ if __name__ == "__main__":
831
+ replace_llama_modality_adaptive()
832
+ config = transformers.LlamaConfig.from_pretrained('/cpfs01/shared/public/test/vicuna-7b-v1.5/')
833
+ model = transformers.LlamaForCausalLM(config)
834
+ print(model)