Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +6 -0
- .gitignore +9 -0
- DeQA-Score/.gitignore +23 -0
- DeQA-Score/DeQA_Score.egg-info/PKG-INFO +322 -0
- DeQA-Score/DeQA_Score.egg-info/SOURCES.txt +36 -0
- DeQA-Score/DeQA_Score.egg-info/dependency_links.txt +1 -0
- DeQA-Score/DeQA_Score.egg-info/requires.txt +30 -0
- DeQA-Score/DeQA_Score.egg-info/top_level.txt +4 -0
- DeQA-Score/LICENSE +21 -0
- DeQA-Score/README.md +281 -0
- DeQA-Score/build_soft_labels/config.json +35 -0
- DeQA-Score/build_soft_labels/gen_soft_label.py +209 -0
- DeQA-Score/fig/boy_colorful.jpg +3 -0
- DeQA-Score/fig/model.png +3 -0
- DeQA-Score/fig/singapore_flyer.jpg +3 -0
- DeQA-Score/fig/teaser.png +3 -0
- DeQA-Score/preprocessor/preprocessor_config.json +19 -0
- DeQA-Score/preprocessor/special_tokens_map.json +24 -0
- DeQA-Score/preprocessor/tokenizer.model +3 -0
- DeQA-Score/preprocessor/tokenizer_config.json +35 -0
- DeQA-Score/pyproject.toml +35 -0
- DeQA-Score/scripts/eval_dist.sh +23 -0
- DeQA-Score/scripts/eval_score.sh +23 -0
- DeQA-Score/scripts/infer.sh +17 -0
- DeQA-Score/scripts/infer_lora.sh +18 -0
- DeQA-Score/scripts/train.sh +48 -0
- DeQA-Score/scripts/train_lora.sh +49 -0
- DeQA-Score/scripts/zero3.json +28 -0
- DeQA-Score/scripts/zero3_offload.json +56 -0
- DeQA-Score/src/__init__.py +2 -0
- DeQA-Score/src/constants.py +9 -0
- DeQA-Score/src/conversation.py +301 -0
- DeQA-Score/src/datasets/__init__.py +11 -0
- DeQA-Score/src/datasets/pair_dataset.py +276 -0
- DeQA-Score/src/datasets/single_dataset.py +244 -0
- DeQA-Score/src/datasets/utils.py +317 -0
- DeQA-Score/src/evaluate/__init__.py +1 -0
- DeQA-Score/src/evaluate/cal_distribution_gap.py +143 -0
- DeQA-Score/src/evaluate/cal_plcc_srcc.py +115 -0
- DeQA-Score/src/evaluate/eval_qbench_mcq.py +138 -0
- DeQA-Score/src/evaluate/iqa_eval.py +184 -0
- DeQA-Score/src/evaluate/scorer.py +63 -0
- DeQA-Score/src/evaluate/scorer_coco.py +103 -0
- DeQA-Score/src/mm_utils.py +112 -0
- DeQA-Score/src/model/__init__.py +2 -0
- DeQA-Score/src/model/builder.py +166 -0
- DeQA-Score/src/model/configuration_mplug_owl2.py +334 -0
- DeQA-Score/src/model/convert_mplug_owl2_weight_to_hf.py +395 -0
- DeQA-Score/src/model/modeling_attn_mask_utils.py +247 -0
- DeQA-Score/src/model/modeling_llama2.py +834 -0
.gitattributes
CHANGED
|
@@ -34,3 +34,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
CLIP-Huge-Flickr-Flat/faiss_IVPQ_PCA.index filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
CLIP-Huge-Flickr-Flat/faiss_IVPQ_PCA.index filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
DeQA-Score/fig/boy_colorful.jpg filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
DeQA-Score/fig/model.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
DeQA-Score/fig/singapore_flyer.jpg filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
DeQA-Score/fig/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
imgs/results1.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
imgs/results2.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
coco_data/
|
| 2 |
+
__pycache__/
|
| 3 |
+
coco_feats/
|
| 4 |
+
coco_faiss_indexes/
|
| 5 |
+
faiss_indexes/
|
| 6 |
+
processed_data/
|
| 7 |
+
outputs/
|
| 8 |
+
wandb/
|
| 9 |
+
results/
|
DeQA-Score/.gitignore
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# source file related
|
| 2 |
+
*__pycache__*
|
| 3 |
+
*.pyc
|
| 4 |
+
*.o
|
| 5 |
+
*.so
|
| 6 |
+
*.egg
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# training related
|
| 10 |
+
*log*
|
| 11 |
+
*.log
|
| 12 |
+
*.pth
|
| 13 |
+
*.pt
|
| 14 |
+
|
| 15 |
+
# result related
|
| 16 |
+
*answer*
|
| 17 |
+
*ckpt*
|
| 18 |
+
*.json
|
| 19 |
+
*.jsonl
|
| 20 |
+
res_*
|
| 21 |
+
|
| 22 |
+
# mac hidden
|
| 23 |
+
*.DS_Store*
|
DeQA-Score/DeQA_Score.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.4
|
| 2 |
+
Name: DeQA-Score
|
| 3 |
+
Version: 1.2.0
|
| 4 |
+
Summary: Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution (based on mPLUG-Owl2)
|
| 5 |
+
Project-URL: Bug Tracker, https://github.com/zhiyuanyou/DeQA-Score/issues
|
| 6 |
+
Classifier: Programming Language :: Python :: 3
|
| 7 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 8 |
+
Requires-Python: >=3.8
|
| 9 |
+
Description-Content-Type: text/markdown
|
| 10 |
+
License-File: LICENSE
|
| 11 |
+
Requires-Dist: torch==2.0.1
|
| 12 |
+
Requires-Dist: torchvision==0.15.2
|
| 13 |
+
Requires-Dist: transformers==4.36.1
|
| 14 |
+
Requires-Dist: tokenizers==0.15.0
|
| 15 |
+
Requires-Dist: sentencepiece==0.1.99
|
| 16 |
+
Requires-Dist: shortuuid
|
| 17 |
+
Requires-Dist: accelerate==0.21.0
|
| 18 |
+
Requires-Dist: peft==0.4.0
|
| 19 |
+
Requires-Dist: bitsandbytes==0.41.0
|
| 20 |
+
Requires-Dist: pydantic<2,>=1
|
| 21 |
+
Requires-Dist: markdown2[all]
|
| 22 |
+
Requires-Dist: numpy
|
| 23 |
+
Requires-Dist: scikit-learn==1.2.2
|
| 24 |
+
Requires-Dist: gradio==3.35.2
|
| 25 |
+
Requires-Dist: gradio_client==0.2.9
|
| 26 |
+
Requires-Dist: requests
|
| 27 |
+
Requires-Dist: httpx==0.24.0
|
| 28 |
+
Requires-Dist: uvicorn
|
| 29 |
+
Requires-Dist: fastapi
|
| 30 |
+
Requires-Dist: icecream
|
| 31 |
+
Requires-Dist: einops==0.6.1
|
| 32 |
+
Requires-Dist: einops-exts==0.0.4
|
| 33 |
+
Requires-Dist: timm==0.6.13
|
| 34 |
+
Requires-Dist: decord
|
| 35 |
+
Requires-Dist: scipy
|
| 36 |
+
Provides-Extra: train
|
| 37 |
+
Requires-Dist: deepspeed==0.9.5; extra == "train"
|
| 38 |
+
Requires-Dist: ninja; extra == "train"
|
| 39 |
+
Requires-Dist: wandb; extra == "train"
|
| 40 |
+
Dynamic: license-file
|
| 41 |
+
|
| 42 |
+
<div align="center">
|
| 43 |
+
<h1>Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution</h1>
|
| 44 |
+
|
| 45 |
+
<div>
|
| 46 |
+
<a href="https://zhiyuanyou.github.io/" target="_blank">Zhiyuan You</a><sup>12</sup>,
|
| 47 |
+
<a href="https://caixin98.github.io/" target="_blank">Xin Cai</a><sup>2</sup>,
|
| 48 |
+
<a href="https://www.jasongt.com/" target="_blank">Jinjin Gu</a><sup>4</sup>,
|
| 49 |
+
<a href="https://tianfan.info/" target="_blank">Tianfan Xue</a><sup>235</sup><sup>#</sup>,
|
| 50 |
+
<a href="https://xpixel.group/2010/01/20/chaodong.html" target="_blank">Chao Dong</a><sup>134</sup><sup>#</sup>
|
| 51 |
+
</div>
|
| 52 |
+
|
| 53 |
+
<div>
|
| 54 |
+
<sup>1</sup>Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences, <sup>2</sup>Multimedia Laboratory, The Chinese University of Hong Kong,
|
| 55 |
+
<sup>3</sup>Shanghai AI Laboratory, <sup>4</sup>Shenzhen University of Advanced Technology, <sup>5</sup>CPII under InnoHK
|
| 56 |
+
</div>
|
| 57 |
+
|
| 58 |
+
<div><sup>#</sup>Corresponding author.</div>
|
| 59 |
+
|
| 60 |
+
<div>
|
| 61 |
+
<a href="https://depictqa.github.io/deqa-score/" target="_blank"><strong>Homepage</strong></a> |
|
| 62 |
+
<strong>Model Weights</strong> (
|
| 63 |
+
<a href="https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3" target="_blank"><strong>Full Tuning</strong></a> /
|
| 64 |
+
<a href="https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3" target="_blank"><strong>LoRA Tuning</strong></a>
|
| 65 |
+
) |
|
| 66 |
+
<a href="https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score" target="_blank"><strong>Datasets</strong></a> |
|
| 67 |
+
<a href="https://arxiv.org/abs/2501.11561" target="_blank"><strong>Paper</strong></a>
|
| 68 |
+
</div>
|
| 69 |
+
|
| 70 |
+
<h2>Motivation</h2>
|
| 71 |
+
|
| 72 |
+
<div style="width: 100%; text-align: center; margin:auto;">
|
| 73 |
+
<img style="width: 75%" src="fig/teaser.png">
|
| 74 |
+
</div>
|
| 75 |
+
|
| 76 |
+
<h2>Model Architecture</h2>
|
| 77 |
+
|
| 78 |
+
<div style="width: 100%; text-align: center; margin:auto;">
|
| 79 |
+
<img style="width: 100%" src="fig/model.png">
|
| 80 |
+
</div>
|
| 81 |
+
</div>
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
## [Installation Free!] Quicker Start with Hugging Face AutoModel
|
| 85 |
+
|
| 86 |
+
[2025.12] Thanks to @[lyf1212](https://github.com/lyf1212)'s suggestion, we add support on `transformers==4.46.3` with minor code modifications. See [details](https://github.com/zhiyuanyou/DeQA-Score/issues/32).
|
| 87 |
+
|
| 88 |
+
The following code could be run directly with `transformers==4.36.1`. No need to install this GitHub repo.
|
| 89 |
+
|
| 90 |
+
```python
|
| 91 |
+
import requests
|
| 92 |
+
import torch
|
| 93 |
+
from transformers import AutoModelForCausalLM
|
| 94 |
+
|
| 95 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 96 |
+
"zhiyuanyou/DeQA-Score-Mix3",
|
| 97 |
+
trust_remote_code=True,
|
| 98 |
+
attn_implementation="eager",
|
| 99 |
+
torch_dtype=torch.float16,
|
| 100 |
+
device_map="auto",
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
from PIL import Image
|
| 104 |
+
|
| 105 |
+
# The inputs should be a list of multiple PIL images
|
| 106 |
+
model.score(
|
| 107 |
+
[Image.open(requests.get(
|
| 108 |
+
"https://raw.githubusercontent.com/zhiyuanyou/DeQA-Score/main/fig/singapore_flyer.jpg", stream=True
|
| 109 |
+
).raw)]
|
| 110 |
+
)
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
## Installation
|
| 114 |
+
|
| 115 |
+
If you only need to infer / evaluate:
|
| 116 |
+
|
| 117 |
+
```shell
|
| 118 |
+
git clone https://github.com/zhiyuanyou/DeQA-Score.git
|
| 119 |
+
cd DeQA-Score
|
| 120 |
+
pip install -e .
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
For training, you need to further install additional dependencies as follows:
|
| 124 |
+
|
| 125 |
+
```shell
|
| 126 |
+
pip install -e ".[train]"
|
| 127 |
+
pip install flash_attn --no-build-isolation
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
## Quick Start
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
### Image Quality Scorer
|
| 135 |
+
|
| 136 |
+
- CLI Interface
|
| 137 |
+
|
| 138 |
+
```shell
|
| 139 |
+
python src/evaluate/scorer.py --img_path fig/singapore_flyer.jpg
|
| 140 |
+
```
|
| 141 |
+
|
| 142 |
+
- Python API
|
| 143 |
+
|
| 144 |
+
```python
|
| 145 |
+
from src import Scorer
|
| 146 |
+
from PIL import Image
|
| 147 |
+
|
| 148 |
+
scorer = Scorer()
|
| 149 |
+
img_list = [Image.open("fig/singapore_flyer.jpg")] # can be a list of multiple PIL images
|
| 150 |
+
print(scorer(img_list).tolist())
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
## Training, Inference & Evaluation
|
| 155 |
+
|
| 156 |
+
### Datasets
|
| 157 |
+
|
| 158 |
+
<a id="datasets"></a>
|
| 159 |
+
|
| 160 |
+
- Download our meta files from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score).
|
| 161 |
+
|
| 162 |
+
- Download source images from [KonIQ](https://database.mmsp-kn.de/koniq-10k-database.html),
|
| 163 |
+
[SPAQ](https://github.com/h4nwei/SPAQ),
|
| 164 |
+
[KADID](https://database.mmsp-kn.de/kadid-10k-database.html),
|
| 165 |
+
[PIPAL](https://github.com/HaomingCai/PIPAL-dataset),
|
| 166 |
+
[LIVE-Wild](https://live.ece.utexas.edu/research/ChallengeDB/index.html),
|
| 167 |
+
[AGIQA](https://github.com/lcysyzxdxc/AGIQA-3k-Database),
|
| 168 |
+
[TID2013](https://www.ponomarenko.info/tid2013.htm),
|
| 169 |
+
and [CSIQ](https://s2.smu.edu/~eclarson/csiq.html).
|
| 170 |
+
|
| 171 |
+
- Arrange the folders as follows:
|
| 172 |
+
|
| 173 |
+
```
|
| 174 |
+
|-- DeQA-Score
|
| 175 |
+
|-- Data-DeQA-Score
|
| 176 |
+
|-- KONIQ
|
| 177 |
+
|-- images/*.jpg
|
| 178 |
+
|-- metas
|
| 179 |
+
|-- SPAQ
|
| 180 |
+
|-- images/*.jpg
|
| 181 |
+
|-- metas
|
| 182 |
+
|-- KADID10K
|
| 183 |
+
|-- images/*.png
|
| 184 |
+
|-- metas
|
| 185 |
+
|-- PIPAL
|
| 186 |
+
|-- images/Distortion_*/*.bmp
|
| 187 |
+
|-- metas
|
| 188 |
+
|-- LIVE-WILD
|
| 189 |
+
|-- images/*.bmp
|
| 190 |
+
|-- metas
|
| 191 |
+
|-- AGIQA3K
|
| 192 |
+
|-- images/*.jpg
|
| 193 |
+
|-- metas
|
| 194 |
+
|-- TID2013
|
| 195 |
+
|-- images/distorted_images/*.bmp
|
| 196 |
+
|-- metas
|
| 197 |
+
|-- CSIQ
|
| 198 |
+
|-- images/dst_imgs/*/*.png
|
| 199 |
+
|-- metas
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
### Pretrained Weights
|
| 203 |
+
|
| 204 |
+
<a id="pretrained_weights"></a>
|
| 205 |
+
|
| 206 |
+
We provide two model weights (full tuning and LoRA tuning) with similar performance.
|
| 207 |
+
|
| 208 |
+
| | Training Datasets | Weights |
|
| 209 |
+
|-----|-----|-----|
|
| 210 |
+
| Full Tuning | KonIQ, SPAQ, KADID | [Huggingface Full](https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3) |
|
| 211 |
+
| LoRA Tuning | KonIQ, SPAQ, KADID | [Huggingface LoRA](https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3) |
|
| 212 |
+
|
| 213 |
+
Download one of the above model weights, then arrange the folders as follows:
|
| 214 |
+
|
| 215 |
+
```
|
| 216 |
+
|-- DeQA-Score
|
| 217 |
+
|-- checkpoints
|
| 218 |
+
|-- DeQA-Score-Mix3
|
| 219 |
+
|-- DeQA-Score-LoRA-Mix3
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
If you would like to use the LoRA tuning weights, you need to download the base mPLUG-Owl2 weights from [Huggingface mPLUG-Owl2](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b), then arrange the folders as follows:
|
| 223 |
+
|
| 224 |
+
```
|
| 225 |
+
|-- DeQA-Score
|
| 226 |
+
|-- ModelZoo
|
| 227 |
+
|-- mplug-owl2-llama2-7b
|
| 228 |
+
```
|
| 229 |
+
|
| 230 |
+
### Inference
|
| 231 |
+
|
| 232 |
+
After preparing the datasets, you can infer using pre-trained **DeQA-Score** or **DeQA-Score-LoRA**:
|
| 233 |
+
|
| 234 |
+
```shell
|
| 235 |
+
sh scripts/infer.sh $ONE_GPU_ID
|
| 236 |
+
```
|
| 237 |
+
|
| 238 |
+
```shell
|
| 239 |
+
sh scripts/infer_lora.sh $ONE_GPU_ID
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
### Evaluation
|
| 243 |
+
|
| 244 |
+
After inference, you can evaluate the inference results:
|
| 245 |
+
|
| 246 |
+
- SRCC / PLCC for quality score.
|
| 247 |
+
|
| 248 |
+
```shell
|
| 249 |
+
sh scripts/eval_score.sh
|
| 250 |
+
```
|
| 251 |
+
|
| 252 |
+
- KL Divergence / JS Divergence / Wasserstein Distance for score distribution.
|
| 253 |
+
|
| 254 |
+
```shell
|
| 255 |
+
sh scripts/eval_dist.sh
|
| 256 |
+
```
|
| 257 |
+
|
| 258 |
+
### Fine-tuning
|
| 259 |
+
|
| 260 |
+
Fine-tuning needs to download the mPLUG-Owl2 weights as in [Pretrained Weights](#pretrained_weights).
|
| 261 |
+
|
| 262 |
+
#### LoRA Fine-tuning
|
| 263 |
+
|
| 264 |
+
- Only **2 RTX3090 GPUs** are required. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
|
| 265 |
+
|
| 266 |
+
```shell
|
| 267 |
+
sh scripts/train_lora.sh $GPU_IDs
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
#### Full Fine-tuning from the Scratch
|
| 271 |
+
|
| 272 |
+
- At least **8 A6000 GPUs** or **4 A100 GPUs** will be enough. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
|
| 273 |
+
|
| 274 |
+
```shell
|
| 275 |
+
sh scripts/train.sh $GPU_IDs
|
| 276 |
+
```
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
## Soft Label Construction
|
| 280 |
+
|
| 281 |
+
- Download `split.json` (training & test split info) and `mos.json` (mos & std info) of KonIQ, SPAQ, and KADID from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score), and arrange the folders as in [Datasets](#datasets).
|
| 282 |
+
|
| 283 |
+
- Run the following scripts to construct the distribution-based soft labels.
|
| 284 |
+
|
| 285 |
+
```shell
|
| 286 |
+
cd build_soft_labels
|
| 287 |
+
python gen_soft_label.py
|
| 288 |
+
```
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
## Acknowledgements
|
| 292 |
+
|
| 293 |
+
This work is based on [Q-Align](https://github.com/Q-Future/Q-Align). Sincerely thanks for this awesome work.
|
| 294 |
+
|
| 295 |
+
## Citation
|
| 296 |
+
|
| 297 |
+
If you find our work useful for your research and applications, please cite using the BibTeX:
|
| 298 |
+
|
| 299 |
+
```bibtex
|
| 300 |
+
@inproceedings{deqa_score,
|
| 301 |
+
title={Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution},
|
| 302 |
+
author={You, Zhiyuan and Cai, Xin and Gu, Jinjin and Xue, Tianfan and Dong, Chao},
|
| 303 |
+
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
| 304 |
+
pages={14483--14494},
|
| 305 |
+
year={2025}
|
| 306 |
+
}
|
| 307 |
+
|
| 308 |
+
@article{depictqa_v2,
|
| 309 |
+
title={Enhancing Descriptive Image Quality Assessment with A Large-scale Multi-modal Dataset},
|
| 310 |
+
author={You, Zhiyuan and Gu, Jinjin and Cai, Xin and Li, Zheyuan and Zhu, Kaiwen and Dong, Chao and Xue, Tianfan},
|
| 311 |
+
journal={IEEE Transactions on Image Processing},
|
| 312 |
+
year={2025}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
@inproceedings{depictqa_v1,
|
| 316 |
+
title={Depicting Beyond Scores: Advancing Image Quality Assessment through Multi-modal Language Models},
|
| 317 |
+
author={You, Zhiyuan and Li, Zheyuan and Gu, Jinjin and Yin, Zhenfei and Xue, Tianfan and Dong, Chao},
|
| 318 |
+
booktitle={European Conference on Computer Vision},
|
| 319 |
+
pages={259--276},
|
| 320 |
+
year={2024}
|
| 321 |
+
}
|
| 322 |
+
```
|
DeQA-Score/DeQA_Score.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
DeQA_Score.egg-info/PKG-INFO
|
| 5 |
+
DeQA_Score.egg-info/SOURCES.txt
|
| 6 |
+
DeQA_Score.egg-info/dependency_links.txt
|
| 7 |
+
DeQA_Score.egg-info/requires.txt
|
| 8 |
+
DeQA_Score.egg-info/top_level.txt
|
| 9 |
+
build_soft_labels/gen_soft_label.py
|
| 10 |
+
src/__init__.py
|
| 11 |
+
src/constants.py
|
| 12 |
+
src/conversation.py
|
| 13 |
+
src/mm_utils.py
|
| 14 |
+
src/utils.py
|
| 15 |
+
src/datasets/__init__.py
|
| 16 |
+
src/datasets/pair_dataset.py
|
| 17 |
+
src/datasets/single_dataset.py
|
| 18 |
+
src/datasets/utils.py
|
| 19 |
+
src/evaluate/__init__.py
|
| 20 |
+
src/evaluate/cal_distribution_gap.py
|
| 21 |
+
src/evaluate/cal_plcc_srcc.py
|
| 22 |
+
src/evaluate/eval_qbench_mcq.py
|
| 23 |
+
src/evaluate/iqa_eval.py
|
| 24 |
+
src/evaluate/scorer.py
|
| 25 |
+
src/evaluate/scorer_coco.py
|
| 26 |
+
src/model/__init__.py
|
| 27 |
+
src/model/builder.py
|
| 28 |
+
src/model/configuration_mplug_owl2.py
|
| 29 |
+
src/model/convert_mplug_owl2_weight_to_hf.py
|
| 30 |
+
src/model/modeling_attn_mask_utils.py
|
| 31 |
+
src/model/modeling_llama2.py
|
| 32 |
+
src/model/modeling_mplug_owl2.py
|
| 33 |
+
src/model/utils.py
|
| 34 |
+
src/model/visual_encoder.py
|
| 35 |
+
src/train/mplug_owl2_trainer.py
|
| 36 |
+
src/train/train_mem.py
|
DeQA-Score/DeQA_Score.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
DeQA-Score/DeQA_Score.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.0.1
|
| 2 |
+
torchvision==0.15.2
|
| 3 |
+
transformers==4.36.1
|
| 4 |
+
tokenizers==0.15.0
|
| 5 |
+
sentencepiece==0.1.99
|
| 6 |
+
shortuuid
|
| 7 |
+
accelerate==0.21.0
|
| 8 |
+
peft==0.4.0
|
| 9 |
+
bitsandbytes==0.41.0
|
| 10 |
+
pydantic<2,>=1
|
| 11 |
+
markdown2[all]
|
| 12 |
+
numpy
|
| 13 |
+
scikit-learn==1.2.2
|
| 14 |
+
gradio==3.35.2
|
| 15 |
+
gradio_client==0.2.9
|
| 16 |
+
requests
|
| 17 |
+
httpx==0.24.0
|
| 18 |
+
uvicorn
|
| 19 |
+
fastapi
|
| 20 |
+
icecream
|
| 21 |
+
einops==0.6.1
|
| 22 |
+
einops-exts==0.0.4
|
| 23 |
+
timm==0.6.13
|
| 24 |
+
decord
|
| 25 |
+
scipy
|
| 26 |
+
|
| 27 |
+
[train]
|
| 28 |
+
deepspeed==0.9.5
|
| 29 |
+
ninja
|
| 30 |
+
wandb
|
DeQA-Score/DeQA_Score.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build_soft_labels
|
| 2 |
+
fig
|
| 3 |
+
preprocessor
|
| 4 |
+
src
|
DeQA-Score/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2025 Depicted image Quality Assessment (DepictQA / DeQA)
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
DeQA-Score/README.md
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution</h1>
|
| 3 |
+
|
| 4 |
+
<div>
|
| 5 |
+
<a href="https://zhiyuanyou.github.io/" target="_blank">Zhiyuan You</a><sup>12</sup>,
|
| 6 |
+
<a href="https://caixin98.github.io/" target="_blank">Xin Cai</a><sup>2</sup>,
|
| 7 |
+
<a href="https://www.jasongt.com/" target="_blank">Jinjin Gu</a><sup>4</sup>,
|
| 8 |
+
<a href="https://tianfan.info/" target="_blank">Tianfan Xue</a><sup>235</sup><sup>#</sup>,
|
| 9 |
+
<a href="https://xpixel.group/2010/01/20/chaodong.html" target="_blank">Chao Dong</a><sup>134</sup><sup>#</sup>
|
| 10 |
+
</div>
|
| 11 |
+
|
| 12 |
+
<div>
|
| 13 |
+
<sup>1</sup>Shenzhen Institutes of Advanced Technology, Chinese Academy of Sciences, <sup>2</sup>Multimedia Laboratory, The Chinese University of Hong Kong,
|
| 14 |
+
<sup>3</sup>Shanghai AI Laboratory, <sup>4</sup>Shenzhen University of Advanced Technology, <sup>5</sup>CPII under InnoHK
|
| 15 |
+
</div>
|
| 16 |
+
|
| 17 |
+
<div><sup>#</sup>Corresponding author.</div>
|
| 18 |
+
|
| 19 |
+
<div>
|
| 20 |
+
<a href="https://depictqa.github.io/deqa-score/" target="_blank"><strong>Homepage</strong></a> |
|
| 21 |
+
<strong>Model Weights</strong> (
|
| 22 |
+
<a href="https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3" target="_blank"><strong>Full Tuning</strong></a> /
|
| 23 |
+
<a href="https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3" target="_blank"><strong>LoRA Tuning</strong></a>
|
| 24 |
+
) |
|
| 25 |
+
<a href="https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score" target="_blank"><strong>Datasets</strong></a> |
|
| 26 |
+
<a href="https://arxiv.org/abs/2501.11561" target="_blank"><strong>Paper</strong></a>
|
| 27 |
+
</div>
|
| 28 |
+
|
| 29 |
+
<h2>Motivation</h2>
|
| 30 |
+
|
| 31 |
+
<div style="width: 100%; text-align: center; margin:auto;">
|
| 32 |
+
<img style="width: 75%" src="fig/teaser.png">
|
| 33 |
+
</div>
|
| 34 |
+
|
| 35 |
+
<h2>Model Architecture</h2>
|
| 36 |
+
|
| 37 |
+
<div style="width: 100%; text-align: center; margin:auto;">
|
| 38 |
+
<img style="width: 100%" src="fig/model.png">
|
| 39 |
+
</div>
|
| 40 |
+
</div>
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## [Installation Free!] Quicker Start with Hugging Face AutoModel
|
| 44 |
+
|
| 45 |
+
[2025.12] Thanks to @[lyf1212](https://github.com/lyf1212)'s suggestion, we add support on `transformers==4.46.3` with minor code modifications. See [details](https://github.com/zhiyuanyou/DeQA-Score/issues/32).
|
| 46 |
+
|
| 47 |
+
The following code could be run directly with `transformers==4.36.1`. No need to install this GitHub repo.
|
| 48 |
+
|
| 49 |
+
```python
|
| 50 |
+
import requests
|
| 51 |
+
import torch
|
| 52 |
+
from transformers import AutoModelForCausalLM
|
| 53 |
+
|
| 54 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 55 |
+
"zhiyuanyou/DeQA-Score-Mix3",
|
| 56 |
+
trust_remote_code=True,
|
| 57 |
+
attn_implementation="eager",
|
| 58 |
+
torch_dtype=torch.float16,
|
| 59 |
+
device_map="auto",
|
| 60 |
+
)
|
| 61 |
+
|
| 62 |
+
from PIL import Image
|
| 63 |
+
|
| 64 |
+
# The inputs should be a list of multiple PIL images
|
| 65 |
+
model.score(
|
| 66 |
+
[Image.open(requests.get(
|
| 67 |
+
"https://raw.githubusercontent.com/zhiyuanyou/DeQA-Score/main/fig/singapore_flyer.jpg", stream=True
|
| 68 |
+
).raw)]
|
| 69 |
+
)
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
## Installation
|
| 73 |
+
|
| 74 |
+
If you only need to infer / evaluate:
|
| 75 |
+
|
| 76 |
+
```shell
|
| 77 |
+
git clone https://github.com/zhiyuanyou/DeQA-Score.git
|
| 78 |
+
cd DeQA-Score
|
| 79 |
+
pip install -e .
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
For training, you need to further install additional dependencies as follows:
|
| 83 |
+
|
| 84 |
+
```shell
|
| 85 |
+
pip install -e ".[train]"
|
| 86 |
+
pip install flash_attn --no-build-isolation
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
## Quick Start
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
### Image Quality Scorer
|
| 94 |
+
|
| 95 |
+
- CLI Interface
|
| 96 |
+
|
| 97 |
+
```shell
|
| 98 |
+
python src/evaluate/scorer.py --img_path fig/singapore_flyer.jpg
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
- Python API
|
| 102 |
+
|
| 103 |
+
```python
|
| 104 |
+
from src import Scorer
|
| 105 |
+
from PIL import Image
|
| 106 |
+
|
| 107 |
+
scorer = Scorer()
|
| 108 |
+
img_list = [Image.open("fig/singapore_flyer.jpg")] # can be a list of multiple PIL images
|
| 109 |
+
print(scorer(img_list).tolist())
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
## Training, Inference & Evaluation
|
| 114 |
+
|
| 115 |
+
### Datasets
|
| 116 |
+
|
| 117 |
+
<a id="datasets"></a>
|
| 118 |
+
|
| 119 |
+
- Download our meta files from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score).
|
| 120 |
+
|
| 121 |
+
- Download source images from [KonIQ](https://database.mmsp-kn.de/koniq-10k-database.html),
|
| 122 |
+
[SPAQ](https://github.com/h4nwei/SPAQ),
|
| 123 |
+
[KADID](https://database.mmsp-kn.de/kadid-10k-database.html),
|
| 124 |
+
[PIPAL](https://github.com/HaomingCai/PIPAL-dataset),
|
| 125 |
+
[LIVE-Wild](https://live.ece.utexas.edu/research/ChallengeDB/index.html),
|
| 126 |
+
[AGIQA](https://github.com/lcysyzxdxc/AGIQA-3k-Database),
|
| 127 |
+
[TID2013](https://www.ponomarenko.info/tid2013.htm),
|
| 128 |
+
and [CSIQ](https://s2.smu.edu/~eclarson/csiq.html).
|
| 129 |
+
|
| 130 |
+
- Arrange the folders as follows:
|
| 131 |
+
|
| 132 |
+
```
|
| 133 |
+
|-- DeQA-Score
|
| 134 |
+
|-- Data-DeQA-Score
|
| 135 |
+
|-- KONIQ
|
| 136 |
+
|-- images/*.jpg
|
| 137 |
+
|-- metas
|
| 138 |
+
|-- SPAQ
|
| 139 |
+
|-- images/*.jpg
|
| 140 |
+
|-- metas
|
| 141 |
+
|-- KADID10K
|
| 142 |
+
|-- images/*.png
|
| 143 |
+
|-- metas
|
| 144 |
+
|-- PIPAL
|
| 145 |
+
|-- images/Distortion_*/*.bmp
|
| 146 |
+
|-- metas
|
| 147 |
+
|-- LIVE-WILD
|
| 148 |
+
|-- images/*.bmp
|
| 149 |
+
|-- metas
|
| 150 |
+
|-- AGIQA3K
|
| 151 |
+
|-- images/*.jpg
|
| 152 |
+
|-- metas
|
| 153 |
+
|-- TID2013
|
| 154 |
+
|-- images/distorted_images/*.bmp
|
| 155 |
+
|-- metas
|
| 156 |
+
|-- CSIQ
|
| 157 |
+
|-- images/dst_imgs/*/*.png
|
| 158 |
+
|-- metas
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
### Pretrained Weights
|
| 162 |
+
|
| 163 |
+
<a id="pretrained_weights"></a>
|
| 164 |
+
|
| 165 |
+
We provide two model weights (full tuning and LoRA tuning) with similar performance.
|
| 166 |
+
|
| 167 |
+
| | Training Datasets | Weights |
|
| 168 |
+
|-----|-----|-----|
|
| 169 |
+
| Full Tuning | KonIQ, SPAQ, KADID | [Huggingface Full](https://huggingface.co/zhiyuanyou/DeQA-Score-Mix3) |
|
| 170 |
+
| LoRA Tuning | KonIQ, SPAQ, KADID | [Huggingface LoRA](https://huggingface.co/zhiyuanyou/DeQA-Score-LoRA-Mix3) |
|
| 171 |
+
|
| 172 |
+
Download one of the above model weights, then arrange the folders as follows:
|
| 173 |
+
|
| 174 |
+
```
|
| 175 |
+
|-- DeQA-Score
|
| 176 |
+
|-- checkpoints
|
| 177 |
+
|-- DeQA-Score-Mix3
|
| 178 |
+
|-- DeQA-Score-LoRA-Mix3
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
If you would like to use the LoRA tuning weights, you need to download the base mPLUG-Owl2 weights from [Huggingface mPLUG-Owl2](https://huggingface.co/MAGAer13/mplug-owl2-llama2-7b), then arrange the folders as follows:
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
|-- DeQA-Score
|
| 185 |
+
|-- ModelZoo
|
| 186 |
+
|-- mplug-owl2-llama2-7b
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
### Inference
|
| 190 |
+
|
| 191 |
+
After preparing the datasets, you can infer using pre-trained **DeQA-Score** or **DeQA-Score-LoRA**:
|
| 192 |
+
|
| 193 |
+
```shell
|
| 194 |
+
sh scripts/infer.sh $ONE_GPU_ID
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
```shell
|
| 198 |
+
sh scripts/infer_lora.sh $ONE_GPU_ID
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
### Evaluation
|
| 202 |
+
|
| 203 |
+
After inference, you can evaluate the inference results:
|
| 204 |
+
|
| 205 |
+
- SRCC / PLCC for quality score.
|
| 206 |
+
|
| 207 |
+
```shell
|
| 208 |
+
sh scripts/eval_score.sh
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
- KL Divergence / JS Divergence / Wasserstein Distance for score distribution.
|
| 212 |
+
|
| 213 |
+
```shell
|
| 214 |
+
sh scripts/eval_dist.sh
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
### Fine-tuning
|
| 218 |
+
|
| 219 |
+
Fine-tuning needs to download the mPLUG-Owl2 weights as in [Pretrained Weights](#pretrained_weights).
|
| 220 |
+
|
| 221 |
+
#### LoRA Fine-tuning
|
| 222 |
+
|
| 223 |
+
- Only **2 RTX3090 GPUs** are required. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
|
| 224 |
+
|
| 225 |
+
```shell
|
| 226 |
+
sh scripts/train_lora.sh $GPU_IDs
|
| 227 |
+
```
|
| 228 |
+
|
| 229 |
+
#### Full Fine-tuning from Scratch
|
| 230 |
+
|
| 231 |
+
- At least **8 A6000 GPUs** or **4 A100 GPUs** are required. Revise `--data_paths` in the training shell to load different datasets. Default training datasets are KonIQ, SPAQ, and KADID.
|
| 232 |
+
|
| 233 |
+
```shell
|
| 234 |
+
sh scripts/train.sh $GPU_IDs
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
## Soft Label Construction
|
| 239 |
+
|
| 240 |
+
- Download `split.json` (training & test split info) and `mos.json` (mos & std info) of KonIQ, SPAQ, and KADID from [Huggingface Metas](https://huggingface.co/datasets/zhiyuanyou/Data-DeQA-Score), and arrange the folders as in [Datasets](#datasets).
|
| 241 |
+
|
| 242 |
+
- Run the following scripts to construct the distribution-based soft labels.
|
| 243 |
+
|
| 244 |
+
```shell
|
| 245 |
+
cd build_soft_labels
|
| 246 |
+
python gen_soft_label.py
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
## Acknowledgements
|
| 251 |
+
|
| 252 |
+
This work is based on [Q-Align](https://github.com/Q-Future/Q-Align). Sincerely thanks for this awesome work.
|
| 253 |
+
|
| 254 |
+
## Citation
|
| 255 |
+
|
| 256 |
+
If you find our work useful for your research and applications, please cite using the BibTeX:
|
| 257 |
+
|
| 258 |
+
```bibtex
|
| 259 |
+
@inproceedings{deqa_score,
|
| 260 |
+
title={Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution},
|
| 261 |
+
author={You, Zhiyuan and Cai, Xin and Gu, Jinjin and Xue, Tianfan and Dong, Chao},
|
| 262 |
+
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition},
|
| 263 |
+
pages={14483--14494},
|
| 264 |
+
year={2025}
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
@article{depictqa_v2,
|
| 268 |
+
title={Enhancing Descriptive Image Quality Assessment with A Large-scale Multi-modal Dataset},
|
| 269 |
+
author={You, Zhiyuan and Gu, Jinjin and Cai, Xin and Li, Zheyuan and Zhu, Kaiwen and Dong, Chao and Xue, Tianfan},
|
| 270 |
+
journal={IEEE Transactions on Image Processing},
|
| 271 |
+
year={2025}
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
@inproceedings{depictqa_v1,
|
| 275 |
+
title={Depicting Beyond Scores: Advancing Image Quality Assessment through Multi-modal Language Models},
|
| 276 |
+
author={You, Zhiyuan and Li, Zheyuan and Gu, Jinjin and Yin, Zhenfei and Xue, Tianfan and Dong, Chao},
|
| 277 |
+
booktitle={European Conference on Computer Vision},
|
| 278 |
+
pages={259--276},
|
| 279 |
+
year={2024}
|
| 280 |
+
}
|
| 281 |
+
```
|
DeQA-Score/build_soft_labels/config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"answer": "The quality of the image is {}.",
|
| 3 |
+
"dataset_params": {
|
| 4 |
+
"koniq": {
|
| 5 |
+
"split_json": "../../Data-DeQA-Score/KONIQ/metas/split.json",
|
| 6 |
+
"mos_json": "../../Data-DeQA-Score/KONIQ/metas/mos.json",
|
| 7 |
+
"save_train": "../../Data-DeQA-Score/KONIQ/metas/train_koniq_7k_new.json",
|
| 8 |
+
"save_test": "../../Data-DeQA-Score/KONIQ/metas/test_koniq_2k_new.json",
|
| 9 |
+
"img_dir": "KONIQ/images",
|
| 10 |
+
"density_type": "pdf",
|
| 11 |
+
"thre_std": 0.2,
|
| 12 |
+
"thre_diff": 0.1
|
| 13 |
+
},
|
| 14 |
+
"spaq": {
|
| 15 |
+
"split_json": "../../Data-DeQA-Score/SPAQ/metas/split.json",
|
| 16 |
+
"mos_json": "../../Data-DeQA-Score/SPAQ/metas/mos.json",
|
| 17 |
+
"save_train": "../../Data-DeQA-Score/SPAQ/metas/train_spaq_9k_new.json",
|
| 18 |
+
"save_test": "../../Data-DeQA-Score/SPAQ/metas/test_spaq_2k_new.json",
|
| 19 |
+
"img_dir": "SPAQ/images",
|
| 20 |
+
"density_type": "cdf",
|
| 21 |
+
"thre_std": 0.2,
|
| 22 |
+
"thre_diff": 0.1
|
| 23 |
+
},
|
| 24 |
+
"kadid": {
|
| 25 |
+
"split_json": "../../Data-DeQA-Score/KADID10K/metas/split.json",
|
| 26 |
+
"mos_json": "../../Data-DeQA-Score/KADID10K/metas/mos.json",
|
| 27 |
+
"save_train": "../../Data-DeQA-Score/KADID10K/metas/train_kadid_8k_new.json",
|
| 28 |
+
"save_test": "../../Data-DeQA-Score/KADID10K/metas/test_kadid_2k_new.json",
|
| 29 |
+
"img_dir": "KADID10K/images",
|
| 30 |
+
"density_type": "pdf",
|
| 31 |
+
"thre_std": 0.2,
|
| 32 |
+
"thre_diff": 0.1
|
| 33 |
+
}
|
| 34 |
+
}
|
| 35 |
+
}
|
DeQA-Score/build_soft_labels/gen_soft_label.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import numpy as np
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
from scipy.stats import norm, pearsonr, spearmanr
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
    """Parse command-line options for soft-label generation.

    Returns:
        argparse.Namespace with a single attribute ``config`` — path to the
        JSON configuration file (default: ``./config.json``).
    """
    arg_parser = argparse.ArgumentParser(
        description="label parameters for DeQA-Score"
    )
    arg_parser.add_argument("--config", type=str, default="./config.json")
    return arg_parser.parse_args()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Prompt pool for training-sample construction: one question is drawn at
# random per image (see main()) to diversify the instruction phrasing while
# the answer template stays fixed.
questions = [
    "What do you think about the quality of this image?",
    "Can you rate the quality of this picture?",
    "Can you judge the quality of this image?",
    "How would you rate the quality of this image?",
    "How would you judge the quality of this image?",
    "What is your quality rating for this image?",
    "What's your opinion on the quality of this picture?",
    "Rate the quality of this image.",
    "Could you evaluate the quality of this image?",
    "How do you assess the quality of this image?",
]
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def calculate_srcc_plcc(pred, mos):
    """Return (SRCC, PLCC): Spearman and Pearson correlations of pred vs mos."""
    srcc = spearmanr(pred, mos)[0]
    plcc = pearsonr(pred, mos)[0]
    return srcc, plcc
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_level(mos, min_mos, max_mos):
    """Map a MOS value to one of five quality words.

    The range [min_mos, max_mos] is split into five equal bins; a small
    epsilon widens every bin edge so values exactly on a boundary (including
    min_mos and max_mos themselves) are always captured.

    Args:
        mos: the mean opinion score to discretize.
        min_mos: minimum MOS observed in the dataset.
        max_mos: maximum MOS observed in the dataset.

    Returns:
        One of "bad", "poor", "fair", "good", "excellent".

    Raises:
        ValueError: if mos lies outside [min_mos, max_mos].  The original
            code left ``level`` unbound in this case and crashed with a
            confusing NameError; an explicit error is clearer.
    """
    eps = 1e-8
    texts = ["bad", "poor", "fair", "good", "excellent"]
    for idx in range(1, len(texts) + 1):
        mos_left = min_mos + (idx - 1) / 5 * (max_mos - min_mos) - eps
        mos_right = min_mos + idx / 5 * (max_mos - min_mos) + eps
        if mos_left < mos <= mos_right:
            return texts[idx - 1]
    raise ValueError(f"mos={mos} outside [{min_mos}, {max_mos}]")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def adjust_gaussian_bar(probs, score):
    """Solve the affine correction (alpha, beta) for the level probabilities.

    After the correction, ``alpha * p_i + beta`` sums to 1 and its
    expectation under levels [5, 4, 3, 2, 1] equals ``score``:

        alpha * (a + b + c + d + e) + 5 * beta = 1
        alpha * (5a + 4b + 3c + 2d + e) + 15 beta = score
        ==>
        alpha * A + 5 * beta = 1
        alpha * B + 15 * beta = score
    """
    prob_arr = np.asarray(probs)
    levels = np.array([5, 4, 3, 2, 1])
    total = prob_arr.sum()                 # A in the system above
    weighted = np.inner(prob_arr, levels)  # B in the system above
    # 1e-9 guards against division by zero when B == 3A.
    alpha = (score - 3) / (weighted - 3.0 * total + 1e-9)
    beta = (1.0 - alpha * total) / 5.0
    return alpha, beta
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_binary_probs(mos, min_mos=1.0, max_mos=5.0):
    """Spread a MOS over its two adjacent levels by linear interpolation.

    The range [min_mos, max_mos] is divided into 4 bins whose 5 edges are
    the levels; the MOS's bin determines the two non-zero weights.

    Returns:
        A 5-element probability list ordered from "excellent" down to
        "bad" with at most two (neighboring) non-zero entries.
    """
    eps = 1e-8
    probs = [0, 0, 0, 0, 0]
    n_bins = len(probs) - 1
    for idx in range(1, len(probs)):
        lo = min_mos + (idx - 1) / n_bins * (max_mos - min_mos) - eps
        hi = min_mos + idx / n_bins * (max_mos - min_mos) + eps
        if lo < mos <= hi:
            width = hi - lo
            probs[idx - 1] = (hi - mos) / width
            probs[idx] = (mos - lo) / width
            break
    # Sanity checks: exactly two adjacent slots filled, mass sums to 1.
    assert np.array((np.array(probs) == 0)).sum() == 3
    assert round(np.array(probs).sum(), 5) == 1
    # should start with "excellent" & end with "bad"
    return probs[::-1]
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def main(cfg):
    """Build distribution-based soft labels for one dataset and dump metas.

    Reads the train/test split and the per-image MOS/std, converts each MOS
    into a 5-level probability distribution (Gaussian pdf/cdf, affinely
    corrected, with a binary-interpolation fallback), and writes train/test
    meta JSON files.

    NOTE(review): relies on the module-level globals ``answer`` (set in the
    ``__main__`` block) and ``questions`` — must be called after those exist.

    Args:
        cfg: one entry of config's "dataset_params": paths (split_json,
            mos_json, save_train, save_test, img_dir) plus density_type
            ("pdf" or "cdf"), thre_std, thre_diff.
    """
    density_type = cfg["density_type"]  # ["pdf", "cdf"]
    thre_std = cfg["thre_std"]
    thre_diff = cfg["thre_diff"]
    with open(cfg["split_json"]) as fr:
        split = json.load(fr)
    with open(cfg["mos_json"]) as fr:
        mos_dict = json.load(fr)
    save_train = cfg["save_train"]
    save_test = cfg["save_test"]
    img_dir = cfg["img_dir"]

    # Flatten the mos dict into parallel lists; dataset-wide min/max are
    # needed to normalize scores into [1, 5].
    moses, stds, imgs = [], [], []
    for img in mos_dict:
        moses.append(mos_dict[img]["mos"])
        stds.append(mos_dict[img]["std"])
        imgs.append(img)
    max_mos = max([float(_) for _ in moses])
    min_mos = min([float(_) for _ in moses])

    num_binary, idx = 0, 0  # idx only counts images in neither split
    preds, gts, raw_diffs, diffs, alphas, betas = [], [], [], [], [], []
    train_metas, test_metas = [], []
    for img, mos_str, std_str in zip(imgs, moses, stds):
        mos, std = float(mos_str), float(std_str)
        if os.path.basename(img) in split["train"]:
            training = True
        elif os.path.basename(img) in split["test"]:
            training = False
        else:
            # Image not in either split: skip (kept countable for debugging).
            idx += 1
            # print(idx, img)
            continue

        # Discrete quality word + randomly phrased question for the sample.
        text = get_level(mos, min_mos, max_mos)
        query = random.choice(questions)
        resp = answer.replace("{}", text)

        # norm mos and std
        mos_norm = 4 * (mos - min_mos) / (max_mos - min_mos) + 1  # [0, 1] -> [1, 5]
        std_norm = 4 * std / (max_mos - min_mos)

        # ["excellent", "good", "fair", "poor", "bad"] -> [5, 4, 3, 2, 1]
        probs = []
        for x in range(5, 0, -1):
            if density_type == "cdf":
                # better for smaller std dataset (see Appendix) like SPAQ
                prob = norm.cdf(x+0.5, mos_norm, std_norm) - norm.cdf(x-0.5, mos_norm, std_norm)
            else:
                # better for larger std dataset (see Appendix) like KonIQ and KADID
                assert density_type == "pdf"
                prob = norm.pdf(x, loc=mos_norm, scale=std_norm)
            probs.append(prob)

        # Reconstruction error of the raw (uncorrected) distribution.
        mos_rec = np.inner(np.array(probs), np.array([5, 4, 3, 2, 1]))
        raw_diff = abs(mos_rec - mos_norm)
        raw_diffs.append(raw_diff)

        # Affine correction so the distribution sums to 1 and its mean
        # matches mos_norm; clamp negatives introduced by the correction.
        alpha, beta = adjust_gaussian_bar(probs, mos_norm)
        probs_norm = [max(_ * alpha + beta, 0) for _ in probs]
        mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1]))
        diff = abs(mos_rec - mos_norm)

        if std_norm < thre_std or diff > thre_diff:
            # if std is too small, use binary probs (see Appendix)
            probs_norm = get_binary_probs(mos_norm)
            mos_rec = np.inner(np.array(probs_norm), np.array([5, 4, 3, 2, 1]))
            diff, alpha, beta = abs(mos_rec - mos_norm), 1., 0.
            num_binary += 1

        preds.append(mos_rec)
        gts.append(mos_norm)
        diffs.append(diff)
        alphas.append(alpha)
        betas.append(beta)

        meta = {
            "id": os.path.basename(img) + f"->{mos_str}",
            "image": os.path.join(img_dir, img),
            "gt_score": mos,
            "gt_score_norm": mos_norm,
            "level_probs_org": probs,
            "level_probs": probs_norm,
            "std": std,
            "std_norm": std_norm,
        }
        if training:
            # Training samples additionally carry a QA conversation pair.
            conversations = [
                {
                    "from": "human",
                    "value": query + "\n<|image|>",
                },
                {
                    "from": "gpt",
                    "value": resp,
                },
            ]
            meta["conversations"] = conversations
            train_metas.append(meta)
        else:
            # Test metas keep only scores — probabilities are train-only.
            del meta["level_probs_org"]
            del meta["level_probs"]
            test_metas.append(meta)

    print("=" * 100)
    print(f"save {len(train_metas)} into {save_train}")
    with open(save_train, "w") as fw:
        fw.write(json.dumps(train_metas, indent=4))

    print(f"save {len(test_metas)} into {save_test}")
    with open(save_test, "w") as fw:
        fw.write(json.dumps(test_metas, indent=4))

    # Report how faithfully the soft labels reproduce the normalized MOS.
    srcc, plcc = calculate_srcc_plcc(preds, gts)
    print("srcc:", srcc, "plcc:", plcc)
    print("[raw_diff]", "l1:", sum(raw_diffs) / len(raw_diffs), "l2:", np.sqrt((np.array(raw_diffs)**2).mean()))
    print("[diff]", "l1:", sum(diffs) / len(diffs), "l2:", np.sqrt((np.array(diffs)**2).mean()))
    print("[alpha]", "mean:", np.mean(alphas), "std:", np.std(alphas))
    print("[beta]", "mean:", np.mean(betas), "std:", np.std(betas))
    print("binary / all:", num_binary, "/", len(train_metas) + len(test_metas))
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
    args = parse_args()
    with open(args.config) as fr:
        cfg = json.load(fr)
    # Module-level global: main() reads `answer` as the response template.
    answer = cfg["answer"]
    for dataset in cfg["dataset_params"]:
        # Re-seed per dataset so the random question sampling is reproducible
        # regardless of dataset order.
        random.seed(131)
        main(cfg["dataset_params"][dataset])
|
DeQA-Score/fig/boy_colorful.jpg
ADDED
|
Git LFS Details
|
DeQA-Score/fig/model.png
ADDED
|
Git LFS Details
|
DeQA-Score/fig/singapore_flyer.jpg
ADDED
|
Git LFS Details
|
DeQA-Score/fig/teaser.png
ADDED
|
Git LFS Details
|
DeQA-Score/preprocessor/preprocessor_config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"crop_size": 448,
|
| 3 |
+
"do_center_crop": true,
|
| 4 |
+
"do_normalize": true,
|
| 5 |
+
"do_resize": true,
|
| 6 |
+
"feature_extractor_type": "CLIPFeatureExtractor",
|
| 7 |
+
"image_mean": [
|
| 8 |
+
0.48145466,
|
| 9 |
+
0.4578275,
|
| 10 |
+
0.40821073
|
| 11 |
+
],
|
| 12 |
+
"image_std": [
|
| 13 |
+
0.26862954,
|
| 14 |
+
0.26130258,
|
| 15 |
+
0.27577711
|
| 16 |
+
],
|
| 17 |
+
"resample": 3,
|
| 18 |
+
"size": 448
|
| 19 |
+
}
|
DeQA-Score/preprocessor/special_tokens_map.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bos_token": {
|
| 3 |
+
"content": "<s>",
|
| 4 |
+
"lstrip": false,
|
| 5 |
+
"normalized": false,
|
| 6 |
+
"rstrip": false,
|
| 7 |
+
"single_word": false
|
| 8 |
+
},
|
| 9 |
+
"eos_token": {
|
| 10 |
+
"content": "</s>",
|
| 11 |
+
"lstrip": false,
|
| 12 |
+
"normalized": false,
|
| 13 |
+
"rstrip": false,
|
| 14 |
+
"single_word": false
|
| 15 |
+
},
|
| 16 |
+
"pad_token": "<unk>",
|
| 17 |
+
"unk_token": {
|
| 18 |
+
"content": "<unk>",
|
| 19 |
+
"lstrip": false,
|
| 20 |
+
"normalized": false,
|
| 21 |
+
"rstrip": false,
|
| 22 |
+
"single_word": false
|
| 23 |
+
}
|
| 24 |
+
}
|
DeQA-Score/preprocessor/tokenizer.model
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
|
| 3 |
+
size 499723
|
DeQA-Score/preprocessor/tokenizer_config.json
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"add_bos_token": true,
|
| 3 |
+
"add_eos_token": false,
|
| 4 |
+
"bos_token": {
|
| 5 |
+
"__type": "AddedToken",
|
| 6 |
+
"content": "<s>",
|
| 7 |
+
"lstrip": false,
|
| 8 |
+
"normalized": false,
|
| 9 |
+
"rstrip": false,
|
| 10 |
+
"single_word": false
|
| 11 |
+
},
|
| 12 |
+
"clean_up_tokenization_spaces": false,
|
| 13 |
+
"eos_token": {
|
| 14 |
+
"__type": "AddedToken",
|
| 15 |
+
"content": "</s>",
|
| 16 |
+
"lstrip": false,
|
| 17 |
+
"normalized": false,
|
| 18 |
+
"rstrip": false,
|
| 19 |
+
"single_word": false
|
| 20 |
+
},
|
| 21 |
+
"legacy": false,
|
| 22 |
+
"model_max_length": 2048,
|
| 23 |
+
"pad_token": null,
|
| 24 |
+
"padding_side": "right",
|
| 25 |
+
"sp_model_kwargs": {},
|
| 26 |
+
"tokenizer_class": "LlamaTokenizer",
|
| 27 |
+
"unk_token": {
|
| 28 |
+
"__type": "AddedToken",
|
| 29 |
+
"content": "<unk>",
|
| 30 |
+
"lstrip": false,
|
| 31 |
+
"normalized": false,
|
| 32 |
+
"rstrip": false,
|
| 33 |
+
"single_word": false
|
| 34 |
+
}
|
| 35 |
+
}
|
DeQA-Score/pyproject.toml
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "DeQA-Score"
|
| 7 |
+
version = "1.2.0"
|
| 8 |
+
description = "Teaching Large Language Models to Regress Accurate Image Quality Scores using Score Distribution (based on mPLUG-Owl2)"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.8"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"License :: OSI Approved :: Apache Software License",
|
| 14 |
+
]
|
| 15 |
+
dependencies = [
|
| 16 |
+
"torch==2.0.1", "torchvision==0.15.2",
|
| 17 |
+
"transformers==4.36.1", "tokenizers==0.15.0", "sentencepiece==0.1.99", "shortuuid",
|
| 18 |
+
"accelerate==0.21.0", "peft==0.4.0", "bitsandbytes==0.41.0",
|
| 19 |
+
"pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
|
| 20 |
+
"gradio==3.35.2", "gradio_client==0.2.9",
|
| 21 |
+
"requests", "httpx==0.24.0", "uvicorn", "fastapi", "icecream",
|
| 22 |
+
"einops==0.6.1", "einops-exts==0.0.4", "timm==0.6.13", "decord", "scipy",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.optional-dependencies]
|
| 26 |
+
train = ["deepspeed==0.9.5", "ninja", "wandb"]
|
| 27 |
+
|
| 28 |
+
[project.urls]
|
| 29 |
+
"Bug Tracker" = "https://github.com/zhiyuanyou/DeQA-Score/issues"
|
| 30 |
+
|
| 31 |
+
[tool.setuptools.packages.find]
|
| 32 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
| 33 |
+
|
| 34 |
+
[tool.wheel]
|
| 35 |
+
exclude = ["assets*", "benchmark*", "docs", "dist*", "playground*", "scripts*", "tests*"]
|
DeQA-Score/scripts/eval_dist.sh
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluate the score-distribution quality of inference results:
# compares predicted level distributions against ground-truth metas
# (see src/evaluate/cal_distribution_gap.py).
export PYTHONPATH=./:$PYTHONPATH

# res_dir: inference outputs; gt_dir: dataset meta files.
res_dir=./results/res_deqa_mix3/
gt_dir=../Data-DeQA-Score/

python src/evaluate/cal_distribution_gap.py \
    --level_names excellent good fair poor bad \
    --pred_paths $res_dir/test_koniq_2k.json \
        $res_dir/test_spaq_2k.json \
        $res_dir/test_kadid_2k.json \
        $res_dir/test_pipal_5k.json \
        $res_dir/test_livew_1k.json \
        $res_dir/test_agiqa_3k.json \
        $res_dir/test_tid2013_3k.json \
        $res_dir/test_csiq_866.json \
    --gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \
        $gt_dir/SPAQ/metas/test_spaq_2k.json \
        $gt_dir/KADID10K/metas/test_kadid_2k.json \
        $gt_dir/PIPAL/metas/test_pipal_5k.json \
        $gt_dir/LIVE-WILD/metas/test_livew_1k.json \
        $gt_dir/AGIQA3K/metas/test_agiqa_3k.json \
        $gt_dir/TID2013/metas/test_tid2013_3k.json \
        $gt_dir/CSIQ/metas/test_csiq_866.json \
|
DeQA-Score/scripts/eval_score.sh
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Evaluate scalar quality scores of inference results:
# SRCC / PLCC between predictions and ground-truth MOS
# (see src/evaluate/cal_plcc_srcc.py).
export PYTHONPATH=./:$PYTHONPATH

# res_dir: inference outputs; gt_dir: dataset meta files.
res_dir=./results/res_deqa_mix3/
gt_dir=../Data-DeQA-Score/

python src/evaluate/cal_plcc_srcc.py \
    --level_names excellent good fair poor bad \
    --pred_paths $res_dir/test_koniq_2k.json \
        $res_dir/test_spaq_2k.json \
        $res_dir/test_kadid_2k.json \
        $res_dir/test_pipal_5k.json \
        $res_dir/test_livew_1k.json \
        $res_dir/test_agiqa_3k.json \
        $res_dir/test_tid2013_3k.json \
        $res_dir/test_csiq_866.json \
    --gt_paths $gt_dir/KONIQ/metas/test_koniq_2k.json \
        $gt_dir/SPAQ/metas/test_spaq_2k.json \
        $gt_dir/KADID10K/metas/test_kadid_2k.json \
        $gt_dir/PIPAL/metas/test_pipal_5k.json \
        $gt_dir/LIVE-WILD/metas/test_livew_1k.json \
        $gt_dir/AGIQA3K/metas/test_agiqa_3k.json \
        $gt_dir/TID2013/metas/test_tid2013_3k.json \
        $gt_dir/CSIQ/metas/test_csiq_866.json \
|
DeQA-Score/scripts/infer.sh
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run inference with the full-tuned DeQA-Score checkpoint on all test sets.
# Usage: sh scripts/infer.sh $ONE_GPU_ID
export CUDA_VISIBLE_DEVICES=$1
export PYTHONPATH=./:$PYTHONPATH

python src/evaluate/iqa_eval.py \
    --level-names excellent good fair poor bad \
    --model-path checkpoints/DeQA-Score-Mix3/ \
    --save-dir results/res_deqa_mix3/ \
    --preprocessor-path ./preprocessor/ \
    --root-dir ../Data-DeQA-Score/ \
    --meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \
        ../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \
        ../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \
        ../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \
        ../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \
        ../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \
        ../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \
        ../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \
|
DeQA-Score/scripts/infer_lora.sh
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Run inference with the LoRA-tuned checkpoint; --model-base points to the
# base mPLUG-Owl2 weights the LoRA deltas are applied to.
# Usage: sh scripts/infer_lora.sh $ONE_GPU_ID
export CUDA_VISIBLE_DEVICES=$1
export PYTHONPATH=./:$PYTHONPATH

python src/evaluate/iqa_eval.py \
    --level-names excellent good fair poor bad \
    --model-path checkpoints/DeQA-Score-LoRA-Mix3/ \
    --model-base ../ModelZoo/mplug-owl2-llama2-7b/ \
    --save-dir results/res_deqa_lora_mix3/ \
    --preprocessor-path ./preprocessor/ \
    --root-dir ../Data-DeQA-Score/ \
    --meta-paths ../Data-DeQA-Score/KONIQ/metas/test_koniq_2k.json \
        ../Data-DeQA-Score/SPAQ/metas/test_spaq_2k.json \
        ../Data-DeQA-Score/KADID10K/metas/test_kadid_2k.json \
        ../Data-DeQA-Score/PIPAL/metas/test_pipal_5k.json \
        ../Data-DeQA-Score/LIVE-WILD/metas/test_livew_1k.json \
        ../Data-DeQA-Score/AGIQA3K/metas/test_agiqa_3k.json \
        ../Data-DeQA-Score/TID2013/metas/test_tid2013_3k.json \
        ../Data-DeQA-Score/CSIQ/metas/test_csiq_866.json \
|
DeQA-Score/scripts/train.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Full fine-tuning of DeQA-Score from the base mPLUG-Owl2 weights.
# Usage: sh scripts/train.sh $GPU_IDs   (comma-separated, e.g. 0,1,2,3)
export PYTHONPATH=./:$PYTHONPATH

# Base model to initialize from.
LOAD="../ModelZoo/mplug-owl2-llama2-7b/"

deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path $LOAD \
    --version v1 \
    --dataset_type pair \
    --level_prefix "The quality of the image is" \
    --level_names excellent good fair poor bad \
    --softkl_loss True \
    --weight_rank 1.0 \
    --weight_softkl 1.0 \
    --weight_next_token 0.05 \
    --continuous_rating_loss True \
    --closeset_rating_loss True \
    --use_fix_std True \
    --detach_pred_std True \
    --data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \
        ../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \
        ../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \
    --data_weights 1 1 1 \
    --image_folder ../Data-DeQA-Score/ \
    --output_dir ./checkpoints/deqa_mix3_rank1.0_next0.05_kl1.0/ \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --tune_visual_abstractor True \
    --freeze_vision_model False \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to tensorboard
|
DeQA-Score/scripts/train_lora.sh
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# LoRA fine-tuning of DeQA-Score (adds --lora_enable to the full recipe).
# Usage: sh scripts/train_lora.sh $GPU_IDs   (comma-separated, e.g. 0,1)
export PYTHONPATH=./:$PYTHONPATH

# Base model the LoRA adapters are trained on top of.
LOAD="../ModelZoo/mplug-owl2-llama2-7b/"

deepspeed --include localhost:$1 --master_port 6688 src/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --lora_enable True \
    --model_name_or_path $LOAD \
    --version v1 \
    --dataset_type pair \
    --level_prefix "The quality of the image is" \
    --level_names excellent good fair poor bad \
    --softkl_loss True \
    --weight_rank 1.0 \
    --weight_softkl 1.0 \
    --weight_next_token 0.05 \
    --continuous_rating_loss True \
    --closeset_rating_loss True \
    --use_fix_std True \
    --detach_pred_std True \
    --data_paths ../Data-DeQA-Score/KONIQ/metas/train_koniq_7k.json \
        ../Data-DeQA-Score/SPAQ/metas/train_spaq_9k.json \
        ../Data-DeQA-Score/KADID10K/metas/train_kadid_8k.json \
    --data_weights 1 1 1 \
    --image_folder ../Data-DeQA-Score/ \
    --output_dir ./checkpoints/deqa_lora_mix3_rank1.0_next0.05_kl1.0 \
    --image_aspect_ratio pad \
    --group_by_modality_length True \
    --bf16 True \
    --num_train_epochs 3 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 4 \
    --gradient_accumulation_steps 1 \
    --evaluation_strategy "no" \
    --save_strategy "no" \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 2048 \
    --gradient_checkpointing True \
    --tune_visual_abstractor True \
    --freeze_vision_model False \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
    --report_to tensorboard
|
DeQA-Score/scripts/zero3.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 3,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": "auto",
|
| 22 |
+
"stage3_param_persistence_threshold": "auto",
|
| 23 |
+
"stage3_max_live_parameters": 0,
|
| 24 |
+
"stage3_max_reuse_distance": 0,
|
| 25 |
+
"stage3_prefetch_bucket_size": 0,
|
| 26 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 27 |
+
}
|
| 28 |
+
}
|
DeQA-Score/scripts/zero3_offload.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"scheduler": {
|
| 23 |
+
"type": "WarmupLR",
|
| 24 |
+
"params": {
|
| 25 |
+
"warmup_min_lr": "auto",
|
| 26 |
+
"warmup_max_lr": "auto",
|
| 27 |
+
"warmup_num_steps": "auto"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"zero_optimization": {
|
| 31 |
+
"stage": 3,
|
| 32 |
+
"offload_optimizer": {
|
| 33 |
+
"device": "cpu",
|
| 34 |
+
"pin_memory": true
|
| 35 |
+
},
|
| 36 |
+
"offload_param": {
|
| 37 |
+
"device": "cpu",
|
| 38 |
+
"pin_memory": true
|
| 39 |
+
},
|
| 40 |
+
"overlap_comm": true,
|
| 41 |
+
"contiguous_gradients": true,
|
| 42 |
+
"sub_group_size": 1e9,
|
| 43 |
+
"reduce_bucket_size": "auto",
|
| 44 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 45 |
+
"stage3_param_persistence_threshold": "auto",
|
| 46 |
+
"stage3_max_live_parameters": 1e9,
|
| 47 |
+
"stage3_max_reuse_distance": 1e9,
|
| 48 |
+
"gather_16bit_weights_on_model_save": true
|
| 49 |
+
},
|
| 50 |
+
"gradient_accumulation_steps": "auto",
|
| 51 |
+
"gradient_clipping": "auto",
|
| 52 |
+
"train_batch_size": "auto",
|
| 53 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 54 |
+
"steps_per_print": 1e5,
|
| 55 |
+
"wall_clock_breakdown": false
|
| 56 |
+
}
|
DeQA-Score/src/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model import MPLUGOwl2LlamaForCausalLM
|
| 2 |
+
from .evaluate import Scorer
|
DeQA-Score/src/constants.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Demo-server heartbeat timing (seconds): a worker is considered dead if no
# heartbeat arrives within the expiration window.
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Directory where the demo server writes its logs.
LOGDIR = "./demo_logs"

# Model Constants
# Label value ignored by the loss (standard PyTorch cross-entropy sentinel).
IGNORE_INDEX = -100
# Sentinel token id marking where image features are spliced into the sequence.
IMAGE_TOKEN_INDEX = -200
# Literal placeholder string for an image in prompt text.
DEFAULT_IMAGE_TOKEN = "<|image|>"
|
DeQA-Score/src/conversation.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import dataclasses
|
| 2 |
+
from enum import auto, Enum
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
from src.constants import DEFAULT_IMAGE_TOKEN
|
| 5 |
+
|
| 6 |
+
class SeparatorStyle(Enum):
    """Different separator style.

    Each member selects one prompt-assembly branch in
    ``Conversation.get_prompt``.
    """
    SINGLE = auto()      # one separator (``sep``) after every message
    TWO = auto()         # alternating sep/sep2; system prompt prepended
    TWO_NO_SYS = auto()  # alternating sep/sep2; no system prompt
    MPT = auto()         # role immediately followed by message text
    PLAIN = auto()       # messages only (roles dropped), alternating seps
    LLAMA_2 = auto()     # [INST] ... [/INST] wrapping with <<SYS>> block
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Messages are ``[role, message]`` pairs; a message may also be a tuple
    ``(text, image, image_process_mode)`` when an image is attached.
    """
    system: str                    # system prompt text
    roles: List[str]               # (user_role, assistant_role)
    messages: List[List[str]]      # conversation turns
    offset: int                    # index of first non-seed message
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None               # second separator for alternating styles
    version: str = "Unknown"

    skip_next: bool = False        # demo-UI flag: skip the next generation

    def get_prompt(self):
        """Assemble the full prompt string according to ``sep_style``."""
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            # First message carries an image: strip any stray image token from
            # the text and re-prefix it exactly once.
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            # init_msg = init_msg[0].replace("<image>", "").strip()
            # if 'mmtag' in self.version:
            #     messages[0] = (init_role, init_msg)
            #     messages.insert(0, (self.roles[0], "<Image><image></Image>"))
            #     messages.insert(1, (self.roles[1], "Received."))
            # else:
            #     messages[0] = (init_role, "<image>\n" + init_msg)
            init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    # Empty message: emit the role header so the model continues.
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO_NO_SYS:
            # Same as TWO but the system prompt is omitted entirely.
            seps = [self.sep, self.sep2]
            ret = ""
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0: message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            # NOTE(review): lstrip strips *characters* of ``sep``, not the
            # prefix string — fine for a single-char/BOS-style sep, but worth
            # confirming for multi-char separators.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append one ``[role, message]`` turn to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Extract images attached to user turns (every other message
        starting at ``offset``).

        Images are padded/cropped/resized per their ``image_process_mode``
        and downscaled so the longest edge stays within 800px (shortest
        within 400px). Returns PIL images if ``return_pil`` else base64 PNG
        strings.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    # Lazy imports: only needed when images are present.
                    import base64
                    from io import BytesIO
                    from PIL import Image
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the shorter side to make the image square,
                            # centering the original content.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result
                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    # Downscale only when the image exceeds the size budget.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Render the history as Gradio chatbot pairs ``[user, assistant]``,
        inlining any attached image as a base64 ``<img>`` tag."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    # NOTE(review): saved as JPEG but the data URI below claims
                    # image/png — browsers tolerate the mismatch, but consider
                    # aligning the two.
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace('<|image|>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # Assistant turn: fill the reply slot of the last pair.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a deep-enough copy (messages list duplicated) of this
        conversation template."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        """Serialize to a plain dict; image tuples are reduced to their text."""
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# Legacy Vicuna v0 template: "###"-separated turns with a seeded example
# exchange (offset=2 skips the seed when rendering new chats).
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        ("Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Vicuna v1 template: alternating " " / "</s>" separators with system prompt.
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# mPLUG-Owl2 template: like v1 but the system prompt is NOT rendered into the
# prompt (TWO_NO_SYS).
conv_mplug_owl2 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO_NO_SYS,
    sep=" ",
    sep2="</s>",
)

# default_conversation = conv_vicuna_v1
default_conversation = conv_mplug_owl2
# Registry used by training/inference code to look templates up by name.
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "mplug_owl2": conv_mplug_owl2,
}


if __name__ == "__main__":
    # Smoke test: print the default template's rendered prompt.
    print(default_conversation.get_prompt())
|
DeQA-Score/src/datasets/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .pair_dataset import make_pair_data_module
|
| 2 |
+
from .single_dataset import make_single_data_module
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def make_data_module(tokenizer, data_args):
    """Build the train dataset / collator bundle for the configured dataset type.

    Args:
        tokenizer: pretrained tokenizer forwarded to the dataset builders.
        data_args: data configuration; ``data_args.dataset_type`` selects the
            builder ("single" or "pair").

    Returns:
        The dict produced by the selected builder (train_dataset,
        eval_dataset, data_collator).

    Raises:
        ValueError: if ``data_args.dataset_type`` is not recognized.
    """
    if data_args.dataset_type == "single":
        return make_single_data_module(tokenizer, data_args)
    elif data_args.dataset_type == "pair":
        return make_pair_data_module(tokenizer, data_args)
    else:
        # Original raised a bare ValueError; include the offending value so
        # config mistakes are diagnosable from the traceback.
        raise ValueError(f"Unknown dataset_type: {data_args.dataset_type!r} "
                         "(expected 'single' or 'pair')")
|
DeQA-Score/src/datasets/pair_dataset.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import Dict, Sequence
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import transformers
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from torch.utils.data import Dataset
|
| 12 |
+
|
| 13 |
+
from src.constants import IGNORE_INDEX
|
| 14 |
+
|
| 15 |
+
from .utils import (expand2square, load_video, preprocess,
|
| 16 |
+
preprocess_multimodal, rank0_print)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class PairDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Loads several JSON meta files (one per IQA dataset), oversamples each by
    an integer weight, and yields *pairs* of samples drawn from the same
    sub-dataset (item_A indexed deterministically, item_B sampled at random)
    for pairwise ranking losses.
    """

    def __init__(
        self,
        data_paths,           # list of JSON meta-file paths, one per dataset
        data_weights,         # integer oversampling factor per dataset
        tokenizer: transformers.PreTrainedTokenizer,
        data_args,
    ):
        super(PairDataset, self).__init__()
        dataset_list = []  # list (different datasets) of list (samples in one dataset)
        for data_path, data_weight in zip(data_paths, data_weights):
            data_list = json.load(open(data_path, "r"))
            # Integer weight replicates the whole sample list.
            dataset_list.append(data_list * data_weight)
        self.dataset_list = dataset_list

        # Construct nums_predata: nums_predata[i] is the cumulative number of
        # samples in datasets 0..i (used to map a flat index to a dataset).
        nums_eachdata = [len(_) for _ in self.dataset_list]
        nums_predata = copy.deepcopy(nums_eachdata)
        for idx in range(1, len(nums_predata)):
            nums_predata[idx] = nums_predata[idx] + nums_predata[idx - 1]

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.nums_eachdata = nums_eachdata
        self.nums_predata = nums_predata
        self.data_args = data_args
        assert self.nums_predata[-1] == sum(self.nums_eachdata)

    def __len__(self):
        # Total sample count across all (weighted) datasets.
        return self.nums_predata[-1]

    @property
    def lengths(self):
        """Approximate token length per sample (whitespace word count plus a
        flat 128-token budget when an image is attached); used by
        length-grouped samplers."""
        length_list = []
        for dataset in self.dataset_list:
            for sample in dataset:
                img_tokens = 128 if "image" in sample else 0
                length_list.append(
                    sum(len(conv["value"].split()) for conv in sample["conversations"])
                    + img_tokens
                )
        return length_list

    @property
    def modality_lengths(self):
        """Like ``lengths`` but signed: positive for multimodal samples,
        negative for text-only ones (modality-grouped sampling)."""
        length_list = []
        for dataset in self.dataset_list:
            for sample in dataset:
                cur_len = sum(
                    len(conv["value"].split()) for conv in sample["conversations"]
                )
                cur_len = cur_len if "image" in sample else -cur_len
                length_list.append(cur_len)
        return length_list

    def next_rand(self):
        # Random flat index, used to resample after a failed item load.
        return random.randint(0, len(self) - 1)

    def __getitem__(self, i):
        """Return ``{"item_A", "item_B"}`` — two distinct samples from the
        same sub-dataset. On any per-item failure (e.g. unreadable image) the
        error is printed and a fresh random index is retried; this broad
        except is a deliberate best-effort policy for long training runs."""
        while True:
            try:
                # Map flat index i -> (idx_dataset, idx_sample) via the
                # cumulative counts.
                if i < self.nums_predata[0]:
                    idx_dataset = 0
                    idx_sample = i
                else:
                    for idx_dataset in range(1, len(self.nums_predata)):
                        if (
                            i < self.nums_predata[idx_dataset]
                            and i >= self.nums_predata[idx_dataset - 1]
                        ):
                            idx_sample = i - self.nums_predata[idx_dataset - 1]
                            break
                # Sample two items: A at the mapped index, B random but
                # distinct and from the same sub-dataset.
                item_A = self.get_one_item(idx_dataset, idx_sample)
                while True:
                    idx_sample_B = random.randint(
                        0, self.nums_eachdata[idx_dataset] - 1
                    )
                    if idx_sample_B != idx_sample:
                        break
                item_B = self.get_one_item(idx_dataset, idx_sample_B)
                return {
                    "item_A": item_A,
                    "item_B": item_B,
                }
            except Exception as ex:
                print(ex)
                i = self.next_rand()
                continue

    def get_one_item(self, idx_dataset, idx_sample) -> Dict[str, torch.Tensor]:
        """Load and preprocess one sample: image(s)/video -> pixel_values,
        conversations -> tokenized input_ids/labels, plus score metadata
        (missing fields default to the -10000 sentinel)."""
        # For IQA data, i must be int
        sources = [self.dataset_list[idx_dataset][idx_sample]]
        sources_org = copy.deepcopy(sources)
        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
        if "image" in sources_org[0]:
            image_file = sources[0]["image"]

            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor

            if isinstance(image_file, list):
                # Multiple Images as Input
                image = [
                    Image.open(os.path.join(image_folder, imfile)).convert("RGB")
                    for imfile in image_file
                ]

                if self.data_args.image_aspect_ratio == "pad":
                    # Pad each image to square using the processor's mean
                    # color before normalization.
                    image = [
                        expand2square(
                            img,
                            tuple(int(x * 255) for x in processor.image_mean),
                        )
                        for img in image
                    ]
                    image = processor.preprocess(image, return_tensors="pt")[
                        "pixel_values"
                    ]
                else:
                    image = processor.preprocess(image, return_tensors="pt")[
                        "pixel_values"
                    ]
            elif os.path.join(image_folder, image_file).endswith("mp4"):
                # Video as Input
                image = load_video(os.path.join(image_folder, image_file))
                if self.data_args.image_aspect_ratio == "pad":
                    image = [
                        expand2square(
                            img,
                            tuple(int(x * 255) for x in processor.image_mean),
                        )
                        for img in image
                    ]
                    image = processor.preprocess(image, return_tensors="pt")[
                        "pixel_values"
                    ]
                else:
                    image = processor.preprocess(image, return_tensors="pt")[
                        "pixel_values"
                    ]
            else:
                # Single image as input.
                image = Image.open(os.path.join(image_folder, image_file)).convert(
                    "RGB"
                )
                if self.data_args.image_aspect_ratio == "pad":
                    image = expand2square(
                        image, tuple(int(x * 255) for x in processor.image_mean)
                    )
                    image = processor.preprocess(image, return_tensors="pt")[
                        "pixel_values"
                    ]
                else:
                    image = processor.preprocess(image, return_tensors="pt")[
                        "pixel_values"
                    ]
            sources = preprocess_multimodal(
                copy.deepcopy([e["conversations"] for e in sources]),
                self.data_args,
            )
        else:
            # Without images
            sources = copy.deepcopy([e["conversations"] for e in sources])

        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=("image" in sources_org[0]),
        )
        # preprocess returns batched tensors; unwrap the single sample.
        data_dict = dict(
            input_ids=data_dict["input_ids"][0],
            labels=data_dict["labels"][0],
        )

        # Defaults: task_type "score"; gt_score & std -10000; level_probs [-10000] * 5
        # (-10000 acts as a "missing" sentinel for downstream losses).
        data_dict["task_type"] = sources_org[0].get("task_type", "score")
        data_dict["gt_score"] = sources_org[0].get("gt_score", -10000)
        data_dict["std"] = sources_org[0].get("std", -10000)
        data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5)

        # image exist in the data
        if "image" in sources_org[0]:
            data_dict["image_file"] = image_file
            data_dict["image"] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal:
            # feed an all-zero placeholder so batch shapes stay uniform.
            crop_size = self.data_args.image_processor.crop_size
            data_dict["image"] = torch.zeros(3, crop_size["height"], crop_size["width"])
        return data_dict
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
@dataclass
class DataCollatorForPairDataset(object):
    """Collate examples for pair fine-tuning.

    Splits each pair into its A/B halves and collates each half into a padded
    batch; the result is tagged ``input_type="pair"``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Collate a list of ``{"item_A", "item_B"}`` samples into two
        parallel batches."""
        instances_A = [instance["item_A"] for instance in instances]
        instances_B = [instance["item_B"] for instance in instances]
        batch = {
            "input_type": "pair",
            "item_A": self.collate_one(instances_A),
            "item_B": self.collate_one(instances_B),
        }
        return batch

    def collate_one(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad/stack one side of the pair batch.

        input_ids are right-padded with the tokenizer's pad token, labels
        with IGNORE_INDEX; both are truncated to ``model_max_length``.
        Score metadata is stacked into tensors; images are stacked only when
        all shapes match, otherwise kept as a list.
        """
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            # Attention mask derived from padding positions.
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # Score supervision (missing values carry the -10000 sentinel set by
        # the dataset).
        batch["task_types"] = [instance["task_type"] for instance in instances]
        batch["gt_scores"] = torch.tensor([instance["gt_score"] for instance in instances])
        batch["stds"] = torch.tensor([instance["std"] for instance in instances])
        batch["level_probs"] = torch.tensor([instance["level_probs"] for instance in instances])

        if "image" in instances[0]:
            images = [instance["image"] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch["images"] = torch.stack(images)
            else:
                # Heterogeneous shapes (e.g. multi-image samples): leave as list.
                batch["images"] = images
            batch["image_files"] = [instance["image_file"] for instance in instances]

        return batch
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def make_pair_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Build the Trainer kwargs for pairwise supervised fine-tuning.

    Returns a dict with keys ``train_dataset``, ``eval_dataset`` (always
    None here) and ``data_collator``.
    """
    dataset = PairDataset(
        tokenizer=tokenizer,
        data_paths=data_args.data_paths,
        data_weights=data_args.data_weights,
        data_args=data_args,
    )
    collator = DataCollatorForPairDataset(tokenizer=tokenizer)
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }
|
DeQA-Score/src/datasets/single_dataset.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from typing import Dict, Sequence
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import transformers
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
|
| 12 |
+
from src.constants import IGNORE_INDEX
|
| 13 |
+
|
| 14 |
+
from .utils import (expand2square, load_video, preprocess,
|
| 15 |
+
preprocess_multimodal, rank0_print)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class SingleDataset(Dataset):
    """Dataset for supervised fine-tuning.

    Each entry of the source JSON files is a sample dict containing at
    least ``conversations``; optionally ``image`` (a file name, a list of
    file names, or an mp4 path), ``task_type`` and ``level_probs``.
    """

    def __init__(
        self,
        data_paths: str,
        data_weights: str,
        tokenizer: transformers.PreTrainedTokenizer,
        data_args,
    ):
        # NOTE(review): despite the ``str`` annotations, data_paths and
        # data_weights are zipped below, so they are list-like in practice.
        super(SingleDataset, self).__init__()
        list_data_dict = []
        for data_path, data_weight in zip(data_paths, data_weights):
            data_dict = json.load(open(data_path, "r"))
            # Repeat each source dataset `data_weight` times to re-balance
            # the mixture of annotation files.
            list_data_dict += data_dict * data_weight

        rank0_print("Formatting inputs...Skip in lazy mode")
        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    def __len__(self):
        return len(self.list_data_dict)

    @property
    def lengths(self):
        # Approximate per-sample lengths (whitespace word counts) for
        # length-grouped sampling; an image is budgeted as 128 tokens.
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(
                sum(len(conv["value"].split()) for conv in sample["conversations"])
                + img_tokens
            )
        return length_list

    @property
    def modality_lengths(self):
        # Same word counts, but text-only samples are negated so samplers
        # can distinguish the two modalities by sign.
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(
                len(conv["value"].split()) for conv in sample["conversations"]
            )
            cur_len = cur_len if "image" in sample else -cur_len
            length_list.append(cur_len)
        return length_list

    def next_rand(self):
        # Random fallback index, used when loading sample `i` fails.
        import random

        return random.randint(0, len(self) - 1)

    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # On any failure (missing/corrupt image, preprocessing error) this
        # retries with a random other index so training never crashes on
        # bad data; failures are only printed, not raised.
        while True:
            try:
                sources = self.list_data_dict[i]
                if isinstance(i, int):
                    sources = [sources]
                # Keep an untouched copy: preprocessing below mutates the
                # conversation dicts.
                sources_org = copy.deepcopy(sources)
                assert (
                    len(sources) == 1
                ), "Don't know why it is wrapped to a list"  # FIXME
                if "image" in sources_org[0]:
                    image_file = sources_org[0]["image"]

                    image_folder = self.data_args.image_folder
                    processor = self.data_args.image_processor
                    from pathlib import Path

                    # if not Path(os.path.join(image_folder, image_file)).exists():
                    #     i = self.next_rand()
                    #     continue
                    if isinstance(image_file, list):
                        # Multiple Images as Input
                        try:
                            image = [
                                Image.open(os.path.join(image_folder, imfile)).convert(
                                    "RGB"
                                )
                                for imfile in image_file
                            ]
                        except Exception as ex:
                            print(ex)
                            i = self.next_rand()
                            continue
                        if self.data_args.image_aspect_ratio == "pad":
                            # Pad each image to a square filled with the
                            # processor's mean color before preprocessing.
                            image = [
                                expand2square(
                                    img,
                                    tuple(int(x * 255) for x in processor.image_mean),
                                )
                                for img in image
                            ]
                            image = processor.preprocess(image, return_tensors="pt")[
                                "pixel_values"
                            ]
                        else:
                            image = processor.preprocess(image, return_tensors="pt")[
                                "pixel_values"
                            ]
                    elif os.path.join(image_folder, image_file).endswith("mp4"):
                        # Video as Input
                        image = load_video(os.path.join(image_folder, image_file))
                        if self.data_args.image_aspect_ratio == "pad":
                            image = [
                                expand2square(
                                    img,
                                    tuple(int(x * 255) for x in processor.image_mean),
                                )
                                for img in image
                            ]
                            image = processor.preprocess(image, return_tensors="pt")[
                                "pixel_values"
                            ]
                        else:
                            image = processor.preprocess(image, return_tensors="pt")[
                                "pixel_values"
                            ]
                    else:
                        # Single image as input.
                        try:
                            image = Image.open(
                                os.path.join(image_folder, image_file)
                            ).convert("RGB")
                        except Exception as ex:
                            print(ex)
                            i = self.next_rand()
                            continue
                        if self.data_args.image_aspect_ratio == "pad":
                            image = expand2square(
                                image, tuple(int(x * 255) for x in processor.image_mean)
                            )
                            image = processor.preprocess(image, return_tensors="pt")[
                                "pixel_values"
                            ]
                        else:
                            image = processor.preprocess(image, return_tensors="pt")[
                                "pixel_values"
                            ]
                    # Normalize the <image> token placement in the text.
                    sources = preprocess_multimodal(
                        copy.deepcopy([e["conversations"] for e in sources]),
                        self.data_args,
                    )
                else:

                    sources = copy.deepcopy([e["conversations"] for e in sources])
                data_dict = preprocess(
                    sources,
                    self.tokenizer,
                    has_image=("image" in sources_org[0]),
                )
                if isinstance(i, int):
                    # Unwrap the singleton batch dimension added above.
                    data_dict = dict(
                        input_ids=data_dict["input_ids"][0],
                        labels=data_dict["labels"][0],
                    )

                # default task_type: "score", level_probs: [-10000] * 5
                data_dict["task_type"] = sources_org[0].get("task_type", "score")
                data_dict["level_probs"] = sources_org[0].get("level_probs", [-10000] * 5)

                # image exist in the data
                if "image" in sources_org[0]:
                    data_dict["image"] = image
                elif self.data_args.is_multimodal:
                    # image does not exist in the data, but the model is multimodal:
                    # feed an all-zero image of the processor's crop size.
                    crop_size = self.data_args.image_processor.crop_size
                    data_dict["image"] = torch.zeros(
                        3, crop_size["height"], crop_size["width"]
                    )
                return data_dict
            except Exception as ex:
                print(ex)
                i = self.next_rand()
                continue
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Pads per-sample token tensors into a batch, builds the attention
    mask, and gathers the auxiliary score-distribution fields.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id
        max_len = self.tokenizer.model_max_length
        seq_ids = [inst["input_ids"] for inst in instances]
        seq_labels = [inst["labels"] for inst in instances]
        # Right-pad to the longest sequence, then clip to the context size.
        padded_ids = torch.nn.utils.rnn.pad_sequence(
            seq_ids, batch_first=True, padding_value=pad_id
        )[:, : max_len]
        padded_labels = torch.nn.utils.rnn.pad_sequence(
            seq_labels, batch_first=True, padding_value=IGNORE_INDEX
        )[:, : max_len]
        batch = {
            "input_type": "single",
            "input_ids": padded_ids,
            "labels": padded_labels,
            "attention_mask": padded_ids.ne(pad_id),
            "task_types": [inst["task_type"] for inst in instances],
            "level_probs": torch.tensor(
                [inst["level_probs"] for inst in instances]
            ),
        }

        if "image" in instances[0]:
            images = [inst["image"] for inst in instances]
            # Stack only when all image tensors share one shape.
            uniform = all(
                img is not None and img.shape == images[0].shape for img in images
            )
            batch["images"] = torch.stack(images) if uniform else images

        return batch
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def make_single_data_module(
    tokenizer: transformers.PreTrainedTokenizer, data_args
) -> Dict:
    """Build the Trainer kwargs for single-sample supervised fine-tuning.

    Returns a dict with keys ``train_dataset``, ``eval_dataset`` (always
    None here) and ``data_collator``.
    """
    dataset = SingleDataset(
        tokenizer=tokenizer,
        data_paths=data_args.data_paths,
        data_weights=data_args.data_weights,
        data_args=data_args,
    )
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return {
        "train_dataset": dataset,
        "eval_dataset": None,
        "data_collator": collator,
    }
|
DeQA-Score/src/datasets/utils.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
from dataclasses import dataclass, field
|
| 3 |
+
from typing import Dict, List, Optional, Sequence
|
| 4 |
+
|
| 5 |
+
from PIL import ImageFile
|
| 6 |
+
|
| 7 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 8 |
+
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import List, Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.distributed as dist
|
| 14 |
+
import transformers
|
| 15 |
+
from PIL import Image
|
| 16 |
+
|
| 17 |
+
from src import conversation as conversation_lib
|
| 18 |
+
from src.constants import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX
|
| 19 |
+
from src.mm_utils import tokenizer_image_token
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def rank0_print(*args):
    """Print only on the rank-0 process of a distributed run.

    Falls back to a plain print when torch.distributed has no initialized
    process group (single-process runs), in which case ``dist.get_rank()``
    raises.
    """
    try:
        if dist.get_rank() == 0:
            print(*args)
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; Exception still covers the uninitialized-group error.
        print(*args)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class DataArguments:
    """Arguments describing the training data and image preprocessing."""

    # JSON annotation files; their samples are concatenated.
    data_paths: List[str] = field(default_factory=lambda: [])
    # Per-path repetition factors. The dataset builders read
    # data_args.data_weights alongside data_paths, but this field was
    # missing from the dataclass; added with an empty default for
    # backward compatibility.
    data_weights: List[int] = field(default_factory=lambda: [])
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = field(default=None)
    # "square" or "pad" (pad to square with the processor mean color).
    image_aspect_ratio: str = "square"
    image_grid_pinpoints: Optional[str] = field(default=None)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _tokenize_fn(
|
| 41 |
+
strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer
|
| 42 |
+
) -> Dict:
|
| 43 |
+
"""Tokenize a list of strings."""
|
| 44 |
+
tokenized_list = [
|
| 45 |
+
tokenizer(
|
| 46 |
+
text,
|
| 47 |
+
return_tensors="pt",
|
| 48 |
+
padding="longest",
|
| 49 |
+
max_length=tokenizer.model_max_length,
|
| 50 |
+
truncation=True,
|
| 51 |
+
)
|
| 52 |
+
for text in strings
|
| 53 |
+
]
|
| 54 |
+
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
| 55 |
+
input_ids_lens = labels_lens = [
|
| 56 |
+
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
|
| 57 |
+
for tokenized in tokenized_list
|
| 58 |
+
]
|
| 59 |
+
return dict(
|
| 60 |
+
input_ids=input_ids,
|
| 61 |
+
labels=labels,
|
| 62 |
+
input_ids_lens=input_ids_lens,
|
| 63 |
+
labels_lens=labels_lens,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _mask_targets(target, tokenized_lens, speakers):
    """Mask the header and all human turns of `target` in place.

    `tokenized_lens[0]` is the header length; the remaining lengths pair
    up with `speakers` turn by turn.
    """
    offset = tokenized_lens[0]
    target[:offset] = IGNORE_INDEX
    for length, speaker in zip(tokenized_lens[1:], speakers):
        if speaker == "human":
            # NOTE(review): the +2 skips a fixed two-token prefix of the
            # human turn (presumably the role signal) — preserved as-is.
            target[offset + 2 : offset + length] = IGNORE_INDEX
        offset += length
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def _add_speaker_and_signal(header, source, get_conversation=True):
    """Add speaker and start/end signal on each round.

    Rewrites each sentence in place as ``"### <role>: <text>\n"`` and,
    when `get_conversation` is True, returns the header plus all rewritten
    turns followed by a trailing ``"### "`` signal.
    """
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        origin = sentence["from"].lower()
        if origin == "human":
            speaker = conversation_lib.default_conversation.roles[0]
        elif origin == "gpt":
            speaker = conversation_lib.default_conversation.roles[1]
        else:
            speaker = "unknown"
        sentence["value"] = (
            BEGIN_SIGNAL + speaker + ": " + sentence["value"] + END_SIGNAL
        )
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def preprocess_multimodal(sources: Sequence[str], data_args: DataArguments) -> Dict:
    """Normalize the placement of the image token in each conversation.

    Wherever DEFAULT_IMAGE_TOKEN occurs inside a sentence, it is moved to
    the front on its own line. Returns `sources` unchanged when the model
    is not multimodal. Mutates the sentence dicts in place.
    """
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if DEFAULT_IMAGE_TOKEN in sentence["value"]:
                # Strip the token from wherever it is, then re-insert it
                # at the beginning on its own line.
                sentence["value"] = (
                    sentence["value"].replace(DEFAULT_IMAGE_TOKEN, "").strip()
                )
                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
                sentence["value"] = sentence["value"].strip()
            # Removed a dead no-op that replaced DEFAULT_IMAGE_TOKEN with
            # itself (leftover from an im_start/im_end token variant).

    return sources
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def preprocess_v1(
    sources, tokenizer: transformers.PreTrainedTokenizer, has_image: bool = False
) -> Dict:
    """Tokenize v1-template conversations and mask non-answer tokens.

    Builds one prompt per source with the v1 conversation template,
    tokenizes it (image-aware when `has_image`), and sets every token that
    is not part of an assistant answer to IGNORE_INDEX in the labels.
    """
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    # Apply prompt templates
    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            # Turns must strictly alternate human/assistant.
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    # Tokenize conversations

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    # Masking below assumes the two-separator conversation styles.
    assert (
        conv.sep_style == conversation_lib.SeparatorStyle.TWO
        or conv.sep_style == conversation_lib.SeparatorStyle.TWO_NO_SYS
    )

    # Mask targets
    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        # Each round is "<instruction><sep><answer>", rounds split by sep2.
        rounds = conversation.split(conv.sep2)
        cur_len = 1 + 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                # -3 / -2 drop tokens that the per-part re-tokenization
                # adds relative to tokenizing the full conversation.
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 3
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2
            # NOTE(review): compensates for one extra token per round
            # (presumably the BOS the tokenizer adds to each split round)
            # — preserve as-is.
            round_len -= 1
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        # Sanity check: if the per-round bookkeeping diverged from the
        # actual token count, ignore the whole sample instead of training
        # on misaligned labels.
        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )
    return dict(
        input_ids=input_ids,
        labels=targets,
    )
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Plain (template-free) preprocessing: image token + answer.

    Each source must be a two-turn conversation whose first turn contains
    the image token. The first turn is collapsed (in place) to exactly the
    image token; prompt+answer+separator is tokenized, and the prompt part
    of the labels is masked with IGNORE_INDEX.
    """
    prompts = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
        source[0]["value"] = DEFAULT_IMAGE_TOKEN
        prompts.append(
            source[0]["value"]
            + source[1]["value"]
            + conversation_lib.default_conversation.sep
        )
    # tokenize conversations
    token_rows = [
        tokenizer_image_token(text, tokenizer, return_tensors="pt")
        for text in prompts
    ]
    label_rows = copy.deepcopy(token_rows)
    for labels, source in zip(label_rows, sources):
        prefix_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
        labels[:prefix_len] = IGNORE_INDEX

    return dict(input_ids=token_rows, labels=label_rows)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add signal '### ' at the beginning each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
    """
    # Dispatch to the template-specific variants first.
    if (
        conversation_lib.default_conversation.sep_style
        == conversation_lib.SeparatorStyle.PLAIN
    ):
        return preprocess_plain(sources, tokenizer)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    # add end signal and concatenate together
    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    # tokenize conversations
    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
            for prompt in conversations
        ]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        # NOTE(review): `header` leaks out of the loop above; all sources
        # share the same header string, so reusing the last value is fine,
        # but this raises NameError for empty `sources`.
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn(
                [header] + [s["value"] for s in source], tokenizer
            )["input_ids_lens"]
        speakers = [sentence["from"] for sentence in source]
        # Mask the header and all human turns in place.
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def load_video(video_file):
    """Decode `video_file` at roughly 1 frame per second of footage.

    Returns the sampled frames as a list of PIL images.
    """
    from decord import VideoReader

    reader = VideoReader(video_file)

    # Average frame rate of the clip.
    fps = reader.get_avg_fps()

    # One frame index per full second of footage.
    num_seconds = int(len(reader) / fps)
    indices = [int(fps * sec) for sec in range(num_seconds)]
    frames = reader.get_batch(indices).asnumpy()
    return [Image.fromarray(frames[k]) for k in range(num_seconds)]
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def expand2square(pil_img, background_color):
    """Pad `pil_img` to a square canvas of `background_color`, centered.

    Returns the original image unchanged when it is already square.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        # Landscape: center vertically.
        canvas.paste(pil_img, (0, (side - height) // 2))
    else:
        # Portrait: center horizontally.
        canvas.paste(pil_img, ((side - width) // 2, 0))
    return canvas
|
DeQA-Score/src/evaluate/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .scorer import Scorer
|
DeQA-Score/src/evaluate/cal_distribution_gap.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def parse_args():
    """Parse command-line flags for the distribution-gap evaluation."""
    parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score")
    # The three list-valued flags share the same shape.
    for flag in ("--level_names", "--pred_paths", "--gt_paths"):
        parser.add_argument(flag, type=str, required=True, nargs="+")
    parser.add_argument("--use_openset_probs", action="store_true")
    return parser.parse_args()
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def kl_divergence(mu_1, mu_2, sigma_1, sigma_2):
    """Elementwise KL(N(mu_1, sigma_1^2) || N(mu_2, sigma_2^2)).

    Parameters are numpy arrays (or scalars) of matching shape; a small
    epsilon guards the divisions against zero standard deviations.
    Returns the divergence with the same shape as the inputs.
    """
    eps = 1e-8
    log_ratio = np.log(sigma_2 / (sigma_1 + eps))
    quad = (sigma_1 ** 2 + (mu_1 - mu_2) ** 2) / (2 * sigma_2 ** 2 + eps)
    return log_ratio + quad - 0.5
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def js_divergence(mu_1, mu_2, sigma_1, sigma_2):
    """Elementwise Jensen-Shannon divergence between two Gaussians.

    Uses the Gaussian-midpoint approximation: the mixture midpoint is
    replaced by the Gaussian with averaged mean and averaged variance,
    and the JS value is the mean of the two KL divergences to it.
    """
    # Parameters of the (Gaussian-approximated) midpoint distribution.
    mid_mu = 0.5 * (mu_1 + mu_2)
    mid_sigma = np.sqrt(0.5 * (sigma_1**2 + sigma_2**2))

    to_mid_1 = kl_divergence(mu_1, mid_mu, sigma_1, mid_sigma)
    to_mid_2 = kl_divergence(mu_2, mid_mu, sigma_2, mid_sigma)
    return 0.5 * to_mid_1 + 0.5 * to_mid_2
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def wasserstein_distance(mu_1, mu_2, sigma_1, sigma_2):
    """Elementwise 2-Wasserstein distance between two 1-D Gaussians.

    For univariate Gaussians this is simply the Euclidean distance in
    (mean, std) space: sqrt((mu_1-mu_2)^2 + (sigma_1-sigma_2)^2).
    """
    mean_gap = mu_1 - mu_2
    sigma_gap = sigma_1 - sigma_2
    return np.sqrt(mean_gap**2 + sigma_gap**2)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def cal_score(level_names, logits=None, probs=None, use_openset_probs=False):
    """Expected quality score (5 = best .. 1 = worst) over rating levels.

    Exactly one of `probs` (with use_openset_probs=True) or `logits`
    (softmax-normalized here) must be given, as a mapping from level name
    to value. Returns (score, probs) with probs ordered by `level_names`.
    """
    if use_openset_probs:
        assert logits is None
        weights = np.array([probs[name] for name in level_names])
    else:
        assert probs is None
        raw = np.array([logits[name] for name in level_names])
        # Softmax over the five level logits.
        weights = np.exp(raw) / np.sum(np.exp(raw))
    levels = np.array([5., 4., 3., 2., 1.])
    score = np.inner(weights, levels)
    return score, weights
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def cal_std(score, probs):
    """Standard deviation of the level distribution around `score`.

    `probs` holds the probabilities of the levels [5, 4, 3, 2, 1] in that
    order; the variance is the probability-weighted squared deviation.
    """
    levels = np.array([5., 4., 3., 2., 1.])
    squared_dev = (levels - score) ** 2
    return np.sqrt(np.inner(probs, squared_dev))
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
    args = parse_args()
    level_names = args.level_names
    pred_paths = args.pred_paths
    gt_paths = args.gt_paths
    use_openset_probs = args.use_openset_probs

    for pred_path, gt_path in zip(pred_paths, gt_paths):
        print("=" * 100)
        print("Pred: ", pred_path)
        print("GT: ", gt_path)

        # Predictions: one JSON object per line (jsonl).
        pred_metas = []
        with open(pred_path) as fr:
            for line in fr:
                pred_meta = json.loads(line)
                pred_metas.append(pred_meta)

        # Ground truth: a single JSON list.
        with open(gt_path) as fr:
            gt_metas = json.load(fr)

        # Align predictions and ground truth by sample id.
        pred_metas.sort(key=lambda x: x["id"])
        gt_metas.sort(key=lambda x: x["id"])

        mu_preds = []
        std_preds = []
        mu_gts = []
        std_gts = []
        for pred_meta, gt_meta in zip(pred_metas, gt_metas):
            assert pred_meta["id"] == gt_meta["id"]
            if use_openset_probs:
                # BUG FIX: cal_score(use_openset_probs=True) asserts that
                # `logits` is None and reads the `probs` argument, but this
                # branch used to pass logits=..., which always tripped the
                # assert. Route the stored values through `probs` instead.
                # NOTE(review): assumes the open-set values saved under the
                # "logits" key are already probabilities — confirm against
                # the prediction writer.
                pred_score, probs = cal_score(
                    level_names, probs=pred_meta["logits"], use_openset_probs=True
                )
            else:
                pred_score, probs = cal_score(
                    level_names, logits=pred_meta["logits"], use_openset_probs=False
                )
            pred_std = cal_std(pred_score, probs)
            mu_preds.append(pred_score)
            std_preds.append(pred_std)
            mu_gts.append(gt_meta["gt_score_norm"])
            std_gts.append(gt_meta["std_norm"])

        mu_preds = np.array(mu_preds)
        std_preds = np.array(std_preds)
        mu_gts = np.array(mu_gts)
        std_gts = np.array(std_gts)

        # Mean distribution gaps between GT and predicted Gaussians.
        kl = kl_divergence(mu_gts, mu_preds, std_gts, std_preds).mean()
        js = js_divergence(mu_gts, mu_preds, std_gts, std_preds).mean()
        wd = wasserstein_distance(mu_gts, mu_preds, std_gts, std_preds).mean()

        print(f"KL: {kl}")
        print(f"JS: {js}")
        print(f"WD: {wd}")
|
DeQA-Score/src/evaluate/cal_plcc_srcc.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
from scipy.optimize import curve_fit
|
| 6 |
+
from scipy.stats import pearsonr, spearmanr
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
    """Parse CLI options for PLCC/SRCC evaluation of DeQA-Score predictions."""
    parser = argparse.ArgumentParser(description="evaluation parameters for DeQA-Score")
    # The three path/name options share the same shape: required, multi-valued.
    for flag in ("--level_names", "--pred_paths", "--gt_paths"):
        parser.add_argument(flag, type=str, required=True, nargs="+")
    parser.add_argument("--use_openset_probs", action="store_true")
    return parser.parse_args()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def calculate_srcc(pred, mos):
    """Spearman rank-order correlation between predictions and MOS labels."""
    return spearmanr(pred, mos)[0]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def calculate_plcc(pred, mos):
    """Pearson linear correlation between predictions and MOS labels."""
    return pearsonr(pred, mos)[0]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def fit_curve(x, y, curve_type="logistic_4params"):
|
| 30 |
+
r"""Fit the scale of predict scores to MOS scores using logistic regression suggested by VQEG.
|
| 31 |
+
The function with 4 params is more commonly used.
|
| 32 |
+
The 5 params function takes from DBCNN:
|
| 33 |
+
- https://github.com/zwx8981/DBCNN/blob/master/dbcnn/tools/verify_performance.m
|
| 34 |
+
"""
|
| 35 |
+
assert curve_type in [
|
| 36 |
+
"logistic_4params",
|
| 37 |
+
"logistic_5params",
|
| 38 |
+
], f"curve type should be in [logistic_4params, logistic_5params], but got {curve_type}."
|
| 39 |
+
|
| 40 |
+
betas_init_4params = [np.max(y), np.min(y), np.mean(x), np.std(x) / 4.0]
|
| 41 |
+
|
| 42 |
+
def logistic_4params(x, beta1, beta2, beta3, beta4):
|
| 43 |
+
yhat = (beta1 - beta2) / (1 + np.exp(-(x - beta3) / beta4)) + beta2
|
| 44 |
+
return yhat
|
| 45 |
+
|
| 46 |
+
betas_init_5params = [10, 0, np.mean(y), 0.1, 0.1]
|
| 47 |
+
|
| 48 |
+
def logistic_5params(x, beta1, beta2, beta3, beta4, beta5):
|
| 49 |
+
logistic_part = 0.5 - 1.0 / (1 + np.exp(beta2 * (x - beta3)))
|
| 50 |
+
yhat = beta1 * logistic_part + beta4 * x + beta5
|
| 51 |
+
return yhat
|
| 52 |
+
|
| 53 |
+
if curve_type == "logistic_4params":
|
| 54 |
+
logistic = logistic_4params
|
| 55 |
+
betas_init = betas_init_4params
|
| 56 |
+
elif curve_type == "logistic_5params":
|
| 57 |
+
logistic = logistic_5params
|
| 58 |
+
betas_init = betas_init_5params
|
| 59 |
+
|
| 60 |
+
betas, _ = curve_fit(logistic, x, y, p0=betas_init, maxfev=10000)
|
| 61 |
+
yhat = logistic(x, *betas)
|
| 62 |
+
return yhat
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def cal_score(level_names, logits=None, probs=None, use_openset_probs=False):
    """Convert per-level logits / probabilities into a scalar quality score.

    Args:
        level_names: rating levels ordered from best to worst
            (e.g. ["excellent", "good", "fair", "poor", "bad"]).
        logits: dict mapping level name -> logit (closed-set mode).
        probs: dict mapping level name -> probability (open-set mode).
        use_openset_probs: if True, read `probs` directly; otherwise
            softmax-normalize `logits` over the given levels.

    Returns:
        The expected score: a probability-weighted mean of the level values.
    """
    if use_openset_probs:
        assert logits is None
        level_probs = np.array([probs[name] for name in level_names])
    else:
        assert probs is None
        level_logits = np.array([logits[name] for name in level_names])
        # Softmax over the closed set of level logits.
        level_probs = np.exp(level_logits) / np.sum(np.exp(level_logits))
    # Levels map to descending integer scores: best -> len(level_names), worst -> 1.
    # Generalized from the hard-coded [5., 4., 3., 2., 1.] so any number of
    # levels works; identical to the original for the standard 5-level setup.
    weights = np.arange(len(level_names), 0, -1, dtype=float)
    score = np.inner(level_probs, weights)
    return score
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
    args = parse_args()
    level_names = args.level_names
    pred_paths = args.pred_paths
    gt_paths = args.gt_paths
    use_openset_probs = args.use_openset_probs

    # Evaluate each (prediction, ground-truth) file pair in turn.
    for pred_path, gt_path in zip(pred_paths, gt_paths):
        print("=" * 100)
        print("Pred: ", pred_path)
        print("GT: ", gt_path)

        # load predict results (JSONL: one JSON object per line)
        pred_metas = []
        with open(pred_path) as fr:
            for line in fr:
                pred_meta = json.loads(line)
                pred_metas.append(pred_meta)

        # load gt results (a single JSON list)
        with open(gt_path) as fr:
            gt_metas = json.load(fr)

        # Align both lists by sample id before pairing. The sibling
        # distribution-gap script sorts the same way; without this, zip()
        # below can pair mismatched samples and trip the id assertion.
        pred_metas.sort(key=lambda x: x["id"])
        gt_metas.sort(key=lambda x: x["id"])

        preds = []
        gts = []
        for pred_meta, gt_meta in zip(pred_metas, gt_metas):
            assert pred_meta["id"] == gt_meta["id"]
            if use_openset_probs:
                pred_score = cal_score(level_names, probs=pred_meta["probs"], use_openset_probs=True)
            else:
                pred_score = cal_score(level_names, logits=pred_meta["logits"], use_openset_probs=False)
            preds.append(pred_score)
            gts.append(gt_meta["gt_score"])

        # Map predictions onto the MOS scale, then report rank / linear correlation.
        preds_fit = fit_curve(preds, gts)
        srcc = calculate_srcc(preds_fit, gts)
        plcc = calculate_plcc(preds_fit, gts)
        print(f"SRCC: {srcc}")
        print(f"PLCC: {plcc}")
|
DeQA-Score/src/evaluate/eval_qbench_mcq.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
| 5 |
+
from src.conversation import conv_templates, SeparatorStyle
|
| 6 |
+
from src.model.builder import load_pretrained_model
|
| 7 |
+
from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
| 8 |
+
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
import requests
|
| 12 |
+
from PIL import Image
|
| 13 |
+
from io import BytesIO
|
| 14 |
+
from transformers import TextStreamer
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    # Pretrained checkpoints overwrite these layers anyway, so the default
    # weight-init hooks are replaced with no-ops.
    for layer_cls in (torch.nn.Linear, torch.nn.LayerNorm):
        layer_cls.reset_parameters = lambda self: None
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def load_image(image_file):
    """Open a local path or an http(s) URL and return it as an RGB PIL image."""
    if image_file.startswith(('http://', 'https://')):
        response = requests.get(image_file)
        raw = BytesIO(response.content)
        image = Image.open(raw)
    else:
        image = Image.open(image_file)
    return image.convert('RGB')
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main(args):
    """Run multiple-choice (Q-Bench style) evaluation with an mPLUG-Owl2 model.

    Loads MCQ samples from ``args.meta_path``, queries the model for the
    answer letter, tracks running accuracy when ground truth is present, and
    appends each sample (with the model response) as JSONL to ``args.save_dir``.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    os.makedirs(args.save_dir, exist_ok=True)
    with open(args.meta_path) as f:
        llvqa_data = json.load(f)

    pbar = tqdm(total=len(llvqa_data))

    conv_mode = "mplug_owl2"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    roles = conv.roles

    correct = 0
    for i, llddata in enumerate(llvqa_data):
        filename = llddata["img_path"]

        # Build the question followed by lettered candidate answers.
        message = llddata["question"] + "\n"
        # Letter of the correct candidate; stays None when the sample carries
        # no "correct_ans". Previously this variable could be read while
        # unbound, raising NameError on datasets without ground truth.
        correct_choice = None
        for choice, ans in zip(["A.", "B.", "C.", "D."], llddata["candidates"]):
            message += f"{choice} {ans}\n"
            if "correct_ans" in llddata and ans == llddata["correct_ans"]:
                correct_choice = choice[0]
        message = message + "Answer with the option's letter from the given choices directly.\n"

        inp = message

        conv = conv_templates[args.conv_mode].copy()
        inp = "The input image:" + DEFAULT_IMAGE_TOKEN + inp
        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        print(prompt)

        image = load_image(os.path.join(args.root_dir, filename))
        image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().to(model.device)

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        # Stop generation at the conversation separator.
        stop_str = conv.sep if conv.sep_style not in [SeparatorStyle.TWO, SeparatorStyle.TWO_NO_SYS] else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                attention_mask=torch.ones_like(input_ids),
                images=image_tensor,
                do_sample=False,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                num_beams=1,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])

        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        llddata["response"] = outputs

        # Only count accuracy when ground truth exists for this sample.
        if correct_choice is not None and correct_choice in outputs:
            correct += 1

        pbar.update(1)
        pbar.set_description("[Running Accuracy]: {:.4f},[Response]: {}, [Correct Ans]: {}, , [Prog]: {}".format(correct/(i+1), outputs, llddata.get("correct_ans", -1), i+1))

        # Append results incrementally so partial runs are preserved.
        save_path = os.path.join(args.save_dir, os.path.basename(args.meta_path))
        with open(save_path, "a") as fw:
            fw.write(json.dumps(llddata) + "\n")

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
    # CLI entry point: parse model / data locations and decoding options,
    # then run the multiple-choice evaluation loop.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)  # base model for delta/LoRA weights
    parser.add_argument("--root-dir", type=str, required=True)  # image root directory
    parser.add_argument("--save-dir", type=str, required=True)  # result JSONL is appended here
    parser.add_argument("--meta-path", type=str, required=True)  # MCQ metadata JSON
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default=None)  # conversation template override
    # NOTE(review): main() passes do_sample=False, so sampling params like
    # temperature presumably have no effect -- confirm against generate().
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")  # echo prompt/response per sample
    parser.add_argument("--image-aspect-ratio", type=str, default='pad')
    args = parser.parse_args()
    main(args)
|
DeQA-Score/src/evaluate/iqa_eval.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
import torch
|
| 9 |
+
from PIL import Image
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
from src.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
|
| 13 |
+
from src.conversation import conv_templates
|
| 14 |
+
from src.mm_utils import get_model_name_from_path, tokenizer_image_token
|
| 15 |
+
from src.model.builder import load_pretrained_model
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch

    def _noop(self):
        return None

    # Pretrained weights overwrite these layers, so skip the default init.
    torch.nn.Linear.reset_parameters = _noop
    torch.nn.LayerNorm.reset_parameters = _noop
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_image(image_file):
    """Load a local file or an http(s) URL and return an RGB PIL image."""
    is_remote = image_file.startswith("http://") or image_file.startswith("https://")
    source = BytesIO(requests.get(image_file).content) if is_remote else image_file
    return Image.open(source).convert("RGB")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def main(args):
    """Score a set of images with a DeQA / mPLUG-Owl2 model.

    For every metadata file in ``args.meta_paths``, this runs batched forward
    passes with one fixed quality prompt, records the next-token logits (and
    optionally the softmax probabilities) of each rating-level word, and
    appends one JSON line per image to ``args.save_dir``. Images already
    present in an existing output file are skipped, so interrupted runs resume.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path,
        args.model_base,
        model_name,
        args.load_8bit,
        args.load_4bit,
        device=args.device,
        preprocessor_path=args.preprocessor_path,
    )

    meta_paths = args.meta_paths
    root_dir = args.root_dir
    batch_size = args.batch_size
    save_dir = args.save_dir
    os.makedirs(save_dir, exist_ok=True)
    with_prob = args.with_prob

    # The prompt is identical for every image, so it is built only once.
    conv_mode = "mplug_owl2"
    inp = "How would you rate the quality of this image?"

    conv = conv_templates[conv_mode].copy()
    inp = inp + "\n" + DEFAULT_IMAGE_TOKEN
    conv.append_message(conv.roles[0], inp)
    image = None

    conv.append_message(conv.roles[1], None)
    # The assistant turn is primed so the next generated token is the rating word.
    prompt = conv.get_prompt() + " The quality of the image is"

    # Token ids of the rating-level words; id_[1] skips the leading token
    # (assumes each level name tokenizes to [BOS, word] -- TODO confirm for
    # tokenizers that split a level name into multiple pieces).
    toks = args.level_names
    print(toks)
    ids_ = [id_[1] for id_ in tokenizer(toks)["input_ids"]]
    print(ids_)

    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(args.device)
    )

    for meta_path in meta_paths:
        with open(meta_path) as f:
            iqadata = json.load(f)

        image_tensors = []
        batch_data = []

        # Resume support: collect image names already present in the output file.
        imgs_handled = []
        save_path = os.path.join(save_dir, os.path.basename(meta_path))
        if os.path.exists(save_path):
            with open(save_path) as fr:
                for line in fr:
                    meta_res = json.loads(line)
                    imgs_handled.append(meta_res["image"])

        meta_name = os.path.basename(meta_path)
        for i, llddata in enumerate(tqdm(iqadata, desc=f"Evaluating [{meta_name}]")):
            # Metadata files use either "image" or "img_path" as the filename key.
            try:
                filename = llddata["image"]
            except:
                filename = llddata["img_path"]
            if filename in imgs_handled:
                continue

            llddata["logits"] = defaultdict(float)
            llddata["probs"] = defaultdict(float)

            image = load_image(os.path.join(root_dir, filename))

            def expand2square(pil_img, background_color):
                # Pad the shorter side with the processor's mean color so the
                # image becomes square without distortion.
                width, height = pil_img.size
                if width == height:
                    return pil_img
                elif width > height:
                    result = Image.new(pil_img.mode, (width, width), background_color)
                    result.paste(pil_img, (0, (width - height) // 2))
                    return result
                else:
                    result = Image.new(pil_img.mode, (height, height), background_color)
                    result.paste(pil_img, ((height - width) // 2, 0))
                    return result

            image = expand2square(
                image, tuple(int(x * 255) for x in image_processor.image_mean)
            )
            image_tensor = (
                image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
                .half()
                .to(args.device)
            )

            image_tensors.append(image_tensor)
            batch_data.append(llddata)

            # Flush a full batch, or whatever remains at the end of the list.
            if (i + 1) % batch_size == 0 or i == len(iqadata) - 1:
                with torch.inference_mode():
                    output_logits = model(
                        input_ids=input_ids.repeat(len(image_tensors), 1),
                        images=torch.cat(image_tensors, 0),
                    )["logits"][:, -1]
                    if with_prob:
                        # Softmax over the full vocabulary ("open-set" probabilities).
                        output_probs = torch.softmax(output_logits, dim=1)

                for j, xllddata in enumerate(batch_data):
                    for tok, id_ in zip(toks, ids_):
                        xllddata["logits"][tok] += output_logits[j, id_].item()
                        if with_prob:
                            xllddata["probs"][tok] += output_probs[j, id_].item()
                    # NOTE(review): this reads xllddata["image"] even though the
                    # lookup above also accepts "img_path"; samples with only
                    # "img_path" would raise KeyError here -- verify the data.
                    meta_res = {
                        "id": xllddata["id"],
                        "image": xllddata["image"],
                        "gt_score": xllddata["gt_score"],
                        "logits": xllddata["logits"],
                    }
                    if with_prob:
                        meta_res["probs"] = xllddata["probs"]
                    with open(save_path, "a") as fw:
                        fw.write(json.dumps(meta_res) + "\n")

                image_tensors = []
                batch_data = []
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
if __name__ == "__main__":
    def _str2bool(value):
        # argparse's `type=bool` treats every non-empty string -- including
        # "False" -- as True. Parse common false-y spellings explicitly while
        # keeping the original `--with-prob <value>` calling convention
        # backward-compatible for truthy inputs.
        return str(value).strip().lower() not in ("false", "0", "no", "")

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, required=True)
    parser.add_argument("--model-base", type=str, default=None)  # base model for delta/LoRA weights
    parser.add_argument("--preprocessor-path", type=str, default=None)
    parser.add_argument("--meta-paths", type=str, required=True, nargs="+")
    parser.add_argument("--root-dir", type=str, required=True)  # image root directory
    parser.add_argument("--save-dir", type=str, default="results")
    parser.add_argument("--level-names", type=str, required=True, nargs="+")
    parser.add_argument("--with-prob", type=_str2bool, default=False)  # whether to save openset prob
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    parser.add_argument("--image-aspect-ratio", type=str, default="pad")
    args = parser.parse_args()
    main(args)
|
DeQA-Score/src/evaluate/scorer.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from typing import List
|
| 7 |
+
|
| 8 |
+
from src.model.builder import load_pretrained_model
|
| 9 |
+
|
| 10 |
+
from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
| 11 |
+
from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Scorer(nn.Module):
    """Callable image-quality scorer wrapping a pretrained DeQA-Score model.

    The score of an image is the expectation of the values {5, 4, 3, 2, 1}
    under the softmax over the logits of the five rating words
    ("excellent" ... "bad") at the primed assistant position.
    """

    def __init__(self, pretrained="zhiyuanyou/DeQA-Score-Mix3", device="cuda:0"):
        super().__init__()
        tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
        prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"

        levels = ["excellent", "good", "fair", "poor", "bad"]
        # id_[1] skips the leading token of each encoded level word.
        self.preferential_ids_ = [id_[1] for id_ in tokenizer(levels)["input_ids"]]
        # Rating words map to scores 5 (excellent) ... 1 (bad).
        self.weight_tensor = torch.Tensor([5., 4., 3., 2., 1.]).half().to(model.device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    def expand2square(self, pil_img, background_color):
        """Pad *pil_img* to a centered square canvas filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        side = max(width, height)
        canvas = Image.new(pil_img.mode, (side, side), background_color)
        # Center the original along the shorter axis.
        canvas.paste(pil_img, ((side - width) // 2, (side - height) // 2))
        return canvas

    def forward(self, image: List[Image.Image]):
        """Return a quality score in [1, 5] for each PIL image in *image*."""
        fill = tuple(int(x * 255) for x in self.image_processor.image_mean)
        image = [self.expand2square(img, fill) for img in image]
        with torch.inference_mode():
            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
            output_logits = self.model(
                input_ids=self.input_ids.repeat(image_tensor.shape[0], 1),
                images=image_tensor
            )["logits"][:, -1, self.preferential_ids_]

        return torch.softmax(output_logits, -1) @ self.weight_tensor
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
if __name__ == "__main__":
    import argparse

    # Quick CLI smoke test: score a single image with a pretrained checkpoint.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="zhiyuanyou/DeQA-Score-Mix3")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
    args = parser.parse_args()

    scorer = Scorer(pretrained=args.model_path, device=args.device)
    # Prints a one-element list holding the predicted quality score.
    print(scorer([Image.open(args.img_path)]).tolist())
|
DeQA-Score/src/evaluate/scorer_coco.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch
|
| 4 |
+
from typing import List
|
| 5 |
+
from src.model.builder import load_pretrained_model
|
| 6 |
+
from src.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
|
| 7 |
+
from src.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# Directory containing this file; anchors data paths given relative to it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))


def resolve_path(*parts):
    """Resolve *parts* relative to this file's directory as an absolute path."""
    combined = os.path.join(BASE_DIR, *parts)
    return os.path.abspath(combined)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Scorer(nn.Module):
    """Callable image-quality scorer built on a pretrained DeQA-Score model.

    The score of an image is the expectation of the values {5, 4, 3, 2, 1}
    under the softmax over the logits of the five rating words
    ("excellent", "good", "fair", "poor", "bad") at the primed position.
    """

    def __init__(self, pretrained="zhiyuanyou/DeQA-Score-Mix3", device="cuda:0"):
        super().__init__()
        tokenizer, model, image_processor, _ = load_pretrained_model(pretrained, None, "mplug_owl2", device=device)
        # The assistant turn is primed so the next token is a rating word.
        prompt = "USER: How would you rate the quality of this image?\n<|image|>\nASSISTANT: The quality of the image is"

        # id_[1] skips the leading token (assumes each rating word encodes as
        # [BOS, word] -- TODO confirm for other tokenizers).
        self.preferential_ids_ = [id_[1] for id_ in tokenizer(["excellent","good","fair","poor","bad"])["input_ids"]]
        # Rating words map to scores 5 (excellent) ... 1 (bad).
        self.weight_tensor = torch.Tensor([5.,4.,3.,2.,1.]).half().to(model.device)

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

    def expand2square(self, pil_img, background_color):
        """Pad the image to a centered square canvas filled with *background_color*."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def forward(self, image: List[Image.Image]):
        """Return a quality score in [1, 5] for each PIL image in *image*."""
        # Pad with the processor's mean color so resizing does not distort.
        image = [self.expand2square(img, tuple(int(x*255) for x in self.image_processor.image_mean)) for img in image]
        with torch.inference_mode():
            image_tensor = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"].half().to(self.model.device)
            output_logits = self.model(
                input_ids=self.input_ids.repeat(image_tensor.shape[0], 1),
                images=image_tensor
            )["logits"][:,-1, self.preferential_ids_]

        # Probability-weighted mean of the level scores.
        return torch.softmax(output_logits, -1) @ self.weight_tensor
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="zhiyuanyou/DeQA-Score-Mix3")
    parser.add_argument("--device", type=str, default="cuda:0")
    parser.add_argument("--img_path", type=str, default="fig/singapore_flyer.jpg")
    args = parser.parse_args()

    scorer = Scorer(pretrained=args.model_path, device=args.device)

    # PIL.Image and os are already imported at module level; the duplicate
    # (and unused ImageFile) imports were dropped.
    from pycocotools.coco import COCO
    from tqdm import tqdm

    # One entry per (image, caption) pair; the IQA score of an image is
    # shared by all of its captions.
    data_IQA = {
        "captions": [],
        "IQAs": [],
        "image_ids": []
    }

    ANN_PATH = resolve_path('../../../', 'coco_data', 'annotations', 'captions_train2017.json')
    IMG_DIR = resolve_path('../../../', 'coco_data', 'train2017')

    coco = COCO(ANN_PATH)
    img_ids = coco.getImgIds()

    for img_id in tqdm(img_ids):
        img_info = coco.loadImgs(img_id)[0]
        file_name = img_info["file_name"]
        img_path = os.path.join(IMG_DIR, file_name)

        # Score the image once; reuse that score for each of its captions.
        IQA_score = scorer([Image.open(img_path).convert("RGB")])

        ann_ids = coco.getAnnIds(imgIds=img_id)
        anns = coco.loadAnns(ann_ids)
        for ann in anns:
            # Strip once and reuse (the original stripped twice, leaving a
            # dead `caption` local after the appends).
            caption = ann["caption"].strip()
            data_IQA["captions"].append(caption)
            data_IQA["image_ids"].append(img_id)
            data_IQA["IQAs"].append(IQA_score.detach().cpu().item())

    save_path = resolve_path('../../../', 'processed_data', 'coco', 'data_IQA.pt')
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(data_IQA, save_path)
|
DeQA-Score/src/mm_utils.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from PIL import Image
|
| 2 |
+
from io import BytesIO
|
| 3 |
+
import base64
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import StoppingCriteria
|
| 7 |
+
from src.constants import IMAGE_TOKEN_INDEX,DEFAULT_IMAGE_TOKEN
|
| 8 |
+
from icecream import ic
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_image_from_base64(image):
    """Decode a base64-encoded string into a PIL image."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def expand2square(pil_img, background_color):
    """Pad *pil_img* on its shorter axis to a centered square canvas."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    # These offsets center the original image on the square canvas.
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def process_images(images, image_processor, model_cfg=None):
    """Preprocess a list of PIL images into pixel-value tensors.

    The strategy comes from ``model_cfg.image_aspect_ratio`` when a config is
    given, otherwise it defaults to ``'resize'``:

    * ``'pad'``    -- pad each image to a square filled with the processor's
      mean color, then run the processor.
    * ``'resize'`` -- stretch each image to a square whose side equals its
      longer edge, then run the processor.
    * anything else -- delegate the whole batch to ``image_processor`` directly.

    Returns a stacked ``torch.Tensor`` when all processed images share the
    same shape, otherwise a list of tensors (an empty input yields ``[]``).
    """
    if model_cfg is not None:
        image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    else:
        image_aspect_ratio = 'resize'
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x*255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == 'resize':
        for image in images:
            max_edge = max(image.size)
            image = image.resize((max_edge, max_edge))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    # Fix: guard against an empty input -- the original unconditionally hit
    # torch.stack([]) here, which raises a RuntimeError.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize ``prompt``, substituting each image placeholder token.

    The prompt is split on ``DEFAULT_IMAGE_TOKEN``; each text chunk is
    tokenized independently and the chunks are re-joined with
    ``image_token_index``. A leading BOS token from the first chunk is kept
    exactly once at the front.
    """
    chunks = [tokenizer(part).input_ids if len(part) > 0 else [] for part in prompt.split(DEFAULT_IMAGE_TOKEN)]

    def _interleave(seqs, sep):
        # [a, b, c] with separator s -> [a, s, b, s, c]
        return [tok for pair in zip(seqs, [sep] * len(seqs)) for tok in pair][:-1]

    input_ids = []
    offset = 0
    if chunks and chunks[0] and chunks[0][0] == tokenizer.bos_token_id:
        # Keep BOS once up front; strip it from every subsequent chunk below.
        offset = 1
        input_ids.append(chunks[0][0])

    for seq in _interleave(chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(seq[offset:])

    if return_tensors is None:
        return input_ids
    if return_tensors == 'pt':
        return torch.tensor(input_ids, dtype=torch.long)
    raise ValueError(f'Unsupported tensor type: {return_tensors}')
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_model_name_from_path(model_path):
    """Derive a display name for a model from its filesystem path.

    For checkpoint directories (``.../name/checkpoint-XXX``) the parent
    directory name is prepended so the result stays unique.
    """
    parts = model_path.strip("/").split("/")
    last = parts[-1]
    if last.startswith('checkpoint-'):
        return f"{parts[-2]}_{last}"
    return last
| 82 |
+
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keyword strings is produced.

    Matching is attempted first on raw token ids (fast path) and then on the
    decoded text of the most recently generated tokens, which is robust to
    tokenization differences.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.tokenizer = tokenizer
        # Length of the prompt; only tokens after this are "generated".
        self.start_len = input_ids.shape[1]
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            ids = tokenizer(keyword).input_ids
            # Drop a leading BOS token so it never blocks a match.
            if len(ids) > 1 and ids[0] == tokenizer.bos_token_id:
                ids = ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(ids))
            self.keyword_ids.append(torch.tensor(ids))

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [kw.to(output_ids.device) for kw in self.keyword_ids]
        # Fast path: compare the tail of the output against each keyword's ids.
        for kw in self.keyword_ids:
            if (output_ids[0, -kw.shape[0]:] == kw).all():
                return True
        # Slow path: decode only the newly generated tail and search the text.
        tail_text = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in tail_text for keyword in self.keywords)
DeQA-Score/src/model/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
|
| 2 |
+
from .configuration_mplug_owl2 import MPLUGOwl2Config
|
DeQA-Score/src/model/builder.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 Haotian Liu
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
import warnings
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from transformers import (
|
| 21 |
+
AutoConfig,
|
| 22 |
+
AutoModelForCausalLM,
|
| 23 |
+
AutoTokenizer,
|
| 24 |
+
BitsAndBytesConfig,
|
| 25 |
+
)
|
| 26 |
+
from transformers.models.clip.image_processing_clip import CLIPImageProcessor
|
| 27 |
+
|
| 28 |
+
from src.model import *
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_pretrained_model(
    model_path,
    model_base,
    model_name,
    load_8bit=False,
    load_4bit=False,
    device_map="auto",
    device="cuda",
    preprocessor_path=None,
):
    """Load a DeQA/mPLUG-Owl2 (or plain causal LM) checkpoint for inference.

    Args:
        model_path: Checkpoint directory or HF Hub repo id of the model
            (or of the LoRA adapter when ``model_base`` is given).
        model_base: Base model to load first when ``model_path`` holds only
            LoRA/adapter weights or a projector; ``None`` for full models.
        model_name: Name used to dispatch loading: containing ``"deqa"``
            selects the MPLUGOwl2 path, containing ``"lora"`` selects the
            LoRA-merge path.
        load_8bit: Load with bitsandbytes 8-bit quantization.
        load_4bit: Load with bitsandbytes NF4 4-bit quantization
            (ignored when ``load_8bit`` is set).
        device_map: Passed through to ``from_pretrained``; overridden to
            ``{"": device}`` when ``device != "cuda"``.
        device: Target device name.
        preprocessor_path: Where to load tokenizer/image processor from;
            defaults to ``model_path``.

    Returns:
        Tuple ``(tokenizer, model, image_processor, context_len)`` where
        ``context_len`` is ``model.config.max_sequence_length`` if present,
        else 2048.
    """
    kwargs = {"device_map": device_map}

    if device != "cuda":
        kwargs["device_map"] = {"": device}

    # Quantization selection: 8-bit wins over 4-bit; full precision falls
    # back to fp16 weights.
    if load_8bit:
        kwargs["load_in_8bit"] = True
    elif load_4bit:
        kwargs["load_in_4bit"] = True
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
    else:
        kwargs["torch_dtype"] = torch.float16

    if preprocessor_path is None:
        preprocessor_path = model_path

    if "deqa" in model_name.lower():
        # Load LLaVA model
        if "lora" in model_name.lower() and model_base is None:
            warnings.warn(
                "There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged."
            )
        if "lora" in model_name.lower() and model_base is not None:
            # LoRA path: load the base model under the adapter's config,
            # then resize heads/embeddings if the adapter changed the vocab.
            lora_cfg_pretrained = AutoConfig.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
            print("Loading mPLUG-Owl2 from base model...")
            model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs
            )
            token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features
            if model.lm_head.weight.shape[0] != token_num:
                # Re-allocate (uninitialized) head/embedding weights; the
                # real values come from non_lora_trainables below.
                model.lm_head.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, tokem_dim, device=model.device, dtype=model.dtype
                    )
                )
                model.model.embed_tokens.weight = torch.nn.Parameter(
                    torch.empty(
                        token_num, tokem_dim, device=model.device, dtype=model.dtype
                    )
                )

            print("Loading additional mPLUG-Owl2 weights...")
            if os.path.exists(os.path.join(model_path, "non_lora_trainables.bin")):
                non_lora_trainables = torch.load(
                    os.path.join(model_path, "non_lora_trainables.bin"),
                    map_location="cpu",
                )
                print(non_lora_trainables.keys())
            else:
                # this is probably from HF Hub
                from huggingface_hub import hf_hub_download

                def load_from_hf(repo_id, filename, subfolder=None):
                    # Download one file from the Hub and load it on CPU.
                    cache_file = hf_hub_download(
                        repo_id=repo_id, filename=filename, subfolder=subfolder
                    )
                    return torch.load(cache_file, map_location="cpu")

                non_lora_trainables = load_from_hf(
                    model_path, "non_lora_trainables.bin"
                )
            # Strip the PEFT "base_model.model." prefix (17 chars) so keys
            # line up with the unwrapped model's state dict.
            non_lora_trainables = {
                (k[17:] if k.startswith("base_model.model.") else k): v
                for k, v in non_lora_trainables.items()
            }
            model.load_state_dict(non_lora_trainables, strict=False)

            from peft import PeftModel

            print("Loading LoRA weights...")
            model = PeftModel.from_pretrained(model, model_path)
            print("Merging LoRA weights...")
            model = model.merge_and_unload()
            print("Model is loaded...")
        elif model_base is not None:
            # this may be mm projector only
            print("Loading mPLUG-Owl2 from base model...")
            tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
            cfg_pretrained = AutoConfig.from_pretrained(model_path)
            model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs
            )
        else:
            # Full (already merged) DeQA checkpoint.
            tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
            model = MPLUGOwl2LlamaForCausalLM.from_pretrained(
                model_path, low_cpu_mem_usage=True, **kwargs
            )
    else:
        # Load language model
        if model_base is not None:
            # PEFT model
            from peft import PeftModel

            tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_base, low_cpu_mem_usage=True, **kwargs
            )
            print(f"Loading LoRA weights from {model_path}")
            model = PeftModel.from_pretrained(model, model_path)
            print(f"Merging weights")
            model = model.merge_and_unload()
            print("Convert to FP16...")
            model.to(torch.float16)
        else:
            tokenizer = AutoTokenizer.from_pretrained(preprocessor_path, use_fast=False)
            model = AutoModelForCausalLM.from_pretrained(
                model_path, low_cpu_mem_usage=True, **kwargs
            )

    # vision_tower = model.get_model().vision_model
    # print(vision_tower.device)
    # vision_tower.to(device=device, dtype=torch.float16)
    image_processor = CLIPImageProcessor.from_pretrained(preprocessor_path)

    if hasattr(model.config, "max_sequence_length"):
        context_len = model.config.max_sequence_length
    else:
        context_len = 2048

    return tokenizer, model, image_processor, context_len
DeQA-Score/src/model/configuration_mplug_owl2.py
ADDED
|
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Alibaba.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
import copy
|
| 6 |
+
import os
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 10 |
+
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
| 11 |
+
from transformers.utils import logging
|
| 12 |
+
from transformers.models.auto import CONFIG_MAPPING
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LlamaConfig(PretrainedConfig):
    r"""Configuration for a LLaMA-style decoder model.

    Mirrors ``transformers``' ``LlamaConfig``: it stores the architectural
    hyper-parameters used to instantiate a LLaMA model — vocabulary and
    hidden sizes, layer/head counts (``num_key_value_heads`` < ``num_attention_heads``
    enables grouped-query attention), activation, maximum position
    embeddings, RMSNorm epsilon, RoPE base/scaling, and attention
    bias/dropout. Defaults reproduce the LLaMA-7B architecture. Inherited
    ``PretrainedConfig`` behavior (serialization, special token ids,
    ``tie_word_embeddings``) is documented upstream.

    ```python
    >>> from transformers import LlamaModel, LlamaConfig

    >>> # Initializing a LLaMA llama-7b style configuration
    >>> configuration = LlamaConfig()

    >>> # Initializing a model from the llama-7b style configuration
    >>> model = LlamaModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "llama"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=11008,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=None,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        pretraining_tp=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # Model geometry.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        # Backward compatibility: default to MHA (one KV head per head).
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.pretraining_tp = pretraining_tp
        self.use_cache = use_cache
        # Rotary position embeddings; the scaling dict is validated eagerly.
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self._rope_scaling_validation()
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def _rope_scaling_validation(self):
        """Raise ``ValueError`` unless ``self.rope_scaling`` is ``None`` or a
        well-formed ``{"type": ..., "factor": ...}`` dictionary."""
        scaling = self.rope_scaling
        if scaling is None:
            return

        if not isinstance(scaling, dict) or len(scaling) != 2:
            raise ValueError(
                "`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
                f"got {scaling}"
            )
        scaling_type = scaling.get("type", None)
        scaling_factor = scaling.get("factor", None)
        if scaling_type is None or scaling_type not in ["linear", "dynamic"]:
            raise ValueError(
                f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {scaling_type}"
            )
        if scaling_factor is None or not isinstance(scaling_factor, float) or scaling_factor <= 1.0:
            raise ValueError(f"`rope_scaling`'s factor field must be a float > 1, got {scaling_factor}")
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class MplugOwlVisionConfig(PretrainedConfig):
    r"""Configuration for the mPLUG-Owl vision encoder (``MplugOwlVisionModel``).

    Stores the architecture of the CLIP-style vision transformer used by
    mPLUG-Owl. Instantiating with defaults yields the configuration used by
    [x-plug/x_plug-llama-7b](https://huggingface.co/x-plug/x_plug-llama-7b).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to
    control the model outputs; see the [`PretrainedConfig`] documentation.

    Args:
        hidden_size (`int`, defaults to 1024): Encoder/pooler dimensionality.
        intermediate_size (`int`, defaults to 4096): Feed-forward dimensionality.
        projection_dim (`int`, defaults to 768): Output projection dimensionality.
        num_hidden_layers (`int`, defaults to 24): Transformer encoder depth.
        num_attention_heads (`int`, defaults to 16): Heads per attention layer.
        num_channels (`int`, defaults to 3): Input image channels.
        image_size (`int`, defaults to 448): Input image resolution.
        patch_size (`int`, defaults to 14): Patch resolution.
        hidden_act (`str` or callable, defaults to `"quick_gelu"`): Activation;
            `"gelu"`, `"relu"`, `"selu"`, `"gelu_new"` and `"quick_gelu"` supported.
        layer_norm_eps (`float`, defaults to 1e-6): LayerNorm epsilon.
        attention_dropout (`float`, defaults to 0.0): Attention dropout ratio.
        initializer_range (`float`, defaults to 0.02): Weight init std.
        initializer_factor (`float`, defaults to 1.0): Init scaling factor
            (keep 1; used internally for initialization testing).
        use_flash_attn (`bool`, defaults to `False`): Enable flash attention.
    """

    model_type = "mplug_owl_vision_model"

    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4096,
        projection_dim=768,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_channels=3,
        image_size=448,
        patch_size=14,
        hidden_act="quick_gelu",
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        initializer_range=0.02,
        initializer_factor=1.0,
        use_flash_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.projection_dim = projection_dim
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_channels = num_channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.initializer_range = initializer_range
        self.initializer_factor = initializer_factor
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.hidden_act = hidden_act
        self.use_flash_attn = use_flash_attn

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the vision config dict if we are loading from MplugOwlConfig
        if config_dict.get("model_type") == "mplug-owl":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            # Fix: this module never defines a global `logger`, so the original
            # `logger.warning(...)` raised NameError when this branch ran.
            # Use the imported `transformers.utils.logging` helper instead.
            logging.get_logger(__name__).warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class MplugOwlVisualAbstractorConfig(PretrainedConfig):
    """Configuration for the mPLUG-Owl visual abstractor.

    The abstractor compresses vision-encoder features into a fixed number of
    learnable query tokens via a small cross-attention transformer.

    Args:
        num_learnable_queries (`int`, defaults to 64): Number of query tokens.
        hidden_size (`int`, defaults to 1024): Abstractor hidden size.
        num_hidden_layers (`int`, defaults to 6): Cross-attention depth.
        num_attention_heads (`int`, defaults to 16): Heads per layer.
        intermediate_size (`int`, defaults to 2816): Feed-forward size.
        attention_probs_dropout_prob (`float`, defaults to 0.0): Attention dropout.
        initializer_range (`float`, defaults to 0.02): Weight init std.
        layer_norm_eps (`float`, defaults to 1e-6): LayerNorm epsilon.
        encoder_hidden_size (`int`, defaults to 1024): Dimensionality of the
            vision-encoder features attended over.
        grid_size (`int`, optional): Spatial grid side; falls back to 32.
    """

    model_type = "mplug_owl_visual_abstract"

    def __init__(
        self,
        num_learnable_queries=64,
        hidden_size=1024,
        num_hidden_layers=6,
        num_attention_heads=16,
        intermediate_size=2816,
        attention_probs_dropout_prob=0.,
        initializer_range=0.02,
        layer_norm_eps=1e-6,
        encoder_hidden_size=1024,
        grid_size=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size
        self.num_learnable_queries = num_learnable_queries
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.encoder_hidden_size = encoder_hidden_size
        # NOTE: falsy values (0) also fall back to 32, matching the original.
        self.grid_size = grid_size if grid_size else 32

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # get the visual_abstractor config dict if we are loading from MplugOwlConfig
        if config_dict.get("model_type") == "mplug-owl":
            config_dict = config_dict["abstractor_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            # Fix: this module never defines a global `logger`, so the original
            # `logger.warning(...)` raised NameError when this branch ran.
            # Use the imported `transformers.utils.logging` helper instead.
            logging.get_logger(__name__).warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
| 313 |
+
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
# Serialized defaults for the vision encoder and visual abstractor, used by
# `MPLUGOwl2Config` when no explicit `visual_config` is supplied.
DEFAULT_VISUAL_CONFIG = {
    "visual_model": MplugOwlVisionConfig().to_dict(),
    "visual_abstractor": MplugOwlVisualAbstractorConfig().to_dict()
}
| 320 |
+
|
| 321 |
+
class MPLUGOwl2Config(LlamaConfig):
    """LLaMA configuration extended with mPLUG-Owl2 visual-module settings."""

    model_type = "mplug_owl2"

    def __init__(self, visual_config=None, **kwargs):
        # Fall back to the module-level default vision/abstractor configs.
        self.visual_config = DEFAULT_VISUAL_CONFIG if visual_config is None else visual_config
        super().__init__(**kwargs)
| 332 |
+
|
| 333 |
+
if __name__ == "__main__":
    # Quick manual check: print the default vision-encoder configuration.
    print(MplugOwlVisionConfig().to_dict())
DeQA-Score/src/model/convert_mplug_owl2_weight_to_hf.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DAMO Academy and The HuggingFace Inc. team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import gc
|
| 16 |
+
import json
|
| 17 |
+
import math
|
| 18 |
+
import os
|
| 19 |
+
import shutil
|
| 20 |
+
import warnings
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
|
| 25 |
+
from .configuration_mplug_owl2 import MPLUGOwl2Config, MplugOwlVisionConfig, MplugOwlVisualAbstractorConfig
|
| 26 |
+
from .modeling_mplug_owl2 import MPLUGOwl2LlamaForCausalLM
|
| 27 |
+
|
| 28 |
+
# Optional fast tokenizer: fall back to the slow SentencePiece-based tokenizer
# when the installed `tokenizers` library does not provide the fast variant.
try:
    from transformers import LlamaTokenizerFast
except ImportError as e:
    warnings.warn(e)
    warnings.warn(
        "The converted tokenizer will be the `slow` tokenizer. To use the fast, update your `tokenizers` library and re-run the tokenizer conversion"
    )
    # Sentinel checked in write_tokenizer() to pick the slow class.
    LlamaTokenizerFast = None
|
| 36 |
+
|
| 37 |
+
"""
|
| 38 |
+
Sample usage:
|
| 39 |
+
|
| 40 |
+
```
|
| 41 |
+
python3 /pure-mlo-scratch/sfan/model-parallel-trainer/llama2megatron/convert_llama2hf.py \
|
| 42 |
+
--input_dir /pure-mlo-scratch/llama/ --model_size 7 --output_dir /pure-mlo-scratch/llama/converted_HF_7B
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
Thereafter, models can be loaded via:
|
| 46 |
+
|
| 47 |
+
```py
|
| 48 |
+
from transformers import LlamaForCausalLM, LlamaTokenizer
|
| 49 |
+
|
| 50 |
+
model = LlamaForCausalLM.from_pretrained("/output/path")
|
| 51 |
+
tokenizer = LlamaTokenizer.from_pretrained("/output/path")
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
|
| 55 |
+
come in several checkpoints they each contain a part of each weight of the model, so we need to load them all in RAM).
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
# Per-model-size (billions of parameters) LLaMA architecture hyper-parameters.
llama_s2layer = {7: 32, 13: 40, 30: 60, 65: 80, 70: 80}
llama_s2heads = {7: 32, 13: 40, 30: 52, 65: 64, 70: 64}
llama_s2dense = {7: 11008, 13: 13824, 30: 17920, 65: 22016,
                 70: 28672}  # should be (2/3)*4*d, but it isn't exactly that
# BUGFIX: the 30B entry was previously keyed as 32, which made
# write_model(model_size=30) fail with a KeyError on this table.
llama_s2hidden = {7: 4096, 13: 5120, 30: 6656, 65: 8192, 70: 8192}
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def compute_intermediate_size(n):
    """Return ceil(n * 8 / 3) rounded up to the nearest multiple of 256."""
    raw = math.ceil(n * 8 / 3)
    return ((raw + 255) // 256) * 256
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def read_json(path):
    """Deserialize and return the JSON document stored at *path*."""
    with open(path, "r") as handle:
        return json.load(handle)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def write_json(text, path):
    """Serialize *text* as JSON into the file at *path*, overwriting it."""
    with open(path, "w") as handle:
        json.dump(text, handle)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def write_model(model_path,
                input_base_path,
                model_size,
                num_input_shards=1,
                num_output_shards=2,
                skip_permute=True,
                norm_eps=1e-05):
    """Convert a Megatron mPLUG-Owl2 checkpoint into HuggingFace format.

    Locates the Megatron ``model_optim_rng.pt`` / ``model_rng.pt`` checkpoint
    under ``input_base_path``, remaps the language-model, vision-model and
    visual-abstractor weights to HF parameter names, and writes per-layer
    ``pytorch_model-*.bin`` shards plus the weight-map index and the model
    config into ``model_path``.

    Args:
        model_path: Output directory (created if missing).
        input_base_path: Megatron checkpoint root. May contain a ``release/``
            directory, point directly at an ``iter_*`` directory, or contain a
            ``latest_checkpointed_iteration.txt`` tracker file.
        model_size: LLaMA size in billions (7, 13, 30, 65 or 70).
        num_input_shards: Tensor-parallel shard count of the input checkpoint.
            Only 1 is supported.
        num_output_shards: Unused; kept for CLI backward compatibility.
        skip_permute: If True (default), do not permute q/k weights for HF's
            sliced-rotary layout.
        norm_eps: Unused; kept for CLI backward compatibility.

    Raises:
        NotImplementedError: If ``num_input_shards`` is not 1.
    """
    os.makedirs(model_path, exist_ok=True)
    # Shards are written directly into the final output directory.
    tmp_model_path = model_path
    os.makedirs(tmp_model_path, exist_ok=True)

    # BUGFIX: the original code deferred this failure to a bare
    # `raise NotImplemented` (a TypeError at raise time) inside the layer loop,
    # after loading every shard. Fail fast with a real exception instead.
    if num_input_shards != 1:
        raise NotImplementedError(
            "Converting sharded (tensor-parallel) Megatron checkpoints is not "
            "supported; merge them into a single shard first."
        )

    n_layers = llama_s2layer[model_size]
    n_heads = llama_s2heads[model_size]
    n_hidden = llama_s2hidden[model_size]
    hidden_per_head = n_hidden // n_heads
    base = 10000.0
    # Rotary-embedding inverse frequencies, stored once per decoder layer below.
    inv_freq = 1.0 / (base ** (torch.arange(0, hidden_per_head, 2).float() / hidden_per_head))

    # Permute for sliced rotary embeddings (no-op when skip_permute is set).
    # Kept for parity with the upstream LLaMA conversion script.
    def permute(w, skip_permute=skip_permute):
        if skip_permute:
            return w
        return w.view(n_heads, n_hidden // n_heads // 2, 2, n_hidden).transpose(1, 2).reshape(n_hidden, n_hidden)

    print(f"Fetching all parameters from the checkpoint at {input_base_path}.")
    # Locate the (single-shard) Megatron checkpoint file.
    if os.path.exists(os.path.join(input_base_path, 'release')):
        filename = os.path.join(input_base_path, 'release', 'mp_rank_00', 'model_optim_rng.pt')
    elif input_base_path.split('/')[-1].startswith('iter_'):
        # Path already points at a specific iteration directory.
        # (BUGFIX: the original also ran eval() on the iteration suffix for an
        # unused local; that unsafe dead code has been removed.)
        filename = os.path.join(input_base_path, 'mp_rank_00', 'model_optim_rng.pt')
        if not os.path.exists(filename):
            filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
    else:
        # Resolve the latest iteration from the tracker file.
        tracker_filename = os.path.join(input_base_path, 'latest_checkpointed_iteration.txt')
        with open(tracker_filename, 'r') as f:
            metastring = f.read().strip()
        iteration = 'iter_{:07d}'.format(int(metastring))
        filename = os.path.join(input_base_path, iteration, 'mp_rank_00', 'model_optim_rng.pt')
        if not os.path.exists(filename):
            filename = filename.replace('model_optim_rng.pt', 'model_rng.pt')
    original_filename = filename
    loaded = torch.load(filename, map_location="cpu")['model']['language_model']

    print('Llama-Megatron Loaded!')
    param_count = 0
    index_dict = {"weight_map": {}}

    print(f'Weighted Converting for {n_layers} layers...')
    for layer_i in range(n_layers):
        print(layer_i)
        filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
        # Map one Megatron decoder layer to HF names. k/v projections and the
        # two layer norms are "multiway" (separate text/vision branches).
        state_dict = {
            f"model.layers.{layer_i}.self_attn.q_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.q_proj.weight"],
            f"model.layers.{layer_i}.self_attn.k_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.0.weight"],
            f"model.layers.{layer_i}.self_attn.v_proj.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.0.weight"],
            f"model.layers.{layer_i}.self_attn.k_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.k_proj.multiway.1.weight"],
            f"model.layers.{layer_i}.self_attn.v_proj.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.v_proj.multiway.1.weight"],
            f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded['encoder'][f"layers.{layer_i}.self_attention.o_proj.weight"],
            f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.gate_proj.weight"],
            f"model.layers.{layer_i}.mlp.down_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.down_proj.weight"],
            f"model.layers.{layer_i}.mlp.up_proj.weight": loaded['encoder'][f"layers.{layer_i}.mlp.up_proj.weight"],
            f"model.layers.{layer_i}.input_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.0.weight"],
            f"model.layers.{layer_i}.post_attention_layernorm.multiway.0.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.0.weight"],
            f"model.layers.{layer_i}.input_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.input_layernorm.multiway.1.weight"],
            f"model.layers.{layer_i}.post_attention_layernorm.multiway.1.weight": loaded['encoder'][f"layers.{layer_i}.post_attention_layernorm.multiway.1.weight"],
        }

        state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
        for k, v in state_dict.items():
            index_dict["weight_map"][k] = filename
            param_count += v.numel()
        torch.save(state_dict, os.path.join(tmp_model_path, filename))
        # BUGFIX: this log line had lost its placeholder; print the real path.
        print(f'Sharded file saved to {os.path.join(tmp_model_path, filename)}')

    # Final shard: embeddings, final norm, LM head, vision tower and abstractor.
    filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
    state_dict = {
        "model.embed_tokens.weight": loaded['embedding']['word_embeddings']['weight'],
        "model.norm.weight": loaded['encoder']['norm.weight'],
        "lm_head.weight": loaded['encoder']['lm_head.weight'],
    }

    # Reload the full checkpoint (not just the language model) for the
    # vision-model and visual-abstractor weights.
    loaded_all = torch.load(original_filename, map_location="cpu")['model']
    # Vision part
    state_dict.update({
        "model.vision_model.embeddings.cls_token": loaded_all['vision_model']['cls_token'],
        "model.vision_model.embeddings.patch_embed.weight": loaded_all['vision_model']['patch_embed']['weight'],
        "model.vision_model.embeddings.position_embedding": loaded_all['vision_model']['position_embeddings'],
        "model.vision_model.embeddings.pre_layernorm.bias": loaded_all['vision_model']['pre_layernorm']['bias'],
        "model.vision_model.embeddings.pre_layernorm.weight": loaded_all['vision_model']['pre_layernorm']['weight'],
        "model.vision_model.post_layernorm.bias": loaded_all['vision_model']['transformer']['final_layernorm.bias'],
        "model.vision_model.post_layernorm.weight": loaded_all['vision_model']['transformer']['final_layernorm.weight'],
    })
    # NOTE(review): the 24-layer vision tower and 6-layer abstractor depths are
    # hard-coded here, matching the checkpoint this script was written for.
    for v_layer_idx in range(24):
        state_dict.update({
            f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.bias'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.input_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.input_layernorm.weight'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.bias'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc1.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_h_to_4h.weight'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.bias'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.mlp.fc2.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.mlp.dense_4h_to_h.weight'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.bias'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.post_attention_layernorm.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.post_attention_layernorm.weight'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.bias'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.dense.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.dense.weight'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.bias": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.bias'],
            f"model.vision_model.encoder.layers.{v_layer_idx}.self_attn.query_key_value.weight": loaded_all['vision_model']['transformer'][f'layers.{v_layer_idx}.self_attention.query_key_value.weight'],
        })

    # Abstractor part
    state_dict.update({
        "model.visual_abstractor.query_embeds": loaded_all['vision_abstractor']['learnable_queries'],
        "model.visual_abstractor.visual_fc.bias": loaded_all['vision_abstractor']['visual_fc']['bias'],
        "model.visual_abstractor.visual_fc.weight": loaded_all['vision_abstractor']['visual_fc']['weight'],
        "model.visual_abstractor.vit_eos": loaded_all['vision_abstractor']['vit_eos'],
    })
    for v_layer_idx in range(6):
        state_dict.update({
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.key.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.k_proj.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.query.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.q_proj.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.attention.value.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.v_proj.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.norm1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm1.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.normk.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.normk.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.ffn_ln.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.ffn_ln.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w1.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w1.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w2.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.mlp.w3.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.mlp.w3.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.norm2.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.norm2.weight"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.bias": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.bias"],
            f"model.visual_abstractor.encoder.layers.{v_layer_idx}.crossattention.output.out_proj.weight": loaded_all['vision_abstractor']['transformer'][f"layers.{v_layer_idx}.self_attention.o_proj.weight"],
        })

    for k, v in state_dict.items():
        index_dict["weight_map"][k] = filename
        param_count += v.numel()
    torch.save(state_dict, os.path.join(tmp_model_path, filename))

    # Write configs (total_size assumes 2 bytes/param, i.e. fp16/bf16 weights).
    index_dict["metadata"] = {"total_size": param_count * 2}
    write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))

    config = MPLUGOwl2Config()
    config.save_pretrained(tmp_model_path)

    # Make space so we can load the model properly now.
    del state_dict
    del loaded
    del loaded_all
    gc.collect()
| 346 |
+
def write_tokenizer(tokenizer_path, input_tokenizer_path):
    """Build a Llama tokenizer from the `spm` model file and save it to disk."""
    # Prefer the fast tokenizer; the import above sets it to None when unavailable.
    tokenizer_class = LlamaTokenizer if LlamaTokenizerFast is None else LlamaTokenizerFast
    print(f"Saving a {tokenizer_class.__name__} to {tokenizer_path}.")
    tokenizer_class(input_tokenizer_path).save_pretrained(tokenizer_path)
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def main():
    """CLI entry point: parse arguments and run the Megatron -> HF conversion."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dir",
        help="Location of LLaMA_Megatron weights",
    )
    parser.add_argument(
        "--model_size",
        type=int,
        default=7,
        choices=[7, 13, 30, 65, 70],
    )
    parser.add_argument(
        "--num_input_shards",
        type=int,
        default=1,
    )
    parser.add_argument(
        "--num_output_shards",
        type=int,
        default=1,
    )
    parser.add_argument('--skip_permute', action='store_true')
    parser.add_argument(
        "--output_dir",
        help="Location to write HF model and tokenizer",
    )

    cli_args = parser.parse_args()
    write_model(
        model_path=cli_args.output_dir,
        input_base_path=cli_args.input_dir,
        model_size=cli_args.model_size,
        num_input_shards=cli_args.num_input_shards,
        num_output_shards=cli_args.num_output_shards,
        skip_permute=cli_args.skip_permute
    )


if __name__ == "__main__":
    main()
|
DeQA-Score/src/model/modeling_attn_mask_utils.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import List, Optional, Tuple, Union
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class AttentionMaskConverter:
|
| 20 |
+
"""
|
| 21 |
+
A utility attention mask class that allows one to:
|
| 22 |
+
- Create a causal 4d mask
|
| 23 |
+
- Create a causal 4d mask with slided window
|
| 24 |
+
- Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length,
|
| 25 |
+
key_value_length) that can be multiplied with attention scores
|
| 26 |
+
|
| 27 |
+
Parameters:
|
| 28 |
+
is_causal (`bool`):
|
| 29 |
+
Whether the attention mask should be a uni-directional (causal) or bi-directional mask.
|
| 30 |
+
|
| 31 |
+
sliding_window (`int`, *optional*):
|
| 32 |
+
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):
    """Store the masking mode and validate the optional sliding-window size."""
    self.is_causal = is_causal
    self.sliding_window = sliding_window

    # A sliding window, when given, must be a strictly positive integer.
    if self.sliding_window is not None and self.sliding_window <= 0:
        raise ValueError(
            f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
        )
|
| 43 |
+
|
| 44 |
+
def to_causal_4d(
    self,
    batch_size: int,
    query_length: int,
    key_value_length: int,
    dtype: torch.dtype = torch.float32,
    device: Union[torch.device, "str"] = "cpu",
) -> torch.Tensor:
    """
    Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative
    bias to upper right hand triangular matrix (causal mask).

    Returns None when the query is a single token and no sliding window is
    configured, since no masking is needed in that case.
    """
    if not self.is_causal:
        raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.")

    input_shape = (batch_size, query_length)
    # Length of any cached prefix the mask must also span.
    past_key_values_length = key_value_length - query_length

    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    if input_shape[-1] <= 1 and self.sliding_window is None:
        return None

    return self._make_causal_mask(
        input_shape,
        dtype,
        device=device,
        past_key_values_length=past_key_values_length,
        sliding_window=self.sliding_window,
    )
|
| 76 |
+
|
| 77 |
+
def to_4d(
|
| 78 |
+
self,
|
| 79 |
+
attention_mask_2d: torch.Tensor,
|
| 80 |
+
query_length: int,
|
| 81 |
+
key_value_length: Optional[int] = None,
|
| 82 |
+
dtype: torch.dtype = torch.float32,
|
| 83 |
+
) -> torch.Tensor:
|
| 84 |
+
"""
|
| 85 |
+
Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
|
| 86 |
+
key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
|
| 87 |
+
causal, a causal mask will be added.
|
| 88 |
+
"""
|
| 89 |
+
input_shape = (attention_mask_2d.shape[0], query_length)
|
| 90 |
+
|
| 91 |
+
# create causal mask
|
| 92 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 93 |
+
causal_4d_mask = None
|
| 94 |
+
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
|
| 95 |
+
if key_value_length is None:
|
| 96 |
+
raise ValueError(
|
| 97 |
+
"This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
past_key_values_length = key_value_length - query_length
|
| 101 |
+
causal_4d_mask = self._make_causal_mask(
|
| 102 |
+
input_shape,
|
| 103 |
+
dtype,
|
| 104 |
+
device=attention_mask_2d.device,
|
| 105 |
+
past_key_values_length=past_key_values_length,
|
| 106 |
+
sliding_window=self.sliding_window,
|
| 107 |
+
)
|
| 108 |
+
elif self.sliding_window is not None:
|
| 109 |
+
raise NotImplementedError("Sliding window is currently only implemented for causal masking")
|
| 110 |
+
|
| 111 |
+
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
|
| 112 |
+
expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to(
|
| 113 |
+
attention_mask_2d.device
|
| 114 |
+
)
|
| 115 |
+
expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask
|
| 116 |
+
|
| 117 |
+
return expanded_4d_mask
|
| 118 |
+
|
| 119 |
+
@staticmethod
|
| 120 |
+
def _make_causal_mask(
|
| 121 |
+
input_ids_shape: torch.Size,
|
| 122 |
+
dtype: torch.dtype,
|
| 123 |
+
device: torch.device,
|
| 124 |
+
past_key_values_length: int = 0,
|
| 125 |
+
sliding_window: Optional[int] = None,
|
| 126 |
+
):
|
| 127 |
+
"""
|
| 128 |
+
Make causal mask used for bi-directional self-attention.
|
| 129 |
+
"""
|
| 130 |
+
bsz, tgt_len = input_ids_shape
|
| 131 |
+
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
|
| 132 |
+
mask_cond = torch.arange(mask.size(-1), device=device)
|
| 133 |
+
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
|
| 134 |
+
|
| 135 |
+
mask = mask.to(dtype)
|
| 136 |
+
|
| 137 |
+
if past_key_values_length > 0:
|
| 138 |
+
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
|
| 139 |
+
|
| 140 |
+
# add lower triangular sliding window mask if necessary
|
| 141 |
+
if sliding_window is not None:
|
| 142 |
+
diagonal = past_key_values_length - sliding_window + 1
|
| 143 |
+
|
| 144 |
+
context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal)
|
| 145 |
+
mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min)
|
| 146 |
+
|
| 147 |
+
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
|
| 148 |
+
|
| 149 |
+
@staticmethod
|
| 150 |
+
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
| 151 |
+
"""
|
| 152 |
+
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
| 153 |
+
"""
|
| 154 |
+
bsz, src_len = mask.size()
|
| 155 |
+
tgt_len = tgt_len if tgt_len is not None else src_len
|
| 156 |
+
|
| 157 |
+
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
|
| 158 |
+
|
| 159 |
+
inverted_mask = 1.0 - expanded_mask
|
| 160 |
+
|
| 161 |
+
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def _prepare_4d_causal_attention_mask(
|
| 165 |
+
attention_mask: Optional[torch.Tensor],
|
| 166 |
+
input_shape: Union[torch.Size, Tuple, List],
|
| 167 |
+
inputs_embeds: torch.Tensor,
|
| 168 |
+
past_key_values_length: int,
|
| 169 |
+
sliding_window: Optional[int] = None,
|
| 170 |
+
):
|
| 171 |
+
"""
|
| 172 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 173 |
+
`(batch_size, key_value_length)`
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
attention_mask (`torch.Tensor` or `None`):
|
| 177 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
| 178 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
| 179 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
| 180 |
+
inputs_embeds (`torch.Tensor`):
|
| 181 |
+
The embedded inputs as a torch Tensor.
|
| 182 |
+
past_key_values_length (`int`):
|
| 183 |
+
The length of the key value cache.
|
| 184 |
+
sliding_window (`int`, *optional*):
|
| 185 |
+
If the model uses windowed attention, a sliding window should be passed.
|
| 186 |
+
"""
|
| 187 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
| 188 |
+
|
| 189 |
+
key_value_length = input_shape[-1] + past_key_values_length
|
| 190 |
+
|
| 191 |
+
# 4d mask is passed through the layers
|
| 192 |
+
if attention_mask is not None:
|
| 193 |
+
attention_mask = attn_mask_converter.to_4d(
|
| 194 |
+
attention_mask, input_shape[-1], key_value_length, dtype=inputs_embeds.dtype
|
| 195 |
+
)
|
| 196 |
+
else:
|
| 197 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
| 198 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
return attention_mask
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _prepare_4d_attention_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
| 205 |
+
"""
|
| 206 |
+
Creates a non-causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
| 207 |
+
`(batch_size, key_value_length)`
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
mask (`torch.Tensor` or `None`):
|
| 211 |
+
A 2D attention mask of shape `(batch_size, key_value_length)`
|
| 212 |
+
dtype (`torch.dtype`):
|
| 213 |
+
The torch dtype the created mask shall have.
|
| 214 |
+
tgt_len (`int`):
|
| 215 |
+
The target length or query length the created mask shall have.
|
| 216 |
+
"""
|
| 217 |
+
return AttentionMaskConverter._expand_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def _create_4d_causal_attention_mask(
|
| 221 |
+
input_shape: Union[torch.Size, Tuple, List],
|
| 222 |
+
dtype: torch.dtype,
|
| 223 |
+
device: torch.device,
|
| 224 |
+
past_key_values_length: int = 0,
|
| 225 |
+
sliding_window: Optional[int] = None,
|
| 226 |
+
):
|
| 227 |
+
"""
|
| 228 |
+
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
|
| 229 |
+
|
| 230 |
+
Args:
|
| 231 |
+
input_shape (`tuple(int)` or `list(int)` or `torch.Size`):
|
| 232 |
+
The input shape should be a tuple that defines `(batch_size, query_length)`.
|
| 233 |
+
dtype (`torch.dtype`):
|
| 234 |
+
The torch dtype the created mask shall have.
|
| 235 |
+
device (`int`):
|
| 236 |
+
The torch device the created mask shall have.
|
| 237 |
+
sliding_window (`int`, *optional*):
|
| 238 |
+
If the model uses windowed attention, a sliding window should be passed.
|
| 239 |
+
"""
|
| 240 |
+
attn_mask_converter = AttentionMaskConverter(is_causal=True, sliding_window=sliding_window)
|
| 241 |
+
|
| 242 |
+
key_value_length = past_key_values_length + input_shape[-1]
|
| 243 |
+
attention_mask = attn_mask_converter.to_causal_4d(
|
| 244 |
+
input_shape[0], input_shape[-1], key_value_length, dtype=dtype, device=device
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
return attention_mask
|
DeQA-Score/src/model/modeling_llama2.py
ADDED
|
@@ -0,0 +1,834 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import warnings
|
| 3 |
+
from functools import partial
|
| 4 |
+
from typing import List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.utils.checkpoint
|
| 9 |
+
from torch import nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import copy
|
| 13 |
+
import os
|
| 14 |
+
import sys
|
| 15 |
+
|
| 16 |
+
dir_path = os.path.dirname(os.path.realpath(__file__))
|
| 17 |
+
sys.path.insert(0, dir_path)
|
| 18 |
+
|
| 19 |
+
import transformers
|
| 20 |
+
from transformers.models.llama.modeling_llama import *
|
| 21 |
+
|
| 22 |
+
def _get_unpad_data(attention_mask):
|
| 23 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 24 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 25 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 26 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
| 27 |
+
return (
|
| 28 |
+
indices,
|
| 29 |
+
cu_seqlens,
|
| 30 |
+
max_seqlen_in_batch,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 35 |
+
from transformers.utils import logging
|
| 36 |
+
|
| 37 |
+
from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
| 38 |
+
from .configuration_mplug_owl2 import LlamaConfig
|
| 39 |
+
|
| 40 |
+
class MultiwayNetwork(nn.Module):
|
| 41 |
+
|
| 42 |
+
def __init__(self, module_provider, num_multiway=2):
|
| 43 |
+
super(MultiwayNetwork, self).__init__()
|
| 44 |
+
|
| 45 |
+
self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])
|
| 46 |
+
|
| 47 |
+
def forward(self, hidden_states, multiway_indices):
|
| 48 |
+
|
| 49 |
+
if len(self.multiway) == 1:
|
| 50 |
+
return self.multiway[0](hidden_states)
|
| 51 |
+
|
| 52 |
+
output_hidden_states = torch.empty_like(hidden_states)
|
| 53 |
+
|
| 54 |
+
for idx, subway in enumerate(self.multiway):
|
| 55 |
+
local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
|
| 56 |
+
hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
|
| 57 |
+
if hidden.numel():
|
| 58 |
+
output = subway(hidden)
|
| 59 |
+
if isinstance(output, tuple):
|
| 60 |
+
output = output[0]
|
| 61 |
+
output = output.squeeze(1)
|
| 62 |
+
output_hidden_states[local_indices] = output
|
| 63 |
+
|
| 64 |
+
return output_hidden_states.contiguous()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class LlamaAttention(nn.Module):
|
| 68 |
+
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
| 69 |
+
|
| 70 |
+
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
|
| 71 |
+
super().__init__()
|
| 72 |
+
self.config = config
|
| 73 |
+
self.layer_idx = layer_idx
|
| 74 |
+
if layer_idx is None:
|
| 75 |
+
logger.warning_once(
|
| 76 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
| 77 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
| 78 |
+
"when creating this class."
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
self.attention_dropout = config.attention_dropout
|
| 82 |
+
self.hidden_size = config.hidden_size
|
| 83 |
+
self.num_heads = config.num_attention_heads
|
| 84 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 85 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 86 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 87 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 88 |
+
self.rope_theta = config.rope_theta
|
| 89 |
+
self.is_causal = True
|
| 90 |
+
|
| 91 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 92 |
+
raise ValueError(
|
| 93 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 94 |
+
f" and `num_heads`: {self.num_heads})."
|
| 95 |
+
)
|
| 96 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
|
| 97 |
+
self.k_proj = MultiwayNetwork(module_provider=partial(
|
| 98 |
+
nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 99 |
+
)
|
| 100 |
+
self.v_proj = MultiwayNetwork(module_provider=partial(
|
| 101 |
+
nn.Linear, in_features=self.hidden_size, out_features=self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
|
| 102 |
+
)
|
| 103 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
|
| 104 |
+
self._init_rope()
|
| 105 |
+
|
| 106 |
+
def _init_rope(self):
|
| 107 |
+
if self.config.rope_scaling is None:
|
| 108 |
+
self.rotary_emb = LlamaRotaryEmbedding(
|
| 109 |
+
self.head_dim,
|
| 110 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 111 |
+
base=self.rope_theta,
|
| 112 |
+
)
|
| 113 |
+
else:
|
| 114 |
+
scaling_type = self.config.rope_scaling["type"]
|
| 115 |
+
scaling_factor = self.config.rope_scaling["factor"]
|
| 116 |
+
if scaling_type == "linear":
|
| 117 |
+
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
|
| 118 |
+
self.head_dim,
|
| 119 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 120 |
+
scaling_factor=scaling_factor,
|
| 121 |
+
base=self.rope_theta,
|
| 122 |
+
)
|
| 123 |
+
elif scaling_type == "dynamic":
|
| 124 |
+
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
|
| 125 |
+
self.head_dim,
|
| 126 |
+
max_position_embeddings=self.max_position_embeddings,
|
| 127 |
+
scaling_factor=scaling_factor,
|
| 128 |
+
base=self.rope_theta,
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
|
| 132 |
+
|
| 133 |
+
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
| 134 |
+
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
| 135 |
+
|
| 136 |
+
def forward(
|
| 137 |
+
self,
|
| 138 |
+
hidden_states: torch.Tensor,
|
| 139 |
+
modality_indicators: torch.Tensor,
|
| 140 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 141 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 142 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 143 |
+
output_attentions: bool = False,
|
| 144 |
+
use_cache: bool = False,
|
| 145 |
+
padding_mask: Optional[torch.LongTensor] = None,
|
| 146 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 147 |
+
bsz, q_len, _ = hidden_states.size()
|
| 148 |
+
|
| 149 |
+
query_states = self.q_proj(hidden_states, )
|
| 150 |
+
key_states = self.k_proj(hidden_states, modality_indicators)
|
| 151 |
+
value_states = self.v_proj(hidden_states, modality_indicators)
|
| 152 |
+
|
| 153 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 154 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 155 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 156 |
+
|
| 157 |
+
kv_seq_len = key_states.shape[-2]
|
| 158 |
+
if past_key_value is not None:
|
| 159 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
| 160 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 161 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 162 |
+
|
| 163 |
+
if past_key_value is not None:
|
| 164 |
+
# reuse k, v, self_attention
|
| 165 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
| 166 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
| 167 |
+
|
| 168 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
| 169 |
+
|
| 170 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 171 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 172 |
+
|
| 173 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
| 174 |
+
|
| 175 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
| 176 |
+
raise ValueError(
|
| 177 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
| 178 |
+
f" {attn_weights.size()}"
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if attention_mask is not None:
|
| 182 |
+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
| 183 |
+
raise ValueError(
|
| 184 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
| 185 |
+
)
|
| 186 |
+
attn_weights = attn_weights + attention_mask
|
| 187 |
+
|
| 188 |
+
# upcast attention to fp32
|
| 189 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
| 190 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
| 191 |
+
|
| 192 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
| 193 |
+
raise ValueError(
|
| 194 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
| 195 |
+
f" {attn_output.size()}"
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 199 |
+
|
| 200 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
| 201 |
+
|
| 202 |
+
attn_output = self.o_proj(attn_output)
|
| 203 |
+
|
| 204 |
+
if not output_attentions:
|
| 205 |
+
attn_weights = None
|
| 206 |
+
|
| 207 |
+
return attn_output, attn_weights, past_key_value
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class LlamaFlashAttention2(LlamaAttention):
|
| 211 |
+
"""
|
| 212 |
+
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
|
| 213 |
+
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
| 214 |
+
flash attention and deal with padding tokens in case the input contains any of them.
|
| 215 |
+
"""
|
| 216 |
+
|
| 217 |
+
def __init__(self, *args, **kwargs):
|
| 218 |
+
super().__init__(*args, **kwargs)
|
| 219 |
+
|
| 220 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 221 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 222 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 223 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 224 |
+
|
| 225 |
+
def forward(
|
| 226 |
+
self,
|
| 227 |
+
hidden_states: torch.Tensor,
|
| 228 |
+
modality_indicators: torch.Tensor,
|
| 229 |
+
attention_mask: Optional[torch.LongTensor] = None,
|
| 230 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 231 |
+
past_key_value: Optional[Cache] = None,
|
| 232 |
+
output_attentions: bool = False,
|
| 233 |
+
use_cache: bool = False,
|
| 234 |
+
**kwargs,
|
| 235 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
| 236 |
+
# LlamaFlashAttention2 attention does not support output_attentions
|
| 237 |
+
if "padding_mask" in kwargs:
|
| 238 |
+
warnings.warn(
|
| 239 |
+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# overwrite attention_mask with padding_mask
|
| 243 |
+
attention_mask = kwargs.pop("padding_mask")
|
| 244 |
+
|
| 245 |
+
output_attentions = False
|
| 246 |
+
|
| 247 |
+
bsz, q_len, _ = hidden_states.size()
|
| 248 |
+
|
| 249 |
+
query_states = self.q_proj(hidden_states)
|
| 250 |
+
key_states = self.k_proj(hidden_states, modality_indicators)
|
| 251 |
+
value_states = self.v_proj(hidden_states, modality_indicators)
|
| 252 |
+
|
| 253 |
+
# Flash attention requires the input to have the shape
|
| 254 |
+
# batch_size x seq_length x head_dim x hidden_dim
|
| 255 |
+
# therefore we just need to keep the original shape
|
| 256 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 257 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 258 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 259 |
+
|
| 260 |
+
kv_seq_len = key_states.shape[-2]
|
| 261 |
+
if past_key_value is not None:
|
| 262 |
+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
| 263 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 264 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
| 265 |
+
|
| 266 |
+
if past_key_value is not None:
|
| 267 |
+
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
| 268 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 269 |
+
|
| 270 |
+
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
|
| 271 |
+
# to be able to avoid many of these transpose/reshape/view.
|
| 272 |
+
query_states = query_states.transpose(1, 2)
|
| 273 |
+
key_states = key_states.transpose(1, 2)
|
| 274 |
+
value_states = value_states.transpose(1, 2)
|
| 275 |
+
|
| 276 |
+
dropout_rate = self.attention_dropout if self.training else 0.0
|
| 277 |
+
|
| 278 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 279 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 280 |
+
# cast them back in the correct dtype just to be sure everything works as expected.
|
| 281 |
+
# This might slowdown training & inference so it is recommended to not cast the LayerNorms
|
| 282 |
+
# in fp32. (LlamaRMSNorm handles it correctly)
|
| 283 |
+
|
| 284 |
+
input_dtype = query_states.dtype
|
| 285 |
+
if input_dtype == torch.float32:
|
| 286 |
+
if torch.is_autocast_enabled():
|
| 287 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 288 |
+
# Handle the case where the model is quantized
|
| 289 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 290 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 291 |
+
else:
|
| 292 |
+
target_dtype = self.q_proj.weight.dtype
|
| 293 |
+
|
| 294 |
+
logger.warning_once(
|
| 295 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 296 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 297 |
+
f" {target_dtype}."
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
query_states = query_states.to(target_dtype)
|
| 301 |
+
key_states = key_states.to(target_dtype)
|
| 302 |
+
value_states = value_states.to(target_dtype)
|
| 303 |
+
|
| 304 |
+
attn_output = self._flash_attention_forward(
|
| 305 |
+
query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 309 |
+
attn_output = self.o_proj(attn_output)
|
| 310 |
+
|
| 311 |
+
if not output_attentions:
|
| 312 |
+
attn_weights = None
|
| 313 |
+
|
| 314 |
+
return attn_output, attn_weights, past_key_value
|
| 315 |
+
|
| 316 |
+
def _flash_attention_forward(
|
| 317 |
+
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
|
| 318 |
+
):
|
| 319 |
+
"""
|
| 320 |
+
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
|
| 321 |
+
first unpad the input, then computes the attention scores and pad the final attention scores.
|
| 322 |
+
|
| 323 |
+
Args:
|
| 324 |
+
query_states (`torch.Tensor`):
|
| 325 |
+
Input query states to be passed to Flash Attention API
|
| 326 |
+
key_states (`torch.Tensor`):
|
| 327 |
+
Input key states to be passed to Flash Attention API
|
| 328 |
+
value_states (`torch.Tensor`):
|
| 329 |
+
Input value states to be passed to Flash Attention API
|
| 330 |
+
attention_mask (`torch.Tensor`):
|
| 331 |
+
The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
|
| 332 |
+
position of padding tokens and 1 for the position of non-padding tokens.
|
| 333 |
+
dropout (`int`, *optional*):
|
| 334 |
+
Attention dropout
|
| 335 |
+
softmax_scale (`float`, *optional*):
|
| 336 |
+
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
| 337 |
+
"""
|
| 338 |
+
if not self._flash_attn_uses_top_left_mask:
|
| 339 |
+
causal = self.is_causal
|
| 340 |
+
else:
|
| 341 |
+
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__.
|
| 342 |
+
causal = self.is_causal and query_length != 1
|
| 343 |
+
|
| 344 |
+
# Contains at least one padding token in the sequence
|
| 345 |
+
if attention_mask is not None:
|
| 346 |
+
batch_size = query_states.shape[0]
|
| 347 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
|
| 348 |
+
query_states, key_states, value_states, attention_mask, query_length
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
|
| 352 |
+
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
|
| 353 |
+
|
| 354 |
+
attn_output_unpad = flash_attn_varlen_func(
|
| 355 |
+
query_states,
|
| 356 |
+
key_states,
|
| 357 |
+
value_states,
|
| 358 |
+
cu_seqlens_q=cu_seqlens_q,
|
| 359 |
+
cu_seqlens_k=cu_seqlens_k,
|
| 360 |
+
max_seqlen_q=max_seqlen_in_batch_q,
|
| 361 |
+
max_seqlen_k=max_seqlen_in_batch_k,
|
| 362 |
+
dropout_p=dropout,
|
| 363 |
+
softmax_scale=softmax_scale,
|
| 364 |
+
causal=causal,
|
| 365 |
+
)
|
| 366 |
+
|
| 367 |
+
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
|
| 368 |
+
else:
|
| 369 |
+
attn_output = flash_attn_func(
|
| 370 |
+
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
return attn_output
|
| 374 |
+
|
| 375 |
+
def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
    """Strip padding tokens so flash-attn's varlen kernels see packed sequences.

    Args:
        query_layer: query states; reshaped below as
            ``(batch * q_len, self.num_heads, head_dim)``, so the input is
            assumed to be ``(batch, q_len, num_heads, head_dim)``.
        key_layer / value_layer: key/value states of shape
            ``(batch, kv_seq_len, num_key_value_heads, head_dim)`` (established
            by the ``key_layer.shape`` unpacking below).
        attention_mask: 2D padding mask over keys; 1 marks real tokens.
        query_length: number of query positions in this forward pass.

    Returns:
        The unpadded q/k/v tensors, the query gather indices, and the
        ``(cu_seqlens_q, cu_seqlens_k)`` / ``(max_seqlen_q, max_seqlen_k)``
        pairs expected by ``flash_attn_varlen_func``.
    """
    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    # Gather only the non-padding key/value rows (flattened over batch * seq).
    key_layer = index_first_axis(
        key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
    )
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
    )
    if query_length == kv_seq_len:
        # Prefill: queries align 1:1 with keys, so reuse the key indices/lengths.
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        # Single-token decoding: every sequence contributes exactly one query,
        # so cumulative lengths are simply 0..batch_size.
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        attention_mask = attention_mask[:, -query_length:]
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
class LlamaSdpaAttention(LlamaAttention):
    """
    Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `LlamaAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.

    Like the other attention variants in this file, `forward` additionally
    takes `modality_indicators`, which is forwarded to the k/v projections.
    """

    # Adapted from LlamaAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # SDPA cannot return attention weights, so fall back to the eager
            # (parent class) implementation.
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                modality_indicators=modality_indicators,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = hidden_states.size()

        # Q uses a plain projection; K/V also receive `modality_indicators` —
        # presumably modality-routed (multiway) projections; confirm against
        # LlamaAttention.__init__ elsewhere in this file.
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states, modality_indicators)
        value_states = self.v_proj(hidden_states, modality_indicators)

        # (bsz, seq, heads, head_dim) -> (bsz, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key length includes positions already stored in the cache.
        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Expand K/V heads to match the number of query heads (grouped-query attention).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
            is_causal=self.is_causal and attention_mask is None and q_len > 1,
        )

        # Back to (bsz, seq, hidden) before the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        # Attention weights slot is always None on the SDPA path (see fallback above).
        return attn_output, None, past_key_value
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
# Maps `config._attn_implementation` to the attention class instantiated by
# LlamaDecoderLayer.__init__.
LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}
|
| 509 |
+
|
| 510 |
+
class LlamaDecoderLayer(nn.Module):
    """A single Llama decoder block with modality-adaptive components.

    Both RMSNorms are wrapped in `MultiwayNetwork` and the attention receives
    `modality_indicators`, so normalization/projection parameters can differ
    per token modality (e.g. visual vs. textual tokens).
    """

    def __init__(self, config: LlamaConfig, layer_idx):
        super().__init__()
        self.hidden_size = config.hidden_size
        # Select the attention implementation requested by the config
        # ("eager" / "flash_attention_2" / "sdpa"); see LLAMA_ATTENTION_CLASSES.
        # Fix: the original code first constructed a throwaway
        # `LlamaAttention(config=config)` and immediately overwrote it here,
        # allocating a full set of attention weights for nothing — the dead
        # assignment has been removed.
        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = MultiwayNetwork(module_provider=partial(
            LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
        ))
        self.post_attention_layernorm = MultiwayNetwork(module_provider=partial(
            LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps
        ))

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            modality_indicators (`torch.Tensor`, *optional*): per-token modality selector forwarded to
                the multiway layer norms and the attention module.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states, modality_indicators)

        # Self Attention (pre-norm residual block)
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            modality_indicators=modality_indicators,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected (pre-norm residual block)
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states, modality_indicators)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Replacement for `LlamaModel.forward` that threads `modality_indicators`
    through every decoder layer (installed by `replace_llama_modality_adaptive`).

    Bound as a method, so `self` is the `LlamaModel` instance.
    """
    # Resolve per-call flags against the config defaults.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    use_cache = use_cache if use_cache is not None else self.config.use_cache

    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # retrieve input_ids and inputs_embeds
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    elif input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    seq_length_with_past = seq_length
    past_key_values_length = 0

    # Legacy (tuple-of-tuples) cache layout: layer 0's key tensor carries the
    # number of already-cached positions in dim 2.
    if past_key_values is not None:
        past_key_values_length = past_key_values[0][0].shape[2]
        seq_length_with_past = seq_length_with_past + past_key_values_length

    if position_ids is None:
        device = input_ids.device if input_ids is not None else inputs_embeds.device
        # New positions continue after the cached prefix.
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
    else:
        position_ids = position_ids.view(-1, seq_length).long()

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)
    # embed positions
    if attention_mask is None:
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )

    # Each attention backend expects a different mask layout.
    if self._use_flash_attention_2:
        # 2d mask is passed through the layers
        attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
    elif self._use_sdpa and not output_attentions:
        # output_attentions=True can not be supported when using SDPA, and we fall back on
        # the manual implementation that requires a 4D causal mask in all cases.
        attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )
    else:
        # 4d mask is passed through the layers
        attention_mask = _prepare_4d_causal_attention_mask(
            attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
        )

    hidden_states = inputs_embeds

    if self.gradient_checkpointing and self.training:
        if use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

    # decoder layers
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = past_key_values[idx] if past_key_values is not None else None

        if self.gradient_checkpointing and self.training:

            # Closure captures past_key_value / output_attentions so only the
            # tensors that need gradients are checkpoint arguments.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # None for past_key_value
                    return module(*inputs, past_key_value, output_attentions)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                modality_indicators,
                attention_mask,
                position_ids,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                modality_indicators=modality_indicators,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        # The layer tuple is (hidden, [attn_weights,] [present_kv]); index
        # shifts by one when attention weights are included.
        if use_cache:
            next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # add hidden states from the last decoder layer
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if not return_dict:
        return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
    return BaseModelOutputWithPast(
        last_hidden_state=hidden_states,
        past_key_values=next_cache,
        hidden_states=all_hidden_states,
        attentions=all_self_attns,
    )
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
def causal_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Replacement for `LlamaForCausalLM.forward` that forwards `modality_indicators`
    to the base model (installed by `replace_llama_modality_adaptive`).

    Args:
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
    >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""

    # Resolve per-call flags against the config defaults.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs = self.model(
        input_ids=input_ids,
        modality_indicators=modality_indicators,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    if self.config.pretraining_tp > 1:
        # Tensor-parallel pretraining: split the lm_head rows, project each
        # slice, then concatenate the vocab dimension back together.
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
        logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
        logits = torch.cat(logits, dim=-1)
    else:
        logits = self.lm_head(hidden_states)
    # Compute the loss in fp32 for numerical stability.
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that tokens < n predict n
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss_fct = CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.config.vocab_size)
        shift_labels = shift_labels.view(-1)
        # Enable model parallelism
        shift_labels = shift_labels.to(shift_logits.device)
        loss = loss_fct(shift_logits, shift_labels)

    if not return_dict:
        # Legacy tuple output: prepend loss only when it was computed.
        output = (logits,) + outputs[1:]
        return (loss,) + output if loss is not None else output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
| 819 |
+
|
| 820 |
+
def replace_llama_modality_adaptive():
    """Monkey-patch HF Transformers' Llama implementation with the
    modality-adaptive variants defined in this module.

    Must be called before any Llama model is instantiated so the patched
    classes and forward methods take effect.
    """
    llama_pkg = transformers.models.llama
    llama_pkg.configuration_llama.LlamaConfig = LlamaConfig

    modeling = llama_pkg.modeling_llama
    modeling.LlamaAttention = LlamaAttention
    modeling.LlamaFlashAttention2 = LlamaFlashAttention2
    modeling.LlamaSdpaAttention = LlamaSdpaAttention
    modeling.LlamaDecoderLayer = LlamaDecoderLayer
    # Only the forward methods are swapped on the model classes, so existing
    # attributes and weights of LlamaModel / LlamaForCausalLM are kept.
    modeling.LlamaModel.forward = model_forward
    modeling.LlamaForCausalLM.forward = causal_model_forward
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
if __name__ == "__main__":
    import sys

    # Smoke test: install the modality-adaptive patches, then build a
    # randomly initialized model from a local config and print its structure.
    replace_llama_modality_adaptive()
    # Generalized: allow the config path as an optional CLI argument instead
    # of only the hard-coded cluster path (which remains the default).
    config_path = sys.argv[1] if len(sys.argv) > 1 else '/cpfs01/shared/public/test/vicuna-7b-v1.5/'
    config = transformers.LlamaConfig.from_pretrained(config_path)
    model = transformers.LlamaForCausalLM(config)
    print(model)
|