Imaginethat committed on
Commit
8a11f7f
·
verified ·
1 Parent(s): 65cddc5

Upload 68 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. README.md +92 -33
  3. assets/avocado.ico +0 -0
  4. assets/case_1.mp4 +3 -0
  5. assets/case_2.png +3 -0
  6. environment.yml +327 -0
  7. eval_scripts/DREAM-1K/dream_example.jsonl +0 -0
  8. eval_scripts/DREAM-1K/eval_DREAM-1K.sh +12 -0
  9. eval_scripts/DREAM-1K/generate_caption.py +91 -0
  10. eval_scripts/DREAM-1K/tarsier/LICENSE +201 -0
  11. eval_scripts/DREAM-1K/tarsier/configs/tarser2_default_config.yaml +14 -0
  12. eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/multi_images_parser.py +199 -0
  13. eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/object_tracking_parser.py +160 -0
  14. eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/standard_vision_parser.py +255 -0
  15. eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils.py +452 -0
  16. eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils_visualize.py +54 -0
  17. eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/video_permutation_parser.py +137 -0
  18. eval_scripts/DREAM-1K/tarsier/dataset/tarsier_datamodule.py +280 -0
  19. eval_scripts/DREAM-1K/tarsier/dataset/tarsier_processor.py +240 -0
  20. eval_scripts/DREAM-1K/tarsier/dataset/utils.py +186 -0
  21. eval_scripts/DREAM-1K/tarsier/evaluation/evaluate.py +177 -0
  22. eval_scripts/DREAM-1K/tarsier/evaluation/metrics/__init__.py +5 -0
  23. eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_caption_cider.py +82 -0
  24. eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_dream_gpt.py +436 -0
  25. eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_mc.py +159 -0
  26. eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_oe_gpt.py +153 -0
  27. eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_video_mme.py +358 -0
  28. eval_scripts/DREAM-1K/tarsier/models/modeling_qwen2_vl_fast.py +1320 -0
  29. eval_scripts/DREAM-1K/tarsier/models/modeling_tarsier.py +502 -0
  30. eval_scripts/DREAM-1K/tarsier/models/utils.py +17 -0
  31. eval_scripts/DREAM-1K/tarsier/scripts/run_demo_cli.sh +15 -0
  32. eval_scripts/DREAM-1K/tarsier/scripts/run_demo_gradio.sh +9 -0
  33. eval_scripts/DREAM-1K/tarsier/scripts/run_evaluation_only.sh +12 -0
  34. eval_scripts/DREAM-1K/tarsier/scripts/run_inference_benchmark.sh +80 -0
  35. eval_scripts/DREAM-1K/tarsier/scripts/run_inference_caption.sh +79 -0
  36. eval_scripts/DREAM-1K/tarsier/tasks/demo_cli.py +116 -0
  37. eval_scripts/DREAM-1K/tarsier/tasks/demo_gradio.py +230 -0
  38. eval_scripts/DREAM-1K/tarsier/tasks/inference_benchmark.py +197 -0
  39. eval_scripts/DREAM-1K/tarsier/tasks/inference_caption.py +165 -0
  40. eval_scripts/DREAM-1K/tarsier/tasks/inference_quick_start.py +91 -0
  41. eval_scripts/DREAM-1K/tarsier/tasks/utils.py +45 -0
  42. eval_scripts/DREAM-1K/tarsier/tools/color.py +36 -0
  43. eval_scripts/DREAM-1K/tarsier/tools/conversation.py +256 -0
  44. eval_scripts/DREAM-1K/tarsier/tools/ptbtokenizer.py +66 -0
  45. eval_scripts/DREAM-1K/tarsier/tools/rw_utils.py +64 -0
  46. eval_scripts/Daily-Omni/Daily-Omni_pipeline.sh +62 -0
  47. eval_scripts/Daily-Omni/analysis.py +18 -0
  48. eval_scripts/Daily-Omni/evaluation.py +225 -0
  49. eval_scripts/Daily-Omni/generate_caption.py +142 -0
  50. eval_scripts/Daily-Omni/grouped_data.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/case_1.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ assets/case_2.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,48 +1,107 @@
1
- ---
2
- title: Avocado On Toast
3
- emoji: 🥑
4
- colorFrom: green
5
- colorTo: yellow
6
- sdk: gradio
7
- app_file: app.py
8
- pinned: false
9
  ---
10
 
11
- # Avocado On Toast
 
12
 
13
- Gradio Space that standardizes uploaded videos with ffmpeg, runs AVoCaDO, and
14
- writes JSONL outputs into a run-labeled dataset folder.
 
15
 
16
- ## Usage (Hugging Face Space)
 
17
 
18
- 1. Create a new Gradio Space and connect it to this repo.
19
- 2. Ensure `ffmpeg` and AVoCaDO are available in the Space image.
20
- 3. Set `AVOCADO_CMD` as a Space secret or enter it in the UI textbox.
21
- - The command must accept `{input}` and `{output}` placeholders.
22
- - Example:
23
- ```bash
24
- python -m avocado_captioner.cli --input {input} --output {output}
25
- ```
26
- 4. Launch the Space, upload videos, and click **Process videos**.
27
 
28
- Outputs are written to:
 
 
 
29
 
 
 
 
 
 
 
30
  ```
31
- /data/dataset/runs/<run_label>/annotations.jsonl
 
 
 
32
  ```
33
 
34
- Each run produces `annotations.jsonl` and `manifest.json` under the run label.
 
 
 
 
 
 
 
35
 
36
- ## Configuration
 
 
 
37
 
38
- - `DATA_ROOT` (default: `/data`) controls the root folder for uploads and output.
39
- - `AVOCADO_CMD` defines the AVoCaDO command used during processing.
 
 
 
40
 
41
- ## Notes
 
 
 
42
 
43
- The standardization step uses:
44
- - 720p scaling (`scale=-2:720`)
45
- - H.264 video (`libx264`, `crf=23`)
46
- - AAC audio (`128k`)
 
 
 
 
47
 
48
- Adjust these settings in `app.py` if you want a different quality/compute tradeoff.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # <img src="assets/avocado.ico" alt="AVoCaDO icon" width="28px"> AVoCaDO: An <u>A</u>udio<u>V</u>isual Vide<u>o</u> <u>Ca</u>ptioner <u>D</u>riven by Temporal <u>O</u>rchestration
2
+
3
+ <p align="left">
4
+ <a href="https://avocado-captioner.github.io/"><img src="https://img.shields.io/badge/Project%20webpage-558b2f?style=for-the-badge"></a>
5
+ <a href="https://huggingface.co/AVoCaDO-Captioner/AVoCaDO"><img src="https://img.shields.io/badge/Model-db8905?style=for-the-badge"></a>
6
+ <a href="https://arxiv.org/abs/2510.10395"><img src="https://img.shields.io/badge/arXiv-red?style=for-the-badge"></a>
7
+ </p>
8
+
9
  ---
10
 
11
+ ## Overview
12
+ Audiovisual video captioning aims to generate semantically rich descriptions with temporal alignment between visual and auditory events, thereby benefiting both video understanding and generation. We introduce <b>AVoCaDO</b>, a powerful audiovisual video captioner driven by the temporal orchestration between audio and visual modalities. Experimental results demonstrate that AVoCaDO significantly outperforms existing open-source models across four audiovisual video captioning benchmarks, and also achieves competitive performance under visual-only settings.
13
 
14
+ ## 🎬 Captioning Case of AVoCaDO
15
+ <img src="assets/case_2.png" alt="AVoCaDO caption">
16
+ An illustration of a video caption generated by AVoCaDO, featuring both <b>precise audiovisual temporal alignment</b> and <u>accurate dialogue rendering</u>.
17
 
18
+ ## 🚀 Getting Started
19
+ Follow these simple steps to set up and run AVoCaDO on your machine.
20
 
21
+ ### 1. Clone the repository
22
+ First, clone the project and navigate into the directory:
 
 
 
 
 
 
 
23
 
24
+ ```bash
25
+ git clone https://github.com/AVoCaDO-Captioner/AVoCaDO.git
26
+ cd AVoCaDO
27
+ ```
28
 
29
+ ### 2. Set Up the Environment
30
+ Create and activate the Conda environment using the provided ``environment.yml`` file.
31
+
32
+ ```bash
33
+ conda env create -f environment.yml
34
+ conda activate AVoCaDO
35
  ```
36
+
37
+ ### 3. Quick Usage
38
+ ```bash
39
+ python inference.py assets/case_1.mp4
40
  ```
41
 
42
+ ## 📈 Benchmark Evaluation
43
+ We provide evaluation scripts for all evaluated benchmarks in our paper.
44
+
45
+ ### Direct Audiovisual Caption Evaluation
46
+ 1. **video-SALMONN2-testset:**
47
+ ```bash
48
+ bash eval_scripts/video-SALMONN2-testset/eval_video-SALMONN2-test.sh <your_save_directory>
49
+ ```
50
 
51
+ 2. **UGC-VideoCap:**
52
+ ```bash
53
+ bash eval_scripts/UGC-VideoCap/eval_UGC-VideoCap.sh <your_save_directory>
54
+ ```
55
 
56
+ ### QA-based Audiovisual Caption Evaluation
57
+ 1. **Daily-Omni:**
58
+ ```bash
59
+ bash eval_scripts/Daily-Omni/Daily-Omni_pipeline.sh <your_save_directory>
60
+ ```
61
 
62
+ 2. **WorldSense:**
63
+ ```bash
64
+ bash eval_scripts/WorldSense/WorldSense_pipeline.sh <your_save_directory>
65
+ ```
66
 
67
+ ### Visual-only Caption Evaluation
68
+ 1. **VDC:**
69
+ First, generate captions for the videos in the VDC benchmark.
70
+ ```bash
71
+ python eval_scripts/VDC/generate_caption.py \
72
+ --model_path <path_to_AVoCaDO> \
73
+ --fout_path <your_save_path>
74
+ ```
75
 
76
+ Next, set up the judge server. This requires installing [SGLang](https://github.com/sgl-project/sglang) to deploy the [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) as the judge model.
77
+ ```bash
78
+ # Deploy the judge model using SGLang
79
+ python -m sglang.launch_server \
80
+ --model-path path_to_Meta-Llama-3.1-8B-Instruct \
81
+ --port 30000 \
82
+ --dp 2 --tp 4
83
+ ```
84
+
85
+ Once the judge model is successfully deployed and running, you can start the evaluation.
86
+ ```bash
87
+ bash AVoCaDO/eval_scripts/VDC/evaluation.sh <your_save_path>
88
+ ```
89
+
90
+ 2. **DREAM-1K:**
91
+ ```bash
92
+ bash eval_scripts/DREAM-1K/eval_DREAM-1K.sh <your_save_directory>
93
+ ```
94
+
95
+
96
+ ## ✒️ Citation
97
+
98
+ If you find our work helpful for your research, please consider giving a star ⭐ and citing our paper. We appreciate your support!
99
+
100
+ ```bibtex
101
+ @article{chen2025avocado,
102
+ title={AVoCaDO: An Audiovisual Video Captioner Driven by Temporal Orchestration},
103
+ author={Chen, Xinlong and Ding, Yue and Lin, Weihong and Hua, Jingyun and Yao, Linli and Shi, Yang and Li, Bozhou and Zhang, Yuanxing and Liu, Qiang and Wan, Pengfei and others},
104
+ journal={arXiv preprint arXiv:2510.10395},
105
+ year={2025}
106
+ }
107
+ ```
assets/avocado.ico ADDED
assets/case_1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78c17b195eae977e5ffdeafb499fbeeec2ea50f9258973154c0da111c2b90b07
3
+ size 6417271
assets/case_2.png ADDED

Git LFS Details

  • SHA256: 46159b25d4560ca19ab7cbe605de1762e71ce1fc3c4f2ed72321af1b268fe2bc
  • Pointer size: 132 Bytes
  • Size of remote file: 2.56 MB
environment.yml ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: AVoCaDO
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - _openmp_mutex=5.1=1_gnu
8
+ - bzip2=1.0.8=h5eee18b_6
9
+ - ca-certificates=2024.9.24=h06a4308_0
10
+ - ld_impl_linux-64=2.40=h12ee557_0
11
+ - libffi=3.4.4=h6a678d5_1
12
+ - libgcc-ng=11.2.0=h1234567_1
13
+ - libgomp=11.2.0=h1234567_1
14
+ - libstdcxx-ng=11.2.0=h1234567_1
15
+ - libuuid=1.41.5=h5eee18b_0
16
+ - mscorefonts=0.0.1=3
17
+ - ncurses=6.4=h6a678d5_0
18
+ - openssl=3.0.15=h5eee18b_0
19
+ - pip=24.2=py310h06a4308_0
20
+ - python=3.10.15=he870216_1
21
+ - readline=8.2=h5eee18b_0
22
+ - setuptools=75.1.0=py310h06a4308_0
23
+ - sqlite=3.45.3=h5eee18b_0
24
+ - tk=8.6.14=h39e8969_0
25
+ - wheel=0.44.0=py310h06a4308_0
26
+ - xz=5.4.6=h5eee18b_1
27
+ - zlib=1.2.13=h5eee18b_1
28
+ - pip:
29
+ - accelerate==0.34.1
30
+ - aiodns==3.2.0
31
+ - aiohappyeyeballs==2.4.3
32
+ - aiohttp==3.10.10
33
+ - aiosignal==1.3.1
34
+ - annotated-types==0.7.0
35
+ - antlr4-python3-runtime==4.11.0
36
+ - anyio==4.6.2.post1
37
+ - apscheduler==3.10.4
38
+ - asttokens==2.4.1
39
+ - async-timeout==4.0.3
40
+ - attrs==24.2.0
41
+ - audioread==3.0.1
42
+ - av==13.1.0
43
+ - backcall==0.2.0
44
+ - beartype==0.19.0
45
+ - beautifulsoup4==4.13.4
46
+ - bleach==6.2.0
47
+ - blinker==1.9.0
48
+ - cachetools==4.2.4
49
+ - cchardet==2.1.7
50
+ - certifi==2021.10.8
51
+ - cffi==1.17.1
52
+ - charset-normalizer==3.4.0
53
+ - cityhash==0.2.4.post11
54
+ - click==8.1.7
55
+ - cloudpickle==3.1.0
56
+ - colorama==0.4.6
57
+ - compressed-tensors==0.8.0
58
+ - contourpy==1.3.1
59
+ - cycler==0.12.1
60
+ - dapr==1.14.0
61
+ - dapr-ext-fastapi==1.14.0
62
+ - dask==2024.12.1
63
+ - datasets==3.1.0
64
+ - dbpool==1.2.1
65
+ - decorator==4.4.2
66
+ - decord==0.6.0
67
+ - deepspeed==0.16.2
68
+ - defusedxml==0.7.1
69
+ - dill==0.4.0
70
+ - diskcache==5.6.3
71
+ - distro==1.9.0
72
+ - docopt==0.6.2
73
+ - docstring-parser==0.16
74
+ - dsc-auth==0.1.18
75
+ - einops==0.8.0
76
+ - et-xmlfile==2.0.0
77
+ - exceptiongroup==1.2.2
78
+ - executing==2.1.0
79
+ - fastapi==0.115.5
80
+ - fastjsonschema==2.21.1
81
+ - ffmpeg-python==0.2.0
82
+ - filelock==3.13.1
83
+ - fire==0.7.0
84
+ - flash-attn==2.7.0.post2
85
+ - flask==3.1.0
86
+ - fonttools==4.55.0
87
+ - frozenlist==1.5.0
88
+ - fsspec==2024.2.0
89
+ - ftfy==6.3.1
90
+ - func-timeout==4.3.5
91
+ - future==1.0.0
92
+ - fvcore==0.1.5.post20221221
93
+ - gguf==0.10.0
94
+ - google-api-core==2.23.0
95
+ - google-auth==2.36.0
96
+ - google-cloud-aiplatform==1.71.1
97
+ - google-cloud-bigquery==3.27.0
98
+ - google-cloud-core==2.4.1
99
+ - google-cloud-resource-manager==1.13.1
100
+ - google-cloud-storage==2.18.2
101
+ - google-crc32c==1.6.0
102
+ - google-resumable-media==2.7.2
103
+ - googleapis-common-protos==1.66.0
104
+ - grpc-google-iam-v1==0.13.1
105
+ - grpcio==1.68.0
106
+ - grpcio-reflection==1.48.2
107
+ - grpcio-status==1.68.0
108
+ - h11==0.14.0
109
+ - h5py==3.12.1
110
+ - hf-xet==1.1.4
111
+ - hickle==5.0.3
112
+ - hiredis==2.4.0
113
+ - hjson==3.1.0
114
+ - httpcore==1.0.6
115
+ - httptools==0.6.4
116
+ - httpx==0.27.2
117
+ - huggingface-hub==0.33.0
118
+ - humanize==4.11.0
119
+ - icecream==2.1.3
120
+ - idna==3.10
121
+ - imageio==2.36.0
122
+ - imageio-ffmpeg==0.5.1
123
+ - importlib-metadata==8.5.0
124
+ - infra-component==1.4.7
125
+ - infra-framework==1.17.10
126
+ - infra-kconf==1.1.3
127
+ - infra-kess==1.1.5
128
+ - infra-keycenter==1.1.1
129
+ - infra-storage==1.3.1
130
+ - install==1.3.5
131
+ - interegular==0.3.3
132
+ - iopath==0.1.10
133
+ - ipdb==0.13.13
134
+ - ipython==8.12.3
135
+ - itsdangerous==2.2.0
136
+ - jedi==0.19.2
137
+ - jinja2==3.1.3
138
+ - jiter==0.7.0
139
+ - joblib==1.4.2
140
+ - jsonschema==4.23.0
141
+ - jsonschema-specifications==2024.10.1
142
+ - jupyter-client==8.6.3
143
+ - jupyter-core==5.8.1
144
+ - jupyterlab-pygments==0.3.0
145
+ - kazoo==2.10.0
146
+ - kiwisolver==1.4.7
147
+ - ks-kafka-python==2.0.3
148
+ - lark==1.2.2
149
+ - lazy-loader==0.4
150
+ - levenshtein==0.26.1
151
+ - librosa==0.11.0
152
+ - llvmlite==0.43.0
153
+ - lm-format-enforcer==0.10.6
154
+ - locket==1.0.0
155
+ - lxml==4.9.4
156
+ - lz4==3.1.10
157
+ - markupsafe==2.1.5
158
+ - matplotlib==3.10.0
159
+ - matplotlib-inline==0.1.7
160
+ - mergedeep==1.3.4
161
+ - mistral-common==1.5.1
162
+ - mistral-inference==1.5.0
163
+ - mistune==3.1.3
164
+ - moviepy==1.0.3
165
+ - mpmath==1.3.0
166
+ - msgpack==1.1.0
167
+ - msgpack-numpy==0.4.8
168
+ - msgspec==0.18.6
169
+ - multidict==6.1.0
170
+ - multiprocess==0.70.18
171
+ - mysql-connector-python==8.0.31
172
+ - nbclient==0.10.2
173
+ - nbconvert==7.16.6
174
+ - nbformat==5.10.4
175
+ - nest-asyncio==1.6.0
176
+ - networkx==3.2.1
177
+ - ninja==1.11.1.3
178
+ - numba==0.60.0
179
+ - numpy==1.26.3
180
+ - nvidia-cublas-cu12==12.4.5.8
181
+ - nvidia-cuda-cupti-cu12==12.4.127
182
+ - nvidia-cuda-nvrtc-cu12==12.4.127
183
+ - nvidia-cuda-runtime-cu12==12.4.127
184
+ - nvidia-cudnn-cu12==9.1.0.70
185
+ - nvidia-cufft-cu12==11.2.1.3
186
+ - nvidia-curand-cu12==10.3.5.147
187
+ - nvidia-cusolver-cu12==11.6.1.9
188
+ - nvidia-cusparse-cu12==12.3.1.170
189
+ - nvidia-cusparselt-cu12==0.6.2
190
+ - nvidia-ml-py==12.560.30
191
+ - nvidia-nccl-cu12==2.21.5
192
+ - nvidia-nvjitlink-cu12==12.4.127
193
+ - nvidia-nvtx-cu12==12.4.127
194
+ - nvitop==1.3.2
195
+ - omegaconf==2.3.0
196
+ - open-clip-torch==2.29.0
197
+ - openai==1.54.3
198
+ - opencv-python==4.10.0.84
199
+ - opencv-python-headless==4.10.0.84
200
+ - openpyxl==3.1.5
201
+ - outlines==0.0.46
202
+ - packaging==24.1
203
+ - pandas==2.2.3
204
+ - pandocfilters==1.5.1
205
+ - parameterized==0.9.0
206
+ - parso==0.8.4
207
+ - partd==1.4.2
208
+ - partial-json-parser==0.2.1.1.post4
209
+ - pathos==0.3.4
210
+ - peft==0.14.0
211
+ - pexpect==4.9.0
212
+ - pickleshare==0.7.5
213
+ - pillow==10.4.0
214
+ - pipreqs==0.5.0
215
+ - platformdirs==4.3.6
216
+ - pooch==1.8.2
217
+ - portalocker==3.0.0
218
+ - pox==0.3.6
219
+ - ppft==1.7.7
220
+ - prettytable==2.5.0
221
+ - proglog==0.1.10
222
+ - prometheus-client==0.21.0
223
+ - prometheus-fastapi-instrumentator==7.0.0
224
+ - prompt-toolkit==3.0.48
225
+ - propcache==0.2.0
226
+ - proto-plus==1.25.0
227
+ - protobuf==3.20.0
228
+ - psutil==6.1.0
229
+ - ptyprocess==0.7.0
230
+ - pure-eval==0.2.3
231
+ - py-cpuinfo==9.0.0
232
+ - pyairports==2.1.1
233
+ - pyarrow==18.0.0
234
+ - pyasn1==0.6.1
235
+ - pyasn1-modules==0.4.1
236
+ - pycares==4.4.0
237
+ - pycocoevalcap==1.2
238
+ - pycocotools==2.0.8
239
+ - pycountry==24.6.1
240
+ - pycparser==2.22
241
+ - pycryptodome==3.21.0
242
+ - pydantic==2.9.2
243
+ - pydantic-core==2.23.4
244
+ - pydub==0.25.1
245
+ - pygments==2.18.0
246
+ - pyparsing==3.2.0
247
+ - pysmhasher==0.2.5
248
+ - pysoundfile==0.9.0.post1
249
+ - python-dateutil==2.9.0.post0
250
+ - python-dotenv==1.0.1
251
+ - python-levenshtein==0.26.1
252
+ - python-snappy==0.6.1
253
+ - pytorchvideo==0.1.5
254
+ - pytube==15.0.0
255
+ - pytz==2021.3
256
+ - pytz-deprecation-shim==0.1.0.post0
257
+ - pyyaml==6.0.2
258
+ - pyzmq==26.2.0
259
+ - qwen-omni-utils==0.0.8
260
+ - qwen-vl-utils==0.0.8
261
+ - rapidfuzz==3.12.2
262
+ - ray==2.38.0
263
+ - redis==4.6.0
264
+ - referencing==0.35.1
265
+ - regex==2024.9.11
266
+ - requests==2.32.3
267
+ - rpds-py==0.21.0
268
+ - rsa==4.9
269
+ - safetensors==0.4.5
270
+ - scenedetect==0.6.4
271
+ - scikit-learn==1.6.0
272
+ - scipy==1.14.1
273
+ - seaborn==0.13.2
274
+ - sentencepiece==0.2.0
275
+ - setuptools-scm==8.1.0
276
+ - shapely==2.0.6
277
+ - simple-parsing==0.1.6
278
+ - six==1.16.0
279
+ - sniffio==1.3.1
280
+ - soundfile==0.13.1
281
+ - soupsieve==2.7
282
+ - soxr==0.5.0.post1
283
+ - sqlparse==0.4.4
284
+ - stack-data==0.6.3
285
+ - starlette==0.41.3
286
+ - sympy==1.13.1
287
+ - tabulate==0.9.0
288
+ - termcolor==2.5.0
289
+ - threadpoolctl==3.5.0
290
+ - tiktoken==0.7.0
291
+ - timm==1.0.12
292
+ - tinycss2==1.4.0
293
+ - tokenizers==0.21.1
294
+ - tomli==2.0.2
295
+ - toolz==1.0.0
296
+ - torch==2.6.0
297
+ - torchvision==0.21.0
298
+ - tornado==6.4.1
299
+ - tqdm==4.67.0
300
+ - traitlets==5.14.3
301
+ - transformers==4.52.3
302
+ - triton==3.2.0
303
+ - typeguard==4.4.1
304
+ - typing-extensions==4.12.2
305
+ - tzdata==2024.2
306
+ - tzlocal==4.3.1
307
+ - unpaddedbase64==2.1.0
308
+ - urllib3==1.26.20
309
+ - uvicorn==0.32.0
310
+ - uvloop==0.21.0
311
+ - vertexai==1.71.1
312
+ - vllm==0.6.3
313
+ - watchfiles==0.24.0
314
+ - wcwidth==0.2.13
315
+ - webencodings==0.5.1
316
+ - websockets==13.1
317
+ - werkzeug==3.1.3
318
+ - word2number==1.1
319
+ - xformers==0.0.27.post2
320
+ - xmltodict==0.12.0
321
+ - xxhash==3.5.0
322
+ - yacs==0.1.8
323
+ - yarg==0.1.9
324
+ - yarl==1.17.1
325
+ - zhon==2.1.1
326
+ - zipp==3.20.2
327
+ - zsvision==0.7.12
eval_scripts/DREAM-1K/dream_example.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
eval_scripts/DREAM-1K/eval_DREAM-1K.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Evaluate AVoCaDO on DREAM-1K: generate captions for every benchmark video,
# then score them with the Tarsier evaluation pipeline.
#
# Usage: bash eval_scripts/DREAM-1K/eval_DREAM-1K.sh <output_dir>
#
# Fail fast: abort on any command error, unset variable, or pipeline failure
# so a broken caption run never feeds the evaluation step.
set -euo pipefail

MODEL_PATH="path_to_AVoCaDO" # TODO: point at the AVoCaDO checkpoint

# Validate the required output-directory argument before touching the filesystem.
if [ $# -lt 1 ]; then
    echo "Usage: $0 <output_dir>" >&2
    exit 1
fi
OUTPUT_DIR="$1"

mkdir -p "$OUTPUT_DIR"

# Step 1: caption generation (writes one JSONL record per video).
python eval_scripts/DREAM-1K/generate_caption.py \
    --model_path "$MODEL_PATH" \
    --save_path "$OUTPUT_DIR/model_caption.jsonl"

# Step 2: score the generated captions.
bash eval_scripts/DREAM-1K/tarsier/scripts/run_evaluation_only.sh "$OUTPUT_DIR/model_caption.jsonl"
eval_scripts/DREAM-1K/generate_caption.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Generate DREAM-1K video captions with a Qwen2.5-Omni checkpoint and stream
# them out as JSONL (one updated record per input example).
import os
import torch
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from qwen_omni_utils import process_mm_info
import argparse
import json
from tqdm import tqdm
from pathlib import Path

VIDEO_MAX_PIXELS = 401408  # 512*28*28 — per-frame pixel budget, passed per message
VIDEO_TOTAL_PIXELS = 20070400  # 512*28*28*50 — budget across all sampled frames
USE_AUDIO_IN_VIDEO = False  # DREAM-1K is evaluated visual-only here
# NOTE(review): the *total* pixel budget is exported under the 'VIDEO_MAX_PIXELS'
# env key — confirm whether 'VIDEO_TOTAL_PIXELS' was the intended key.
os.environ['VIDEO_MAX_PIXELS'] = str(VIDEO_TOTAL_PIXELS)
script_dir = Path(__file__).resolve().parent
example_path = script_dir / "dream_example.jsonl"  # benchmark records, one JSON per line
video_dir = "" # TODO — directory holding the DREAM-1K video files

parser = argparse.ArgumentParser(description="Evaluate a model and save results.")
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
parser.add_argument("--save_path", type=str, required=True, help="Path to save the evaluation results.")
args = parser.parse_args()

model_path = args.model_path
fout_path = args.save_path

# Both files stay open for the whole run; fout is flushed after every sample
# by the main loop below.
f_example = open(example_path, 'r', encoding='utf-8')
fout = open(fout_path, 'w', encoding='utf-8')

model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
model.disable_talker()  # text-only generation; the speech ("talker") head is unused
processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
38
def chat(data):
    """Caption one video with the globally loaded Qwen2.5-Omni model.

    ``data`` supplies ``"video_path"`` and ``"question"``; the return value is
    the decoded model response with the chat-template preamble stripped off.
    """
    system_turn = {
        "role": "system",
        "content": [
            {"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
        ],
    }
    user_turn = {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": data["video_path"],
                "max_pixels": VIDEO_MAX_PIXELS,
            },
            {
                "type": "text",
                "text": data["question"],
            },
        ],
    }
    conversation = [system_turn, user_turn]

    # Render the prompt and extract the multimodal payloads for the processor.
    prompt = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
    model_inputs = processor(text=prompt, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO)
    model_inputs = model_inputs.to(model.device).to(model.dtype)

    # Greedy decoding; only the thinker (text) head generates.
    generated_ids = model.generate(**model_inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO, do_sample=False, thinker_max_new_tokens=2048)

    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # Everything after the final "\nassistant\n" marker is the model's reply.
    return decoded.split("\nassistant\n")[-1]
73
+
74
+
75
# Iterate the DREAM-1K examples, caption each video, and stream the updated
# records to the output JSONL as they are produced.
for idx, line in tqdm(enumerate(f_example, start=1)):
    data = json.loads(line)
    video_path = os.path.join(video_dir, data["messages"][0]["content"][0]["video"]["video_file"])
    question = "Imagine the video from these frames and describe it in detail."

    temp_data = {
        "video_path": video_path,
        "question": question,
    }
    # inference_mode disables autograd bookkeeping for the generation call.
    with torch.inference_mode():
        response = chat(temp_data)

    # Record the prompt actually used and the model's caption directly on the
    # parsed record. (The original bound `out_data = data` without copying, so
    # both names aliased the same dict; writing through one name makes the
    # in-place mutation explicit.)
    data["messages"][0]["content"][1]["text"] = question
    data["messages"][1]["content"][0]["text"] = response
    fout.write(json.dumps(data, ensure_ascii=False) + '\n')
    fout.flush()  # keep partial results on disk if the run is interrupted
eval_scripts/DREAM-1K/tarsier/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
eval_scripts/DREAM-1K/tarsier/configs/tarser2_default_config.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ max_n_frames: 256
2
+ n_frames: 16
3
+ max_pixels: 460800 # 1280 * 720 // 2
4
+ min_pixels: 0
5
+ max_seq_len: 16384
6
+ is_training: false # affects: 1. frame sampling differs between training and testing; 2. response is ignored at test time.
7
+ print_data_error: true
8
+ # is_training: false  # NOTE: duplicate key commented out — duplicate mapping keys are invalid YAML and the later value silently overrides the one defined above.
9
+ do_image_padding: false
10
+ do_image_crop: false
11
+ do_image_resize: false
12
+ video_sampling_strategy: {'video_sampler_version': 'v1', 'force_frames_n_divisible': 1, 'use_multi_images_for_video': true}
13
+ prompt: ""
14
+ train_task: sft
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/multi_images_parser.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import random
3
+ import re
4
+ from PIL import Image
5
+
6
+ from .utils import sample_video, read_image
7
+
8
class MultiImagesParser:
    """Parse a multi-image sample (several captioned images, possibly frames
    extracted from a video) into chat-style ``messages``.

    Depending on ``dataset``, the sample is rendered either as a caption task
    ("describe the images ... in order") or as a per-frame QA task.
    """

    def __init__(
        self,
        n_frames=8,
        is_training=True,
    ):
        # n_frames: int -> sample 1..n_frames images per example;
        #           sequence -> choose one of the listed counts at random.
        self.n_frames = n_frames
        # NOTE(review): is_training is stored but never read in this class —
        # kept for API symmetry with the other parsers.
        self.is_training = is_training
        # fmt: off
        # Reference example of the expected input schema (documentation only).
        self.data_temp = {
            "text": [
                [{
                    "prompt": "Describe the image in short.",
                    "response": "A rollerblader rides high in a full pipe while others watch"
                }],
                [{
                    "prompt": "Describe the image in short.",
                    "response": "A woman in winter clothes is on the sidewalk with a phone."
                }]
            ],
            "image": [
                {
                    "image_file": "/mnt/bn/videonaslq/images/flickr30k/images/3371533654.jpg"
                },
                {
                    "image_file": "/mnt/bn/videonaslq/images/coco/train2014/COCO_train2014_000000177950.jpg"
                },
                {
                    "video_file": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4",
                    "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
                }
            ],
            "dataset": "coco",
            "task": "multi_images",
            "image_processing_config": {},
        }
        # fmt: on

    def check_format(self, data_dict: Dict, image_processing_config: Dict):
        """Validate dataset name and reject samples carrying coordinates."""
        assert data_dict['dataset'] in ['coco', 'sharegpt4v_cap100k', 'sharegpt4v_mix665k', 'webvid', 'movie'], data_dict

        # Multi-image data is currently assumed to contain no coordinate annotations.
        if image_processing_config.get('has_coordinates', False):
            raise ValueError(f'do_crop and has_coordinates cannot be True at the same time in MultiImagesParser!')

        # Heuristic scan for coordinate-like patterns (e.g. "[0.1, 0.2, 0.3, 0.4]") left in the text.
        texts = data_dict['text']
        for text in texts:
            # NOTE(review): per data_temp each element of `text` may itself be a
            # *list* of prompt/response dicts; this indexing assumes a dict —
            # confirm the upstream format before relying on this check.
            match = re.search(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', text['prompt'] + text['response'])
            if match:
                print(f'[Warning] 疑似检测到包含坐标的数据:{data_dict}')

    def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
        """Build the user/assistant message pair for one multi-image sample."""
        self.check_format(data_dict, image_processing_config)

        # Shuffle texts and images with the same permutation so pairs stay aligned.
        texts = data_dict['text']
        images = data_dict['image']
        images = self.load_images(images)
        idxs = list(range(len(texts)))
        random.shuffle(idxs)
        texts = [texts[i] for i in idxs]
        # NOTE(review): load_images may yield more images than there are texts
        # (a video_file item expands to several frames); extra images are dropped
        # here — confirm this is intended.
        images = [images[i] for i in idxs]

        # Randomly choose how many (text, image) pairs to keep.
        if isinstance(self.n_frames, int):
            n_frames = random.choice(list(range(1, self.n_frames + 1)))
        else:
            n_frames = random.choice(self.n_frames)
        texts = texts[: n_frames]
        images = images[: n_frames]

        dataset = data_dict['dataset']
        if dataset in ['coco', 'sharegpt4v_cap100k', 'webvid', 'movie']:
            prompt, response = self.transform_for_caption_task(texts, dataset, images)
        else:
            prompt, response = self.transform_for_qa_task(texts, dataset, images)

        messages = [
            {
                "role": "user",
                "content": [
                    *[{"type": "image", "image": img} for img in images],
                    {"type": "text", "text": prompt},
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": response}
                ]
            }
        ]

        return messages

    def transform_for_caption_task(self, texts, dataset, images):
        """Render a captioning prompt starting from a random frame index;
        the response concatenates the captions from that frame onward."""
        idx = random.choice(list(range(len(texts))))

        # Prompt wording varies by dataset (short vs. detailed captions).
        if dataset == 'coco':
            if len(texts) == 1:
                prompt = 'Describe the image in short.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in short in order.'
        elif dataset == 'sharegpt4v_cap100k':
            if len(texts) == 1:
                prompt = 'Describe the image in detail.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in detail in order.'
        else:
            if len(texts) == 1:
                prompt = 'Describe the image.'
            else:
                prompt = f'Describe the images starting from frame {idx + 1} in order.'
        response = ''
        for i, text in enumerate(texts):
            if i < idx:
                continue  # captions before the chosen start frame are skipped
            if not isinstance(text, dict):
                # A text entry may be a list of alternative captions; pick one.
                text = random.choice(text)
            resp = text['response']
            response += f'{resp}\n'
        return prompt, response

    def transform_for_qa_task(self, texts, dataset, images):
        """Render per-frame question/answer pairs (frame-tagged when >1 frame)."""
        prompt, response = '', ''
        for i, text in enumerate(texts):
            if not isinstance(text, dict):
                # A text entry may be a list of alternative QA pairs; pick one.
                text = random.choice(text)
            if len(texts) > 1:
                prompt += f'Question for frame {i+1}:\n' + text['prompt'] + '\n'
                response += f'Answer to question of frame {i+1}:\n' + text['response'] + '\n'
            else:
                prompt += text['prompt'] + '\n'
                response += text['response'] + '\n'
        return prompt, response


    def load_images(self, image_items: List[Dict]) -> List[Image.Image]:
        """
        Load all referenced images into PIL images.

        image_items: List[Dict]. each item like:
            {"video_file": "path/to/video", "frame_indices": [1]}
        or
            {"image_file": "path/to/image"}
        """
        if image_items is None:
            raise ValueError(f'image_items is None!')

        if isinstance(image_items, dict):
            image_items = [image_items]

        images = []

        for image_item in image_items:

            if 'video_file' in image_item:
                file_key = 'video_file'
            elif 'image_file' in image_item:
                file_key = 'image_file'
            else:
                raise KeyError(f'video_file or image_file not in {image_item}')

            file_path = image_item[file_key]
            if file_key == 'video_file':
                # Video item: decode the listed frames as individual images.
                frame_indices = image_item.get('frame_indices', None)
                if frame_indices is None:
                    raise ValueError(f'read 0 frame: {image_item}')
                if isinstance(frame_indices, int):
                    frame_indices = [frame_indices]
                frames = sample_video(file_path, frame_indices = frame_indices)
                images.extend(frames)
            else:
                # Image item: a single path or a list of paths.
                if isinstance(file_path, str):
                    file_path = [file_path]
                images.extend([read_image(f) for f in file_path])

        return images
+
187
+ if __name__ == '__main__':
188
+ # python3 -m xenon_generation.data.custom_data_parsers.multi_images_parser
189
+
190
+ from tqdm import tqdm
191
+ from tools.rw_utils import read_jsonlines
192
+
193
+ lines = read_jsonlines('/mnt/bn/videonaslq/VideoCaption/datasets_1009/sharegpt4v_cap100k/part_36.jsonl')
194
+ lines = lines[:10]
195
+ parser = MultiImagesParser(n_frames=8)
196
+ for i, l in tqdm(enumerate(lines)):
197
+ l_image_processing_config = l.get('image_processing_config', {})
198
+ messages = parser.transform(l, l_image_processing_config)
199
+ print(messages)
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/object_tracking_parser.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import random
3
+ import re
4
+
5
+ from torchvision import transforms
6
+
7
+ from .utils import sample_video
8
+
9
def return_same(x):
    """Identity transform — hand *x* back unchanged (no-op image transform)."""
    return x
11
+
12
+ def _bbox_transform_for_padding(bbox, frame):
13
+ w1, h1, w2, h2 = bbox
14
+ width, height = frame.size
15
+ if width == height:
16
+ pass
17
+ elif width > height:
18
+ h1 += (width - height) // 2
19
+ h2 += (width - height) // 2
20
+ height = width
21
+ else:
22
+ w1 += (height - width) // 2
23
+ w2 += (height - width) // 2
24
+ width = height
25
+ new_bbox = [w1 / width, h1 / height, w2 / width, h2 / height]
26
+ new_bbox = [round(i, 2) for i in new_bbox]
27
+ return new_bbox
28
+
29
+ def _bbox_transform_for_resize(bbox, frame):
30
+ w1, h1, w2, h2 = bbox
31
+ width, height = frame.size
32
+ new_bbox = [w1 / width, h1 / height, w2 / width, h2 / height]
33
+ new_bbox = [round(i, 2) for i in new_bbox]
34
+ return new_bbox
35
+
36
class InAndOutCropAndResize(object):
    """Crop and resize for in_and_out boxes data according to yuchen.

    Crops a horizontally-centered window spanning 75% of the frame height
    (rows 12.5%..87.5%, columns width/2 ± 0.375*height), then resizes it
    to the target size.

    Args:
        size: tuple of (width, height)
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        """
        Args:
            img (PIL Image): PIL Image
        Returns:
            PIL Image: PIL image.
        """
        width, height = img.width, img.height
        left = int(width * 0.5 - height * 0.375)
        top = int(height * 0.125)
        right = int(width * 0.5 + height * 0.375)
        bottom = int(height * 0.875)
        cropped = img.crop((left, top, right, bottom))
        return cropped.resize(self.size)
60
+
61
+
62
class ObjectTrackingParser:
    """Parse object-tracking samples into a chat-style task: given the first-frame
    bounding boxes of up to ``max_objects`` objects, the model must predict the
    boxes in the remaining frames."""

    def __init__(
        self,
        n_frames = 8,       # int, or a sequence of candidate frame counts to sample from
        max_objects = 3,    # at most this many objects are included per sample
        is_training=True,   # NOTE(review): stored but not read in this class
    ):
        self.n_frames = n_frames
        self.max_objects = max_objects
        self.is_training = is_training
        # Per-dataset image transforms (see get_img_transform).
        self.img_transform = self.get_img_transform()
        # fmt: off
        # Reference example of the expected input schema (documentation only).
        self.data_temp = {
            "video_file": "/mnt/bn/llmdatalq/jiaxin/hdvila/20230926/saved/saved_video_clips/0076/lOjn__YCec4.624.1104.mp4",
            "frame_indices": [154, 157, 160, 163, 166, 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202],
            "objects": {
                "0": {
                    "phrase": "person",
                    "all_frame_bounding_boxes": [[2, 0, 255, 250], [17, 0, 255, 251], [35, 0, 255, 253], [44, 0, 255, 255], [52, 0, 255, 255], [54, 0, 255, 255], [63, 0, 255, 255], [60, 0, 255, 255], [54, 0, 253, 255], [43, 0, 250, 255], [36, 1, 249, 255], [36, 0, 252, 254], [41, 0, 252, 254], [61, 0, 255, 253], [68, 4, 255, 255], [74, 8, 255, 255], [91, 3, 255, 255]]
                }
            },
            "task": "object_tracking",
            "dataset": "hdvila"
        }
        # fmt: on

    def check_format(self, data_dict: Dict, image_processing_config: Dict):
        # Box-tracking data does not support do_crop: cropping would invalidate the boxes.
        if image_processing_config.get('do_crop', False):
            raise ValueError(f'do_crop is not supported in ObjectTrackingParser!')

    def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
        """Build the tracking prompt/response messages for one sample."""
        self.check_format(data_dict, image_processing_config)

        # Box normalization must match how the images will be processed downstream.
        bbox_transform = _bbox_transform_for_padding if image_processing_config['do_padding'] else _bbox_transform_for_resize

        # Randomly sub-sample up to n_frames annotated frames (order preserved).
        if isinstance(self.n_frames, int):
            n_frames = self.n_frames
        else:
            n_frames = random.choice(self.n_frames)
        total_frames = list(range(len(data_dict['frame_indices'])))
        idxs = random.sample(total_frames, min(n_frames, len(total_frames)))
        idxs.sort()

        frame_indices = [data_dict['frame_indices'][i] for i in idxs]
        frames = sample_video(data_dict['video_file'], frame_indices=frame_indices)
        img_transform = self.img_transform[data_dict['dataset']]
        frames = [img_transform(f) for f in frames]

        # Collect up to max_objects tracks, normalizing each box per frame.
        # NOTE(review): boxes are normalized against the *transformed* frame sizes,
        # which assumes the source annotations already match that geometry — confirm.
        objects = []
        for _, o in data_dict['objects'].items():
            if o is None:
                continue
            all_frame_bounding_boxes = [o['all_frame_bounding_boxes'][i] for i in idxs]
            all_frame_bounding_boxes_t = []
            for bbox, frame in zip(all_frame_bounding_boxes, frames):
                all_frame_bounding_boxes_t.append(bbox_transform(bbox, frame))
            objects.append(all_frame_bounding_boxes_t)
            if len(objects) >= self.max_objects:
                break

        prompt = "Given the bounding box coordinates of these objects in the first frame, output the bounding box coordinates in the following frames.\n{}"
        response = ''

        # First-frame boxes go into the prompt; boxes of later frames form the answer.
        object_info = ''
        for i, o in enumerate(objects):
            object_info += f'object {i+1}: {o[0]}\n'
            response += f'object {i+1}: {o[1:]}\n'
        response = response.strip()
        prompt = prompt.format(object_info)

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": frames},
                    {"type": "text", "text": prompt}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": response}
                ]
            }
        ]

        return messages

    def get_img_transform(self):
        # Per-dataset geometry: webvid frames are used as-is; hdvila variants are
        # normalized to 256x256 via resize+center-crop or the in/out crop helper.
        return {
            'webvid': return_same,
            'hdvila': transforms.Compose([
                transforms.Resize(size=256),
                transforms.CenterCrop(size=(256, 256))
            ]),
            'hdvila_in_and_out_boxes': InAndOutCropAndResize(size=(256, 256))
        }
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/standard_vision_parser.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from PIL import Image
3
+ import random
4
+
5
+ from .utils import sample_video, read_image, adjust_bbox, filter_ocr_polygon
6
+
7
+
8
class VisionParser:
    """Standard vision parser: replaces image/video file references inside chat
    ``messages`` with decoded PIL images, budgeting the total number of frames
    and post-processing text (bbox adjustment, OCR filtering)."""

    def __init__(
        self,
        n_frames=8,                  # int, or a sequence of candidate totals to sample from
        max_n_frames=256,            # hard upper bound on frames per sample
        is_training=True,            # train/eval frame sampling differs inside sample_video
        video_sampling_strategy={},  # extra options, e.g. force_frames_n_divisible, use_multi_images_for_video
    ):
        self.n_frames = n_frames
        self.max_n_frames = max_n_frames
        self.is_training = is_training
        self.video_sampling_strategy = video_sampling_strategy

        # fmt: off
        # Reference example of the expected input schema (documentation only).
        self.data_temp = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Describe the image and the video."},
                        # supported image formats:
                        {"type": "image", "image": {"image_file": "/path/to/image"}},
                        {"type": "image", "image": {"video_file": "/path/to/video", "frame_indices": 0}},
                        # supported video formats:
                        {"type": "video", "video": {"video_file": "/path/to/video"}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "frame_indices": [0, 1, 2]}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "time_indices": [0, 1, 2]}},
                        {"type": "video", "video": {"video_file": "/path/to/video", "start_time": 0, "end_time": 100}},
                        {"type": "video", "video": {"image_file": ["/path/to/image"]}, "frame_indices": [0, 1, 2]},
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text","text": "xxx"}
                    ]
                }
            ],
            "dataset": "LSMDC",
            "task": "video/caption"
        }
        # fmt: on

    def check_format(self, data_dict: Dict, image_processing_config: Dict):
        if image_processing_config.get('do_crop', False) and image_processing_config.get('has_coordinates', False):
            raise ValueError(f'do_crop and has_coordinates cannot be True at the same time!')

    def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
        """Transform one sample in place and return its ``messages``.

        1. Replaces each image/video reference in the messages with the loaded
           PIL.Image / List[PIL.Image].
        2. Text special-casing: adjusts bounding boxes for padding; filters
           OCR entries whose area is too small.
        """
        self.check_format(data_dict, image_processing_config)

        self.set_n_frames(data_dict)

        # ugly! Only image tasks need bbox adjustment / OCR filtering, and they
        # are normalized against the first image encountered.
        first_image = None

        for msg in data_dict['messages']:
            if isinstance(msg['content'], dict):
                msg['content'] = [msg['content']]
            for content in msg['content']:

                if content['type'] == 'image':
                    content['image'] = self.load_image_item(content['image'])
                    if first_image is None:
                        first_image = content['image']
                elif content['type'] == 'video':
                    video = self.load_video_item(content['video'])
                    content['video'] = video.pop('frames')
                    if video:
                        # NOTE(review): assumes data_dict already has an 'extra_info'
                        # dict — would raise KeyError otherwise; confirm upstream.
                        data_dict['extra_info']['frame_disturb_info'] = video.pop('video_info', {})
                elif content['type'] == 'text':
                    pass
                else:
                    raise ValueError(f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']")
        # Second pass: text post-processing once first_image is known.
        for msg in data_dict['messages']:
            for content in msg['content']:
                if content['type'] == 'text':
                    self.postprocess_text(content, data_dict, image_processing_config, first_image)

        return data_dict['messages']

    # set n_frames for each vision item.
    def set_n_frames(self, data_dict):
        """Distribute the per-sample frame budget across all vision items.

        Items with explicit frame_indices/time_indices are fixed; the remaining
        ("dynamic") video items share whatever budget is left, then totals are
        clamped to max_n_frames and optionally rounded up to a divisibility
        constraint."""

        if isinstance(self.n_frames, int):
            n_frames = self.n_frames
        else:
            n_frames = random.choice(self.n_frames)

        assert n_frames <= self.max_n_frames

        # Pass 1: count fixed frames and mark dynamic items.
        curr_n_frames = 0
        has_dynamic = False
        for msg in data_dict['messages']:
            if isinstance(msg['content'], dict):
                msg['content'] = [msg['content']]

            for content in msg['content']:

                if content['type'] == 'image':
                    curr_n_frames += 1
                elif content['type'] == 'video':
                    if 'frame_indices' in content['video']:
                        curr_n_frames += len(content['video']['frame_indices'])
                        content['video']['n_frames'] = len(content['video']['frame_indices'])
                    elif 'time_indices' in content['video']:
                        curr_n_frames += len(content['video']['time_indices'])
                        content['video']['n_frames'] = len(content['video']['time_indices'])
                    elif 'min_n_frames' in content['video']:
                        content['video']['min_n_frames'] = int(content['video']['min_n_frames'])
                        curr_n_frames += content['video']['min_n_frames']
                        content['video']['n_frames'] = content['video']['min_n_frames']
                        has_dynamic = True
                    elif 'fps' in content['video']:
                        # fps-based sampling: reserve the maximum, trimmed below.
                        content['video']['n_frames'] = self.max_n_frames
                        curr_n_frames += self.max_n_frames
                        has_dynamic = True
                    else:
                        content['video']['n_frames'] = 0
                        has_dynamic = True

        # Pass 2: grow dynamic items one frame at a time until the target is met.
        while curr_n_frames < n_frames and has_dynamic:
            for msg in data_dict['messages']:
                for content in msg['content']:
                    if content['type'] == 'video':
                        if 'frame_indices' in content['video']:
                            pass
                        elif 'time_indices' in content['video']:
                            pass
                        else:
                            if curr_n_frames < n_frames:
                                content['video']['n_frames'] += 1
                                curr_n_frames += 1

        # Pass 3: shrink dynamic items if the total exceeds the hard cap.
        while curr_n_frames > self.max_n_frames and has_dynamic:
            for msg in data_dict['messages']:
                for content in msg['content']:
                    if content['type'] == 'video':
                        if 'frame_indices' in content['video']:
                            pass
                        elif 'time_indices' in content['video']:
                            pass
                        else:
                            if curr_n_frames > self.max_n_frames:
                                content['video']['n_frames'] -= 1
                                curr_n_frames -= 1


        # Pass 4: round dynamic items up so n_frames is divisible by the
        # configured factor (may exceed the target budget slightly).
        for msg in data_dict['messages']:
            for content in msg['content']:
                if content['type'] == 'video':
                    if 'frame_indices' in content['video']:
                        pass
                    elif 'time_indices' in content['video']:
                        pass
                    else:
                        n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
                        if n > 1 and content['video']['n_frames'] % n != 0:
                            content['video']['n_frames'] += n - content['video']['n_frames'] % n

    def load_image_item(self, image_item) -> Image.Image:
        """
        Load a single image, either from an image file or one video frame.

        image_item:
            {"image_file": {"lq": "/path/to/image"}}
            {"video_file": {"lq": "/path/to/video"}, "frame_indices": 0}
        """

        # check format
        if ("image_file" not in image_item) and ("video_file" not in image_item):
            raise KeyError(f"Key 'image_file' or 'video_file' not found in image_item")
        if 'image_file' in image_item:
            if not isinstance(image_item['image_file'], str):
                raise ValueError(f"{image_item['image_file']} is not a str!")
        if 'video_file' in image_item:
            if not isinstance(image_item['frame_indices'], int):
                raise ValueError(f"{image_item['frame_indices']} is not a int!")

        if 'image_file' in image_item:
            image = read_image(image_item['image_file'])
        else:
            frame_indices = [image_item['frame_indices']]
            image = sample_video(image_item['video_file'], frame_indices = frame_indices)[0]

        return image

    def load_video_item(self, video_item) -> List[Image.Image]:
        """
        Load a video item into frames, delegating sampling to sample_video.

        video_item:
            {"video_file": {"lq": "/path/to/video"}, "n_frames": 8}
            {"video_file": {"lq": "/path/to/video"}, "frame_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": {"lq": "/path/to/video"}, "start_frame": 0, "end_frame": 100, "n_frames": 8}
            {"video_file": {"lq": "/path/to/video"}, "time_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": {"lq": "/path/to/video"}, "start_time": 0, "end_time": 100, "n_frames": 8}
            {"image_file": {"lq": ["/path/to/image"]}, "frame_indices": [0, 1, 2], "n_frames": 3}
        """

        # check format
        if ("image_file" not in video_item) and ("video_file" not in video_item):
            raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")

        video_path = video_item.get('video_file', video_item.get('image_file'))
        n_frames = video_item.get('n_frames', None)
        frame_indices = video_item.get('frame_indices', None)
        start_frame = video_item.get('start_frame', None)
        end_frame = video_item.get('end_frame', None)
        time_indices = video_item.get('time_indices', None)
        start_time = video_item.get('start_time', None)
        end_time = video_item.get('end_time', None)
        mask_boxes = video_item.get('mask_boxes', None)
        fps = video_item.get('fps', None)

        frames, frame_indices = sample_video(
            video_path=video_path,
            frame_indices=frame_indices,
            start_frame=start_frame,
            end_frame=end_frame,
            n_frames=n_frames,
            time_indices=time_indices,
            start_time=start_time,
            end_time=end_time,
            sampling_fps=fps,
            mask_boxes=mask_boxes,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
            return_frame_ids=True,
        )

        # Optionally duplicate every frame (video treated as paired multi-images).
        if self.video_sampling_strategy.get('use_multi_images_for_video', False):
            new_frames = []
            for f in frames:
                new_frames.extend([f, f])
            frames = new_frames

        # sample_video may return a dict of frame metadata instead of plain ids.
        if isinstance(frame_indices, dict):
            return {
                'frames': frames,
                'video_info': frame_indices
            }
        return {'frames': frames}

    def postprocess_text(self, content, data_dict, image_processing_config, first_image):
        # Adjust coordinates for square padding; drop tiny OCR polygons.
        if image_processing_config.get('has_coordinates') and image_processing_config.get('do_padding'):
            content['text'] = adjust_bbox(content['text'], frame=first_image)
        if data_dict.get('task') == 'image/OCR' and image_processing_config.get('has_coordinates'):
            content['text'] = filter_ocr_polygon(content['text'])
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Union
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from PIL import Image, ImageSequence
6
+ import base64
7
+ import io
8
+ import re
9
+ import uuid
10
+ import json
11
+ import numpy as np
12
+ import pyarrow.fs as pf
13
+ import func_timeout
14
+ from func_timeout import func_set_timeout
15
+ import math
16
+
17
+ # fmt: on
18
+ import decord
19
+ # fmt: off
20
+
21
+
22
def denorm_box(points, height, width):
    """Convert normalized (x, y) points (0..1) into absolute pixel coordinates,
    rounding to the nearest integer."""
    return [(round(p[0] * width), round(p[1] * height)) for p in points]
27
+
28
def process_image_for_tiktok(frames: List[Image.Image], mask_boxes):
    """Black out masked regions (e.g. captions/watermarks) in TikTok frames and
    crop each frame just below its lowest masked region.

    Args:
        frames: decoded PIL frames, all assumed to share one size.
        mask_boxes: per-frame lists of boxes; each box is a polygon of
            normalized (x, y) points. Extra entries beyond len(frames) are ignored.
    Returns:
        List of processed PIL images (possibly shorter in height than input).
    """
    mask_boxes = mask_boxes[:len(frames)]
    frames = [np.array(f) for f in frames]
    # assert len(mask_boxes) == len(frames)
    height, width = frames[0].shape[:2]

    new_frames = []
    for boxes, frame in zip(mask_boxes, frames):
        left, top, right, bottom = 0, 0, width, height
        for box in boxes:
            pts = np.array(denorm_box(box, height, width), np.int32)
            # Crop boundary: 30px below the lowest point of any masked box.
            upper_bound = max([p[1] for p in pts]) + 30
            if bottom > upper_bound:
                bottom = upper_bound
            # Zero out the masked rectangle.
            # NOTE(review): assumes polygon points are ordered top-left,
            # top-right, bottom-right, ... — confirm annotation format.
            frame[pts[0][1]: pts[2][1], pts[0][0]: pts[1][0]] = 0

        new_frames.append(Image.fromarray(frame[top: bottom, left: right]))
    return new_frames
46
+
47
+ # 先将视频分成 n_frames 份。训练时,每份随机抽一帧;测试时,每份抽中间的那一帧。
48
+ def _sample_frame_indices_v2(
49
+ total_frames: int,
50
+ n_frames: int,
51
+ is_training=False,
52
+ video_sampling_strategy = {},
53
+ ):
54
+ total_frames_idxs = list(range(total_frames))
55
+ if total_frames <= n_frames:
56
+ return total_frames_idxs
57
+ k, m = divmod(total_frames, n_frames)
58
+ frame_splits = [total_frames_idxs[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(n_frames))]
59
+ if is_training:
60
+ sample_ids = [random.choice(i) for i in frame_splits]
61
+ else:
62
+ sample_ids = [i[(len(i)+1)//2-1] for i in frame_splits]
63
+ return sample_ids
64
+
65
+ # 均匀抽帧,必采样首尾帧。
66
+ def _sample_frame_indices_v1(total_frames: int, n_frames: int, is_training=False, video_sampling_strategy = {}):
67
+ if n_frames == 1:
68
+ return [0] # sample first frame in default
69
+ if total_frames <= n_frames:
70
+ return list(range(total_frames))
71
+ sample_ids = [round(i * (total_frames - 1) / (n_frames - 1)) for i in range(n_frames)]
72
+ return sample_ids
73
+
74
def conduct_disturb_frame(frame_indices):
    """Randomly perturb a frame-index sequence (for temporal-disturbance data).

    One of four perturbations is chosen uniformly:
      - 'exchange': split into 4 equal segments, swap two random segments
      - 'crop':     keep a random 3/4-length window, resample uniformly inside it
      - 'reverse':  reverse a randomly placed span of 1/2..1 of the sequence
      - 'discard':  randomly drop half of the frames (temporal order kept)

    Returns:
        (disturb_type, new_frame_indices)
    """
    disturb_type = random.choice(['exchange', 'crop', 'reverse', 'discard'])
    n_frames = len(frame_indices)
    frame_indices_new = []
    if disturb_type == 'exchange':
        # Split into 4 equal segments and swap two randomly chosen ones.
        seg_len = math.ceil(n_frames / 4)
        seg_idxs = list(range(0, n_frames, seg_len))
        target_idxs = random.sample(range(0, 4), 2)
        seg_idxs[target_idxs[0]], seg_idxs[target_idxs[1]] = seg_idxs[target_idxs[1]], seg_idxs[target_idxs[0]]
        for idx in seg_idxs:
            frame_indices_new += frame_indices[idx: idx+seg_len]
    elif disturb_type == 'crop':
        # Randomly crop out a 3/4-length window, then resample n_frames uniformly.
        crop_len = math.ceil(n_frames / 4)
        idx_s = random.choice(range(0, crop_len+1))
        idx_e = n_frames - 1 - (crop_len - idx_s)
        frame_indices_new = np.linspace(frame_indices[idx_s], frame_indices[idx_e], n_frames, dtype=int).tolist()
    elif disturb_type == 'reverse':
        # Reverse a randomly placed span covering 1/2..1 of the total length.
        reverse_len = math.ceil(random.uniform(0.5,1) * n_frames)
        idx_s = random.choice(range(0, n_frames-reverse_len+1))
        idx_e = idx_s + reverse_len - 1
        frame_indices_new = frame_indices[:idx_s] + list(reversed(frame_indices[idx_s: idx_e+1])) + frame_indices[idx_e+1:]
    elif disturb_type == 'discard':
        # Randomly discard half of the frames, keeping temporal order.
        frame_indices_new = random.sample(frame_indices, n_frames//2)
        frame_indices_new.sort()
    return disturb_type, frame_indices_new
103
+
104
@func_set_timeout(60)
def _download_file(path):
    """Make `path` available on local disk, downloading from HDFS if needed.

    Non-HDFS paths are returned unchanged. HDFS files are copied to a
    uniquely-named file under the system temp dir, using a strategy based on
    file size. The whole call is bounded to 60s by @func_set_timeout.

    Raises:
        FunctionTimedOut: (from the decorator) if the download exceeds 60s.
        FileNotFoundError: if the resulting local path does not exist.
    """
    if path.startswith("hdfs"):
        # Unique temp name to avoid collisions between concurrent workers.
        local_path = os.path.join(tempfile.gettempdir(), f'{uuid.uuid4()}_' + os.path.basename(path))

        fs = pf.HadoopFileSystem.from_uri(uri="hdfs://harunava")
        hdfs_file = fs.open_input_file(path)
        file_size = hdfs_file.size()
        if file_size > 1024 * 1024 * 1024: # 1G
            # Very large file: multi-threaded chunked fetch via the hadoop CLI.
            # NOTE(review): os.system return codes are ignored; failures only
            # surface through the FileNotFoundError check below.
            os.system(f"hadoop fs -get --ct 8 -c 512 '{path}' '{local_path}' > /dev/null 2>&1")
        elif file_size > 1024 * 1024 * 100: # 100M
            # Large file: plain hadoop CLI copy.
            os.system(f"hadoop fs -get '{path}' '{local_path}' > /dev/null 2>&1")
        else:
            # Small file: stream through pyarrow.
            local_fs = pf.LocalFileSystem()
            with local_fs.open_output_stream(local_path) as local_file:
                while True:
                    chunk = hdfs_file.read(1024 * 1024 * 100) # read in 100MB chunks; adjust as needed
                    if not chunk:
                        break
                    local_file.write(chunk)
    else:
        # Already a local path; nothing to download.
        local_path = path

    if not os.path.exists(local_path):
        raise FileNotFoundError(f'{local_path}')

    return local_path
131
+
132
def download_file(path):
    """Fetch `path` to local disk, converting download timeouts to ValueError.

    Thin wrapper around `_download_file` so callers can handle a slow/hung
    download like any other bad-sample error.
    """
    try:
        local_path = _download_file(path)
    except func_timeout.exceptions.FunctionTimedOut as timeout_err:
        # Re-raise the 60s timeout as a ValueError for uniform error handling.
        raise ValueError(timeout_err)
    return local_path
138
+
139
class VideoReader:
    """Decord-backed video reader; HDFS paths are downloaded to a temp file first."""

    def __init__(self, path: str) -> None:
        self.path = path
        # Local copy of the video (identical to `path` for non-HDFS inputs).
        self.local_path = self.preprocess()
        # fault_tol=1 lets decord tolerate corrupted frames instead of raising.
        self.vr = decord.VideoReader(self.local_path, num_threads=1, ctx=decord.cpu(0), fault_tol=1)
        self.vr.seek(0)
        self._length = len(self.vr)
        self._fps = self.vr.get_avg_fps()

    @property
    def length(self):
        # Total number of decodable frames.
        return self._length

    @property
    def fps(self):
        # Average frames per second as reported by decord.
        return self._fps

    def sample(self, frame_indices) -> List[Image.Image]:
        """Decode the given frame indices and return them as RGB PIL images."""
        frames = self.vr.get_batch(frame_indices).asnumpy()
        frames = [Image.fromarray(f).convert('RGB') for f in frames]
        return frames

    def preprocess(self):
        # Ensure the video is available on local disk.
        return download_file(self.path)

    def postprocess(self):
        # Remove the temp copy only if it was downloaded from HDFS.
        if self.path.startswith("hdfs"):
            os.remove(self.local_path)
167
+
168
class ImageSeqReader:
    """Treats a list of image paths as a pseudo-video (one image per frame)."""

    def __init__(self, path: List[str]) -> None:
        self.path = path
        self.local_path = self.preprocess()
        self._length = len(self.local_path)
        # An image sequence has no intrinsic frame rate.
        self._fps = None

    @property
    def length(self):
        return self._length

    @property
    def fps(self):
        return self._fps

    def sample(self, frame_indices):
        """Load and return the images at the requested frame positions."""
        return [read_image(self.local_path[idx]) for idx in frame_indices]

    def preprocess(self):
        # Images are loaded lazily in sample(); just copy the path list.
        return [p for p in self.path]

    def postprocess(self):
        # Nothing to clean up: no temp files are created here.
        pass
194
+
195
class GIFReader:
    """Reads animated GIFs frame-by-frame via PIL; HDFS paths are downloaded first."""

    def __init__(self, path: str) -> None:
        self.path = path
        self.local_path = self.preprocess()
        self.gif = Image.open(self.local_path)
        self._length = self.gif.n_frames
        # GIF stores a per-frame duration in milliseconds; convert to seconds.
        frame_seconds = self.gif.info.get('duration', 0) / 1000
        if frame_seconds > 0:
            self._fps = 1 / frame_seconds
        else:
            self._fps = None

    @property
    def length(self):
        return self._length

    @property
    def fps(self):
        return self._fps

    def sample(self, frame_indices):
        """Return RGB frames whose playback position appears in `frame_indices`.

        Frames come back in playback order (not in the order given), and a
        duplicated index yields a single frame.
        """
        frames = []
        for pos, frame in enumerate(ImageSequence.Iterator(self.gif)):
            if pos in frame_indices:
                frames.append(frame.convert('RGB'))
        return frames

    def preprocess(self):
        # Ensure a local copy exists (no-op for non-HDFS paths).
        return download_file(self.path)

    def postprocess(self):
        # Delete the temp copy if it was pulled from HDFS.
        if self.path.startswith("hdfs"):
            os.remove(self.local_path)
230
+
231
def check_frame_indices(frame_indices, total_frames, video_path):
    """Clamp and validate sampled frame indices against the real frame count.

    An index equal to `total_frames` (a common rounding artifact) is clamped
    to the last valid frame; any other out-of-range index is dropped with an
    error log.

    Args:
        frame_indices: candidate frame indices (may be empty).
        total_frames: number of frames actually in the video.
        video_path: only used in the error message.

    Returns:
        A new list containing only indices in [0, total_frames).
    """
    # Work on a copy so the caller's list is never mutated.
    frame_indices = list(frame_indices)
    # Robustness fix: an empty list previously raised IndexError on [-1].
    if frame_indices and frame_indices[-1] == total_frames:
        frame_indices[-1] = total_frames - 1

    valid_frame_indices = [i for i in frame_indices if 0 <= i < total_frames]

    if len(valid_frame_indices) != len(frame_indices):
        print(f'[Error] frame out of index. video_path={video_path}, frame_indices={frame_indices}, total_frames={total_frames}', flush=True)

    return valid_frame_indices
241
+
242
+
243
def sample_video(
    video_path: Union[str, List[str]],
    frame_indices: List[int] = None,
    start_frame:int=None,
    end_frame:int=None,
    n_frames:int = None,
    time_indices: List[float] = None,
    start_time:int=None,
    end_time:int=None,
    sampling_fps:float=None,
    mask_boxes=None,
    is_training:bool=False,
    video_sampling_strategy={'video_sampler_version': 'v1'},
    return_frame_ids: bool=False,
) -> List[Image.Image]:
    """Load a video (file / GIF / image sequence) and sample frames from it.

    Frame selection, in order of precedence:
      1. `sampling_fps` builds indices at a fixed rate (discarded again if it
         would exceed `n_frames`);
      2. `time_indices` (seconds) are converted to frame indices;
      3. explicit `start_time`/`end_time` (seconds) or
         `start_frame`/`end_frame` bound a range from which `n_frames` are
         drawn with the configured sampler ('v1' or 'v2');
      4. otherwise explicit `frame_indices` are used as given.

    Optionally applies face-mask boxes, pads the frame count up to a
    multiple of `force_frames_n_divisible` with black frames, and supports a
    "frame disturb" mode that perturbs the temporal order for training.

    Returns:
        `frames`, or `(frames, frame_indices)` if `return_frame_ids`, or
        `(frames, disturb-info dict)` if frame disturbance is enabled.
    """

    do_frame_disturb = video_sampling_strategy.get('do_frame_disturb', False)

    # Pick a reader implementation based on the input type.
    if isinstance(video_path, str):
        if video_path.endswith('.gif'):
            reader = GIFReader(video_path)
        else:
            reader = VideoReader(video_path)
    else:
        # A list of image paths is treated as one frame per image.
        reader = ImageSeqReader(video_path)

    total_frames = reader.length
    fps = reader.fps

    if sampling_fps is not None:
        # Fixed-rate sampling. NOTE(review): assumes fps >= sampling_fps;
        # round(fps / sampling_fps) == 0 would make range() raise — confirm
        # with callers.
        frame_indices = list(range(0, total_frames, round(fps / sampling_fps)))
        if len(frame_indices) > n_frames:
            # Too many frames at this rate: fall back to range-based sampling below.
            frame_indices = None

    if time_indices is not None:
        # Convert timestamps (seconds) to frame indices.
        frame_indices = [round(float(i) * fps) for i in time_indices]

    if start_time is not None and end_time is not None:
        # Convert a time window (seconds) to a frame window.
        start_frame = round(start_time * fps)
        end_frame = round(end_time * fps)

    if frame_indices is None:
        # Sample n_frames uniformly from [start_frame, end_frame].
        start_frame = 0 if start_frame is None else round(start_frame)
        end_frame = total_frames - 1 if end_frame is None else round(end_frame)

        if end_frame == total_frames:
            end_frame -= 1

        if video_sampling_strategy['video_sampler_version'] == 'v1':
            # Uniform sampling that always includes the first and last frame.
            frame_indices = _sample_frame_indices_v1(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
        elif video_sampling_strategy['video_sampler_version'] == 'v2':
            frame_indices = _sample_frame_indices_v2(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
        else:
            raise ValueError(f"video_sampler_version={video_sampling_strategy['video_sampler_version']} must be 'v1' or 'v2'")
        # Shift window-relative indices back to absolute frame positions.
        frame_indices = [i + start_frame for i in frame_indices]

    # Clamp/drop indices that fall outside the actual frame count.
    frame_indices = check_frame_indices(frame_indices, total_frames, video_path)

    if do_frame_disturb:
        # Perturb temporal order (training-time augmentation); keep the
        # original indices for the returned metadata.
        frame_disturb_type, frame_indices_new = conduct_disturb_frame(frame_indices)
        frame_indices_raw = frame_indices[:]
        frame_indices = frame_indices_new

    frames = reader.sample(frame_indices)
    if mask_boxes is not None:
        # Apply per-frame mask boxes (e.g. to blank out watermarks/faces).
        frames = process_image_for_tiktok(frames, mask_boxes)

    # Pad with black frames so len(frames) is divisible by n, if requested.
    n = video_sampling_strategy.get('force_frames_n_divisible', 1)
    if n > 1 and len(frames) % n != 0:
        new_n = n - len(frames) % n
        frames.extend([Image.new(mode='RGB', size=frames[-1].size) for _ in range(new_n)])

    # Clean up any temp file the reader downloaded.
    reader.postprocess()

    if do_frame_disturb:
        return frames, {"frame_indices": frame_indices, "disturb_type": frame_disturb_type, "frame_indices_raw": frame_indices_raw}
    if return_frame_ids:
        return frames, frame_indices
    return frames
323
+
324
+
325
+
326
def load_image_from_base64String(img_path):
    """Load a PIL image from a file whose contents are base64-encoded image bytes.

    Args:
        img_path: path to a file containing base64 text (e.g. a ".dat" file).

    Returns:
        The decoded PIL.Image (not converted to RGB; callers convert as needed).
    """
    # BUGFIX: the original `open(img_path, "rb").read()` leaked the file
    # handle; use a context manager to close it deterministically.
    with open(img_path, "rb") as f:
        img_bytes = base64.b64decode(f.read())
    return Image.open(io.BytesIO(img_bytes))
331
+
332
def read_image(image_path):
    """Load one image, transparently pulling it from HDFS when necessary.

    ".dat" files are treated as base64-encoded payloads; everything else is
    opened with PIL and converted to RGB. Temp copies downloaded from HDFS
    are deleted after loading.
    """
    local_file = download_file(image_path)

    is_base64_payload = local_file.endswith('.dat')
    if is_base64_payload:
        image = load_image_from_base64String(local_file)
    else:
        image = Image.open(local_file).convert('RGB')

    # Clean up the temp copy created for HDFS sources.
    if image_path.startswith("hdfs"):
        os.remove(local_file)
    return image
342
+
343
+
344
def adjust_bbox(text, frame):
    """Remap normalized coordinate lists embedded in `text` from the
    square-padded image space back to `frame`'s aspect ratio.

    Every "[x, y, x, y, ...]" group in `text` is rewritten: for landscape
    frames only y values are shifted/rescaled, for portrait frames only x
    values; square frames pass through unchanged (values become floats).
    """
    width, height = frame.size
    pieces = []
    cursor = 0
    for match in re.finditer(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', text):
        coords = [float(c) for c in re.findall(r"([0-9.]+)", match.group(0))]

        adjusted = []
        for i, value in enumerate(coords):
            is_y = (i % 2 != 0)
            if width > height and is_y:
                # Undo vertical padding: to pixels, shift by pad, renormalize.
                pixel = value * height + (width - height) // 2
                value = round(pixel / width, 2)
            elif height > width and not is_y:
                # Undo horizontal padding.
                pixel = value * width + (height - width) // 2
                value = round(pixel / height, 2)
            adjusted.append(value)

        pieces.append(text[cursor: match.start()])
        pieces.append(str(adjusted))
        cursor = match.end()
    pieces.append(text[cursor:])
    return ''.join(pieces)
380
+
381
def bbox_area(vertices, convert_format = True):
    """Area of an axis-aligned box given two opposite corners.

    With convert_format=True, `vertices` is a flat [x0, y0, x1, y1, ...]
    sequence that is first paired up into (x, y) tuples.
    """
    if convert_format:
        vertices = list(zip(vertices[::2], vertices[1::2]))
    (x0, y0), (x1, y1) = vertices[0], vertices[1]
    return abs((x1 - x0) * (y1 - y0))
387
+
388
def polygon_area(vertices, convert_format = True):
    """Area of a polygon via the shoelace formula.

    With convert_format=True, `vertices` is a flat [x0, y0, x1, y1, ...]
    sequence. A two-point input is treated as an axis-aligned box.
    """
    if convert_format:
        vertices = list(zip(vertices[::2], vertices[1::2]))
    n = len(vertices)  # number of polygon vertices
    if n == 2:
        # Degenerate "polygon": interpret the two points as a bounding box.
        return bbox_area(vertices, convert_format=False)
    # Shoelace formula: accumulate cross products of consecutive vertices.
    twice_area = 0
    for i, (x1, y1) in enumerate(vertices):
        x2, y2 = vertices[(i + 1) % n]
        twice_area += x1 * y2 - x2 * y1
    return abs(twice_area) / 2
400
+
401
def get_text_len(text_line):
    """Visual length of a text line: CJK ideographs count 1, all other chars 0.5."""
    return sum(1 if '\u4e00' <= ch <= '\u9fff' else 0.5 for ch in text_line)
409
+
410
def filter_ocr_polygon(response, area_threshold=0.0005):
    """Drop OCR lines whose polygon is too small relative to their text length.

    Args:
        response: JSON string of [[coords, text_line], ...] with normalized
            polygon coordinates.
        area_threshold: minimum polygon-area-per-character to keep a line.

    Returns:
        The filtered JSON string (ensure_ascii=False), or `response`
        unchanged when it is not valid JSON.
    """
    try:
        resp = json.loads(response)
    # BUGFIX: narrowed from a bare `except:` (which also swallowed
    # KeyboardInterrupt/SystemExit). json.loads raises ValueError
    # (JSONDecodeError) for malformed JSON and TypeError for non-str input.
    except (ValueError, TypeError):
        return response
    new_resp = []
    for coords, text_line in resp:
        area = polygon_area(coords, convert_format=True)
        text_len = get_text_len(text_line)
        if text_len == 0:
            continue  # no visible text at all
        if area / text_len < area_threshold:
            continue  # polygon too small for the amount of text
        new_resp.append([coords, text_line])
    return json.dumps(new_resp, ensure_ascii=False)
427
+
428
def put_pred_to_data_dict(prediction, data_dict):
    """Write a model prediction into `data_dict` as the assistant's answer.

    If the last message is already an assistant turn, its final text content
    is overwritten; otherwise a new assistant message is appended. Mutates
    `data_dict` in place.
    """
    last_msg = data_dict['messages'][-1]
    if last_msg['role'] != 'assistant':
        data_dict['messages'].append({
            "role": "assistant",
            "content": [{"type": "text", "text": prediction}]
        })
    else:
        last_msg['content'][-1]['text'] = prediction
437
+
438
def get_prompt_from_data_dict(data_dict):
    """Render a chat-style data dict as a flat text prompt.

    Each content item becomes "[role]: ..." (text, or an <image>/<video>
    placeholder); a newline is appended after every message. Empty text
    contents are skipped.
    """
    parts = []
    for msg in data_dict['messages']:
        role = msg['role']
        assert role in {'system', 'user', 'assistant'}
        for content in msg['content']:
            kind = content['type']
            if kind == 'text':
                if content['text']:
                    parts.append(f"[{role}]: {content['text']}")
            elif kind == 'image':
                parts.append(f"[{role}]: <image>")
            elif kind == 'video':
                parts.append(f"[{role}]: <video>")
        parts.append('\n')
    return ''.join(parts)
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils_visualize.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import Dict, List, Optional
3
+ from PIL import Image, ImageDraw, ImageFont
4
+
5
+
6
def scale_polygon(polygon, w, h):
    """Scale normalized (x, y) vertices to pixel coordinates for a w-by-h image."""
    return [(x * w, y * h) for (x, y) in polygon]
11
+
12
def draw_polygon(image: Image.Image, points: List[List[int]], label: Optional[str] = None):
    """Draw a red polygon (or rectangle, when exactly two points are given)
    onto `image`, optionally labelling it in blue at the first vertex.

    Mutates `image` in place and returns it.

    Raises:
        ValueError: when fewer than two points are supplied.
    """
    drawer = ImageDraw.Draw(image)
    count = len(points)
    if count > 2:
        drawer.polygon(points, outline="red", width=3)
    elif count == 2:
        # Two points are interpreted as opposite corners of a rectangle.
        drawer.rectangle(points, outline="red", width=3)
    else:
        raise ValueError(f'points={points} only has one point!')

    if label is not None:
        font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 20)
        drawer.text(points[0], label, font=font, fill=(0, 0, 255))
    return image
25
+
26
def visualize_image_bbox(data_dict, image_processing_config, processor):
    """Debug helper: draw every coordinate list found in the sample's text
    contents onto the sample's first image content, in place.

    No-op unless `image_processing_config['has_coordinates']` is True.
    NOTE(review): assumes at least one image content exists when coordinates
    are present — otherwise `first_image_content` stays None and the
    subscript below raises; confirm with callers.
    """
    if image_processing_config.get('has_coordinates') != True:
        return

    messages = data_dict['messages']

    polygons = []
    first_image_content = None

    for msg in messages:
        for content in msg['content']:
            if content['type'] == 'text':
                # Match flat coordinate lists like "[0.1, 0.2, 0.3, 0.4]".
                for match in re.finditer(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', content["text"]):
                    coordinate_matches = re.findall(r"([0-9.]+)", match.group(0))
                    coords = [float(coord) for coord in coordinate_matches]
                    # Pair flat [x0, y0, x1, y1, ...] into (x, y) vertices.
                    polygons.append(list(zip(coords[::2], coords[1::2])))
            elif first_image_content is None and content['type'] == 'image':
                first_image_content = content

    first_image = first_image_content['image']
    # Preprocess first so drawn boxes line up with the final image size.
    first_image = processor.preprocess_image(first_image, image_processing_config)
    w, h = first_image.size

    if len(polygons) > 0:
        for i, polygon in enumerate(polygons):
            # Coordinates are normalized; scale to pixels before drawing.
            polygon = scale_polygon(polygon, w, h)
            first_image = draw_polygon(first_image, polygon, label=str(i))

    first_image_content['image'] = first_image
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/video_permutation_parser.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ import random
3
+ from PIL import Image, ImageDraw, ImageFont
4
+
5
+ from .utils import sample_video
6
+
7
+
8
class VideoPermutationParser:
    """Builds "frame permutation" samples: a video's frames are shuffled, part
    of the true chronological order is revealed in the prompt, and the target
    is the order of the remaining frames."""

    def __init__(
        self,
        n_frames=8,
        is_training=True,
        frame_nums=None,
        video_sampling_strategy=None,
    ):
        """
        Args:
            n_frames: nominal frame count (kept for interface parity; the
                actual per-sample count is drawn from `frame_nums`).
            is_training: forwarded to the frame sampler.
            frame_nums: candidate per-sample frame counts; defaults to
                range(8, 25).
            video_sampling_strategy: options forwarded to `sample_video`.
        """
        self.n_frames = n_frames
        self.is_training = is_training
        # BUGFIX: `frame_nums=list(range(8, 25))` and
        # `video_sampling_strategy={}` were mutable default arguments shared
        # across all callers; use None sentinels instead.
        self.frame_nums = list(range(8, 25)) if frame_nums is None else frame_nums
        self.video_sampling_strategy = {} if video_sampling_strategy is None else video_sampling_strategy
        # fmt: off
        # Example of the expected input schema (documentation only).
        self.data_temp = {
            "text": [{
                "prompt": "<video>",
                "response": ""
            }],
            "video": [{
                "video_file": {
                    "yg": "/mnt/bn/videonasyg/videos/webvid_10M_download/011851_011900/1047443473.mp4",
                    "lq": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4"
                },
                "frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
            }],
        }
        # fmt: on

    def check_format(self, data_dict: Dict):
        # Format validation is currently disabled.
        pass
        # for k in self.data_temp.keys():
        #     assert k in data_dict

    def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
        """Build the shuffled-frames chat sample for one annotation."""
        self.check_format(data_dict)

        frames = self.load_video_item(data_dict['video'][0])

        # frames = self.add_text_to_frames(frames) # for debug

        # 1-based original positions, shuffled to define the permutation.
        idxs = list(range(1, len(frames) + 1))
        random.shuffle(idxs)

        # Reveal the first 3/8 of the permutation in the prompt.
        prefix_len = int(3/8*len(idxs))

        shuffled_frames = [frames[i-1] for i in idxs]

        prompt = f'Output the correct chronological order of scrambled video frames. The order of the first {prefix_len} ones are:\n'
        prompt += '\n'.join([str(i) for i in idxs[: prefix_len]]) + '\nOutput the order of the following frames:'
        response = '\n'.join([str(i) for i in idxs[prefix_len: ]])

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": shuffled_frames},
                    {"type": "text", "text": prompt},
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": response}
                ]
            }
        ]

        return messages

    def load_video_item(self, video_item) -> List[Image.Image]:
        """Sample frames for one video item.

        video_item:
            {"video_file": "/path/to/video", "n_frames": 8}
            {"video_file": "/path/to/video", "frame_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100, "n_frames": 8}
            {"video_file": "/path/to/video", "time_indices": [0, 1, 2], "n_frames": 3}
            {"video_file": "/path/to/video", "start_time": 0, "end_time": 100, "n_frames": 8}
            {"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2], "n_frames": 3}
        """

        # check format
        if ("image_file" not in video_item) and ("video_file" not in video_item):
            raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")

        video_path = video_item.get('video_file', video_item.get('image_file'))
        frame_indices = video_item.get('frame_indices', None)
        start_frame = video_item.get('start_frame', None)
        end_frame = video_item.get('end_frame', None)
        time_indices = video_item.get('time_indices', None)
        start_time = video_item.get('start_time', None)
        end_time = video_item.get('end_time', None)
        mask_boxes = video_item.get('mask_boxes', None)

        # NOTE: any 'n_frames' in video_item is intentionally overridden by a
        # random draw from self.frame_nums, padded up so the count is
        # divisible by force_frames_n_divisible when configured.
        n_frames = random.choice(self.frame_nums)
        n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
        if n > 1 and n_frames % n != 0:
            n_frames += n - n_frames % n

        frames, frame_indices = sample_video(
            video_path=video_path,
            frame_indices=frame_indices,
            start_frame=start_frame,
            end_frame=end_frame,
            n_frames=n_frames,
            time_indices=time_indices,
            start_time=start_time,
            end_time=end_time,
            mask_boxes=mask_boxes,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
            return_frame_ids=True,
        )
        return frames

    def add_text_to_frames(self, frames: List[Image.Image]):
        """Debug helper: stamp each frame with its 1-based index in red."""
        new_frames = []
        for i, image in enumerate(frames):
            draw = ImageDraw.Draw(image)

            font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 100)
            text_position = (50, 50)
            text_content = f'{i+1}'
            text_color = (255, 0, 0)
            draw.text(text_position, text_content, font=font, fill=text_color)
            new_frames.append(image)
        return new_frames
137
+
eval_scripts/DREAM-1K/tarsier/dataset/tarsier_datamodule.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Datamodule for Llava Pretraining and Finetuning"""
2
+ import os
3
+ import re
4
+ from PIL import Image
5
+ import numpy as np
6
+ import re
7
+ import tempfile
8
+ from typing import Dict, List, Union, Tuple
9
+ import traceback
10
+ import json
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from transformers import DataCollatorForSeq2Seq
15
+
16
+ from tools.rw_utils import read_jsonlines
17
+ from torch.utils.data import Dataset, DataLoader
18
+
19
+ np_str_obj_array_pattern = re.compile(r"[SaUO]")
20
+
21
+ default_collate_err_msg_format = (
22
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
23
+ "dicts or lists; found {}"
24
+ )
25
+
26
+ from .custom_data_parsers.standard_vision_parser import VisionParser
27
+ from .custom_data_parsers.object_tracking_parser import ObjectTrackingParser
28
+ from .custom_data_parsers.multi_images_parser import MultiImagesParser
29
+ from .custom_data_parsers.video_permutation_parser import VideoPermutationParser
30
+ from .custom_data_parsers.utils_visualize import visualize_image_bbox
31
+
32
+ from .tarsier_processor import TarsierProcessor
33
+
34
+ from tools.rw_utils import NumpyArrayEncoder
35
+ from .utils import DictToObject
36
+
37
class TarsierDataProcessor:
    """Turns chat-style annotation dicts into batched model inputs.

    Wraps a TarsierProcessor: each sample is routed to a task-specific parser
    (object tracking / multi-image / video permutation / standard vision),
    tokenized, and collated with longest-padding. Supports 'sft' and 'dpo'
    (chosen/rejected pair) training tasks.
    """

    def __init__(
        self,
        processor: TarsierProcessor,
        n_frames: Union[int, list],
        max_n_frames=256,
        max_pixels=int(1280 * 720 // 2),
        min_pixels=0,
        max_seq_len=None,
        is_training=True,  # affects frame sampling; responses are ignored at eval time
        print_data_error=True,
        do_image_padding=False,
        do_image_crop=False,
        do_image_resize=True,
        video_sampling_strategy=None,
        prompt='',
        train_task='sft',
        **kwargs
    ):
        """
        Args:
            processor: the wrapped TarsierProcessor (owns the tokenizer).
            n_frames: frames per video (int, or a list of candidate counts).
            max_seq_len: token budget; falls back to the tokenizer's
                model_max_length when None.
            video_sampling_strategy: forwarded to the vision parsers.
                (BUGFIX: was a mutable `{}` default argument.)
            prompt: when non-empty, overwrites every user text content.
            train_task: 'sft' or 'dpo'.
        """
        self.kwargs = kwargs

        self.processor = processor
        self.pad_collator = DataCollatorForSeq2Seq(processor.tokenizer, padding='longest')

        # BUGFIX: this previously read `self.tokenizer`, an attribute that
        # does not exist on this class (AttributeError whenever max_seq_len
        # was None); the tokenizer lives on the wrapped TarsierProcessor.
        self.processor.max_seq_len = processor.tokenizer.model_max_length if max_seq_len is None else max_seq_len

        self.n_frames = n_frames
        self.max_n_frames = max_n_frames
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels

        self.is_training = is_training
        self.print_data_error = print_data_error
        self.do_image_padding = do_image_padding
        self.do_image_crop = do_image_crop
        self.do_image_resize = do_image_resize
        self.video_sampling_strategy = {} if video_sampling_strategy is None else video_sampling_strategy
        self.prompt = prompt
        self.train_task = train_task

        # One parser per task family; select_parser() dispatches per sample.
        self.object_tracking_parser = ObjectTrackingParser(
            n_frames=self.n_frames,
            max_objects=4,
            is_training=self.is_training,
        )
        self.multi_images_parser = MultiImagesParser(
            n_frames=self.n_frames,
            is_training=self.is_training,
        )
        self.video_permutation_parser = VideoPermutationParser(
            n_frames=self.n_frames,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy,
        )
        self.vision_parser = VisionParser(
            n_frames=self.n_frames,
            max_n_frames=self.max_n_frames,
            is_training=self.is_training,
            video_sampling_strategy=self.video_sampling_strategy
        )

    def select_parser(self, data_dict):
        """Route a sample to its parser based on 'task' / 'dataset' tags."""
        if data_dict.get('task', None) == 'video/object_tracking':
            return self.object_tracking_parser
        elif data_dict.get('task', None) == 'multi_images':
            return self.multi_images_parser
        elif data_dict.get('dataset', None) == 'video_permutation':
            return self.video_permutation_parser
        else:
            return self.vision_parser

    def parse_image_processing_config(self, data_dict):
        """Merge per-sample image-processing overrides with instance defaults."""
        image_processing_config = data_dict.get('image_processing_config', {})

        do_padding = image_processing_config.get('do_padding', self.do_image_padding)
        do_crop = image_processing_config.get('do_crop', self.do_image_crop)
        do_resize = image_processing_config.get('do_resize', self.do_image_resize)
        max_pixels = image_processing_config.get('max_pixels', self.max_pixels)
        min_pixels = image_processing_config.get('min_pixels', self.min_pixels)

        assert min_pixels <= max_pixels

        image_processing_config['do_padding'] = do_padding
        image_processing_config['do_crop'] = do_crop
        image_processing_config['do_resize'] = do_resize
        image_processing_config['max_pixels'] = max_pixels
        image_processing_config['min_pixels'] = min_pixels

        return image_processing_config

    def _transform(self, raw_data_dict: Dict) -> Dict:
        """Process one sample dict into model inputs (wrapped in a list)."""
        # Deep-copy via JSON so parsers can mutate freely (numpy-safe encoder).
        data_dict = json.loads(json.dumps(raw_data_dict, cls=NumpyArrayEncoder))
        del raw_data_dict

        # Optional global prompt override for every user text content.
        if self.prompt:
            for msg in data_dict['messages']:
                if msg['role'] == 'user':
                    for content in msg['content']:
                        if content['type'] == 'text':
                            content['text'] = self.prompt

        # Keep an unmodified copy to return alongside the tensors.
        data_dict_copy = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))

        image_processing_config = self.parse_image_processing_config(data_dict)
        parser = self.select_parser(data_dict)
        messages = parser.transform(data_dict, image_processing_config)
        data_dict_copy['extra_info'] = data_dict.pop('extra_info', {})

        outputs = self.processor(messages, image_processing_config, is_training=self.is_training)

        outputs['raw_data_dict'] = data_dict_copy

        return [outputs]

    def _split_chosen_rejected(self, data_dict: Dict):
        """For DPO: fan one sample out into a (chosen, rejected) pair.

        NOTE: the chosen dict aliases (and mutates) the input `data_dict`;
        the rejected dict is a deep copy.
        """
        chosen_data_dict = data_dict
        rejected_data_dict = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
        for msg in chosen_data_dict['messages']:
            if msg['role'] == 'assistant':
                for content in msg['content']:
                    if content['type'] == 'text':
                        content['text'] = content['chosen']

        for msg in rejected_data_dict['messages']:
            if msg['role'] == 'assistant':
                for content in msg['content']:
                    if content['type'] == 'text':
                        content['text'] = content['rejected']

        return chosen_data_dict, rejected_data_dict

    def transform(self, data_dict: Dict) -> Dict:
        """Process one sample; returns [] on failure (optionally logging it)."""
        try:
            if self.train_task == 'dpo':
                chosen_data_dict, rejected_data_dict = self._split_chosen_rejected(data_dict)
                return self._transform(chosen_data_dict) + self._transform(rejected_data_dict)
            return self._transform(data_dict)
        except Exception as e:
            if self.print_data_error:
                print(traceback.format_exc())
                print(f'Error occurs when processing: \n{data_dict}')
            return []

    def batch_transform(self, batch_data: List[Dict]) -> Dict:
        """Collate processed samples into one padded batch."""
        model_inputs = {}
        raw_data_dict = [d.pop('raw_data_dict') for d in batch_data]
        model_inputs['raw_data_dict'] = raw_data_dict

        batch_pixel_values = [d.pop('pixel_values') for d in batch_data if 'pixel_values' in d]
        batch_image_grid_thw = [d.pop('image_grid_thw') for d in batch_data if 'image_grid_thw' in d]
        if len(batch_pixel_values) == 0:
            # Text-only batch: inject a dummy image so the vision tower still
            # receives well-formed tensors.
            vision_placeholder = self.get_vision_placeholder()
            batch_pixel_values = [vision_placeholder.get('pixel_values')]
            batch_image_grid_thw = [vision_placeholder.get('image_grid_thw')] if 'image_grid_thw' in vision_placeholder else []

        model_inputs['pixel_values'] = torch.cat(batch_pixel_values, dim=0)
        if len(batch_image_grid_thw) > 0:
            model_inputs['image_grid_thw'] = torch.cat(batch_image_grid_thw, dim=0)

        batch_num_images = [d.pop('num_images') for d in batch_data]
        model_inputs['num_images'] = torch.tensor(batch_num_images)
        # Pad the remaining token-level fields to the longest sequence.
        model_inputs.update(self.pad_collator(batch_data))
        return model_inputs

    def __call__(self, batch_data: Union[Dict, List[Dict]]) -> Dict:
        if isinstance(batch_data, dict):
            batch_data = [batch_data]
        # NOTE(review): transform() returns [] on error, which would make
        # [0] raise IndexError here — confirm upstream filtering.
        batch = [self.transform(d)[0] for d in batch_data]
        return self.batch_transform(batch)

    def get_vision_placeholder(self):
        """Dummy single-image sample used to pad text-only batches."""
        messages = [{"role": "user", "content": [{"type": "image", "image": Image.new(mode='RGB', size=(336, 336))}]}]
        image_processing_config = self.parse_image_processing_config({})
        return self.processor(messages, image_processing_config)

    def get_text_placeholder(self):
        """Dummy text-only sample (e.g. for probing sequence shapes)."""
        messages = [
            {"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "Thank you very much"}]},
        ]
        image_processing_config = self.parse_image_processing_config({})
        return self.processor(messages, image_processing_config)
223
+
224
def init_processor(processor: Union[TarsierProcessor, str]=None, config: Dict=None):
    """Build a TarsierDataProcessor from either an existing TarsierProcessor
    or a pretrained-checkpoint path, configured from `config` (dict or object)."""
    if isinstance(config, dict):
        config = DictToObject(config)
    if isinstance(processor, str):
        # A string is interpreted as a checkpoint path/name to load from.
        base_processor = TarsierProcessor.from_pretrained(
            processor,
            padding_side='left',
            trust_remote_code=True
        )
    else:
        base_processor = processor
    return TarsierDataProcessor(
        processor=base_processor,
        n_frames=config.n_frames,
        max_n_frames=config.max_n_frames,
        max_pixels=config.max_pixels,
        min_pixels=config.min_pixels,
        max_seq_len=config.max_seq_len,
        is_training=config.is_training,
        print_data_error=config.print_data_error,
        do_image_padding=config.do_image_padding,
        do_image_crop=config.do_image_crop,
        do_image_resize=config.do_image_resize,
        video_sampling_strategy=config.video_sampling_strategy,
        prompt=config.prompt,
        train_task=config.train_task
    )
251
+
252
class TarsierDataset(Dataset):
    """Dataset over jsonlines annotations, processed by a TarsierDataProcessor.

    `__getitem__` returns (annotation, model_inputs); model_inputs is None
    when processing fails, so the collate side can skip bad samples.
    """

    def __init__(self, ann_path="", anns=None, config: Dict=None, processor: Union[TarsierDataProcessor, TarsierProcessor, str]=None):
        """
        Args:
            ann_path: one jsonlines path or a list of paths (ignored when
                `anns` is provided).
            anns: pre-loaded list of annotation dicts.
            config: processor settings (dict or attribute object).
            processor: a ready TarsierDataProcessor, or a TarsierProcessor /
                checkpoint path from which one is built via init_processor.
        """
        self.config = DictToObject(config) if isinstance(config, dict) else config
        if not isinstance(processor, TarsierDataProcessor):
            self.processor = init_processor(processor, config)
        else:
            self.processor = processor
        if anns is None:
            self.anns = []
            if isinstance(ann_path, str):
                ann_path = [ann_path]
            for path in ann_path:
                self.anns.extend(read_jsonlines(path))
        else:
            self.anns = anns

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, index):
        if index < 0 or index >= len(self.anns):
            raise IndexError("Index out of range")
        # BUGFIX: hoisted out of the try block. If the annotation lookup
        # itself failed, `ann` was unbound and the `return ann, None` in the
        # except branch raised a NameError that masked the real error.
        ann = self.anns[index]
        try:
            model_inputs = self.processor(ann)
        except Exception as e:
            # Best-effort: report and let the caller drop the sample.
            print(f"Load data error: {e}")
            return ann, None
        return ann, model_inputs
eval_scripts/DREAM-1K/tarsier/dataset/tarsier_processor.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Union
2
+ from PIL import Image
3
+
4
+ import torch
5
+
6
+ from transformers.feature_extraction_utils import BatchFeature
7
+ from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
8
+ from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
9
+ from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
10
+ from transformers.utils import logging
11
+ from transformers import Qwen2VLImageProcessor
12
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
13
+
14
+ logger = logging.get_logger(__name__)
15
+
16
+
17
class TarsierProcessorKwargs(ProcessingKwargs, total=False):
    """Typed kwargs container for TarsierProcessor, split into text/images groups."""
    # Both groups default to empty so that tokenizer / image-processor defaults
    # (merged in ProcessorMixin._merge_kwargs) apply unless explicitly overridden.
    _defaults = {
        "text_kwargs": {},
        "images_kwargs": {},
    }
22
+
23
+
24
class TarsierProcessor(ProcessorMixin):
    """Joint text + vision processor for Tarsier.

    Wraps an image processor and a tokenizer (loaded via the Auto* classes)
    and turns a chat-style `messages` list — whose contents may carry inline
    PIL images or frame lists — into model inputs: `input_ids`, `labels`,
    `pixel_values` and (for Qwen2-VL style processors) grid info.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template", "image_token", "patch_size", "merge_size", "temporal_patch_size", "max_seq_len"]
    image_processor_class = "AutoImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor=None,
        tokenizer=None,
        chat_template=None,
        image_token="<image>",      # placeholder token later expanded to vision tokens
        patch_size=None,            # vision patch size (used for non-Qwen2-VL token counting)
        merge_size=1,               # spatial merge factor for vision tokens
        temporal_patch_size=1,      # frames folded into one temporal patch per image
        max_seq_len=8192,           # hard cap on the token sequence length
        **kwargs,
    ) -> None:

        self.image_token = image_token
        self.patch_size = patch_size
        self.merge_size = merge_size
        self.temporal_patch_size = temporal_patch_size
        self.max_seq_len = max_seq_len
        # Total pixel budget per sample, shared across all frames/images.
        self.max_pixels_per_sample = 128 * 384 * 384

        super().__init__(image_processor, tokenizer, chat_template=chat_template)

    def __call__(
        self,
        messages,
        image_processing_config=None,
        is_training=True,
    ) -> torch.Tensor:
        """Convert chat `messages` into model inputs.

        Returns a dict with input_ids, labels (non-assistant tokens masked to
        -100), num_images, and — when vision content is present — pixel_values
        plus image_grid_thw for Qwen2-VL style image processors.
        """

        output_kwargs = self._merge_kwargs(
            TarsierProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
        )

        # --- Vision preprocessing ---
        pixel_values, image_grid_thw = [], []
        # First pass: count frames so the per-frame pixel budget can be derived.
        num_images = 0
        for msg in messages:
            for content in msg['content']:
                if content['type'] == 'image':
                    num_images += self.temporal_patch_size
                elif content['type'] == 'video':
                    num_images += len(content['video'])
        # Shrink max_pixels so that all frames together stay within the sample budget.
        # NOTE(review): assumes image_processing_config is a dict whenever vision
        # content is present — confirm callers never pass None alongside images.
        if num_images > 0 and self.max_pixels_per_sample // num_images < image_processing_config['max_pixels']:
            image_processing_config['max_pixels'] = self.max_pixels_per_sample // num_images
            image_processing_config['min_pixels'] = min(image_processing_config['min_pixels'], image_processing_config['max_pixels'])

        # Second pass: run crop/pad/resize, then the HF image processor, in place.
        for msg in messages:
            for content in msg['content']:
                if content['type'] == 'image':
                    content['image'] = self.preprocess_image(content['image'], image_processing_config)
                    content['image'] = self.image_processor(images = content['image'], **output_kwargs["images_kwargs"], return_tensors="pt")
                    content['num_vision_tokens'] = self.get_num_vision_tokens(content)
                    pixel_values.append(content['image']['pixel_values'])
                    if 'image_grid_thw' in content['image']:
                        image_grid_thw.extend(content['image']['image_grid_thw'])
                elif content['type'] == 'video':
                    content['video'] = self.preprocess_image(content['video'], image_processing_config)
                    # Qwen2-VL has a dedicated video path; other processors treat
                    # the frame list as a batch of images.
                    if isinstance(self.image_processor, Qwen2VLImageProcessor):
                        content['video'] = self.image_processor(images = None, videos = content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
                        pixel_values.append(content['video']['pixel_values_videos'])
                    else:
                        content['video'] = self.image_processor(images = content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
                        pixel_values.append(content['video']['pixel_values'])

                    if 'video_grid_thw' in content['video']:
                        image_grid_thw.extend(content['video']['video_grid_thw'])
                    content['num_vision_tokens'] = self.get_num_vision_tokens(content)

        # --- Text processing ---
        # At inference time, append the generation prompt unless the sample
        # already ends with an assistant turn (then strip its final EOS instead).
        add_generation_prompt = (not is_training and messages[-1]['role'] != 'assistant')
        strip_final_eos = (not is_training and messages[-1]['role'] == 'assistant')
        text_inputs = self.tokenizer.apply_chat_template(
            messages,
            chat_template = self.chat_template,
            tokenize=True,
            tokenizer_kwargs = output_kwargs["text_kwargs"],
            return_assistant_tokens_mask=True,
            return_dict=True,
            add_generation_prompt=add_generation_prompt,
            strip_final_eos=strip_final_eos,
        )
        # Supervise only assistant tokens; everything else becomes -100.
        labels = [-100 if j == 0 else i for i, j in zip(text_inputs['input_ids'], text_inputs['assistant_masks'])]
        labels = labels[:self.max_seq_len]
        input_ids = text_inputs['input_ids'][:self.max_seq_len]

        # Refuse to truncate through vision tokens: an image token beyond the
        # cap would desynchronize the text stream from pixel_values.
        image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
        if image_token_id in text_inputs['input_ids'][self.max_seq_len:]:
            raise ValueError(f'Too long sequence! {len(text_inputs["input_ids"])}')

        outputs = {
            'input_ids': input_ids,
            'labels': labels,
            'num_images': num_images,
        }
        if len(pixel_values) > 0:
            outputs['pixel_values'] = torch.cat(pixel_values, dim=0)
        if len(image_grid_thw) > 0:
            outputs['image_grid_thw'] = torch.stack(image_grid_thw)
        return outputs


    def preprocess_image(self, pil_img: Union[Image.Image, List[Image.Image]], image_processing_config):
        """Apply the optional crop/pad/resize steps to one image or a list of frames."""
        if image_processing_config is None:
            return pil_img
        images = pil_img
        if isinstance(pil_img, Image.Image):
            images = [images]
        if image_processing_config['do_crop']:
            images = [self.centralcrop(img, rate=[4, 3]) for img in images]
        if image_processing_config['do_padding']:
            images = [self.expand2square(
                img,
                # Black padding; alternatively the processor's image_mean could be used:
                # tuple(int(x * 255) for x in self.processor.image_processor.image_mean)
                tuple(int(x * 255) for x in [0, 0, 0])
            ) for img in images]
        if image_processing_config['do_resize']:
            images = [self.resize2square(img) for img in images]
        if image_processing_config.get('max_pixels'):
            images = [self.resize2pixels(
                img,
                int(image_processing_config['max_pixels']),
                int(image_processing_config['min_pixels'])
            ) for img in images]
        # Preserve the input shape: single image in, single image out.
        if isinstance(pil_img, Image.Image):
            images = images[0]
        return images

    def expand2square(self, pil_img, background_color):
        """Pad the shorter side with background_color so the image becomes square."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        elif width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        else:
            result = Image.new(pil_img.mode, (height, height), background_color)
            result.paste(pil_img, ((height - width) // 2, 0))
            return result

    def resize2square(self, pil_img: Image.Image):
        """Stretch (not pad) the image to a square of its longer side."""
        width, height = pil_img.size
        pil_img = pil_img.resize((max(width, height), max(width, height)))
        return pil_img

    def centralcrop(self, pil_img: Image.Image, rate=[4, 3]):
        """Center-crop along the longer side so the aspect ratio is at most rate (default 4:3)."""
        width, height = pil_img.size
        size = (width, height)
        min_len = min(size)
        longer_side = 0 if width >= height else 1
        center = (width/2, height/2)
        box = [0, 0, size[0], size[1]]

        # Clamp the crop box on the longer axis around the image center.
        box[longer_side] = max(0, center[longer_side] - 1/2*min_len/rate[1]*rate[0])
        box[2 + longer_side] = min(size[longer_side], center[longer_side] + 1/2*min_len/rate[1]*rate[0])

        pil_img = pil_img.crop(box)
        return pil_img

    def resize2pixels(self, pil_img: Image.Image, max_pixels=None, min_pixels=None):
        """Resize so the pixel count falls within [min_pixels, max_pixels] via Qwen2-VL smart_resize."""
        width, height = pil_img.size
        new_height, new_width = smart_resize(height, width, factor=1, max_pixels=max_pixels, min_pixels=min_pixels)
        pil_img = pil_img.resize((new_width, new_height))
        return pil_img

    def get_num_vision_tokens(self, content):
        """Number of LLM tokens the processed image/video will occupy."""
        if isinstance(self.image_processor, Qwen2VLImageProcessor):
            # Qwen2-VL: grid cells divided by the spatial merge factor.
            merge_length = self.image_processor.merge_size**2
            if content['type'] == 'image':
                num_image_tokens = content['image']['image_grid_thw'].prod() // merge_length
            else:
                num_image_tokens = content['video']['video_grid_thw'].prod() // merge_length
            return num_image_tokens
        else:
            # Other models: image tokens (2x2-compressed patch grid) plus one
            # image_newline token per row and one image_new token per frame.
            k = 'image' if content['type'] == 'image' else 'video'
            pixel_values = content[k]['pixel_values'][0]
            n_frames = len(content[k]['pixel_values'])

            height, width = get_image_size(to_numpy_array(pixel_values))
            num_image_tokens = (height // (self.patch_size * self.merge_size)) * (width // (self.patch_size * self.merge_size) + 1) + 1
            return num_image_tokens * n_frames

    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    def model_input_names(self):
        # Order-preserving union of tokenizer and image-processor input names.
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
eval_scripts/DREAM-1K/tarsier/dataset/utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import List
15
+ import os
16
+ from PIL import Image, ImageSequence
17
+ import decord
18
+
19
+ VALID_DATA_FORMAT_STRING = "Input data must be {'.jpg', '.jpeg', '.png', '.tif'} for image; or {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv', '.gif'} for videos!"
20
+
21
# Uniformly sample frame indices; the first and last frames are always included.
def sample_frame_indices(start_frame, total_frames: int, n_frames: int):
    """Return n_frames indices evenly spread over [start_frame, start_frame + total_frames - 1].

    Args:
        start_frame: absolute index of the first frame of the clip.
        total_frames: number of frames in the clip.
        n_frames: how many indices to sample.
    """
    if n_frames == 1:
        # Bug fix: honor start_frame instead of always returning absolute frame 0,
        # consistent with the offsetting applied in the general branch below.
        return [start_frame]  # sample first frame of the clip by default
    sample_ids = [round(i * (total_frames - 1) / (n_frames - 1)) for i in range(n_frames)]
    sample_ids = [i + start_frame for i in sample_ids]
    return sample_ids
28
+
29
def sample_video(
    video_path: str,
    n_frames: int = None,
    start_time: int = 0,
    end_time: int = -1
) -> List[Image.Image]:
    """Decode a video and uniformly sample `n_frames` RGB frames.

    `start_time` / `end_time` (seconds) clip the sampled range; a
    non-positive `end_time` means "until the last frame".
    """
    assert os.path.exists(video_path), f"File not found: {video_path}"

    reader = decord.VideoReader(video_path, num_threads=1, ctx=decord.cpu(0))
    reader.seek(0)
    total = len(reader)
    fps = reader.get_avg_fps()

    last = total - 1
    first = 0 if start_time <= 0 else min(last, int(fps * start_time))
    end = last
    if end_time > 0:
        end = min(max(first, int(fps * end_time)), last)

    indices = sample_frame_indices(
        start_frame=first,
        total_frames=end - first + 1,
        n_frames=n_frames,
    )

    batch = reader.get_batch(indices).asnumpy()
    return [Image.fromarray(arr).convert('RGB') for arr in batch]
58
+
59
def sample_gif(
    gif_path: str,
    n_frames: int = None,
    start_time: int = 0,
    end_time: int = -1
) -> List[Image.Image]:
    """Uniformly sample `n_frames` RGB frames from a GIF (start/end times are unused)."""
    assert os.path.exists(gif_path), f"File not found: {gif_path}"

    gif = Image.open(gif_path)
    wanted = set(sample_frame_indices(
        start_frame=0,
        total_frames=gif.n_frames,
        n_frames=n_frames,
    ))

    return [
        frame.convert('RGB')
        for idx, frame in enumerate(ImageSequence.Iterator(gif))
        if idx in wanted
    ]
85
+
86
def sample_image(
    image_path: str,
    n_frames: int = None,
    start_time: int = 0,
    end_time: int = -1
):
    """Load a single image as a one-element RGB frame list (n_frames/times are ignored)."""
    assert os.path.exists(image_path), f"File not found: {image_path}"
    return [Image.open(image_path).convert('RGB')]
95
+
96
def get_visual_type(input_file):
    """Classify a media file by extension: 'image' | 'video' | 'gif' | 'unk'."""
    ext = os.path.splitext(input_file)[-1]
    video_exts = {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv'}
    image_exts = {'.jpg', '.jpeg', '.png', '.tif'}
    if ext == '.gif':
        return 'gif'
    if ext in video_exts:
        return 'video'
    if ext in image_exts:
        return 'image'
    print(f"{VALID_DATA_FORMAT_STRING} But found {ext}!")
    return 'unk'
107
+
108
def get_benchmarks(benchmarks):
    """Expand benchmark aliases and task-type groups into a deduplicated name list."""
    type2bm = {
        'dream': ['dream'],
        'caption': ['msvd-caption', 'msr-vtt-caption', 'vatex-caption'],
        'mc_qa': ['next-qa', 'egoschema', 'mvbench', 'video-mme'],
        'oe_qa': ['msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa'],
    }
    selected = []
    for name in benchmarks:
        name = name.lower()
        if name in selected:
            continue
        if name == 'all':
            # 'all' short-circuits: append every grouped benchmark and stop.
            for group in type2bm.values():
                selected.extend(group)
            return selected
        # Task types expand to their group; unknown names pass through as-is.
        selected.extend(type2bm.get(name, [name]))
    return selected
129
+
130
def check_data_format(data):
    """Validate message structure and media paths; normalizes dict content to a list in place."""
    for message in data['messages']:
        content_list = message['content']
        if isinstance(content_list, dict):
            # Single content dict is normalized to a one-element list in place.
            content_list = [content_list]
            message['content'] = content_list
        for piece in content_list:
            ctype = piece['type']
            assert ctype in {'image', 'video', 'text'}, f"content['type']={ctype} MUST be one of ['image', 'video', 'text']"
            if ctype == "text":
                continue
            paths = piece[ctype][f"{ctype}_file"]
            if isinstance(paths, str):
                paths = [paths]
            for path in paths:
                assert os.path.exists(path), f"File not found: {path}"
143
+
144
def format_one_sample(media_file=None, prompt="Describe the video in detail."):
    """Build a single-sample messages dict for one optional media file plus a text prompt."""
    user_content = {
        "role": "user",
        "content": []
    }
    media_type = None
    if media_file is not None:
        media_type = get_visual_type(media_file)
        # GIFs are treated as videos downstream.
        if media_type in ("video", "gif"):
            media_type = "video"
        user_content["content"].append({
            "type": media_type,
            media_type: {
                f"{media_type}_file": media_file,
            }
        })
    user_content["content"].append({
        "type": "text",
        "text": prompt
    })

    sample = {
        "messages": [
            user_content,
            {"role": "assistant", "content": []},
        ]
    }
    sample["task"] = 'text-only' if media_file is None else f"{media_type}/QA"
    check_data_format(sample)
    return sample
181
+
182
+
183
class DictToObject(object):
    """Lightweight attribute-access wrapper: copies each dict entry onto the instance."""

    def __init__(self, dictionary):
        for key in dictionary:
            setattr(self, key, dictionary[key])
eval_scripts/DREAM-1K/tarsier/evaluation/evaluate.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import os
15
+ import random
16
+
17
+ from .metrics import CIDErMetric, GPTMetric, DREAMGPTMetric, AccuracyMetric, VideoMMEAccuracyMetric
18
+ import sys
19
+ sys.path.append('eval_scripts/DREAM-1K/tarsier')
20
+ from tools.rw_utils import read_jsonlines
21
+ from tools.color import Color
22
+ from dataset.utils import get_benchmarks
23
+
24
def extract_item_for_eval(data_dict):
    """Flatten a messages-style record into {prompt, prediction, response, ...} for metrics."""
    prompt_parts, pred_parts, ref_parts = [], [], []
    for msg in data_dict['messages']:
        role = msg['role']
        for content in msg['content']:
            if content['type'] != 'text':
                continue
            if role == 'user':
                prompt_parts.append(content['text'])
            elif role == 'assistant':
                # Ground-truth reference rides alongside the model prediction.
                if content.get('reference'):
                    ref_parts.append(content['reference'])
                pred_parts.append(content['text'])

    item = {
        'prompt': ''.join(prompt_parts),
        'prediction': ''.join(pred_parts),
        'response': ''.join(ref_parts),
        'dataset': data_dict['dataset'],
        'idx': f"{data_dict['dataset']}_{data_dict['idx']}",
    }

    extra_info = data_dict.get('extra_info', None)
    vid = data_dict.get('vid', None)
    if vid is not None:
        item['vid'] = vid
    if extra_info:
        item['events'] = extra_info.get('events', None)
        item['extra_info'] = extra_info
    if 'is_hard' in data_dict:
        item['is_hard'] = data_dict['is_hard']

    return item
55
+
56
def read_dataset(path, dataset_name):
    """Load prediction records for one benchmark, dedup by dataset+idx, convert for eval."""
    if os.path.isdir(path):
        # A directory means sharded predictions: merge every .jsonl inside.
        lines = []
        for fname in os.listdir(path):
            if fname.endswith('.jsonl'):
                lines.extend(read_jsonlines(os.path.join(path, fname)))
    else:
        lines = read_jsonlines(path)

    dataset, seen = [], set()
    for line in lines:
        # Keep only records belonging to the requested benchmark.
        if line['dataset'].split('/')[0] != dataset_name:
            continue
        key = f"{line['dataset']}_{line['idx']}"
        if key in seen:
            continue
        seen.add(key)
        dataset.append(extract_item_for_eval(line))
    return dataset
77
+
78
# Registry resolving metric names (strings, as used in Benchmark2Metric below
# and in serialized configs) to the metric classes imported from .metrics.
METRIC_MAPPING = {
    'CIDErMetric': CIDErMetric,
    'GPTMetric': GPTMetric,
    'AccuracyMetric': AccuracyMetric,
    'DREAMGPTMetric': DREAMGPTMetric,
    'VideoMMEAccuracyMetric': VideoMMEAccuracyMetric
}
85
+
86
def evaluate(pred_file, METRIC, dataset_name, sample_num=-1, verbose = False):
    """Run one metric over the predictions of dataset_name; optionally subsample."""
    dataset = read_dataset(pred_file, dataset_name)
    if not dataset:
        # No record for this benchmark in the prediction file.
        return
    if sample_num > 0:
        dataset = random.sample(dataset, sample_num)

    metric = METRIC(dataset_name=dataset_name, verbose=verbose)
    metric.process(dataset)
    metric.summarize_metric()
    metric.save_results(pred_file)
    # DREAM additionally persists the per-event GPT judgments.
    if isinstance(metric, DREAMGPTMetric):
        metric.save_eval_infos(pred_file)
98
+
99
def evaluate_all(pred_file, METRIC2DATASET, sample_num=-1, verbose = False):
    """Evaluate pred_file against every (metric, dataset) pair; metrics may be name strings."""
    for metric_cls, dataset_name in METRIC2DATASET:
        if isinstance(metric_cls, str):
            # Resolve a metric name to its class.
            metric_cls = METRIC_MAPPING[metric_cls]
        print(f"### Start Evaluating on {dataset_name}")
        evaluate(pred_file, metric_cls, dataset_name, sample_num, verbose)
        print(f"### Finish Evaluating on {dataset_name}")
107
+
108
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--pred_file', type=str)
    parser.add_argument('--benchmarks', nargs='+', default=["all"], help="Default as 'all' to evaluate on all benchmarks; Also could be task types: ('dream', 'caption', 'mc_qa', 'oe_qa'); And specific benchmark names: ('dream', 'msvd-caption', 'msr-vtt-caption', 'vatex-caption', 'next-qa', 'egoschema', 'mvbench', 'tvbench', 'video-mme', 'msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa', 'favor-bench')")
    parser.add_argument('--sample_num', type=int, default=-1)
    parser.add_argument('--verbose', action='store_true')

    args = parser.parse_args()

    # Expand aliases / task-type groups (e.g. 'all', 'caption') into concrete names.
    args.benchmarks = get_benchmarks(args.benchmarks)
    print("### Selected Benchmarks:", args.benchmarks)

    # Metric (by name, resolved through METRIC_MAPPING) used to score each benchmark.
    Benchmark2Metric = {
        # Multi-choice QA
        'next-qa': 'AccuracyMetric',
        'egoschema': 'AccuracyMetric',
        'mvbench': 'AccuracyMetric',
        'tvbench': 'AccuracyMetric',
        'video-mme': 'VideoMMEAccuracyMetric',
        'favor-bench': 'AccuracyMetric',

        # Open-ended QA
        'msvd-qa': 'GPTMetric',
        'msr-vtt-qa': 'GPTMetric',
        'tgif-qa': 'GPTMetric',
        'anet-qa': 'GPTMetric',

        # Caption DREAM
        'dream': 'DREAMGPTMetric',

        # Caption CIDEr
        'msvd-caption': 'CIDErMetric',
        'msr-vtt-caption': 'CIDErMetric',
        'vatex-caption': 'CIDErMetric',
    }

    # Dataset-name prefix (as stored in the prediction records) for each benchmark.
    Benchmark2Dataset = {
        'dream': 'DREAM',

        'next-qa': 'Next-QA-val-multi_choice',
        'egoschema': 'EgoSchema',
        'mvbench': 'MVBench',
        'tvbench': 'TVBench',
        'video-mme': 'Video-MME',
        'favor-bench': 'FAVOR-Bench',

        'msvd-qa': 'MSVD-QA-val',
        'msr-vtt-qa': 'MSR-VTT-QA-val',
        'tgif-qa': 'TGIF-QA-test',
        'anet-qa': 'ActivityNet-QA-test',

        'msvd-caption': 'MSVD-Caption-test',
        'msr-vtt-caption': 'MSR-VTT-Caption-test',
        'vatex-caption': 'VATEX-test',
    }

    METRIC2DATASET = []

    for bm in args.benchmarks:
        if bm not in Benchmark2Metric:
            # Unknown names (passed through by get_benchmarks) are reported, not fatal.
            print(Color.red(f"Unknown benchmark: {bm}"))
            continue

        METRIC2DATASET.append([Benchmark2Metric[bm], Benchmark2Dataset[bm]])

    evaluate_all(args.pred_file, METRIC2DATASET, args.sample_num, args.verbose)

# python3 -m evaluation.evaluate --pred_file $pred_file --sample_num=100
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .evaluate_caption_cider import CIDErMetric
2
+ from .evaluate_qa_oe_gpt import GPTMetric
3
+ from .evaluate_qa_mc import AccuracyMetric
4
+ from .evaluate_dream_gpt import DREAMGPTMetric
5
+ from .evaluate_video_mme import VideoMMEAccuracyMetric
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_caption_cider.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ from typing import List, Dict
16
+ import os
17
+ from pycocoevalcap.cider.cider import Cider
18
+ import sys
19
+ sys.path.append('eval_scripts/DREAM-1K/tarsier')
20
+ from tools.ptbtokenizer import PTBTokenizer
21
+
22
+ from tools.color import Color
23
+
24
class CIDErMetric:
    """Computes CIDEr over (prediction, response) caption pairs.

    Captions are lowercased and PTB-tokenized before scoring; per-sample
    scores end up in self.results and the corpus score in self.score.
    """

    def __init__(self, dataset_name, verbose=False) -> None:
        self.dataset_name = dataset_name
        self.tokenizer = PTBTokenizer()
        self.scorer = Cider()
        self.score = None     # corpus-level CIDEr, set by process()
        self.results = []     # per-sample {'score', 'data'} records
        self.dataset = []     # samples collected via add()
        self.verbose = verbose

    def add(self, data):
        # Collect one sample; note process() takes its own dataset argument
        # and does not read self.dataset.
        self.dataset.append(data)

    def process(self, dataset: List[Dict]):
        """Tokenize references/predictions, then compute corpus + per-sample CIDEr."""
        references, predictions = {}, {}
        for i, data in enumerate(dataset):
            ref = data['response']
            pred = data['prediction']

            # A single reference string becomes a one-element list.
            if isinstance(ref, str):
                ref = [ref]

            references[i] = [{'caption': r.lower()} for r in ref]
            predictions[i] = [{'caption': pred.lower()}]

        references = self.tokenizer.tokenize(references)
        predictions = self.tokenizer.tokenize(predictions)
        score, scores = self.scorer.compute_score(references, predictions)
        self.score = score
        for data, s in zip(dataset, scores):
            self.results.append({
                'score': s,
                'data': data,
            })

    def summarize_metric(self):
        """Print (and record in self.eval_records) the evaluation summary."""
        if self.verbose:
            for result in self.results:
                print(Color.blue(json.dumps(result['data'])))
                print(Color.red(f"CIDEr score: {result['score']}"))
        print(f'=====Evaluation Summary=====')
        self.eval_records = [
            f'Dataset: {self.dataset_name}\tMetric: CIDEr',
            f'#Successful Results: {len(self.results)}',
            f'CIDEr score: {round(self.score*100, 1)}'
        ]
        for info in self.eval_records:
            print(info)

    def save_results(self, pred_path):
        """Write the summary lines next to the predictions, under eval_records/."""
        if os.path.isdir(pred_path):
            output_dir = os.path.join(pred_path, 'eval_records')
        else:
            output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
        os.makedirs(output_dir, exist_ok=True)
        fout = open(os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt'), 'w')
        for info in self.eval_records:
            fout.write(info+'\n')
        fout.close()
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_dream_gpt.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import numpy as np
16
+ import ast
17
+ import time
18
+ from typing import List, Dict
19
+ from tqdm import tqdm
20
+ from pathos.multiprocessing import ProcessingPool as Pool
21
+ import func_timeout
22
+ from func_timeout import func_set_timeout
23
+
24
+ import sys
25
+ sys.path.append('eval_scripts/DREAM-1K/tarsier')
26
+ import re
27
+ import os
28
+ from copy import deepcopy
29
+ from traceback import format_exc
30
+
31
# Load the OpenAI API key from the working directory; fall back to an empty
# key so that import succeeds and API calls fail with an auth error instead.
try:
    with open("apikey.txt", "r") as f:
        # strip() removes the trailing newline most editors add, which would
        # otherwise corrupt the Authorization header.
        api_key = f.read().strip()
except OSError:  # narrowed from bare except: only file-access errors are expected
    api_key = ''
36
+
37
def call_gpt35(msg):
    """Call the gpt-3.5-turbo chat API with the given messages, retrying until success.

    Args:
        msg: list of chat messages ({'role': ..., 'content': ...} dicts).
    Returns:
        The content string of the first returned choice.
    """
    # Bug fix: `openai` was never imported anywhere in this module; import it
    # lazily so the dependency is only required when GPT evaluation runs.
    import openai

    while True:
        try:
            response = openai.ChatCompletion.create(
                model="gpt-3.5-turbo",
                messages=msg,
                api_key=api_key,
                request_timeout=5)
            break
        except Exception:  # narrowed from bare except; retry on any API/timeout error
            print("Timeout, retrying...")
            time.sleep(5)

    output_text = response['choices'][0]['message']['content']
    return output_text
52
+
53
def count_f1(r, p):
    """Harmonic mean (F1) of recall r and precision p.

    Returns 0.0 when both are zero instead of raising ZeroDivisionError
    (the conventional F1 definition for an empty intersection).
    """
    if r + p == 0:
        return 0.0
    return 2*r*p/(r+p)
55
+
56
+
57
def call_azure_gpt_api(events, reference, prediction, model):
    """Ask GPT to judge each event against the predicted video description.

    For every event, the model labels the (description, event) pair as
    entailment / neutral / contradiction and returns the raw JSON text.

    NOTE(review): `model` is currently unused — call_gpt35 hard-codes
    gpt-3.5-turbo; confirm whether it should be forwarded.
    """
    # With no extracted events, fall back to the whole reference as one event.
    if len(events) == 0:
        events = [reference.replace('\n', ' ')]
    messages=[
        {
            "role": "user",
            "content":
            "Given a video description and a list of events. For each event, classify the relationship between the video description and the event into three classes: entailment, neutral, contradiction.\n"
            "- \"entailment\" means that the video description entails the event.\n"
            "- \"contradiction\" means that some detail in the video description contradicts with the event.\n"
            "- \"neutral\" means that the relationship is neither \"entailment\" or \"contradiction\".\n\n"
            f"Video Description:\n{prediction}\n\n"
            f"Events: {events}\n"

            "Output a JSON formed as:\n"
            "{\n"
            "  \"events\": [\n"
            "    {\"event\": \"copy an event here\", \"relationship\": \"put class name here\", \"reason\": \"give your reason here\"},\n"
            "    ...\n"
            "  ]\n"
            "}\n\n"
            "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only output the JSON. Output:"
        }
    ]

    completion = call_gpt35(messages)
    return completion
84
+
85
+
86
def call_azure_gpt_api_for_events(caption, model):
    """Ask GPT to extract up to 10 atomic action events from a caption.

    Returns the raw model output, expected to be a Python-dict-style string
    of the form {"events": [...]}.

    NOTE(review): `model` is currently unused — call_gpt35 hard-codes
    gpt-3.5-turbo; confirm whether it should be forwarded.
    """
    messages=[
        {
            "role": "user",
            "content":
            "Bellow is a description of a video clip:\n"
            f"Video Description: {caption}\n\n"

            "Extract at most 10 key events from the above video description paragraph. Requirements\n:"
            "- An event must include an action, motion or movement (NOT STATIC INFORMATION). DON'T repeat same events.\n"
            "- Every event is represented by a brief sentence within 10 words, with a subject, a predicate and optionally an object, avoid unnecessary appearance descriptions.\n"
            "- Every event must be atomic, meaning that it cannot be further split into multiple events.\n"
            "- Scene cuts and camera motions are NOT events.\n"
            "- Substitute pronouns by the nouns they refer to.\n\n"
            "Please generate the response in the form of a Python dictionary string with keys \"events\". The value of \"events\" is a List(str), of which each item is an event. "
            "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
            "For example, your response should look like this: {\"events\": [event1, event2, ...]}"
        }
    ]

    completion = call_gpt35(messages)
    return completion
108
+
109
+ def try_call_api_for_eval(events, answer, prediction, model, verbose=False, max_retry=100):
110
+ for i in range(max_retry):
111
+ gpt_q = call_azure_gpt_api(events, answer, prediction, model)
112
+ if gpt_q is not None:
113
+ gpt_q = gpt_q.strip()
114
+ gpt_q = re.sub(r'\n+', '\n', gpt_q)
115
+ gpt_q = re.sub(r'\s+', ' ', gpt_q)
116
+
117
+ if gpt_q.startswith("```json"):
118
+ gpt_q = gpt_q.replace("```json", "").replace("```", "").strip()
119
+ elif gpt_q.startswith("```python"):
120
+ gpt_q = gpt_q.replace("```python", "").replace("```", "").strip()
121
+ if not gpt_q.startswith('{'):
122
+ gpt_q = '{' + gpt_q
123
+ if not gpt_q.endswith('}'):
124
+ gpt_q = gpt_q + '}'
125
+ gpt_q = gpt_q.replace("True", "true").replace("False", "false")
126
+ gpt_q = gpt_q.replace("} {", "}, {").replace("}{", "}, {")
127
+ gpt_q = gpt_q.replace(",\n}", "\n}").replace(", \n}", "\n}").replace(", }", "}").replace(",}", "}")
128
+ gpt_q = gpt_q.replace(",\n]", "\n]").replace(", \n]", "\n]").replace(", ]", "]").replace(",]", "]")
129
+ gpt_q = gpt_q.replace("[Placeholder]", "null")
130
+ gpt_q = gpt_q.replace("{Events:", "").strip()
131
+
132
+ return gpt_q, True
133
+
134
+ return f"Exceed max try: {max_retry}", False
135
+
136
def try_call_api_for_events(caption, model, verbose=False, max_retry=100):
    """Call the event-extraction API with retries.

    Args:
        caption: video caption to decompose into events.
        model: judge model name (forwarded to the API helper).
        verbose: unused here; kept for interface compatibility with callers.
        max_retry: maximum number of API attempts (new parameter; defaults to
            the previously hard-coded 100, so callers are unaffected).

    Returns:
        (text, success) — on success, `text` is the raw completion with any
        Markdown code fences stripped; otherwise an error message.
    """
    for _ in range(max_retry):
        gpt_q = call_azure_gpt_api_for_events(caption, model)
        if gpt_q is not None:
            # Strip Markdown code fences the model sometimes wraps around JSON.
            if gpt_q.startswith("```json"):
                gpt_q = gpt_q.replace("```json", "").replace("```", "").strip()
            elif gpt_q.startswith("```python"):
                gpt_q = gpt_q.replace("```python", "").replace("```", "").strip()
            return gpt_q, True

    # Bug fix: the failure message previously said "Exceed max try: 5" even
    # though the loop retried 100 times (matches try_call_api_for_eval now).
    return f"Exceed max try: {max_retry}", False
147
+
148
def extract_events(inputs, is_pred=False, max_retry=100):
    """Extract atomic events from either the GT caption or the prediction.

    Args:
        inputs: tuple of (data dict, model name, verbose flag).
        is_pred: if True, extract from data['prediction'], else from
            data['response'].
        max_retry: outer retry budget; a negative value retries forever.

    Returns:
        List[str] of extracted events.

    Raises:
        ValueError: if no parseable event list was obtained within max_retry.
    """
    data, model, verbose = inputs
    if is_pred:
        caption = data['prediction'].lower()
    else:
        caption = data['response'].lower()
    # Double quotes would break the literal_eval of the model's reply.
    caption = caption.replace("\"", "\'")
    retry = 0
    while True and (retry<max_retry or max_retry<0):
        retry += 1
        result, success = try_call_api_for_events(caption, model, verbose)
        if not success:
            print(f"[error]: try_call_api_for_events failed!", flush=True)
            continue
        try:
            # The model replies with a Python-dict-like string, not strict JSON.
            result = ast.literal_eval(result)
            events = result['events']
            if verbose:
                print("pred_events=" if is_pred else "gt events=", events, ":", caption)
            assert isinstance(events, list) and (len(events)==0 or isinstance(events[0], str))
            return events
        except Exception as e:
            print(format_exc(), flush=True)
            continue
    print("[error]: Exceed max_retry!", flush=True)
    raise ValueError("[error]: Exceed max_retry!")
174
+
175
+
176
def evaluate_one_sample(events, response, prediction, model, verbose, return_hit_num=False, is_recall=False, max_retry=100):
    """Score how many of `events` are entailed by the `prediction` description.

    Args:
        events: reference events to classify.
        response: reference caption (forwarded to the judge prompt).
        prediction: caption the events are checked against.
        model: judge model name.
        verbose: forwarded to the API retry helper.
        return_hit_num: if True, also return the judged events and a hit string.
        is_recall: if True, only 'entailment' counts as a hit; otherwise
            'neutral' also counts.
        max_retry: outer retry budget; a negative value retries forever.

    Returns:
        motion_score (float in [0, 1]); or (motion_score, events_filled,
        "hit: k / n") when return_hit_num is True. An empty event list
        scores 1.0 by definition.

    Raises:
        ValueError: if no valid judge response was obtained within max_retry.
    """
    retry = 0
    while True and (retry<max_retry or max_retry<0):
        retry += 1
        try:
            assert isinstance(events, list)
            result = None
            result, success = try_call_api_for_eval(events, response, prediction, model, verbose)
            if not success:
                print("[error]: try_call_api_for_eval failed!", flush=True)
                continue
            try:
                events_filled = json.loads(result)
                events_filled = events_filled['events']
            except Exception as e:
                print("load json failed:", result)
                continue
            # The judge must return one entry per event (or a single entry
            # when there were no events at all).
            assert len(events) == len(events_filled) or (len(events) == 0 and len(events_filled) == 1)
            num_matched_events = 0
            try:
                for event in events_filled:
                    pred = event['relationship'].strip().lower()
                    assert pred in ['entailment', 'neutral', 'contradiction']
                    pos_classes = ['entailment'] if is_recall else ['entailment', 'neutral']
                    if pred in pos_classes:
                        num_matched_events += 1
            except Exception as e:
                print(f"Invalid response: {events_filled}")
                continue
            if len(events) == 0:
                motion_score = 1.0
            else:
                motion_score = num_matched_events / len(events)
            if return_hit_num:
                return motion_score, events_filled, f"hit: {num_matched_events} / {len(events)}"
            return motion_score
        except Exception as e:
            print(format_exc(), flush=True)
            continue
        # NOTE(review): unreachable — every path above returns or continues.
        time.sleep(1)
    print("[error]: Exceed max_retry!", flush=True)
    raise ValueError(f"[error]: Exceed max_retry!")
218
+
219
+
220
def process_one_sample(inputs):
    """Evaluate one DREAM-1K sample: event recall and precision via GPT.

    Args:
        inputs: tuple of (data dict, model name, verbose flag). The data dict
            needs 'response', 'prediction', 'idx', optionally precomputed
            'events' and 'extra_info'.

    Returns:
        {'success': bool, 'result': dict or None, 'data': data}. On success
        'result' carries 'score_r' (GT events entailed by the prediction),
        'score_p' (predicted events entailed by the GT) and 'eval_infos'.
    """
    data, model, verbose = inputs
    response, prediction = data['response'].lower(), data['prediction'].lower()
    result = None
    try:
        # Reuse pre-annotated GT events when present; otherwise extract them.
        if isinstance(data.get('events', None), list):
            gt_events = data['events']
        else:
            gt_events = extract_events(inputs, is_pred=False)
        pred_events = extract_events(inputs, is_pred=True)
        assert isinstance(gt_events, list) and isinstance(pred_events, list)
        result = {}
        # Recall: GT events checked against the prediction.
        motion_score_r, events_filled_r, hit_num_r = evaluate_one_sample(gt_events, response, prediction, model, verbose, return_hit_num=True, is_recall=True)
        # Precision: predicted events checked against the GT caption.
        motion_score_p, events_filled_p, hit_num_p = evaluate_one_sample(pred_events, prediction, response, model, verbose, return_hit_num=True, is_recall=True)
        result['score_r'] = motion_score_r
        result['score_p'] = motion_score_p
        result['eval_infos'] = {
            'idx': data['idx'],
            'gt': response,
            'pred': prediction,
            'events_gt': events_filled_r,
            'hit_num_recall': hit_num_r,
            'events_pred': events_filled_p,
            "hit_num_precision": hit_num_p,
        }
        if 'extra_info' in data:
            result['extra_info'] = data['extra_info']
    except Exception as e:
        if verbose:
            print(e)
            print(f'invalid GPT response: {result}')
        result = None
        return {'success': False, 'result': result, 'data': data}
    return {'success': True, 'result': result, 'data': data}
254
+
255
class DREAMGPTMetric:
    """GPT-judged DREAM-1K captioning metric.

    Computes event-level recall (GT events entailed by the prediction) and
    precision (predicted events entailed by the GT), reports F1 per subtask
    and additionally bucketed by #subjects / #shots / #events.
    """

    def __init__(self, dataset_name, verbose=False) -> None:
        self.dataset_name = dataset_name
        self.num_worker = 64  # size of the pathos process pool used in _process()
        # self.model = 'gpt-35-turbo'
        self.model = 'gpt-35-turbo-0125'
        # self.model='gpt-4-1106-preview'
        self.results = []          # successfully judged samples
        self.invalid_results = []  # samples that failed judging
        self.dataset = []          # raw samples added via add()
        self.verbose = verbose
        self.eval_infos = []       # per-sample details collected at summary time
        # Complexity buckets; keys are condition strings interpreted by
        # _match_bucket().
        self.buckets = {
            "subjects": {
                '<=1': [], '==2': [], '==3': [], '>=4': []
            },
            "shots": {
                '<=1': [], '==2': [], '==3': [], '>=4': []
            },
            "events": {
                '<=3': [], 'in [4, 5]': [], 'in [6, 7]': [], '>=8': []
            }
        }

    def add(self, data):
        """Queue one raw sample for later processing."""
        self.dataset.append(data)

    @staticmethod
    def _match_bucket(num, cond):
        """Return True if `num` satisfies a condition key like '<=1' or 'in [4, 5]'.

        Replaces the previous eval(f"{num}{cond}"): same behavior for all the
        keys declared in __init__, without executing arbitrary strings (the
        eval form also produces invalid syntax like "4in [4, 5]" for the
        membership keys).
        """
        cond = cond.strip()
        if cond.startswith('<='):
            return num <= int(cond[2:])
        if cond.startswith('>='):
            return num >= int(cond[2:])
        if cond.startswith('=='):
            return num == int(cond[2:])
        if cond.startswith('in'):
            return num in ast.literal_eval(cond[2:].strip())
        return False

    def select_bucket(self, bucket_name, num):
        """Return the first sub-bucket key of `bucket_name` matched by `num`, or ''."""
        for key in self.buckets[bucket_name]:
            if self._match_bucket(num, key):
                return key
        return ''

    def add_to_bucket(self, bucket_name, data):
        """File one judged result into its sub-bucket based on extra_info counts."""
        sub_bucket = self.select_bucket(bucket_name, data['result']['extra_info'][f'n_{bucket_name}'])
        if sub_bucket:
            self.buckets[bucket_name][sub_bucket].append(data)

    def process(self, dataset: List[Dict]):
        """Evaluate all samples, grouped by their source subtask."""
        self._process_group_by_subtask(dataset)

    def _process(self, dataset: List[Dict], subtask=None):
        """Judge one subtask's samples in parallel via a pathos process pool."""
        pool = Pool(processes = self.num_worker, )
        inputs = [(d, self.model, self.verbose) for d in dataset]
        # uimap yields results in completion order.
        results = pool.uimap(process_one_sample, inputs, chunksize = 1)

        for result in tqdm(results, total = len(dataset), desc=f'eval {subtask}'):
            if subtask:
                result['subtask'] = subtask
            self.update_metric(result)
        pool.close()
        pool.join()
        pool.clear()  # MUST: pathos caches pools between instantiations

    def _process_group_by_subtask(self, dataset: List[Dict]):
        """Partition `dataset` by its 'dataset' field and evaluate each group."""
        subtasks = {}
        for data in dataset:
            subtasks.setdefault(data['dataset'], []).append(data)
        for subtask, subdata in subtasks.items():
            self._process(subdata, subtask)

    def update_metric(self, result):
        """Route one judged sample into results or invalid_results."""
        if result['success']:
            self.results.append(result)
        else:
            self.invalid_results.append(result)

    def summarize_metric(self):
        """Print the per-subtask table, then the per-bucket tables."""
        self._summarize_metric_by_subtask()
        self._summarize_metric_by_bucket()

    def _summarize_metric_by_subtask(self):
        """Build self.table with per-subtask F1 / recall / precision rows."""
        from prettytable import PrettyTable
        self.table = PrettyTable(['Task', 'F1 Score', 'Action Recall', 'Action Precision', 'Success', 'Failed'])

        sub_results = {}
        sub_invalid_results = {}
        for data in self.results:
            sub_results.setdefault(data['subtask'], []).append(data)
        for data in self.invalid_results:
            sub_invalid_results.setdefault(data['subtask'], []).append(data)

        overall_avg_recall = []
        overall_avg_precision = []
        for subtask in sorted(sub_results.keys()):
            sub_rsts = sub_results[subtask]
            sub_in_rsts = sub_invalid_results.get(subtask, [])
            recalls = []
            precisions = []
            for result in sub_rsts:
                r, p, infos = result['result']['score_r'], result['result']['score_p'], result['result']['eval_infos']
                recalls.append(r)
                precisions.append(p)
                self.eval_infos.append(infos)
            avg_recall = np.average(recalls)
            avg_precision = np.average(precisions)
            f1 = count_f1(avg_recall, avg_precision)
            overall_avg_recall.append(avg_recall)
            overall_avg_precision.append(avg_precision)
            self.table.add_row([subtask, round(f1, 3), round(avg_recall, 3), round(avg_precision, 3), len(sub_rsts), len(sub_in_rsts)])
        # OVERALL row is the macro-average over subtasks.
        overall_recall = np.average(overall_avg_recall)
        overall_precision = np.average(overall_avg_precision)
        overall_f1 = count_f1(overall_recall, overall_precision)
        self.table.add_row(['OVERALL', round(overall_f1, 3), round(overall_recall, 3), round(overall_precision, 3), len(self.results), len(self.invalid_results)])
        print(f'=====DREAM Evaluation Summary=====')
        print(self.table)

    def _summarize_metric_by_bucket(self):
        """Build and print one Recall/Precision/F1 table per complexity bucket."""
        from prettytable import PrettyTable
        self.bucket_tables = []
        for bucket in self.buckets:
            table = PrettyTable(['Score'] + list(self.buckets[bucket].keys()))
            for data in self.results:
                self.add_to_bucket(bucket_name=bucket, data=data)
            bucket_result = {}
            for sub_bucket in self.buckets[bucket]:
                recalls = []
                precisions = []
                for result in self.buckets[bucket][sub_bucket]:
                    recalls.append(result['result']['score_r'])
                    precisions.append(result['result']['score_p'])
                avg_recall = np.average(recalls)
                avg_precision = np.average(precisions)
                f1 = count_f1(avg_recall, avg_precision)
                bucket_result[sub_bucket] = (avg_recall, avg_precision, f1)

            scores = ['Recall', 'Precision', 'F1']
            for i in range(len(scores)):
                row = [scores[i]]
                for sub_bucket in bucket_result:
                    row.append(round(bucket_result[sub_bucket][i], 3))
                table.add_row(row)
            sample_num = ['Count']
            for k in self.buckets[bucket]:
                sample_num.append(len(self.buckets[bucket][k]))
            table.add_row(sample_num)
            bucket_info = f'\n=====DREAM Evaluation Split by Bucket #{bucket}====='
            print(bucket_info)
            print(table)
            self.bucket_tables.append(bucket_info)
            self.bucket_tables.append(deepcopy(table))

    def save_results(self, pred_path):
        """Write the subtask table AND the bucket tables to an eval_records file."""
        if os.path.isdir(pred_path):
            output_dir = os.path.join(pred_path, 'eval_records')
        else:
            output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
        os.makedirs(output_dir, exist_ok=True)
        model_flag = os.path.basename(pred_path).split('.')[0]
        fout = open(os.path.join(output_dir, f'{self.dataset_name}_{model_flag}_eval_result.txt'), 'w')
        print(self.table, file=fout)
        for bucket_info in self.bucket_tables:
            # Bug fix: bucket tables were printed to stdout only and never
            # written to the results file (missing file=fout).
            print(bucket_info, file=fout)
        fout.close()

    def save_eval_infos(self, pred_path):
        """Dump per-sample evaluation details as JSONL next to the predictions."""
        if os.path.isdir(pred_path):
            output_dir = os.path.join(pred_path, 'eval_records')
        else:
            output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
        os.makedirs(output_dir, exist_ok=True)
        model_flag = os.path.basename(pred_path).split('.')[0]
        out_path = os.path.join(output_dir, f'DREAM_{model_flag}_eval_infos.jsonl')
        fout = open(out_path, 'w')
        for info in self.eval_infos:
            fout.write(json.dumps(info) +'\n')
        fout.close()
        # Bug fix: the logged path previously omitted the model flag and so
        # did not match the file actually written.
        print(f"DREAM evaluation information saved in: {out_path}", flush=True)
436
+
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_mc.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import numpy as np
16
+ import os
17
+ from typing import List, Dict
18
+ import sys
19
+ sys.path.append('eval_scripts/DREAM-1K/tarsier')
20
+ from tools.color import Color
21
+
22
+
23
class AccuracyMetric:
    """Exact-match accuracy for multiple-choice QA predictions.

    Reports a per-subtask table for MVBench / TVBench / FAVOR-Bench and a
    flat summary for every other dataset.
    """

    def __init__(self, dataset_name, verbose=False) -> None:
        self.dataset_name = dataset_name
        self.results = []          # samples whose prediction parsed to an option letter
        self.invalid_results = []  # samples whose prediction could not be parsed
        self.dataset = []          # raw samples added via add()
        self.verbose = verbose

    def add(self, data):
        """Queue one raw sample for later processing."""
        self.dataset.append(data)

    def process(self, dataset: List[Dict]):
        """Score `dataset`, grouping per subtask for multi-task benchmarks."""
        if self.dataset_name in ['MVBench', 'TVBench', 'FAVOR-Bench']:
            return self._process_group_by_subtask(dataset)
        else:
            return self._process(dataset)

    def _process(self, dataset: List[Dict], subtask=None):
        """Parse each prediction into an option letter and score exact match.

        A sample is 'successful' when the prediction parses to a single
        upper-case letter; digit options '0'-'5' are mapped onto 'A'-'F'.
        """
        for data in dataset:
            prompt, response, prediction = data['prompt'], data['response'], data['prediction']
            prediction = prediction.replace('(', '').replace(')', '').strip()
            # NOTE(review): an empty ground-truth response raises IndexError
            # here (original behavior preserved).
            response = response.replace('(', '').replace(')', '').strip()[0]
            if len(prediction) <= 0:
                success = False
            else:
                prediction = prediction[0]
                if '0'<=prediction<='5':
                    prediction = chr(int(prediction) + ord('A'))
                success = prediction.isupper() and prediction.isalpha() and len(prediction) == 1
            # The two branches previously built an identical dict; deduplicated.
            rst = {
                'success': success,
                'data': data,
                'result': {'acc': response == prediction}
            }
            if subtask:
                rst['subtask'] = subtask
            (self.results if success else self.invalid_results).append(rst)

    def _process_group_by_subtask(self, dataset: List[Dict]):
        """Partition `dataset` by its 'dataset' field and score each group."""
        subtasks = {}
        for data in dataset:
            subtasks.setdefault(data['dataset'], []).append(data)
        for subtask, subdata in subtasks.items():
            self._process(subdata, subtask)

    def summarize_metric(self):
        """Print either the per-subtask table or the flat summary."""
        if self.dataset_name in ['MVBench', 'TVBench', 'FAVOR-Bench']:
            return self._summarize_metric_by_subtask()
        else:
            return self._summarize_metric()

    def _summarize_metric(self):
        """Flat summary: average accuracy over successfully-parsed samples."""
        if self.verbose:
            for result in self.results + self.invalid_results:
                print(f"{Color.red('Success: ' + str(result['success']))}")
                print(Color.blue(json.dumps(result['data'], ensure_ascii=False)))
                print(f"{Color.green('Accuracy: ' + str(result['result']['acc']))}")

        accs = [result['result']['acc'] for result in self.results]
        avg_acc = np.average(accs)

        self.eval_records = [
            f'=====Evaluation Summary=====',
            f'Dataset: {self.dataset_name}\tMetric: Accuracy',
            f'#Successful Results: {len(self.results)}\n#Failed Results: {len(self.invalid_results)}',
            f'Accuracy: {round(avg_acc*100, 1)}',
        ]
        for info in self.eval_records:
            print(info)

    def _summarize_metric_by_subtask(self):
        """Per-subtask accuracy table; the OVERALL row is micro-averaged."""
        from prettytable import PrettyTable
        self.table = PrettyTable(['Task','Accuracy','Success','Failed'])

        sub_results = {}
        sub_invalid_results = {}
        for data in self.results:
            sub_results.setdefault(data['subtask'], []).append(data)
        for data in self.invalid_results:
            sub_invalid_results.setdefault(data['subtask'], []).append(data)

        oa_accs = []
        subtasks = sorted(sub_results.keys(), key=lambda x: x.split('/')[-1])
        for subtask in subtasks:
            sub_rsts = sub_results[subtask]
            sub_in_rsts = sub_invalid_results.get(subtask, [])
            accs = []
            for result in sub_rsts:
                acc = result['result']['acc']
                accs.append(acc)
                oa_accs.append(acc)
            avg_acc = np.average(accs)
            task_name = subtask.split('/')[-1]
            self.table.add_row([task_name, round(avg_acc*100, 1), len(sub_rsts), len(sub_in_rsts)])
        self.table.add_row(['OVERALL', round(np.average(oa_accs)*100, 1), len(self.results), len(self.invalid_results)])
        print(f'=====Evaluation Summary=====')
        print(self.table)

    def save_results(self, pred_path):
        """Write either the subtask table or the flat summary lines to disk."""
        if os.path.isdir(pred_path):
            output_dir = os.path.join(pred_path, 'eval_records')
        else:
            output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
        os.makedirs(output_dir, exist_ok=True)
        fout = open(os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt'), 'w')
        if self.dataset_name in ['MVBench', 'TVBench', 'FAVOR-Bench']:
            print(self.table, file=fout)
        else:
            for info in self.eval_records:
                fout.write(info+'\n')
        fout.close()
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_oe_gpt.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ import numpy as np
16
+ import ast
17
+ import time
18
+ from typing import List, Dict
19
+ from tqdm import tqdm
20
+ from pathos.multiprocessing import ProcessingPool as Pool
21
+ import func_timeout
22
+ from func_timeout import func_set_timeout
23
+ import os
24
+ import sys
25
+ sys.path.append('eval_scripts/DREAM-1K/tarsier')
26
+ from tools.color import Color
27
+
28
+
29
def call_azure_gpt_api(question, answer, prediction, model):
    """Ask the judge model to grade one open-ended QA prediction.

    Args:
        question: the question posed about the video.
        answer: ground-truth answer.
        prediction: model answer to be graded.
        model: judge model name. NOTE(review): unused — `call_gpt35` is
            invoked unconditionally; confirm whether `model` should be forwarded.

    Returns:
        The raw completion string (expected form: {'pred': 'yes'/'no',
        'score': 0-5}) or None if the API call failed.
    """
    messages=[
        {
            "role": "system",
            "content":
            "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
            "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
            "------"
            "##INSTRUCTIONS: "
            "- Focus on the meaningful match between the predicted answer and the correct answer.\n"
            "- Consider synonyms or paraphrases as valid matches.\n"
            "- Evaluate the correctness of the prediction compared to the answer."
        },
        {
            "role": "user",
            "content":
            "Please evaluate the following video-based question-answer pair:\n\n"
            f"Question: {question}\n"
            f"Correct Answer: {answer}\n"
            f"Predicted Answer: {prediction}\n\n"
            "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
            "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
            "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
            "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}"
        }
    ]
    completion = call_gpt35(messages)
    return completion
58
+
59
def try_call_api(question, answer, prediction, model, verbose=False):
    """Query the judge model, retrying up to five times on empty replies.

    Returns:
        (reply, success) — `reply` is None when every attempt fails.
    """
    attempts = 0
    while attempts < 5:
        attempts += 1
        reply = call_azure_gpt_api(question, answer, prediction, model)
        if reply is not None:
            return reply, True
    return None, False
66
+
67
def process_one_sample(inputs):
    """Judge one (question, answer, prediction) triple with GPT.

    Args:
        inputs: tuple of (data dict, model name, verbose flag). The data dict
            needs 'question', 'response' and 'prediction'.

    Returns:
        {'success': bool, 'result': parsed judge dict (or raw/None on
        failure), 'data': data}. A result counts as successful only when the
        judge reply parses to a dict with 'pred' in {'yes', 'no'} and a
        numeric 'score' within [0, 5].
    """
    data, model, verbose = inputs
    prompt, response, prediction = data['question'], data['response'].lower(), data['prediction'].lower()
    result = None
    try:
        result, success = try_call_api(prompt, response, prediction, model, verbose)
        if not success:
            raise ValueError(result)
        # The judge replies with a Python-dict-like string, not strict JSON.
        result = ast.literal_eval(result)
        pred, score = result['pred'], result['score']
        # check pred
        if pred not in ['yes', 'no']:
            raise ValueError()
        # check score
        result['score'] = float(result['score'])
        # NOTE(review): the range check compares the pre-conversion `score`;
        # a string score would raise TypeError here and count as invalid.
        if score < 0 or score > 5:
            raise ValueError()
    except Exception as e:
        if verbose:
            print(e)
            print(f'invalid GPT response: {result}')
        return {'success': False, 'result': result, 'data': data}
    return {'success': True, 'result': result, 'data': data}
90
+
91
class GPTMetric:
    """GPT-judged open-ended QA metric: accuracy ('yes'/'no') plus a 0-5 score."""

    def __init__(self, dataset_name, verbose=False) -> None:
        self.dataset_name = dataset_name
        self.num_worker = 64  # size of the pathos process pool used in process()
        # NOTE(review): call_azure_gpt_api calls call_gpt35 regardless of this
        # field; confirm whether the model name is actually honored.
        self.model = 'gpt-35-turbo-0125'
        self.results = []          # samples the judge scored successfully
        self.invalid_results = []  # samples with unusable judge output
        self.dataset = []          # raw samples added via add()
        self.verbose = verbose

    def add(self, data):
        """Queue one raw sample for later processing."""
        self.dataset.append(data)

    def process(self, dataset: List[Dict]):
        """Judge every sample in parallel via a pathos process pool."""
        pool = Pool(processes = self.num_worker, )
        inputs = [(d, self.model, self.verbose) for d in dataset]
        # uimap yields results in completion order.
        results = pool.uimap(process_one_sample, inputs, chunksize = 1)

        for result in tqdm(results, total = len(dataset)):
            self.update_metric(result)
        pool.close()
        pool.join()
        pool.clear() # MUST: pathos caches pools between instantiations

    def update_metric(self, result):
        """Route one judged sample into results or invalid_results."""
        if result['success']:
            self.results.append(result)
        else:
            self.invalid_results.append(result)

    def summarize_metric(self):
        """Print accuracy / average score and cache the lines in self.eval_records."""
        if self.verbose:
            for result in self.results + self.invalid_results:
                print(f"Success: {Color.red(str(result['success']))}")
                print(Color.blue(json.dumps(result['data'], ensure_ascii=False)))
                print(Color.green(json.dumps(result['result'], ensure_ascii=False)))
        preds, scores = [], []
        for result in self.results:
            pred, score = result['result']['pred'], result['result']['score']
            preds.append(pred)
            scores.append(score)
        avg_score = np.average(scores)
        acc = np.average([p == 'yes' for p in preds])
        print(f'=====Evaluation Summary=====')
        self.eval_records = [
            f'Dataset: {self.dataset_name}\tMetric: GPT Accuracy',
            f'#Successful Results: {len(self.results)}\n#Failed Results: {len(self.invalid_results)}',
            f'Accuracy: {round(acc*100, 1)}',
            f'Average Score: {round(avg_score, 3)}',
        ]
        for info in self.eval_records:
            print(info)

    def save_results(self, pred_path):
        """Write the cached summary lines into an eval_records directory
        placed next to the prediction file (or inside the prediction dir)."""
        if os.path.isdir(pred_path):
            output_dir = os.path.join(pred_path, 'eval_records')
        else:
            output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
        os.makedirs(output_dir, exist_ok=True)
        fout = open(os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt'), 'w')
        for info in self.eval_records:
            fout.write(info+'\n')
        fout.close()
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_video_mme.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
import json
import os
import sys
from typing import Dict, List, Optional, Union

sys.path.append('eval_scripts/DREAM-1K/tarsier')
# Bug fix: this import was fused onto the end of the sys.path.append(...)
# line (a SyntaxError); it must be a separate statement, and it must come
# after the path append so `tools` is resolvable.
from tools.color import Color
20
+
21
# Video-MME top-level video domains (used to group per-category accuracy).
CATEGORIES = [
    "Knowledge",
    "Film & Television",
    "Sports Competition",
    "Artistic Performance",
    "Life Record",
    "Multilingual"
]

# Fine-grained sub-categories nested under the domains above.
SUB_CATEGORIES = [
    "Humanity & History",
    "Literature & Art",
    "Biology & Medicine",
    "Finance & Commerce",
    "Astronomy",
    "Geography",
    "Law",
    "Life Tip",
    "Technology",
    "Animation",
    "Movie & TV Show",
    "Documentary",
    "News Report",
    "Esports",
    "Basketball",
    "Football",
    "Athletics",
    "Other Sports",
    "Stage Play",
    "Magic Show",
    "Variety Show",
    "Acrobatics",
    "Handicraft",
    "Food",
    "Fashion",
    "Daily Life",
    "Travel",
    "Pet & Animal",
    "Exercise",
    "Multilingual"
]

# Question task types (perception / reasoning / synopsis splits).
TASK_CATEGORIES = [
    "Temporal Perception",
    "Spatial Perception",
    "Attribute Perception",
    "Action Recognition",
    "Object Recognition",
    "OCR Problems",
    "Counting Problem",
    "Temporal Reasoning",
    "Spatial Reasoning",
    "Action Reasoning",
    "Object Reasoning",
    "Information Synopsis",
]
77
+
78
+
79
+ class VideoMMEAccuracyMetric:
80
    def __init__(self, dataset_name, verbose=False) -> None:
        """Collect per-question Video-MME results and report accuracy."""
        self.dataset_name = dataset_name
        self.results = []          # every processed sample (valid and invalid)
        self.invalid_results = []  # samples whose prediction could not be parsed
        self.dataset = []          # raw samples added via add()
        self.verbose = verbose
86
+
87
    def add(self, data):
        """Queue one raw sample for later processing."""
        self.dataset.append(data)
89
+
90
    def process(self, dataset: List[Dict]):
        """Score every sample (Video-MME has no sub-task grouping step here)."""
        return self._process(dataset)
92
+
93
    def _process(self, dataset: List[Dict]):
        """Parse each prediction into an option letter and score exact match.

        Every sample is appended to self.results; samples whose prediction
        cannot be parsed into a single upper-case option letter are ALSO
        appended to self.invalid_results and flagged 'missing' (scored as
        wrong). The double bookkeeping is intentional: merge_results() needs
        every question present, while summarize_metric() reports the number
        of parse failures.
        """
        for data in dataset:
            prompt, response, prediction = data['prompt'], data['response'], data['prediction']
            extra_info = data['extra_info']
            # Strip parentheses so formats like "(A)" and "A)" both parse.
            prediction = prediction.replace('(', '').replace(')', '').strip()
            if len(prediction) <= 0:
                success = False
            else:
                prediction = prediction[0]
                # Map numeric options '1'-'5' onto letters.
                # NOTE(review): '1' maps to chr(1 + ord('A')) == 'B' — confirm
                # the intended numbering (the sister metric maps '0' -> 'A').
                if '1'<=prediction<='5':
                    prediction = chr(int(prediction) + ord('A'))
                success = prediction.isupper() and prediction.isalpha() and len(prediction) == 1
            if success:
                rst = {
                    'success': success,
                    'data': data,
                    'result': {'acc': response == prediction},
                    'extra_info': extra_info,
                    'missing': False
                }
                self.results.append(rst)
            else:
                rst = {
                    'success': success,
                    'data': data,
                    'result': {'acc': False},
                    'extra_info': extra_info,
                    'missing': True
                }
                self.results.append(rst)
                self.invalid_results.append(rst)
124
+
125
    def summarize_metric(self):
        """Print a summary and delegate detailed reporting to eval_your_results()."""
        if self.verbose:
            for result in self.results + self.invalid_results:
                print(f"{Color.red('Success: ' + str(result['success']))}")
                print(Color.blue(json.dumps(result['data'], ensure_ascii=False)))
                print(f"{Color.green('Accuracy: ' + str(result['result']['acc']))}")
        print(f'=====Evaluation Summary=====')
        print(f'Dataset: {self.dataset_name}\tMetric: Accuracy')
        # self.results holds every sample, so successful = total - invalid.
        print(f'#Successful Results: {len(self.results) - len(self.invalid_results)}\n#Failed Results: {len(self.invalid_results)}')
        self.eval_your_results(
            video_types = ["short","medium","long"],
            skip_missing = True,
            return_categories_accuracy = True,
            return_sub_categories_accuracy = False,
            return_task_types_accuracy = False,
        )
141
+
142
    def merge_results(self):
        """Regroup per-question results into per-video records.

        Returns:
            Dict mapping video_id -> record carrying the video's metadata
            (duration / domain / sub_category from extra_info), the list of
            its question results, and a 'missing' flag set to True if any of
            the video's predictions failed to parse.
        """
        results_merged_by_vid = {}
        for result in self.results:
            vid = result['extra_info']['vid']
            if vid not in results_merged_by_vid:
                results_merged_by_vid[vid] = {
                    'video_id': vid,
                    "duration": result['extra_info']['duration'],
                    "domain": result['extra_info']['domain'],
                    "sub_category": result['extra_info']['sub_category'],
                    'questions': [],
                    'missing': False
                }
            # One unparseable question marks the whole video as missing.
            if result['missing']:
                results_merged_by_vid[vid]['missing'] = True
            results_merged_by_vid[vid]['questions'].append({
                'qid': result['extra_info']['idx'],
                'task_type': result['extra_info']['task_type'],
                'acc': result['result']['acc']
            }
            )
        return results_merged_by_vid
164
+
165
def eval_your_results(
    self,
    video_types: Optional[Union[List[str], str]] = None,
    skip_missing: Optional[bool] = False,
    return_categories_accuracy: Optional[bool] = True,
    return_sub_categories_accuracy: Optional[bool] = False,
    return_task_types_accuracy: Optional[bool] = False,
    gt_answer_key: Optional[str] = "answer",
    your_answer_key: Optional[str] = "response"
):
    """
    Evaluate the merged per-video results, Video-MME style.

    Adapted from https://github.com/thanku-all/parse_answer/blob/main/eval_your_results.py

    Args:
        - video_types (Optional[List[str], str]): duration buckets to evaluate
          ("short" / "medium" / "long"), as a list or comma-separated string.
        - skip_missing (Optional[bool]): if True, videos flagged as missing are skipped;
          if False, each bucket is asserted to contain exactly 300 videos.
        - return_categories_accuracy (Optional[bool]): report per video-category accuracy.
        - return_sub_categories_accuracy (Optional[bool]): report per video-sub-category accuracy.
        - return_task_types_accuracy (Optional[bool]): report per task-category accuracy.
        - gt_answer_key / your_answer_key (Optional[str]): kept for interface compatibility
          with the upstream script; this implementation reads each question's
          pre-computed 'acc' field instead of re-parsing answers.
    """

    def _record(info):
        # Every summary line is both persisted (consumed by save_results) and printed.
        self.eval_records.append(info)
        print(info)

    your_results = list(self.merge_results().values())
    self.eval_records = []
    if isinstance(video_types, str):
        video_types = video_types.split(",")

    q_type_dict = {}
    v_type_dict = {}
    v_sub_type_dict = {}

    # ---- Accumulate correct/answered counters per duration bucket ----
    for video_type in video_types:
        # Filter results belonging to this duration bucket.
        your_results_video_type = [item for item in your_results if item['duration'] == video_type]

        q_type_dict[video_type] = {q: {"correct": 0, "answered": 0} for q in TASK_CATEGORIES}
        v_type_dict[video_type] = {v: {"correct": 0, "answered": 0} for v in CATEGORIES}
        v_sub_type_dict[video_type] = {v: {"correct": 0, "answered": 0} for v in SUB_CATEGORIES}

        if not skip_missing:
            # Video-MME ships exactly 300 videos per duration bucket.
            print(len(your_results_video_type))
            assert len(your_results_video_type) == 300, f"Number of files in {video_type} is not 300. Check if there are missing files."

        for item in your_results_video_type:
            if skip_missing and item["missing"]:
                continue

            video_category = item["domain"]
            video_sub_category = item["sub_category"]

            for question in item["questions"]:
                q_type = question["task_type"]
                acc = question['acc']  # pre-computed correctness for this question

                if acc is not None:
                    q_type_dict[video_type][q_type]["answered"] += 1
                    q_type_dict[video_type][q_type]["correct"] += acc

                    v_type_dict[video_type][video_category]["answered"] += 1
                    v_type_dict[video_type][video_category]["correct"] += acc

                    v_sub_type_dict[video_type][video_sub_category]["answered"] += 1
                    v_sub_type_dict[video_type][video_sub_category]["correct"] += acc

    # ---- Report per duration bucket ----
    for video_type in video_types:
        _record(f"=====================================\nEvaluation on video Type: {video_type}\n=====================================")
        if return_categories_accuracy:
            _record(f"-------------------------------------\nVideo Categories\n-------------------------------------")
            for v_type in v_type_dict[video_type]:
                _record(f"{v_type}: {100 * v_type_dict[video_type][v_type]['correct'] / v_type_dict[video_type][v_type]['answered'] if v_type_dict[video_type][v_type]['answered'] > 0 else 0 : .1f}%")
        if return_sub_categories_accuracy:
            # BUGFIX: this header was appended to eval_records but never printed,
            # unlike every sibling section header.
            _record(f"-------------------------------------\nVideo Sub Categories\n-------------------------------------")
            for v_sub_type in v_sub_type_dict[video_type]:
                _record(f"{v_sub_type}: {100 * v_sub_type_dict[video_type][v_sub_type]['correct'] / v_sub_type_dict[video_type][v_sub_type]['answered'] if v_sub_type_dict[video_type][v_sub_type]['answered'] > 0 else 0 : .1f}%")
        if return_task_types_accuracy:
            _record(f"-------------------------------------\nTask Categories\n-------------------------------------")
            for q_type in q_type_dict[video_type]:
                _record(f"{q_type}: {100 * q_type_dict[video_type][q_type]['correct'] / q_type_dict[video_type][q_type]['answered'] if q_type_dict[video_type][q_type]['answered'] > 0 else 0 : .1f}%")
        # BUGFIX: this header was printed but never appended to eval_records,
        # so the saved record file was missing it.
        _record(f"-------------------------------------\nOverall Performance\n-------------------------------------")
        total_correct = sum(q_type_dict[video_type][q_type]["correct"] for q_type in TASK_CATEGORIES)
        total_answered = sum(q_type_dict[video_type][q_type]["answered"] for q_type in TASK_CATEGORIES)
        _record(f"Overall: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")
        print()  # trailing blank line, matching the original `print(info + '\n')`

    # ---- Report on the entire dataset ----
    _record(f"=====================================\nEvaluation on the entire dataset\n=====================================")

    if return_categories_accuracy:
        _record(f"-------------------------------------\nVideo Categories\n-------------------------------------")
        for v_type in CATEGORIES:
            total_correct = sum(v_type_dict[video_type][v_type]["correct"] for video_type in video_types)
            total_answered = sum(v_type_dict[video_type][v_type]["answered"] for video_type in video_types)
            _record(f"{v_type}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

    if return_sub_categories_accuracy:
        _record(f"-------------------------------------\nVideo Sub Categories\n-------------------------------------")
        for v_sub_type in SUB_CATEGORIES:
            total_correct = sum(v_sub_type_dict[video_type][v_sub_type]["correct"] for video_type in video_types)
            total_answered = sum(v_sub_type_dict[video_type][v_sub_type]["answered"] for video_type in video_types)
            _record(f"{v_sub_type}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

    if return_task_types_accuracy:
        _record(f"-------------------------------------\nTask Categories\n-------------------------------------")
        for q_type in TASK_CATEGORIES:
            total_correct = sum(q_type_dict[video_type][q_type]["correct"] for video_type in video_types)
            total_answered = sum(q_type_dict[video_type][q_type]["answered"] for video_type in video_types)
            _record(f"{q_type}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")

    _record(f"*************************************\nOverall Performance\n*************************************")
    total_correct = sum(sum(q_type_dict[video_type][q_type]["correct"] for q_type in TASK_CATEGORIES) for video_type in video_types)
    total_answered = sum(sum(q_type_dict[video_type][q_type]["answered"] for q_type in TASK_CATEGORIES) for video_type in video_types)
    _record(f"Overall: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%")
348
+
349
def save_results(self, pred_path):
    """Write the collected evaluation summary lines to
    `<dir>/eval_records/<dataset_name>_eval_result.txt`.

    Args:
        pred_path: prediction file or directory; the `eval_records` directory is
            created inside it (if a directory) or next to it (if a file).
    """
    if os.path.isdir(pred_path):
        output_dir = os.path.join(pred_path, 'eval_records')
    else:
        output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
    os.makedirs(output_dir, exist_ok=True)
    out_file = os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt')
    # Context manager guarantees the handle is closed even if a write fails
    # (the original opened/closed the file manually and could leak it).
    with open(out_file, 'w') as fout:
        for info in self.eval_records:
            fout.write(info + '\n')
eval_scripts/DREAM-1K/tarsier/models/modeling_qwen2_vl_fast.py ADDED
@@ -0,0 +1,1320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.nn import LayerNorm
10
+
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.configuration_utils import PretrainedConfig
13
+ from transformers.modeling_rope_utils import rope_config_validation, ROPE_INIT_FUNCTIONS
14
+ from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
15
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
16
+ from transformers.utils import (
17
+ add_start_docstrings,
18
+ add_start_docstrings_to_model_forward,
19
+ is_flash_attn_2_available,
20
+ is_flash_attn_greater_or_equal_2_10,
21
+ logging,
22
+ replace_return_docstrings,
23
+ )
24
+ from transformers.modeling_outputs import (
25
+ BaseModelOutputWithPast,
26
+ ModelOutput,
27
+ )
28
+ from transformers.activations import ACT2FN
29
+ from transformers.generation import GenerationMixin
30
+
31
+ if is_flash_attn_2_available():
32
+ from flash_attn import flash_attn_varlen_func
33
+
34
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
35
+ else:
36
+ flash_attn_varlen_func = None
37
+
38
+ # from apex.normalization.fused_layer_norm import fused_rms_norm_affine
39
+
40
+ logger = logging.get_logger(__name__)
41
+
42
@dataclass
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Qwen2VL causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`).

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attention weights after the softmax.

    NOTE(review): unlike the upstream Transformers class of the same name, this variant
    does NOT define a `rope_deltas` field — callers must not rely on it here.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
78
+
79
class Qwen2VLVisionConfig(PretrainedConfig):
    """Configuration of the Qwen2-VL vision encoder (patch embedding, ViT trunk and merger)."""

    model_type = "qwen2_vl"

    def __init__(
        self,
        depth=32,
        embed_dim=1280,
        hidden_size=3584,
        hidden_act="quick_gelu",
        mlp_ratio=4,
        num_heads=16,
        in_channels=3,
        patch_size=14,
        spatial_merge_size=2,
        temporal_patch_size=2,
        attn_implementation='flash_attention_2',
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Record the vision-tower hyper-parameters verbatim.
        (
            self.depth,
            self.embed_dim,
            self.hidden_size,
            self.hidden_act,
            self.mlp_ratio,
            self.num_heads,
            self.in_channels,
            self.patch_size,
            self.spatial_merge_size,
            self.temporal_patch_size,
            self.attn_implementation,
        ) = (
            depth,
            embed_dim,
            hidden_size,
            hidden_act,
            mlp_ratio,
            num_heads,
            in_channels,
            patch_size,
            spatial_merge_size,
            temporal_patch_size,
            attn_implementation,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
        """Load the vision sub-config, unwrapping it from a full qwen2_vl config when needed."""
        cls._set_token_in_kwargs(kwargs)

        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)

        # A full multimodal checkpoint nests the vision settings under "vision_config".
        if config_dict.get("model_type") == "qwen2_vl":
            config_dict = config_dict["vision_config"]

        if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
            logger.warning(
                f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
                f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
            )

        return cls.from_dict(config_dict, **kwargs)
127
+
128
+
129
class Qwen2VLConfig(PretrainedConfig):
    r"""
    Configuration class for [`Qwen2VLModel`]. Instantiating it with the defaults yields a
    configuration similar to Qwen2-VL-7B-Instruct
    [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the
    model outputs. Read the documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 152064): Vocabulary size of the model.
        hidden_size (`int`, *optional*, defaults to 8192): Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 29568): Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 80): Number of decoder layers.
        num_attention_heads (`int`, *optional*, defaults to 64): Attention heads per layer.
        num_key_value_heads (`int`, *optional*, defaults to 8): Key/value heads for Grouped
            Query Attention; equal to `num_attention_heads` gives MHA, `1` gives MQA. When
            `None`, falls back to `num_attention_heads`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): Decoder activation.
        max_position_embeddings (`int`, *optional*, defaults to 32768): Maximum sequence length.
        initializer_range (`float`, *optional*, defaults to 0.02): Std of the weight initializer.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05): Epsilon of the RMS norm layers.
        use_cache (`bool`, *optional*, defaults to `True`): Whether to return key/value caches.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`): Tie input/output embeddings.
        rope_theta (`float`, *optional*, defaults to 1000000.0): Base period of the RoPE embeddings.
        use_sliding_window (`bool`, *optional*, defaults to `False`): Enable sliding-window attention.
        sliding_window (`int`, *optional*, defaults to 4096): Sliding window (SWA) size.
        max_window_layers (`int`, *optional*, defaults to 80): Number of bottom layers using SWA.
        attention_dropout (`float`, *optional*, defaults to 0.0): Attention dropout probability.
        rope_scaling (`Dict`, *optional*): RoPE scaling configuration; see the Transformers
            `rope_scaling` documentation for the supported `rope_type` variants
            ('default', 'linear', 'dynamic', 'yarn', 'longrope', 'llama3') and their keys.
        spatial_merge_size (`int`, *optional*, defaults to 2): Vision patch-merge factor.
        attn_implementation (`str`, *optional*, defaults to `'flash_attention_2'`): Attention backend.

    ```python
    >>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig

    >>> # Initializing a Qwen2VL style configuration
    >>> configuration = Qwen2VLConfig()

    >>> # Initializing a model from the Qwen2-VL-7B style configuration
    >>> model = Qwen2VLForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "qwen2_vl"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=152064,
        hidden_size=8192,
        intermediate_size=29568,
        num_hidden_layers=80,
        num_attention_heads=64,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-05,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=1000000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=80,
        attention_dropout=0.0,
        rope_scaling=None,
        spatial_merge_size=2,
        attn_implementation='flash_attention_2',
        **kwargs,
    ):
        # Text-decoder shape parameters.
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads

        # Sliding-window attention settings.
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # Backward compatibility: old configs omit num_key_value_heads (pure MHA).
        self.num_key_value_heads = num_attention_heads if num_key_value_heads is None else num_key_value_heads

        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.rope_scaling = rope_scaling
        self.spatial_merge_size = spatial_merge_size
        self.attn_implementation = attn_implementation

        # Validate the rotary position embedding parameters.
        # BC: move a legacy 'type' key to 'rope_type'; 'mrope' performs default RoPE
        # calculations, so it is renamed to 'default' before validation. One can set
        # it to "linear"/"dynamic" etc. to have scaled RoPE.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            if self.rope_scaling["type"] == "mrope":
                self.rope_scaling["type"] = "default"
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self, ignore_keys={"mrope_section"})

        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
300
+
301
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]
    return torch.cat((-second_half, first_half), dim=-1)
307
+
308
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
        vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately.
        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
        difference with modern LLMs.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        mrope_section(`List(int)`):
            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos and sin so that they
            can be properly broadcasted to the dimensions of q and k. For example, if q and k have the shape
            [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos and sin broadcastable to
            the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then
            set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # Repeating the section list makes the split spans cover the full last
    # dimension of cos/sin (which is twice the sum of the sections).
    mrope_section = mrope_section * 2
    # i % 3 cycles temporal -> height -> width, interleaving the three positional
    # components back into a single head_dim-wide cos/sin tensor.
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    # Standard rotary application once the multimodal cos/sin are assembled.
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
351
+
352
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    """Apply 1D rotary position embedding to vision tokens.

    The rotation is computed in float32 for numerical stability, then cast back
    to the input dtype.
    """
    src_dtype = tensor.dtype
    x = tensor.float()
    # Duplicate each frequency across the paired rotary channels and add
    # batch/head broadcast dimensions.
    cos = freqs.cos().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    sin = freqs.sin().unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
    rotated = (x * cos) + (rotate_half(x) * sin)
    return rotated.to(src_dtype)
362
+
363
+
364
class VisionRotaryEmbedding(nn.Module):
    """Precompute rotary frequencies for vision token positions."""

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        exponents = torch.arange(0, dim, 2, dtype=torch.float) / dim
        # inv_freq[i] = 1 / theta^(2i/dim); not persisted in state_dict.
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        # Outer product of positions and inverse frequencies: (seqlen, dim // 2).
        positions = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        return torch.outer(positions, self.inv_freq)
374
+
375
class PatchEmbed(nn.Module):
    """Embed flattened (temporal, height, width) patches with a single strided 3D conv."""

    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1152,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        # kernel == stride, so each conv application consumes exactly one patch.
        window = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=window, stride=window, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        weight_dtype = self.proj.weight.dtype
        # Unflatten each row back into a (C, T, H, W) patch volume.
        patches = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        return self.proj(patches.to(dtype=weight_dtype)).view(-1, self.embed_dim)
399
+
400
+
401
class PatchMerger(nn.Module):
    """Merge spatial_merge_size**2 neighboring patch features and project them to `dim`."""

    def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
        super().__init__()
        # Each merged token concatenates spatial_merge_size**2 patch features.
        self.hidden_size = context_dim * (spatial_merge_size**2)
        self.ln_q = LayerNorm(context_dim, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.GELU(),
            nn.Linear(self.hidden_size, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = self.ln_q(x)
        # Grouping happens implicitly: the flat view packs merge-group neighbors together.
        return self.mlp(normed.view(-1, self.hidden_size))
415
+
416
class VisionMlp(nn.Module):
    """Two-layer feed-forward block used inside each vision transformer block."""

    def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        # Activation is looked up by name from the transformers registry.
        self.act = ACT2FN[hidden_act]
        self.fc2 = nn.Linear(hidden_dim, dim)

    def forward(self, x) -> torch.Tensor:
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
426
+
427
class VisionAttention(nn.Module):
    """Eager (materialized-matrix) self-attention over packed vision sequences.

    Sequences from multiple images/videos are concatenated along dim 0;
    ``cu_seqlens`` holds their cumulative boundaries and is used to build a
    block-diagonal mask so tokens never attend across images.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        # One projection, then split into (q, k, v), each (seq, heads, head_dim).
        packed = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1)
        q, k, v = packed.permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Start fully masked, then open each per-image diagonal block.
        attention_mask = torch.full(
            [1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
        )
        for i in range(1, len(cu_seqlens)):
            start, end = cu_seqlens[i - 1], cu_seqlens[i]
            attention_mask[..., start:end, start:end] = 0

        q = q.transpose(0, 1)
        k = k.transpose(0, 1)
        v = v.transpose(0, 1)
        attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
        attn_weights = attn_weights + attention_mask
        # Softmax in float32 for stability, then cast back.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v)
        attn_output = attn_output.transpose(0, 1).reshape(seq_length, -1)
        return self.proj(attn_output)
462
class VisionFlashAttention2(nn.Module):
    """Vision self-attention backed by FlashAttention-2 varlen kernels.

    Packed sequences are delimited by `cu_seqlens`, so no explicit
    block-diagonal mask is built: `flash_attn_varlen_func` restricts
    attention to each segment directly.
    """

    def __init__(self, dim: int, num_heads: int = 16) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=True)
        self.proj = nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        # Project once, then split into (q, k, v), each (seq, heads, head_dim).
        q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
        k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)

        # Longest packed segment; required by the varlen flash kernel.
        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
            seq_length, -1
        )
        attn_output = self.proj(attn_output)
        return attn_output
484
# Maps a vision attn_implementation string to the attention module class.
QWEN2_VL_VISION_ATTENTION_CLASSES = {
    "eager": VisionAttention,
    "flash_attention_2": VisionFlashAttention2,
}
490
class Qwen2VLVisionBlock(nn.Module):
    """One pre-norm transformer block of the vision tower (attention + MLP)."""

    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
        self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
        mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)

        attn_cls = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]
        self.attn = attn_cls(config.embed_dim, num_heads=config.num_heads)
        self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)

    def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
        # Residual around attention, then residual around the MLP.
        attn_out = self.attn(self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
        hidden_states = hidden_states + attn_out
        return hidden_states + self.mlp(self.norm2(hidden_states))
509
class Qwen2VLPreTrainedModel(PreTrainedModel):
    """Base class wiring Qwen2-VL modules into the HF PreTrainedModel API."""

    config_class = Qwen2VLConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Keep decoder layers and vision blocks whole when sharding across devices.
    _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialize linear/conv/embedding weights with N(0, initializer_range)."""
        std = self.config.initializer_range
        if isinstance(module, (nn.Linear, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                # Padding embedding row stays zero.
                module.weight.data[module.padding_idx].zero_()
531
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
    """Vision tower of Qwen2-VL: patch embed -> transformer blocks -> merger.

    Consumes flattened patch pixel values plus a per-image/video (t, h, w)
    grid tensor and returns merged patch features for the language model.
    """

    config_class = Qwen2VLVisionConfig
    _no_split_modules = ["Qwen2VLVisionBlock"]

    def __init__(self, config) -> None:
        super().__init__(config)
        self.spatial_merge_size = config.spatial_merge_size

        self.patch_embed = PatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.embed_dim,
        )

        head_dim = config.embed_dim // config.num_heads
        # Rotary table spans half the head dim; the h and w position ids each
        # contribute half of the final embedding (see rot_pos_emb).
        self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList(
            [Qwen2VLVisionBlock(config, config.attn_implementation) for _ in range(config.depth)]
        )
        self.merger = PatchMerger(
            dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
        )
        # Initialize weights and apply final processing
        self.gradient_checkpointing = False
        self.post_init()

    def get_dtype(self) -> torch.dtype:
        # A representative parameter stands in for the whole tower's dtype.
        return self.blocks[0].mlp.fc2.weight.dtype

    def get_device(self) -> torch.device:
        return self.blocks[0].mlp.fc2.weight.device

    def rot_pos_emb(self, grid_thw):
        """Build per-patch (h, w) rotary embeddings for each (t, h, w) grid.

        Position ids are emitted in spatial-merge-window order so they line
        up with how PatchMerger later groups neighboring patches.
        """
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            # The same spatial ids repeat for every temporal slice.
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        # Index by (h, w) ids -> (N, 2, head_dim // 4), flatten to (N, head_dim // 2).
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
        """Encode packed patches.

        Args:
            hidden_states: flattened patch pixel values, one row per patch.
            grid_thw: (num_images, 3) tensor of (t, h, w) patch-grid sizes.

        Returns:
            Merged patch features (``spatial_merge_size ** 2`` patches per row).
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)

        # One attention segment per temporal slice: h*w patches, repeated t
        # times, then turned into cumulative boundaries with a leading zero.
        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0, dtype=torch.int32
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for blk in self.blocks:
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    blk.__call__,
                    hidden_states,
                    cu_seqlens,
                    rotary_pos_emb,
                )
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)

        return self.merger(hidden_states)
616
+ # class Qwen2RMSNorm(nn.Module):
617
+ # def __init__(self, hidden_size, eps=1e-6):
618
+ # """
619
+ # Qwen2RMSNorm is equivalent to T5LayerNorm
620
+ # """
621
+ # super().__init__()
622
+ # self.weight = nn.Parameter(torch.ones(hidden_size))
623
+ # self.variance_epsilon = eps
624
+ # self.normalized_shape = torch.Size((hidden_size, ))
625
+
626
+ # def forward(self, hidden_states):
627
+ # return fused_rms_norm_affine(input=hidden_states,
628
+ # weight=self.weight,
629
+ # normalized_shape=self.normalized_shape,
630
+ # eps=self.variance_epsilon,
631
+ # memory_efficient=True)
632
+
633
class Qwen2RMSNorm(nn.Module):
    """Root-mean-square layer norm (equivalent to T5LayerNorm).

    Normalizes in float32 for stability, rescales by a learned weight, and
    casts back to the input dtype. No mean subtraction is performed.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * x.to(original_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
652
class Qwen2VLRotaryEmbedding(nn.Module):
    """Multimodal rotary embedding (M-RoPE) for Qwen2-VL.

    Unlike standard RoPE, `position_ids` carries three components (temporal,
    height, width), so cos/sin are produced with a leading dimension of 3.
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[Qwen2VLConfig] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            logger.warning_once(
                "`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Keep the unscaled table so a dynamic-scaling growth can be reset.
        self.original_inv_freq = self.inv_freq

    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Move the 3 THW components to the front: the expanded shapes below
        # imply position_ids arrives as (bs, positions, 3) -> (3, bs, positions).
        position_ids = position_ids.permute(2, 0, 1)
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
        # So we expand the inv_freq to shape (3, ...)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
740
+ # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
741
class Qwen2MLP(nn.Module):
    """Gated feed-forward network: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        gated = self.act_fn(self.gate_proj(hidden_state))
        return self.down_proj(gated * self.up_proj(hidden_state))
754
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
755
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat each key/value head `n_rep` times along the head dimension.

    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep):
    the input goes from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        # Plain MHA: nothing to expand.
        return hidden_states
    expanded = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, num_kv_heads * n_rep, seq_len, head_dim)
766
class Qwen2VLAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # layer_idx is needed for per-layer KV-cache updates downstream.
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        # GQA: queries use num_heads, keys/values use num_key_value_heads.
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.rope_scaling = config.rope_scaling

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Biased QKV projections; bias-free output projection.
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
805
class Qwen2VLFlashAttention2(Qwen2VLAttention):
    """
    Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
    as the weights of the module stays untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal with padding tokens
    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        use_rmpad: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = False,
    ):
        """
        Train:
            unpad: (bsz, q_len) = (1, acc_seqlen)
            pad: (bsz, q_len) = (bsz, q_len)
        Test:
            first_iter: (bsz, q_len) = (bsz, q_len)
            other: (bsz, q_len) = (bsz, 1)
        """
        # NOTE(review): `cu_seqlens` defaults to `False` despite the tensor
        # type hint; `None` would be clearer — confirm before changing.
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, seq, heads*head_dim) -> (bsz, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings

        # Multimodal RoPE: cos/sin carry the 3 THW components, split per
        # `mrope_section` from the config's rope_scaling.
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons
        # therefore the input hidden states gets silently casted in float32. Hence, we need
        # cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reashape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if use_rmpad:
            # Padding-free path: sequences are packed into a single row
            # (bsz == 1, see docstring) and delimited by cu_seqlens.
            max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
            attn_output = flash_attn_varlen_func(
                query_states.squeeze(0), key_states.squeeze(0), value_states.squeeze(0),
                cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
                dropout_p=dropout_rate,
                causal=self.is_causal, window_size=(-1, -1),
            )
        else:
            attn_output = _flash_attention_forward(
                query_states, key_states, value_states,
                attention_mask,
                q_len,
                dropout=dropout_rate,
                sliding_window=None,
                is_causal=self.is_causal,
                use_top_left_mask=self._flash_attn_uses_top_left_mask,
            )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        # NOTE(review): `attn_weights` is only assigned on this branch, so
        # calling with output_attentions=True would raise NameError — flash
        # attention never materializes the attention matrix.
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
923
# Text-decoder attention dispatch; only flash-attention-2 is implemented here.
QWEN2_VL_ATTENTION_CLASSES = {
    "flash_attention_2": Qwen2VLFlashAttention2,
}
927
class Qwen2VLDecoderLayer(nn.Module):
    """One decoder layer: pre-norm self-attention followed by a pre-norm MLP."""

    def __init__(self, config: Qwen2VLConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Only flash_attention_2 is wired up below. Any other value is logged
        # (the message is Chinese for "only flash_attention_2 is supported!")
        # and then fails with a KeyError at the dict lookup.
        if config.attn_implementation != "flash_attention_2":
            logger.error(
                f"只支持 flash_attention_2!config.attn_implementation={config.attn_implementation}"
            )
        self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)

        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        use_rmpad: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            use_rmpad=use_rmpad,
            cu_seqlens=cu_seqlens,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
1011
class Qwen2VLModel(Qwen2VLPreTrainedModel):
    """Decoder-only transformer backbone of Qwen2-VL (text stack).

    Owns the token embedding table, the stack of decoder layers, the final
    RMSNorm, and the multimodal rotary embedding shared by all layers.
    """

    def __init__(self, config: Qwen2VLConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_rmpad: Optional[bool] = False,
        cu_seqlens: Optional[torch.Tensor] = False,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack.

        Exactly one of `input_ids` / `inputs_embeds` must be provided; the
        multimodal caller typically passes pre-built `inputs_embeds` with
        vision features already merged in.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # Fix: the original assigned `hidden_states = inputs_embeds`
        # unconditionally, which crashed whenever callers passed `input_ids`
        # only — a combination the validation above explicitly allows.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    position_embeddings,
                    use_rmpad,
                    cu_seqlens,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    position_embeddings=position_embeddings,
                    use_rmpad=use_rmpad,
                    cu_seqlens=cu_seqlens,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # Cache tuple position depends on whether attentions were emitted.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
1130
+ class Qwen2VLForCausalLM(Qwen2VLPreTrainedModel, GenerationMixin):
1131
+ _tied_weights_keys = ["lm_head.weight"]
1132
+
1133
    def __init__(self, config):
        """Build the Qwen2-VL backbone plus the vocabulary projection head."""
        super().__init__(config)
        self.model = Qwen2VLModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.padding_side = "left"  # set it to left by default, user can use setter to change padding_sides

        # Initialize weights and apply final processing
        self.post_init()
1142
    def get_input_embeddings(self):
        # The token embedding table is owned by the inner Qwen2VLModel.
        return self.model.embed_tokens
1145
    def set_input_embeddings(self, value):
        # Delegate to the inner model, which owns the embedding table.
        self.model.embed_tokens = value
1148
    def get_output_embeddings(self):
        # Output projection used for next-token logits.
        return self.lm_head
1151
    def set_output_embeddings(self, new_embeddings):
        # Replace the vocab projection head (e.g. after resizing embeddings).
        self.lm_head = new_embeddings
1154
    def set_decoder(self, decoder):
        # Swap out the decoder backbone wholesale.
        self.model = decoder
1157
    def get_decoder(self):
        # Expose the decoder backbone for HF utilities.
        return self.model
1160
+ def get_rope_index(
1161
+ self,
1162
+ input_ids: torch.LongTensor,
1163
+ image_grid_thw: Optional[torch.LongTensor] = None,
1164
+ attention_mask: Optional[torch.Tensor] = None,
1165
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
1166
+ """
1167
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
1168
+
1169
+ Explanation:
1170
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
1171
+
1172
+ For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs.
1173
+ Examples:
1174
+ input_ids: [T T T T T], here T is for text.
1175
+ temporal position_ids: [0, 1, 2, 3, 4]
1176
+ height position_ids: [0, 1, 2, 3, 4]
1177
+ width position_ids: [0, 1, 2, 3, 4]
1178
+
1179
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
1180
+ and 1D rotary position embeddin for text part.
1181
+ Examples:
1182
+ Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
1183
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
1184
+ vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
1185
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
1186
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
1187
+ text temporal position_ids: [3, 4, 5, 6, 7]
1188
+ text height position_ids: [3, 4, 5, 6, 7]
1189
+ text width position_ids: [3, 4, 5, 6, 7]
1190
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
1191
+
1192
+ Args:
1193
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1194
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
1195
+ it.
1196
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1197
+ The temporal, height and width of feature shape of each image in LLM.
1198
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1199
+ The temporal, height and width of feature shape of each video in LLM.
1200
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1201
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1202
+
1203
+ - 1 for tokens that are **not masked**,
1204
+ - 0 for tokens that are **masked**.
1205
+
1206
+ Returns:
1207
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
1208
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
1209
+ """
1210
+ spatial_merge_size = self.config.spatial_merge_size
1211
+ vision_token_id = self.config.image_token_id
1212
+ vision_start_token_id = self.config.vision_start_token_id
1213
+ assert image_grid_thw is not None # TODO:测试纯文本会不会卡住
1214
+ total_input_ids = input_ids
1215
+ position_ids = torch.ones(
1216
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1217
+ )
1218
+ vision_index = 0
1219
+ for i, input_ids in enumerate(total_input_ids):
1220
+ if attention_mask is not None:
1221
+ input_ids = input_ids[attention_mask[i] == 1]
1222
+ vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
1223
+ vision_num = (input_ids[vision_start_indices + 1] == vision_token_id).sum()
1224
+ input_tokens = input_ids.tolist()
1225
+ llm_pos_ids_list: list = []
1226
+ st = 0
1227
+ remain_vision_num = vision_num
1228
+ for _ in range(vision_num):
1229
+ if vision_token_id in input_tokens and remain_vision_num > 0:
1230
+ ed_vision = input_tokens.index(vision_token_id, st)
1231
+ else:
1232
+ ed_vision = len(input_tokens) + 1
1233
+
1234
+ t, h, w = (
1235
+ image_grid_thw[vision_index][0],
1236
+ image_grid_thw[vision_index][1],
1237
+ image_grid_thw[vision_index][2],
1238
+ )
1239
+ vision_index += 1
1240
+ remain_vision_num -= 1
1241
+ ed = ed_vision
1242
+
1243
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1244
+ t.item(),
1245
+ h.item() // spatial_merge_size,
1246
+ w.item() // spatial_merge_size,
1247
+ )
1248
+ text_len = ed - st
1249
+
1250
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1251
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1252
+
1253
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1254
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
1255
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1256
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
1257
+ st = ed + llm_grid_t * llm_grid_h * llm_grid_w
1258
+
1259
+ if st < len(input_tokens):
1260
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
1261
+ text_len = len(input_tokens) - st
1262
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1263
+
1264
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1265
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1266
+ position_ids = position_ids.permute(1, 2, 0)
1267
+ return position_ids
1268
+
1269
+ def forward(
1270
+ self,
1271
+ input_ids: torch.LongTensor = None,
1272
+ attention_mask: Optional[torch.Tensor] = None,
1273
+ position_ids: Optional[torch.LongTensor] = None,
1274
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
1275
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1276
+ labels: Optional[torch.LongTensor] = None,
1277
+ use_cache: Optional[bool] = None,
1278
+ output_attentions: Optional[bool] = None,
1279
+ output_hidden_states: Optional[bool] = None,
1280
+ return_dict: Optional[bool] = None,
1281
+ use_rmpad: Optional[bool] = False,
1282
+ cu_seqlens: Optional[torch.Tensor] = False,
1283
+ ) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
1284
+
1285
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1286
+ output_hidden_states = (
1287
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1288
+ )
1289
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1290
+
1291
+
1292
+ outputs = self.model(
1293
+ input_ids=input_ids,
1294
+ attention_mask=attention_mask,
1295
+ position_ids=position_ids,
1296
+ past_key_values=past_key_values,
1297
+ inputs_embeds=inputs_embeds,
1298
+ use_cache=use_cache,
1299
+ output_attentions=output_attentions,
1300
+ output_hidden_states=output_hidden_states,
1301
+ return_dict=return_dict,
1302
+ use_rmpad=use_rmpad,
1303
+ cu_seqlens=cu_seqlens,
1304
+ )
1305
+
1306
+ hidden_states = outputs[0]
1307
+ logits = self.lm_head(hidden_states)
1308
+
1309
+ if not return_dict:
1310
+ output = (logits,) + outputs[1:]
1311
+ return output
1312
+
1313
+ return Qwen2VLCausalLMOutputWithPast(
1314
+ logits=logits,
1315
+ past_key_values=outputs.past_key_values,
1316
+ hidden_states=outputs.hidden_states,
1317
+ attentions=outputs.attentions,
1318
+ )
1319
+
1320
+
eval_scripts/DREAM-1K/tarsier/models/modeling_tarsier.py ADDED
@@ -0,0 +1,502 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Optional, Tuple, Union, Dict, Any
3
+ import math
4
+
5
+ import torch.utils.checkpoint
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+
9
+ from transformers import PreTrainedModel, AutoConfig, AutoModel
10
+ from transformers.activations import ACT2FN
11
+ from transformers.cache_utils import Cache
12
+ from transformers.modeling_outputs import ModelOutput
13
+ from transformers.utils import logging
14
+ from transformers.configuration_utils import PretrainedConfig
15
+ from transformers.dynamic_module_utils import get_class_from_dynamic_module
16
+ from transformers.models.auto import AutoModel, AutoModelForCausalLM, CONFIG_MAPPING
17
+ from transformers.generation import GenerationMixin
18
+
19
+ from transformers import LlamaForCausalLM, Qwen2ForCausalLM
20
+ # from models.modeling_qwen2 import Qwen2ForCausalLM
21
+ from models.modeling_qwen2_vl_fast import Qwen2VLForCausalLM
22
+ from models.utils import _pad_input, _unpad_input
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
class LlavaConfig(PretrainedConfig):
    """Composite configuration for Tarsier/LLaVA-style models.

    Holds a vision-tower config, a text (LLM) config, and the projector /
    special-token settings that glue them together. Sub-configs given as dicts
    are resolved to concrete config classes (via ``auto_map`` or
    ``CONFIG_MAPPING``).
    """

    model_type = "llava"
    is_composition = False

    def __init__(
        self,
        vision_config=None,
        text_config=None,
        ignore_index=-100,
        image_token_index=32000,
        projector_hidden_act="gelu",
        vision_feature_select_strategy="default",
        vision_feature_layer=-2,
        image_newline_idx=32002,
        image_new_idx=32003,
        projection_head="MLP",
        **kwargs,
    ):
        self.ignore_index = ignore_index
        self.image_token_index = image_token_index
        self.projector_hidden_act = projector_hidden_act
        self.vision_feature_select_strategy = vision_feature_select_strategy
        self.vision_feature_layer = vision_feature_layer
        self.image_newline_idx = image_newline_idx
        self.image_new_idx = image_new_idx
        self.projection_head = projection_head

        # The vision/text resolution logic was duplicated; factored into one helper.
        self.vision_config = self._resolve_sub_config(
            vision_config, "clip_vision_model", "vision_config", **kwargs
        )
        self.text_config = self._resolve_sub_config(
            text_config, "llama", "text_config", **kwargs
        )

        super().__init__(**kwargs)

    @staticmethod
    def _resolve_sub_config(cfg, default_model_type, name, **kwargs):
        """Turn a dict sub-config into a config object; pass through otherwise.

        Resolution order: ``auto_map``-declared dynamic class, then
        ``CONFIG_MAPPING`` by ``model_type``; raises ValueError for unknown types.
        """
        if not isinstance(cfg, dict):
            # Already a config object (or None) — leave untouched.
            return cfg
        cfg["model_type"] = cfg["model_type"] if "model_type" in cfg else default_model_type
        if 'auto_map' in cfg:
            repo_id, class_ref = cfg['auto_map']['AutoConfig'].split("--")
            config_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
            return config_class(**cfg)
        if cfg["model_type"] in CONFIG_MAPPING:
            return CONFIG_MAPPING[cfg["model_type"]](**cfg)
        raise ValueError(f'{name}["model_type"] = {cfg["model_type"]} not supported!')
85
+
86
+
87
+
88
@dataclass
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
class LlavaCausalLMOutputWithPast(ModelOutput):
    """Output container for TarsierForConditionalGeneration.forward.

    Mirrors the usual causal-LM output and additionally carries the
    ``position_ids`` that were used, so cached generation can continue the
    position sequence (consumed by ``_update_model_kwargs_for_generation``).
    """

    # Language-modeling loss (present only when labels were provided).
    loss: Optional[torch.FloatTensor] = None
    # Prediction scores over the vocabulary.
    logits: torch.FloatTensor = None
    # Key/value cache for fast autoregressive decoding.
    past_key_values: Optional[List[torch.FloatTensor]] = None
    # Per-layer hidden states, when requested.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer attention weights, when requested.
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    # Position ids used for this forward pass.
    position_ids: Optional[torch.LongTensor] = None
98
+
99
def add_split_tokens(image_features, image_newline_embed, image_new_embed):
    """Append separator embeddings to per-image patch features.

    Inserts a "newline" embedding at the end of every patch row and an
    "image end" embedding after the last row, mirroring how the text side
    delimits image content. Assumes the patch count is a perfect square
    (square patch grid).

    Args:
        image_features: tensor of shape (num_images, num_patches, dim).
        image_newline_embed: (dim,) embedding appended per row.
        image_new_embed: (dim,) embedding appended per image.

    Returns:
        Tensor of shape (num_images, num_patches + side + 1, dim).
    """
    num_images, num_patches, dim = image_features.shape
    side = int(math.sqrt(num_patches))

    # Reshape to the 2D grid and append one newline embedding per row.
    grid = image_features.view(num_images, side, side, dim)
    newline_col = image_newline_embed.expand((num_images, side, 1, dim))
    grid = torch.cat([grid, newline_col], dim=2)
    flat = grid.view(num_images, num_patches + side, dim)

    # Terminate every image with the "image end" embedding.
    end_tok = image_new_embed.expand((num_images, 1, dim))
    return torch.cat([flat, end_tok], dim=1)
119
+
120
+
121
class LlavaMultiModalProjector(nn.Module):
    """Two-layer MLP projector from vision-tower features to LLM embeddings.

    Selects one hidden layer of the vision tower, optionally drops the CLS
    token, projects through Linear-act-Linear, and appends the newline /
    image-end separator embeddings via ``add_split_tokens``.
    """

    def __init__(self, config: LlavaConfig):
        super().__init__()
        self.config = config

        # Attribute names are load-bearing: they are checkpoint keys.
        self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
        self.act = ACT2FN[config.projector_hidden_act]
        self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)

        # Separator token ids as non-persistent buffers so they follow the
        # module across devices without being saved in the state dict.
        self.register_buffer(
            'image_newline_idx', torch.tensor([config.image_newline_idx], dtype=torch.long), persistent=False
        )
        self.register_buffer(
            'image_new_idx', torch.tensor([config.image_new_idx], dtype=torch.long), persistent=False
        )

    def forward(self, image_features, input_embeddings):
        feats = image_features[self.config.vision_feature_layer]

        strategy = self.config.vision_feature_select_strategy
        if strategy == "default":
            feats = feats[:, 1:]  # drop the CLS token
        elif strategy != "full":
            raise ValueError(
                f"Unexpected select feature strategy: {strategy}"
            )

        projected = self.linear_2(self.act(self.linear_1(feats)))

        newline_embed = input_embeddings(self.image_newline_idx).squeeze()
        image_end_embed = input_embeddings(self.image_new_idx).squeeze()
        return add_split_tokens(projected, newline_embed, image_end_embed)
157
+
158
class PixelShuffleMultiModalProjector(nn.Module):
    """Projector that pixel-shuffles vision features before an MLP.

    With the fixed downsample ratio of 0.5 the token count shrinks 4x while
    the channel dimension grows 4x; the MLP then maps into the LLM embedding
    space, and separator embeddings are appended via ``add_split_tokens``.
    """

    def __init__(self, config: LlavaConfig):
        super().__init__()
        self.config = config

        self.downsample_ratio = 0.5
        vit_hidden_size = config.vision_config.hidden_size
        llm_hidden_size = config.text_config.hidden_size
        # Channels after shuffling: c * (1/ratio)^2.
        shuffled_dim = vit_hidden_size * int(1 / self.downsample_ratio) ** 2

        self.mlp = nn.Sequential(
            nn.LayerNorm(shuffled_dim),
            nn.Linear(shuffled_dim, llm_hidden_size),
            nn.GELU(),
            nn.Linear(llm_hidden_size, llm_hidden_size)
        )

        # Separator token ids as non-persistent buffers.
        self.register_buffer(
            'image_newline_idx', torch.tensor([config.image_newline_idx], dtype=torch.long), persistent=False
        )
        self.register_buffer(
            'image_new_idx', torch.tensor([config.image_new_idx], dtype=torch.long), persistent=False
        )

    def forward(self, image_features, input_embeddings):
        feats = image_features[self.config.vision_feature_layer]

        strategy = self.config.vision_feature_select_strategy
        if strategy == "default":
            feats = feats[:, 1:]  # drop the CLS token
        elif strategy != "full":
            raise ValueError(
                f"Unexpected select feature strategy: {strategy}"
            )

        shuffled = self.pixel_shuffle(feats)
        projected = self.mlp(shuffled)

        newline_embed = input_embeddings(self.image_newline_idx).squeeze()
        image_end_embed = input_embeddings(self.image_new_idx).squeeze()
        return add_split_tokens(projected, newline_embed, image_end_embed)

    def pixel_shuffle(self, x, scale_factor=0.5):
        """Trade spatial resolution for channels.

        (n, wh, c) -> (n, wh * s^2, c / s^2) for s = scale_factor.
        Assumes a square token grid. NOTE(review): after the first view the
        names h and w are swapped relative to the reshape above; this is
        harmless only because h == w here — confirm before supporting
        non-square grids.
        """
        if scale_factor == 1:
            return x
        n, wh, c = x.shape
        side = int(math.sqrt(wh))
        x = x.view(n, side, side, c)

        n, w, h, c = x.size()
        # N, W, H, C --> N, W, H * scale, C // scale
        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
        x = x.permute(0, 2, 1, 3).contiguous()
        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
        x = x.view(n, int(h * scale_factor), int(w * scale_factor),
                   int(c / (scale_factor * scale_factor)))
        x = x.permute(0, 2, 1, 3).contiguous()
        return x.view(x.shape[0], -1, x.shape[-1])
218
+
219
+
220
# Shared docstring fragment describing PreTrainedModel inheritance, kept for
# parity with upstream transformers model files.
LLAVA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
235
+
236
class TarsierPreTrainedModel(PreTrainedModel):
    """Base class wiring Tarsier models into the HF PreTrainedModel machinery."""

    config_class = LlavaConfig
    base_model_prefix = "llm"
    supports_gradient_checkpointing = True  # TODO: support latest gc
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = False
    _supports_cache_class = True  # TODO: support different cache
    _supports_static_cache = True

    def _init_weights(self, module):
        """Normal-init weights; zero biases and padding rows."""
        # Fall back to the text config when the top-level config has no range.
        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            std = self.config.text_config.initializer_range

        if hasattr(module, "class_embedding"):
            module.class_embedding.data.normal_(mean=0.0, std=std)

        if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.weight.data.fill_(1.0)
            if module.bias is not None:
                module.bias.data.zero_()

    @property
    def _no_split_modules(self):
        # Combine the no-split lists of both sub-models for device placement.
        return self.language_model._no_split_modules + self.vision_tower._no_split_modules
271
+
272
+
273
class TarsierForConditionalGeneration(TarsierPreTrainedModel, GenerationMixin):
    """Tarsier vision-language model: vision tower + projector + language model.

    The vision tower encodes pixel inputs, a projector maps them into the LLM
    embedding space, and the features are scattered into the token embedding
    sequence at image-token positions before running the language model.
    """

    def __init__(self, config: LlavaConfig):
        super().__init__(config)
        self.vision_tower = AutoModel.from_config(config.vision_config, trust_remote_code=True)
        if config.text_config.model_type == 'qwen2':
            self.language_model = Qwen2ForCausalLM(config.text_config)
        elif config.text_config.model_type == 'qwen2_vl':
            self.language_model = Qwen2VLForCausalLM(config.text_config)
        elif config.text_config.model_type == 'llama':
            self.language_model = LlamaForCausalLM(config.text_config)
        else:
            raise ValueError(f'{config.text_config.model_type} not supported!')

        if config.projection_head == 'Pixel_Shuffle':
            self.multi_modal_projector = PixelShuffleMultiModalProjector(config)
        elif config.projection_head == 'MLP':
            self.multi_modal_projector = LlavaMultiModalProjector(config)
        elif config.projection_head == 'auto_map':
            repo_id, class_ref = config.auto_map['ProjectionLayer'].split("--")
            model_class = get_class_from_dynamic_module(class_ref, repo_id)
            self.multi_modal_projector = model_class(config)
        elif config.projection_head is None:
            # Identity projection: vision features are used as-is.
            self.multi_modal_projector = lambda x, *args, **kwargs: x
        else:
            # Robustness fix: previously an unknown value silently left
            # multi_modal_projector undefined, failing later with AttributeError.
            raise ValueError(f'projection_head = {config.projection_head} not supported!')

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.language_model.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        self.language_model.set_output_embeddings(new_embeddings)

    def set_decoder(self, decoder):
        self.language_model.set_decoder(decoder)

    def get_decoder(self):
        return self.language_model.get_decoder()

    def tie_weights(self):
        return self.language_model.tie_weights()

    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        # update vocab size
        self.config.text_config.vocab_size = model_embeds.num_embeddings
        return model_embeds

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        pixel_values: torch.FloatTensor = None,
        image_grid_thw: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        num_images: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_rmpad: Optional[bool] = False,
        **kwargs,
    ) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
        """Fuse vision features into token embeddings and run the LLM.

        When ``labels`` are given a shifted cross-entropy loss is computed;
        with ``use_rmpad`` the batch is flattened to a single unpadded
        sequence (FlashAttention varlen style) before the language model.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You must specify input_ids")

        bsz, max_seq_len = input_ids.shape[0], input_ids.shape[1]

        if max_seq_len > 1:
            special_image_mask = input_ids == self.config.image_token_index
            # Fixed: was a bare print() on every forward (and crashed when
            # num_images was None); now a lazily-formatted debug log.
            if num_images is not None:
                logger.debug(
                    '[%s] num_images: %s num_image_tokens: %s',
                    input_ids.device, num_images.tolist(), special_image_mask.sum(-1).tolist()
                )

        if position_ids is None:
            if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
                position_ids = self.language_model.get_rope_index(input_ids, image_grid_thw, attention_mask)  # [bsz, seqlen, 3]
            else:
                position_ids = attention_mask.long().cumsum(-1) - 1  # [bsz, seqlen]
                position_ids.masked_fill_(attention_mask == 0, 1)

        if use_rmpad:
            # Flatten [bsz, seqlen] -> [1, total_tokens] dropping padding.
            input_ids, input_ids_indices, cu_seqlens, _ = _unpad_input(input_ids, attention_mask)
            position_ids, _, _, _ = _unpad_input(position_ids, attention_mask)
            input_ids, position_ids = input_ids.unsqueeze(0), position_ids.unsqueeze(0)
        else:
            input_ids_indices, cu_seqlens = None, None

        inputs_embeds = self.get_input_embeddings()(input_ids)  # [1, seqlen, dim]

        image_features = None
        if pixel_values is not None:  # training / first step in generation
            if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
                pixel_values = pixel_values.type(self.vision_tower.get_dtype())
                image_features = self.vision_tower(pixel_values, image_grid_thw)
            else:
                image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
                image_features = self.multi_modal_projector(
                    image_outputs.hidden_states,
                    self.get_input_embeddings(),
                )

            special_image_mask = (input_ids == self.config.image_token_index).to(inputs_embeds.device)
            if special_image_mask.sum() > 0:
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(
                    special_image_mask.unsqueeze(-1).expand_as(inputs_embeds),
                    image_features
                )
            else:
                # No image tokens in this batch: add a zero contribution so the
                # vision tower still participates in the autograd graph.
                inputs_embeds = image_features.sum(dim=(0, 1)) * 0. + inputs_embeds

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_rmpad=use_rmpad,
            cu_seqlens=cu_seqlens,
        )

        logits = outputs[0]

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            if use_rmpad:
                labels = labels.view(-1)[input_ids_indices.long()]
                shift_labels = torch.cat((labels[1:], labels.new_ones((1)) * -100))
                shift_labels.requires_grad = False
                # Mask the last position of every packed sequence so we never
                # predict across sequence boundaries.
                lbl_seq_lens = (cu_seqlens[1:] - 1).long()
                shift_labels[lbl_seq_lens] = -100
                loss = loss_fct(logits.squeeze(0), shift_labels)
            else:
                # Shift so that tokens < n predict n
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens
                shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
        elif use_rmpad:
            # During training we skip re-padding the logits to save memory;
            # only pad them back when the caller actually consumes logits.
            logits = _pad_input(logits.squeeze(0), input_ids_indices, bsz, max_seq_len)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return LlavaCausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            position_ids=position_ids,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        cache_position=None,
        use_cache=True,
        pixel_values=None,
        image_grid_thw=None,
        **kwargs,
    ):
        """Trim cached tokens and route pixel inputs only on the first step."""
        if past_key_values is not None:
            past_length = past_key_values.get_seq_length()
            input_ids = input_ids[:, past_length:]

        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }
        if kwargs.get('num_images') is not None:
            model_inputs['num_images'] = kwargs['num_images']

        if cache_position[0] == 0:
            # If we're in cached decoding stage, pixel values should be None because
            # input ids do not contain special image tokens anymore.
            # Otherwise we need pixel values to be passed to the model.
            model_inputs["pixel_values"] = pixel_values
            model_inputs["image_grid_thw"] = image_grid_thw
        else:
            # Continue the (possibly 3D) position sequence from the last step.
            model_inputs['position_ids'] = position_ids[:, -1, ...].unsqueeze(1).to(device=input_ids.device) + 1
        return model_inputs

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        """Propagate position_ids from the output into the next step's kwargs."""
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            num_new_tokens=num_new_tokens,
        )

        if getattr(outputs, "position_ids", None) is not None:
            model_kwargs["position_ids"] = outputs.position_ids

        return model_kwargs
eval_scripts/DREAM-1K/tarsier/models/utils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from einops import rearrange
4
+
5
+ def _unpad_input(input_ids, attention_mask):
6
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
7
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
8
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
9
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
10
+ input_ids = rearrange(input_ids, 'b s ... -> (b s) ...')[indices]
11
+ return input_ids, indices, cu_seqlens, max_seqlen_in_batch
12
+
13
+ def _pad_input(hidden_states, indices, batch, seqlen):
14
+ output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
15
+ dtype=hidden_states.dtype)
16
+ output[indices] = hidden_states
17
+ return rearrange(output, '(b s) ... -> b s ...', b=batch)
eval_scripts/DREAM-1K/tarsier/scripts/run_demo_cli.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch the interactive CLI demo.
# Usage: run_demo_cli.sh <model_path> [n_frames] [max_new_tokens] [top_p] [temperature]

model_path=$1
n_frames=${2:-8}
max_new_tokens=${3:-512}
top_p=${4:-0.8}
temperature=${5:-0}

# Quote expansions so model paths containing spaces work.
python3 -m tasks.demo_cli \
    --model_name_or_path "$model_path" \
    --config "configs/tarser2_default_config.yaml" \
    --max_n_frames "$n_frames" \
    --max_new_tokens "$max_new_tokens" \
    --top_p "$top_p" \
    --temperature "$temperature"
eval_scripts/DREAM-1K/tarsier/scripts/run_demo_gradio.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch the Gradio web demo.
# Usage: run_demo_gradio.sh <model_path> [max_n_frames]

model_path=$1
max_n_frames=${2:-8}

# Quoted so model paths containing spaces survive the export.
export MODEL_PATH="$model_path"
export MAX_N_FRAMES="$max_n_frames"

python3 -m tasks.demo_gradio
eval_scripts/DREAM-1K/tarsier/scripts/run_evaluation_only.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Evaluate an existing prediction file without re-running inference.
# Fill in your Azure/OpenAI credentials for the GPT-based metrics.
export AZURE_ENDPOINT=...
export OPENAI_API_KEY=...

pred_file=$1
# benchmarks=${@:2}
# NOTE: hard-coded to the DREAM-1K benchmark; the "all" fallback below is
# therefore inert. Restore the line above to take benchmarks from the CLI.
benchmarks=dream
benchmarks=${benchmarks:-"all"}

# $benchmarks is intentionally unquoted: it may expand to several benchmark names.
python -m evaluation.evaluate \
    --pred_file "$pred_file" \
    --benchmarks $benchmarks
eval_scripts/DREAM-1K/tarsier/scripts/run_inference_benchmark.sh ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Copy and Modified on: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/video_inference/scripts/video/eval/video_detail_description_eval_shard.sh
# Shards benchmark inference across chunks of GPUs, waiting until each chunk's
# GPUs are free, then evaluates the merged predictions.

model_name_or_path=$1
output_dir=$2
benchmarks=${@:3}
benchmarks=${benchmarks:-"all"}
resume=True
CHUNKS=8

# -p: do not fail if the directory already exists (e.g. on resume).
mkdir -p "$output_dir"

echo "Using $CHUNKS GPUs"

# Assuming GPULIST is a bash array containing your GPUs
GPULIST=(0 1 2 3 4 5 6 7)

# Get the number of GPUs
NUM_GPUS=${#GPULIST[@]}

# Calculate GPUs per chunk
GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS))


for IDX in $(seq 1 $CHUNKS); do
    START=$(((IDX-1) * GPUS_PER_CHUNK))
    LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index

    CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})

    # Convert the chunk GPUs array to a comma-separated string
    CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")

    ALL_GPUS_FREE=0
    while [ $ALL_GPUS_FREE -eq 0 ]; do
        ALL_GPUS_FREE=1 # Assume all GPUs are free initially

        # BUGFIX: `$CHUNK_GPUS` expands to element 0 only for a bash array,
        # so only the first GPU of each chunk used to be checked. Iterate
        # over the whole array instead.
        for GPU_ID in "${CHUNK_GPUS[@]}"; do
            MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i "$GPU_ID" | tr -d '[:space:]')

            # Assuming a GPU is considered free if its memory usage is less than 100 MiB
            if [ "$MEM_USAGE" -ge 100 ]; then
                ALL_GPUS_FREE=0
                echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB."
                break # Exit the loop early as we found a GPU that is not free
            fi
        done

        if [ $ALL_GPUS_FREE -eq 0 ]; then
            echo "Not all GPUs in chunk are free. Checking again in 10 seconds..."
            sleep 10
        fi
    done

    echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR"
    # $benchmarks stays unquoted on purpose: it can hold several names.
    CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 -m tasks.inference_benchmark \
        --model_name_or_path "$model_name_or_path" \
        --config "configs/tarser2_default_config.yaml" \
        --max_new_tokens 512 \
        --top_p 1 \
        --temperature 0 \
        --output_dir "$output_dir" \
        --output_name predictions \
        --max_n_samples_per_benchmark -1 \
        --benchmarks $benchmarks \
        --resume $resume \
        --num_chunks $CHUNKS \
        --chunk_idx $(($IDX - 1)) > "$output_dir/run_$IDX.log" 2>&1 &

done

wait

python3 -m evaluation.evaluate \
    --pred_file "$output_dir" \
    --benchmarks $benchmarks
eval_scripts/DREAM-1K/tarsier/scripts/run_inference_caption.sh ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Copy and Modified on: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/video_inference/scripts/video/eval/video_detail_description_eval_shard.sh

# Usage: run_inference_caption.sh <model_name_or_path> <input_file> <output_dir>

model_name_or_path=$1
input_file=$2
output_dir=$3
CHUNKS=1
resume=True

# BUG FIX: plain `mkdir` aborts when the directory already exists, which broke
# reruns with resume=True; `-p` makes the call idempotent.
mkdir -p "$output_dir"

echo "Using $CHUNKS GPUs"

# Assuming GPULIST is a bash array containing your GPUs
# GPULIST=(0 1 2 3 4 5 6 7)
GPULIST=(0)

# Get the number of GPUs
NUM_GPUS=${#GPULIST[@]}

# Calculate GPUs per chunk
GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS))


for IDX in $(seq 1 $CHUNKS); do
    START=$(((IDX-1) * GPUS_PER_CHUNK))
    LENGTH=$GPUS_PER_CHUNK  # Length for slicing, not the end index

    CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})

    # Convert the chunk GPUs array to a comma-separated string
    CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")

    ALL_GPUS_FREE=0
    while [ $ALL_GPUS_FREE -eq 0 ]; do
        ALL_GPUS_FREE=1  # Assume all GPUs are free initially

        # BUG FIX: `$CHUNK_GPUS` expands to the FIRST array element only, so
        # only one GPU per chunk was ever polled; "${CHUNK_GPUS[@]}" iterates
        # over every GPU in the chunk.
        for GPU_ID in "${CHUNK_GPUS[@]}"; do
            MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i "$GPU_ID" | tr -d '[:space:]')

            # Assuming a GPU is considered free if its memory usage is less than 100 MiB
            if [ "$MEM_USAGE" -ge 100 ]; then
                ALL_GPUS_FREE=0
                echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB."
                break  # Exit the loop early as we found a GPU that is not free
            fi
        done

        if [ $ALL_GPUS_FREE -eq 0 ]; then
            echo "Not all GPUs in chunk are free. Checking again in 10 seconds..."
            sleep 10
        fi
    done

    echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR"
    CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 -m tasks.inference_caption \
        --model_name_or_path $model_name_or_path \
        --config "configs/tarser2_default_config.yaml" \
        --max_new_tokens 512 \
        --top_p 1 \
        --temperature 0 \
        --input_file $input_file \
        --output_dir $output_dir \
        --output_name predictions \
        --max_n_samples_per_benchmark -1 \
        --resume $resume \
        --num_chunks $CHUNKS \
        --chunk_idx $(($IDX - 1)) > $output_dir/run_$IDX.log 2>&1 &

done

wait

# python3 -m evaluation.evaluate \
#     --pred_file $output_dir \
#     --benchmarks $benchmarks
eval_scripts/DREAM-1K/tarsier/tasks/demo_cli.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import os
16
+ import torch
17
+ from copy import deepcopy
18
+ from transformers import StoppingCriteriaList
19
+ from tasks.utils import load_model_and_processor
20
+ from dataset.utils import *
21
+ from tools.conversation import Chat, conv_templates, StoppingCriteriaSub
22
+ from transformers import TextStreamer
23
+ from tools.color import Color
24
+
25
+
26
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
27
+
28
def main(args):
    """Interactive CLI chat over a single image/video/GIF.

    Loads the model named by ``args.model_name_or_path``, prompts for a media
    file path, then loops reading user turns from stdin and streaming the
    model's reply. Typing ``exit`` quits; ``reset`` restarts the conversation.

    Args:
        args: parsed CLI namespace (model_name_or_path, config, max_n_frames,
            max_new_tokens, top_p, temperature, debug).
    """
    # Load Model
    print(f"### Start loading model...")
    model, processor = load_model_and_processor(args.model_name_or_path, args.config)
    print(f"### Finish loading model.")
    # Pick the conversation template from the checkpoint name.
    if 'tarsier2' in args.model_name_or_path.lower():
        conv_type = 'tarsier2-7b'
    else:
        if '7b' in args.model_name_or_path.lower():
            conv_type = 'tarsier-7b'
        elif '13b' in args.model_name_or_path.lower():
            conv_type = 'tarsier-13b'
        elif '34b' in args.model_name_or_path.lower():
            conv_type = 'tarsier-34b'
        else:
            raise ValueError(f"Unknow model: {args.model_name_or_path}")

    chat = Chat(model, processor, device=device, debug = args.debug)
    conv = deepcopy(conv_templates[conv_type])

    img_path = ''
    has_img = False
    while True:
        # First ask for a media file; only paths recognized by
        # get_visual_type as video/gif/image are accepted.
        if not has_img:
            try:
                img_path = input(Color.green(f"{conv.roles[1]}: ") + "Input a file path of your image/video:")
                img_path = img_path.strip()
                if not (os.path.exists(img_path) and get_visual_type(img_path) in ['video', 'gif', 'image']):
                    continue
                has_img = True
                conv.messages.append([conv.roles[0], {"type": "video", "text": img_path}])
                print(Color.green(f"{conv.roles[1]}: ") + "Received your file. Now let's start conversation! :)")
                print(Color.red(f"<Input \'exit\' to exit and \'reset\' to restart>"))
            except Exception as e:
                print(f"Error: {e}")
                print("exit...")
                exit()
        # Keep prompting until a non-empty user turn is entered.
        inp = ""
        while inp == "":
            inp = input(Color.blue(f"{conv.roles[0]}: ")).strip()
        if inp.strip() == 'exit':
            print("exit...")
            exit()
        elif inp.strip() == "reset":
            # NOTE(review): has_img is not reset to False here, so the
            # file-path prompt is skipped after 'reset' and the fresh
            # conversation has no media attached — confirm intended.
            conv = deepcopy(conv_templates[conv_type])
            img_path = ''
            continue
        conv = chat.ask(inp, conv)

        # Stop generation at EOS; stream tokens to stdout as they decode.
        stop_words_ids = [torch.tensor([processor.processor.tokenizer.eos_token_id]).to(device)]
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
        streamer = TextStreamer(processor.processor.tokenizer, skip_prompt=True, skip_special_tokens=True)

        inputs, conv = chat.prepare_model_inputs(conv, args.max_n_frames)
        print("conv:", conv)
        print(Color.green(f"{conv.roles[1]}: "), end="")
        with torch.inference_mode():
            # NOTE(review): stopping_criteria is already a StoppingCriteriaList;
            # it is passed wrapped in another list — confirm transformers
            # accepts the nested form in the pinned version.
            outputs = model.generate(
                **inputs,
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                top_p=args.top_p,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria])
        # Decode only the newly generated tokens (skip the prompt prefix).
        outputs = processor.processor.tokenizer.decode(outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
        conv.messages.append(
            [conv.roles[1], {"text": outputs, "type": "text"}]
        )

        if args.debug:
            print(f"Conversation state: {conv}")
+
102
+ if __name__ == "__main__":
103
+ # python3 -m tasks.demo_cli --model_name_or_path /tmp/tarsier2-1226-dpo --config configs/tarser2_default_config.yaml
104
+ import argparse
105
+
106
+ parser = argparse.ArgumentParser()
107
+ parser.add_argument('--model_name_or_path', type=str)
108
+ parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
109
+ parser.add_argument("--max_n_frames", type=int, default=16, help="Max number of frames to apply average sampling from the given video.")
110
+ parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
111
+ parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
112
+ parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")
113
+ parser.add_argument("--debug", action="store_true")
114
+ args = parser.parse_args()
115
+
116
+ main(args)
eval_scripts/DREAM-1K/tarsier/tasks/demo_gradio.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # copy and modify from: https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/demo/demo.py
16
+
17
+ # import spaces # for deploying on huggingface ZeroGPU
18
+ from copy import deepcopy
19
+ import gradio as gr
20
+ from gradio.themes.utils import colors, fonts, sizes
21
+ from tools.conversation import Chat, conv_templates
22
+ from tasks.utils import load_model_and_processor, file_to_base64
23
+ from dataset.tarsier_datamodule import init_processor
24
+ import os
25
+ import torch
26
+
27
+ # huggingface-cli login
28
+
29
+ model_path = os.getenv("MODEL_PATH", "omni-research/Tarsier2-7b")
30
+ config_path = "configs/tarser2_default_config.yaml"
31
+ max_n_frames = int(os.getenv("MAX_N_FRAMES", 16))
32
+ debug = False
33
+ device = 'cuda' if not debug else 'cpu'
34
+
35
+ # ========================================
36
+ # Model Initialization
37
+ # ========================================
38
def init_model():
    """Load the Tarsier model + processor and wrap them in a Chat helper.

    In debug mode (module-level ``debug`` flag) only the processor is
    initialized and the model is left as None.

    Returns:
        Chat: chat wrapper bound to the global device/debug settings.
    """
    print("Start Initialization...")
    # if torch.cuda.is_available():
    if not debug:
        model, processor = load_model_and_processor(model_path, config_path)
    else:
        print(f"No Valid GPU! Lauch in debug mode!")
        processor = init_processor(model_path, config_path)
        model = None
    # BUG FIX: the original line ended with a stray trailing `c`
    # (`Chat(model, processor, device, debug)c`), which is a SyntaxError
    # and prevented the module from importing at all.
    chat = Chat(model, processor, device, debug)
    print('Initialization Finished')
    return chat
50
+
51
+
52
+ # ========================================
53
+ # Gradio Setting
54
+ # ========================================
55
def gradio_reset(chat_state, img_file):
    """Clear the conversation history and reset every input widget."""
    if chat_state is not None:
        chat_state.messages = []
    img_file = None
    # Order matches the `clear.click` outputs list:
    # chatbot, up_image, up_video, up_gif, text_input, upload_button,
    # chat_state, img_file.
    return (
        None,
        gr.update(value=None, interactive=True),
        gr.update(value=None, interactive=True),
        gr.update(value=None, interactive=True),
        gr.update(placeholder='Please upload your video first', interactive=False),
        gr.update(value="Upload & Start Chat", interactive=True),
        chat_state,
        img_file,
    )
60
+
61
+
62
def upload_img(gr_img, gr_video, gr_gif, chat_state, num_frames):
    """Start a fresh conversation seeded with the uploaded image/video/GIF.

    Args:
        gr_img / gr_video / gr_gif: at most one is non-None; the first
            non-None (video > image > gif) is attached to the conversation.
        chat_state: previous conversation state (replaced by a new template).
        num_frames: frame-count slider value (currently unused here; kept for
            interface compatibility with the click handler).

    Returns:
        7 values matching the handler's outputs:
        up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file.
    """
    print("video, image or gif:", gr_video, gr_img, gr_gif)
    conv_type = ''
    if 'tarsier2-7b' in model_path.lower():
        conv_type = 'tarsier2-7b'
    # elif '7b' in model_path.lower():
    #     conv_type = 'tarsier-7b'
    # elif '13b' in model_path.lower():
    #     conv_type = 'tarsier-13b'
    # elif '34b' in model_path.lower():
    #     conv_type = 'tarsier-34b'
    else:
        raise ValueError(f"Unknow model: {model_path}")
    chat_state = deepcopy(conv_templates[conv_type])

    if gr_img is None and gr_video is None and gr_gif is None:
        # BUG FIX: this branch used to return 8 values while the Gradio click
        # handler binds only 7 outputs, which raised at runtime. Return one
        # update per output component instead.
        return (None, None, None,
                gr.update(interactive=True, placeholder='Please upload video/image first!'),
                gr.update(value="Upload & Start Chat", interactive=True),
                chat_state, None)
    if gr_video or gr_img or gr_gif:
        # Pick the first non-None upload (video takes precedence).
        for img_file in [gr_video, gr_img, gr_gif]:
            if img_file is not None:
                break
        chat_state.messages.append([chat_state.roles[0], {"type": "video", "text": img_file}])
    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_file
85
+
86
+
87
def gradio_ask(user_message, chatbot, chat_state):
    """Append the user's turn to the conversation and the chat display."""
    if not user_message:
        # Empty input: keep the textbox active and show a hint instead.
        return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
    chat_state = chat.ask(user_message, chat_state)
    # New chat row with the answer slot left empty for gradio_answer to fill.
    return '', chatbot + [[user_message, None]], chat_state
93
+
94
# @spaces.GPU(duration=120) # for deploying on huggingface ZeroGPU
def gradio_answer(chatbot, chat_state, img_file, top_p, temperature, n_frames=None):
    """Generate the assistant's reply for the latest user turn."""
    llm_message, chat_state = chat.answer(
        conv=chat_state,
        n_frames=n_frames,
        max_new_tokens=256,
        num_beams=1,
        temperature=temperature,
        top_p=top_p,
    )
    # Fill the answer slot gradio_ask left empty in the last chat row.
    chatbot[-1][1] = llm_message
    print(chat_state)
    print(f"Answer: {llm_message}")
    return chatbot, chat_state
101
+
102
+
103
class OpenGVLab(gr.themes.base.Base):
    """Gradio theme for the demo.

    Thin wrapper over ``gr.themes.base.Base`` that pins the demo's palette,
    spacing, and fonts, and overrides the page background. Copied from the
    Ask-Anything / VideoChat2 demo.
    """
    def __init__(
        self,
        *,
        primary_hue=colors.blue,
        secondary_hue=colors.sky,
        neutral_hue=colors.gray,
        spacing_size=sizes.spacing_md,
        radius_size=sizes.radius_sm,
        text_size=sizes.text_md,
        font=(
            fonts.GoogleFont("Noto Sans"),
            "ui-sans-serif",
            "sans-serif",
        ),
        font_mono=(
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        # Forward every style knob to the base theme unchanged.
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Single project-specific override on top of the base theme.
        super().set(
            body_background_fill="*neutral_50",
        )
137
+
138
+
139
# Instantiate the demo theme defined above.
gvlabtheme = OpenGVLab(primary_hue=colors.blue,
    secondary_hue=colors.sky,
    neutral_hue=colors.gray,
    spacing_size=sizes.spacing_md,
    radius_size=sizes.radius_sm,
    text_size=sizes.text_md,
    )

# Inline the logo as base64 so the page has no external asset dependency.
logo_b64 = file_to_base64("assets/figures/tarsier_logo.jpg")
title = f"""<center><a href="https://github.com/bytedance/tarsier"><img src="data:image/jpeg;base64,{logo_b64}" alt="Tarsier" border="0" style="margin: 0 auto; height: 140px;" /></a></center>"""
description ="""<center><p><a href='https://github.com/bytedance/tarsier'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p></center>
"""


# Page layout: left column = media upload tabs + sampling controls,
# right column = chatbot and text input.
with gr.Blocks(title="Tarsier",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Row():
        with gr.Column(scale=0.5, visible=True) as video_upload:
            with gr.Column(elem_id="image", scale=0.5) as img_part:
                with gr.Tab("Video", elem_id='video_tab'):
                    up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
                with gr.Tab("Image", elem_id='image_tab'):
                    up_image = gr.Image(type="filepath", interactive=True, elem_id="image_upload", height=360)
                with gr.Tab("GIF", elem_id='gif_tab'):
                    up_gif = gr.File(type="filepath", file_count="single", file_types=[".gif"], interactive=True, elem_id="gif_upload", height=360)
            upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
            # NOTE(review): this `clear` is shadowed by the second
            # `clear = gr.Button("🔄Clear️")` below, so this "Restart" button
            # never gets a click handler — confirm whether it should be wired.
            clear = gr.Button("Restart")

            # num_beams = gr.Slider(
            #     minimum=1,
            #     maximum=10,
            #     value=1,
            #     step=1,
            #     interactive=True,
            #     label="beam search numbers)",
            # )

            # Sampling controls passed into gradio_answer on each turn.
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Top_p",
            )

            num_frames = gr.Slider(
                minimum=4,
                maximum=16,
                value=16,
                step=2,
                interactive=True,
                label="#Frames",
            )

        with gr.Column(visible=True) as input_raws:
            # Per-session state: conversation object and uploaded media path.
            chat_state = gr.State()
            img_file = gr.State()
            chatbot = gr.Chatbot(elem_id="chatbot",label='VideoChat')
            with gr.Row():
                with gr.Column(scale=0.7):
                    text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
                with gr.Column(scale=0.15, min_width=0):
                    run = gr.Button("💭Send")
                with gr.Column(scale=0.15, min_width=0):
                    clear = gr.Button("🔄Clear️")

    # Model is loaded once at page build time.
    chat = init_model()
    upload_button.click(upload_img, [up_image, up_video, up_gif, chat_state, num_frames], [up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file])

    # Both Enter-in-textbox and the Send button trigger ask -> answer.
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_file, top_p, temperature, num_frames], [chatbot, chat_state]
    )
    run.click(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_file, top_p, temperature, num_frames], [chatbot, chat_state]
    )
    run.click(lambda: "", None, text_input)
    clear.click(gradio_reset, [chat_state, img_file], [chatbot, up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file], queue=False)


demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=11451)
eval_scripts/DREAM-1K/tarsier/tasks/inference_benchmark.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import torch
16
+ from tasks.utils import load_model_and_processor
17
+ # from dataset.mm_dataset import MMDataset
18
+ from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
19
+ from dataset.tarsier_datamodule import TarsierDataset
20
+ from dataset.utils import *
21
+
22
+ import json
23
+ import os
24
+ import math
25
+ from tqdm import tqdm
26
+ import yaml
27
+
28
# Root directory holding the benchmark annotation jsonl files,
# resolved relative to this file (tasks/../data/annotations).
ANN_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + '/../data/annotations'

# Maps benchmark short name -> annotation filename under ANN_ROOT_DIR.
# Grouped by task type: captioning (DREAM), multi-choice QA, open-ended QA,
# and classic caption benchmarks.
Benchmark2fname = {
    'dream': 'DREAM-1k.jsonl',

    'next-qa': 'Next-QA-val-multi_choice.jsonl',
    'egoschema': 'EgoSchema_subset.jsonl', # change to EgoSchema_fullset.jsonl if you test on the fullset
    'mvbench': 'MVBench.jsonl',
    'tvbench': 'TVBench.jsonl',
    'video-mme': 'Video-MME.jsonl',
    'favor-bench': 'FAVOR-Bench.jsonl',

    'msvd-qa': 'MSVD-QA-val.jsonl',
    'msr-vtt-qa': 'MSR-VTT-QA-val.jsonl',
    'tgif-qa': 'TGIF-QA-test.jsonl',
    'anet-qa': 'ActivityNet-QA-test.jsonl',

    'msvd-caption': 'MSVD-Caption-test.jsonl',
    'msr-vtt-caption': 'MSR-VTT-Caption-test.jsonl',
    'vatex-caption': 'VATEX-test.jsonl',

    'video_caption': "caption-test.jsonl", # custom for video caption test
}
51
+
52
def get_ann_file_path(benchmark):
    """Return the annotation jsonl path for a benchmark short name.

    Args:
        benchmark: a key of ``Benchmark2fname``.

    Raises:
        KeyError: if `benchmark` is not a known benchmark name.
        FileNotFoundError: if the annotation file is missing on disk.
    """
    ann_fpath = os.path.join(ANN_ROOT_DIR, Benchmark2fname[benchmark])
    # Explicit raise instead of `assert`: asserts are stripped when Python
    # runs with -O, which would silently skip this existence check.
    if not os.path.exists(ann_fpath):
        raise FileNotFoundError(f"The annotation file for {benchmark} not exists: {ann_fpath}")
    return ann_fpath
56
+
57
def split_list(lst, n):
    """Partition `lst` into at most `n` consecutive chunks of near-equal size."""
    # Ceiling division fixes the per-chunk size; the final chunk may be shorter.
    per_chunk = math.ceil(len(lst) / n)
    return [lst[start:start + per_chunk] for start in range(0, len(lst), per_chunk)]


def get_chunk(lst, n, k):
    """Return the k-th (0-based) of the `n` chunks of `lst`."""
    return split_list(lst, n)[k]
66
+
67
+
68
def parse_args():
    """
    Parse command-line arguments.

    Returns:
        argparse.Namespace with `benchmarks` already expanded from task-type
        aliases to concrete benchmark names via `get_benchmarks`.
    """
    parser = argparse.ArgumentParser()

    # Define the command-line arguments

    parser.add_argument('--model_name_or_path', type=str, required=True)
    parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
    # parser.add_argument("--max_n_frames", type=int, default=8, help="Max number of frames to apply average sampling from the given video.")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
    parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
    parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")

    parser.add_argument("--output_dir", type=str, help="Directory to save the model results", required=True)
    parser.add_argument("--output_name", type=str, default="predictions", help="Name of the file for storing results")

    # Sharding controls: this process handles chunk `chunk_idx` of `num_chunks`.
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)

    parser.add_argument("--max_n_samples_per_benchmark", type=int, default=-1, help="Set as a small number (like 100) to run as debug.")
    parser.add_argument('--benchmarks', nargs='+', default=["all"], help="Default as 'all' to inference on all benchmarks; Also could be task types: ('dream', 'caption', 'mc_qa', 'oe_qa'); And specific benchmark names: ('dream', 'msvd-caption', 'msr-vtt-caption', 'vatex-caption', 'next-qa', 'egoschema', 'mvbench', 'video-mme', 'msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa')")

    # Accepts "true"/"false" strings from shell wrappers.
    parser.add_argument("--resume", type=lambda x: (str(x).lower() == 'true'), default=True, help="Resume from existing inference results file or overwrite them.")

    args = parser.parse_args()

    # Expand task-type aliases (e.g. 'all', 'caption') into benchmark names.
    args.benchmarks = get_benchmarks(args.benchmarks)
    print("### Selected Benchmarks:", args.benchmarks)

    return args
100
+
101
+
102
def run_inference(args):
    """
    Run inference on selected benchmarks.

    Loads this worker's shard of each benchmark's annotations, generates a
    prediction per sample, and appends each result as one jsonl line to
    `<output_dir>/<output_name>[_<num_chunks>_<chunk_idx>].jsonl`.

    Args:
        args: Command-line arguments.
    """
    # Initialize the model
    # model, processor = load_model_and_processor(args.model_name_or_path, args.max_n_frames) # max_n_frames set in config_file
    data_config = yaml.safe_load(open(args.config, 'r'))
    model, processor = load_model_and_processor(args.model_name_or_path, data_config=data_config)

    # Gather this worker's shard of every selected benchmark's annotations.
    all_chunks = []
    count = 0
    print(f"Start loading dataset...")
    for benchmark in args.benchmarks:
        ann_fpath = get_ann_file_path(benchmark)
        cur_anns = [json.loads(line) for line in open(ann_fpath)]
        if args.max_n_samples_per_benchmark > 0:
            cur_anns = cur_anns[:args.max_n_samples_per_benchmark]
        count += len(cur_anns)
        cur_chunk = get_chunk(cur_anns, args.num_chunks, args.chunk_idx)
        all_chunks.extend(cur_chunk)
        print(f"### [{benchmark}] Load chunk with {len(cur_chunk)} samples from {len(cur_anns)} samples.")
    print(f"### Finish loading chunk with {len(all_chunks)} samples from {count} samples in total.")

    # Create the output directory if it doesn't exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Per-shard output file name so parallel workers don't collide.
    if args.num_chunks > 1:
        output_name = f"{args.output_name}_{args.num_chunks}_{args.chunk_idx}"
    else:
        output_name = args.output_name
    answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")
    if args.resume and os.path.exists(answers_file):
        # Skip samples already written, keyed by "<dataset>-<idx>".
        processed_data = [json.loads(line) for line in open(answers_file)]
        processed_idxs = set([f"{d['dataset']}-{d['idx']}" for d in processed_data])
        all_chunks = [d for d in all_chunks if f"{d['dataset']}-{d['idx']}" not in processed_idxs]
        print(f"### Resume from {len(processed_idxs)} samples. {len(all_chunks)} samples to run.", flush=True)
        ans_file = open(answers_file, "a")
    else:
        ans_file = open(answers_file, "w")

    dataset = TarsierDataset(
        anns=all_chunks, config=data_config, processor=processor
    )

    # Greedy decoding by default; temperature > 0 enables sampling.
    generate_kwargs = {
        "do_sample": True if args.temperature > 0 else False,
        "max_new_tokens": args.max_new_tokens,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "use_cache": True
    }

    if len(dataset) == 0:
        return
    for ann, inputs in tqdm(dataset, total=len(dataset)):
        if inputs is not None:
            if "prompt" in inputs:
                prompt = get_prompt_from_data_dict(ann)
                print(f"###Prompt:\n{prompt}", flush=True)
            # print(f"Input: {processor.processor.tokenizer.decode(inputs['input_ids'][0]), skip_special_tokens=True}", flush=True)
            try:
                # Only tensor entries are model inputs; everything else in
                # `inputs` is metadata and is skipped.
                model_inputs = {}
                for k, v in inputs.items():
                    if not isinstance(v, torch.Tensor):
                        continue
                    model_inputs[k] = v.to(model.device)
                outputs = model.generate(
                    **model_inputs,
                    **generate_kwargs,
                )
                # Decode only the newly generated tokens (skip the prompt prefix).
                output_text = processor.processor.tokenizer.decode(outputs[0][model_inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
            except Exception as e:
                # Best-effort: record a sentinel so the run continues.
                print(f"Error: {e}")
                output_text = "<error>"
            print(f"###Prediction:\n{output_text}", flush=True)
            answer = ann['messages'][-1]['content'][-1]['reference']
            print(f"###Answer:\n{answer}", flush=True)
            put_pred_to_data_dict(output_text, ann)
        else:
            put_pred_to_data_dict("<error>", ann)
        try:
            ans_file.write(json.dumps(ann, ensure_ascii=False) + "\n")
        except:
            # NOTE(review): bare except; presumably guards against
            # non-serializable/encoding issues — consider narrowing to
            # (TypeError, UnicodeEncodeError).
            ans_file.write(json.dumps(ann) + "\n")
        ans_file.flush()

    ans_file.close()
193
+
194
+
195
if __name__ == "__main__":
    # Parse CLI arguments and run this worker's inference shard.
    run_inference(parse_args())
eval_scripts/DREAM-1K/tarsier/tasks/inference_caption.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import argparse
15
+ import torch
16
+ from tasks.utils import load_model_and_processor
17
+ # from dataset.mm_dataset import MMDataset
18
+ from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
19
+ from dataset.tarsier_datamodule import TarsierDataset
20
+ from dataset.utils import *
21
+
22
+ import json
23
+ import os
24
+ import math
25
+ from tqdm import tqdm
26
+ import yaml
27
+
28
+
29
# Root directory holding annotation jsonl files, resolved relative to this
# file (tasks/../data/annotations).
ANN_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + '/../data/annotations'
30
+
31
+
32
def split_list(lst, n):
    """Split `lst` into at most `n` consecutive, near-equal chunks."""
    # Ceiling division picks the chunk size; only the last chunk can be short.
    size = math.ceil(len(lst) / n)
    return [lst[i:i + size] for i in range(0, len(lst), size)]


def get_chunk(lst, n, k):
    """Return chunk `k` (0-based) of `lst` split into `n` chunks."""
    return split_list(lst, n)[k]
41
+
42
+
43
def parse_args():
    """
    Parse command-line arguments.

    Returns:
        argparse.Namespace for the captioning inference run (input jsonl file,
        output location, sampling settings, and sharding controls).
    """
    parser = argparse.ArgumentParser()

    # Define the command-line arguments

    parser.add_argument('--model_name_or_path', type=str, required=True)
    parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
    # parser.add_argument("--max_n_frames", type=int, default=8, help="Max number of frames to apply average sampling from the given video.")
    parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
    parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
    parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")

    parser.add_argument("--input_file", type=str, help="Directory to input_file (jsonline)", required=True)
    parser.add_argument("--output_dir", type=str, help="Directory to save the model results", required=True)
    parser.add_argument("--output_name", type=str, default="predictions", help="Name of the file for storing results")

    # Sharding controls: this process handles chunk `chunk_idx` of `num_chunks`.
    parser.add_argument("--num_chunks", type=int, default=1)
    parser.add_argument("--chunk_idx", type=int, default=0)

    parser.add_argument("--max_n_samples_per_benchmark", type=int, default=-1, help="Set as a small number (like 100) to run as debug.")
    # Accepts "true"/"false" strings from shell wrappers.
    parser.add_argument("--resume", type=lambda x: (str(x).lower() == 'true'), default=True, help="Resume from existing inference results file or overwrite them.")

    args = parser.parse_args()

    return args
71
+
72
+
73
def run_inference(args):
    """
    Run inference on selected benchmarks.

    Reads caption annotations from `args.input_file` (jsonl), generates one
    prediction per sample for this worker's shard, and appends results to
    `<output_dir>/<output_name>[_<num_chunks>_<chunk_idx>].jsonl`.

    Args:
        args: Command-line arguments.
    """
    # Initialize the model
    # model, processor = load_model_and_processor(args.model_name_or_path, args.max_n_frames) # max_n_frames set in config_file
    data_config = yaml.safe_load(open(args.config, 'r'))
    model, processor = load_model_and_processor(args.model_name_or_path, data_config=data_config)

    # Load this worker's shard of the input annotations.
    all_chunks = []
    count = 0
    print(f"Start loading dataset...")
    ann_fpath = args.input_file
    cur_anns = [json.loads(line) for line in open(ann_fpath)]
    if args.max_n_samples_per_benchmark > 0:
        cur_anns = cur_anns[:args.max_n_samples_per_benchmark]
    count += len(cur_anns)
    cur_chunk = get_chunk(cur_anns, args.num_chunks, args.chunk_idx)
    all_chunks.extend(cur_chunk)
    print(f"### Load chunk with {len(cur_chunk)} samples from {len(cur_anns)} samples.")
    print(f"### Finish loading chunk with {len(all_chunks)} samples from {count} samples in total.")

    # Create the output directory if it doesn't exist
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Per-shard output file name so parallel workers don't collide.
    if args.num_chunks > 1:
        output_name = f"{args.output_name}_{args.num_chunks}_{args.chunk_idx}"
    else:
        output_name = args.output_name
    answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")
    if args.resume and os.path.exists(answers_file):
        # Skip samples already written, keyed by "<dataset>-<idx>".
        processed_data = [json.loads(line) for line in open(answers_file)]
        processed_idxs = set([f"{d['dataset']}-{d['idx']}" for d in processed_data])
        all_chunks = [d for d in all_chunks if f"{d['dataset']}-{d['idx']}" not in processed_idxs]
        print(f"### Resume from {len(processed_idxs)} samples. {len(all_chunks)} samples to run.", flush=True)
        ans_file = open(answers_file, "a")
    else:
        ans_file = open(answers_file, "w")

    dataset = TarsierDataset(
        anns=all_chunks, config=data_config, processor=processor
    )

    # Greedy decoding by default; temperature > 0 enables sampling.
    generate_kwargs = {
        "do_sample": True if args.temperature > 0 else False,
        "max_new_tokens": args.max_new_tokens,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "use_cache": True
    }

    if len(dataset) == 0:
        return
    for ann, inputs in tqdm(dataset):
        if inputs is not None:
            prompt = get_prompt_from_data_dict(ann)
            print(f"###Prompt:\n{prompt}", flush=True)
            # print(f"Input: {processor.processor.tokenizer.decode(inputs['input_ids'][0]), skip_special_tokens=True}", flush=True)
            try:
                # Only tensor entries are model inputs; everything else in
                # `inputs` is metadata and is skipped.
                model_inputs = {}
                for k, v in inputs.items():
                    if not isinstance(v, torch.Tensor):
                        continue
                    model_inputs[k] = v.to(model.device)
                outputs = model.generate(
                    **model_inputs,
                    **generate_kwargs,
                )
                # Decode only the newly generated tokens (skip the prompt prefix).
                output_text = processor.processor.tokenizer.decode(outputs[0][model_inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
            except Exception as e:
                # Best-effort: record a sentinel so the run continues.
                print(f"Error: {e}")
                output_text = "<error>"
            print(f"###Prediction:\n{output_text}", flush=True)
            put_pred_to_data_dict(output_text, ann)
        else:
            put_pred_to_data_dict("<error>", ann)
        try:
            ans_file.write(json.dumps(ann, ensure_ascii=False) + "\n")
        except:
            # NOTE(review): bare except; presumably guards against
            # non-serializable/encoding issues — consider narrowing to
            # (TypeError, UnicodeEncodeError).
            ans_file.write(json.dumps(ann) + "\n")
        ans_file.flush()

    ans_file.close()
160
+
161
+
162
if __name__ == "__main__":
    # Example:
    # python3 -m tasks.inference_caption --model_name_or_path /tmp/tarsier2-1226-dpo --config configs/tarser2_default_config.yaml --input_file data/annotations/caption-test-new.jsonl --output_dir tmp_outputs
    run_inference(parse_args())
eval_scripts/DREAM-1K/tarsier/tasks/inference_quick_start.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from tasks.utils import load_model_and_processor
15
+ from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
16
+ from dataset.utils import *
17
+
18
+ import os
19
+ import torch
20
+ from tqdm import tqdm
21
+ import yaml
22
+
23
def process_one(model, processor, prompt, video_file, generate_kwargs):
    """Run one generation round for a single video/image file.

    Formats the sample, moves tensor inputs onto the model device, generates,
    and returns only the newly generated text (prompt tokens stripped).
    """
    sample = format_one_sample(video_file, prompt)
    batch_data = processor(sample)
    print(f"###Prompt:\n{get_prompt_from_data_dict(sample)}")
    # Keep only tensor fields and place them on the inference device.
    model_inputs = {
        key: value.to(model.device)
        for key, value in batch_data.items()
        if isinstance(value, torch.Tensor)
    }
    outputs = model.generate(
        **model_inputs,
        **generate_kwargs,
    )
    # Decode only tokens produced after the prompt.
    prompt_len = model_inputs['input_ids'][0].shape[0]
    return processor.processor.tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    )
40
+
41
def run():
    """CLI entry point: caption/QA over one media file or a directory of them."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name_or_path', type=str)
    parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
    parser.add_argument('--instruction', type=str, default="Describe the video in detail.", help='Input prompt.')
    parser.add_argument('--input_path', type=str, default="assets/examples", help='Path to video/image; or Dir to videos/images')
    # max_n_frames is configured via the data config file rather than a CLI flag.
    parser.add_argument("--max_new_tokens", type=int, default=256, help="max number of generated tokens")
    parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
    parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")

    args = parser.parse_args()

    # fix: close the config file handle instead of leaking it.
    with open(args.config, 'r') as f:
        data_config = yaml.safe_load(f)
    model, processor = load_model_and_processor(args.model_name_or_path, data_config=data_config)

    generate_kwargs = {
        "do_sample": True if args.temperature > 0 else False,
        "max_new_tokens": args.max_new_tokens,
        "top_p": args.top_p,
        "temperature": args.temperature,
        "use_cache": True
    }
    assert os.path.exists(args.input_path), f"input_path not exist: {args.input_path}"
    # fix: `input_files` was unbound (NameError) when input_path existed but was
    # neither a directory nor a supported media file; now the assert below fires
    # with its intended message instead.
    input_files = []
    if os.path.isdir(args.input_path):
        input_files = [os.path.join(args.input_path, fn) for fn in os.listdir(args.input_path) if get_visual_type(fn) in ['video', 'gif', 'image']]
    elif get_visual_type(args.input_path) in ['video', 'gif', 'image']:
        input_files = [args.input_path]
    assert len(input_files) > 0, f"None valid input file in: {args.input_path} {VALID_DATA_FORMAT_STRING}"

    for input_file in tqdm(input_files, desc="Generating..."):
        visual_type = get_visual_type(input_file)
        # NOTE(review): --instruction defaults to a non-empty string, so the
        # per-type fallbacks below only apply when it is explicitly passed as "".
        if args.instruction:
            prompt = args.instruction
        else:
            if visual_type == 'image':
                prompt = "Describe the image in detail."
            else:
                prompt = "Describe the video in detail."

        pred = process_one(model, processor, prompt, input_file, generate_kwargs)
        print(f"###Prediction:\n{pred}")
        print('-'*100)
88
+
89
+ if __name__ == "__main__":
90
+ # python3 -m tasks.inference_quick_start --model_name_or_path /tmp/tarsier2-1226-dpo --config configs/tarser2_default_config.yaml --input_path /mnt/bn/videonasi18n/wangjw/workspace/tarsier/diving.mp4 --instruction "List the names of all sponsors on the background wall."
91
+ run()
eval_scripts/DREAM-1K/tarsier/tasks/utils.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from models.modeling_tarsier import TarsierForConditionalGeneration, LlavaConfig
15
+ # from dataset.processor import Processor
16
+ from dataset.tarsier_datamodule import init_processor
17
+ import torch
18
+ import base64
19
+ from tools.color import Color
20
+ import yaml
21
+
22
def load_model_and_processor(model_name_or_path, data_config):
    """Load a Tarsier model (bf16, auto device map) and its data processor.

    Args:
        model_name_or_path: HF hub id or local checkpoint directory.
        data_config: a config dict, or a path to a yaml file to parse.

    Returns:
        (model, processor) with the model already in eval mode.
    """
    print(Color.red(f"Load model and processor from: {model_name_or_path}"), flush=True)
    if isinstance(data_config, str):
        # fix: close the yaml file handle instead of leaking it.
        with open(data_config, 'r') as f:
            data_config = yaml.safe_load(f)
    processor = init_processor(model_name_or_path, data_config)
    model_config = LlavaConfig.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = TarsierForConditionalGeneration.from_pretrained(
        model_name_or_path,
        config=model_config,
        device_map='auto',
        torch_dtype=torch.bfloat16,
        trust_remote_code=True
    )
    model.eval()
    return model, processor
40
+
41
def file_to_base64(img_path):
    """Read a file in binary mode and return its base64-encoded string."""
    with open(img_path, 'rb') as fin:
        return base64.b64encode(fin.read()).decode()
45
+
eval_scripts/DREAM-1K/tarsier/tools/color.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
class Color:
    """Wrap text in ANSI escape sequences for colored terminal output."""

    @staticmethod
    def red(x):
        return f'\033[31m{x}\033[0m'

    @staticmethod
    def green(x):
        return f'\033[32m{x}\033[0m'

    @staticmethod
    def yellow(x):
        return f'\033[33m{x}\033[0m'

    @staticmethod
    def blue(x):
        return f'\033[34m{x}\033[0m'

    @staticmethod
    def violet(x):
        return f'\033[35m{x}\033[0m'
35
+
36
+
eval_scripts/DREAM-1K/tarsier/tools/conversation.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # copy and modify from: https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/conversation.py
16
+ from PIL import Image
17
+ import torch
18
+ from transformers import StoppingCriteria, StoppingCriteriaList
19
+ from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
20
+ from dataset.tarsier_datamodule import TarsierDataProcessor
21
+ from dataset.utils import *
22
+
23
+ from enum import auto, Enum
24
+ import os
25
+ import re
26
+
27
+ data_dict_tmp = {
28
+ "messages": [
29
+ {
30
+ "role": "user",
31
+ "content": [
32
+ {
33
+ "type": "video",
34
+ "video": {
35
+ "video_file": "/mnt/hdfs/vlm/videos/movies_aligned_0523/tt8266310/tt8266310_1.50.24-1.50.29.mp4"}
36
+ },
37
+ {
38
+ "type": "text",
39
+ "text": "Describe the video in detail."
40
+ }
41
+ ]
42
+ },
43
+ {
44
+ "role": "assistant",
45
+ "content": [
46
+ {
47
+ "type": "text",
48
+ "text": "A man in the driver's seat, wearing a black jacket with a maroon shirt, fastens his seatbelt while smiling at the man in the passenger seat, who is adjusting his position. The passenger, also wearing a black jacket with a maroon shirt, turns to look forward and smiles. The driver then leans forward to start the car and leans back in his seat. In the background, a beige car is visible through the window."
49
+ }]}
50
+ ],
51
+ "dataset": "video_caption",
52
+ "task": "video/caption",
53
+ "idx": 0,
54
+ }
55
+
56
+
57
+ IMAGE_TOKEN = "<image>"
58
+ VIDEO_TOKEN = "<video>"
59
+
60
class SeparatorStyle(Enum):
    """Different separator styles for conversation templates."""
    SINGLE = auto()
    TWO = auto()
65
def get_data_dict(conv, max_n_frames=None):
    """Convert a conversation object into the standard messages data dict.

    Consecutive contents from the same role are merged into a single message.
    When given, *max_n_frames* caps the number of frames sampled per video.
    The dict is validated with check_data_format() before being returned.
    """
    data_dict = {
        "messages": []
    }
    # fix: default the task label so an empty conversation no longer raises
    # NameError at the `data_dict['dataset'] = task` assignment below.
    # NOTE(review): `task` reflects the type of the *last* non-empty content,
    # so a video-then-question chat ends up labelled "text-only" — confirm
    # this is the intended downstream behavior.
    task = "text-only"
    for role, message in conv.messages:
        if not message:
            continue
        text = message["text"]
        content_type = message["type"]
        content = {}
        if content_type == "video":
            content['type'] = 'video'
            content['video'] = {
                "video_file": text
            }
            if max_n_frames is not None:
                content['video']['n_frames'] = max_n_frames
            task = "video/QA"
        elif content_type == "image":
            content['type'] = 'image'
            content['image'] = {
                "image_file": text
            }
            task = "image/QA"
        else:
            # "text" and any unrecognized type are treated as plain text.
            content['type'] = 'text'
            content['text'] = text
            task = "text-only"
        if data_dict['messages'] and data_dict['messages'][-1]['role'] == role:
            data_dict['messages'][-1]['content'].append(content)
        else:
            data_dict['messages'].append({
                "role": role,
                "content": [content]
            })
    data_dict['dataset'] = task
    data_dict['task'] = task
    check_data_format(data_dict)
    return data_dict
107
+
108
+
109
class StoppingCriteriaSub(StoppingCriteria):
    """Stop generation once the tail of the sequence matches any stop-id sequence."""

    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # fix: avoid a shared mutable default argument; `encounters` is kept
        # only for call-site compatibility (it was never used).
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
        # Compare each stop sequence against the last len(stop) generated ids.
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False
119
+
120
+
121
class Chat:
    """Conversational wrapper: turns a conversation object into model inputs
    and appends generated replies back onto the conversation."""

    def __init__(self, model, processor: TarsierDataProcessor, device='cuda', debug=False):
        # When `model` is None, answer() returns a fixed fake response (debug mode).
        self.model = model
        self.processor = processor
        self.device = device
        self.debug = debug
        # Stop generation at the tokenizer's EOS token.
        stop_words_ids = [torch.tensor([self.processor.processor.tokenizer.eos_token_id]).to(device)]
        self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def ask(self,text,conv):
        """Append a user turn containing `text` to the conversation and return it."""
        conv.messages.append([conv.roles[0], {"text": text, "type": "text"}])
        return conv

    def prepare_model_inputs(self, conv, n_frames=None):
        """Build tensor inputs (on self.device) from the conversation."""
        # print(conv.messages)
        data_dict = get_data_dict(conv, n_frames)
        if self.debug:
            # print(f"visual_data_file: {visual_data_file}", flush=True)
            print(f"###Prompt:\n{get_prompt_from_data_dict(data_dict)}")

        batch_data = self.processor(data_dict)
        # Keep only tensor fields and move them to the inference device.
        model_inputs = {}
        for k, v in batch_data.items():
            if not isinstance(v, torch.Tensor):
                continue
            model_inputs[k] = v.to(self.device)
        return model_inputs, conv

    def answer(self, conv, n_frames=None, max_new_tokens=256, num_beams=1, min_length=1, top_p=1.0,
               repetition_penalty=1.0, length_penalty=1, temperature=0):
        """Generate the assistant reply, append it to the conversation,
        and return (reply_text, conv). temperature > 0 enables sampling."""
        inputs, conv = self.prepare_model_inputs(conv, n_frames)
        if self.model is not None:
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                stopping_criteria=self.stopping_criteria,
                num_beams=num_beams,
                do_sample=True if temperature > 0 else False,
                min_length=min_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                length_penalty=length_penalty,
                temperature=temperature,
            )
            # Decode only the newly generated tokens (skip the echoed prompt).
            output_text = self.processor.processor.tokenizer.decode(outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
        else:
            output_text = "Fake respone as launched in debug mode!"
        conv.messages.append(
            [conv.roles[1], {"text": output_text, "type": "text"}]
        )
        return output_text, conv
172
+
173
class EasyDict(dict):
    """
    Get attributes

    >>> d = EasyDict({'foo':3})
    >>> d['foo']
    3
    >>> d.foo
    3
    >>> d.bar
    Traceback (most recent call last):
    ...
    AttributeError: 'EasyDict' object has no attribute 'bar'

    Works recursively

    >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
    >>> isinstance(d.bar, dict)
    True
    >>> d.bar.x
    1
    """

    def __init__(self, d=None, **kwargs):
        if d is None:
            d = {}
        if kwargs:
            d.update(**kwargs)
        # Route every entry through __setattr__ so nested dicts are wrapped too.
        for k, v in d.items():
            setattr(self, k, v)
        # Class attributes
        for k in self.__class__.__dict__.keys():
            if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
                setattr(self, k, getattr(self, k))

    def __setattr__(self, name, value):
        # Recursively wrap dicts (including those inside lists/tuples) so that
        # attribute access works at any depth; note tuples become lists here.
        if isinstance(value, (list, tuple)):
            value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
        elif isinstance(value, dict) and not isinstance(value, self.__class__):
            value = self.__class__(value)
        # Keep the attribute view and the dict view in sync.
        super(EasyDict, self).__setattr__(name, value)
        super(EasyDict, self).__setitem__(name, value)

    # Item assignment shares the exact same wrapping/sync logic.
    __setitem__ = __setattr__

    def update(self, e=None, **f):
        # Reimplemented so updates also go through __setattr__ (keeps sync).
        d = e or dict()
        d.update(f)
        for k in d:
            setattr(self, k, d[k])

    def pop(self, k, d=None):
        # Drop the attribute mirror before removing the dict entry.
        if hasattr(self, k):
            delattr(self, k)
        return super(EasyDict, self).pop(k, d)
228
+
229
+ conv_tarsier = EasyDict({
230
+ "system": "",
231
+ "roles": ("USER", "ASSISTANT"),
232
+ "messages": [],
233
+ "sep1": " ",
234
+ "sep2": "</s>",
235
+ }
236
+ )
237
+
238
+ conv_tarsier_yi = EasyDict({
239
+ "system": "",
240
+ "roles": ("USER", "ASSISTANT"),
241
+ "messages": [],
242
+ "sep1": " ",
243
+ "sep2": "<|endoftext|>",
244
+ }
245
+ )
246
+
247
+ conv_tarsier_qwen2_vl = EasyDict({
248
+ "system": "",
249
+ "roles": ("user", "assistant"),
250
+ "messages": [],
251
+ }
252
+ )
253
+
254
+ conv_templates = {
255
+ "tarsier2-7b": conv_tarsier_qwen2_vl
256
+ }
eval_scripts/DREAM-1K/tarsier/tools/ptbtokenizer.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : ptbtokenizer.py
4
+ #
5
+ # Description : Do the PTB Tokenization and remove punctuations.
6
+ #
7
+ # Creation Date : 29-12-2014
8
+ # Last Modified : Thu Mar 19 09:53:35 2015
9
+ # Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
10
+ import os
11
+ import subprocess
12
+ import tempfile
13
+
14
+ # path to the stanford corenlp jar
15
+ STANFORD_CORENLP_3_4_1_JAR = os.path.dirname(os.path.abspath(__file__)) + '/stanford-corenlp-3.4.1.jar'
16
+
17
+ # punctuations to be removed from the sentences
18
+ PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
19
+ ".", "?", "!", ",", ":", "-", "--", "...", ";"]
20
+
21
class PTBTokenizer:
    """Python wrapper of Stanford PTBTokenizer"""

    def tokenize(self, captions_for_image):
        """Tokenize captions with the Stanford PTB tokenizer and strip punctuation.

        *captions_for_image* maps an id to a list of dicts carrying a 'caption'
        field; returns a dict mapping the same ids to lists of lower-cased,
        space-joined token strings (order preserved).
        """
        # NOTE(review): $JAVA_HOME is used directly as the executable, so the
        # environment must point JAVA_HOME at the `java` binary itself rather
        # than the usual JDK directory — confirm before running.
        cmd = [os.getenv("JAVA_HOME"), '-cp', STANFORD_CORENLP_3_4_1_JAR, \
                'edu.stanford.nlp.process.PTBTokenizer', \
                '-preserveLines', '-lowerCase']

        # ======================================================
        # prepare data for PTB Tokenizer
        # ======================================================
        # One id entry per caption so ids can be zipped back with output lines.
        final_tokenized_captions_for_image = {}
        image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
        sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])

        # ======================================================
        # save sentences to temporary file
        # ======================================================
        path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
        tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
        tmp_file.write(sentences.encode())
        tmp_file.close()

        # ======================================================
        # tokenize sentence
        # ======================================================
        # The tokenizer reads from the temp file passed as the last argument.
        cmd.append(os.path.basename(tmp_file.name))
        # NOTE(review): `communicate(input=...)` is called without stdin=PIPE,
        # which the subprocess module does not support — presumably inherited
        # from pycocoevalcap; verify on the target Python version.
        p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
                stdout=subprocess.PIPE)
        token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
        token_lines = token_lines.decode()
        lines = token_lines.split('\n')
        # remove temp file
        os.remove(tmp_file.name)

        # ======================================================
        # create dictionary for tokenized captions
        # ======================================================
        # Re-associate each tokenized line with its id and drop punctuation tokens.
        for k, line in zip(image_id, lines):
            if not k in final_tokenized_captions_for_image:
                final_tokenized_captions_for_image[k] = []
            tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
                    if w not in PUNCTUATIONS])
            final_tokenized_captions_for_image[k].append(tokenized_caption)

        return final_tokenized_captions_for_image
eval_scripts/DREAM-1K/tarsier/tools/rw_utils.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (2024) Bytedance Ltd. and/or its affiliates
2
+
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import json
15
+ from json import JSONEncoder
16
+ import numpy
17
+ import pandas as pd
18
+
19
class NumpyArrayEncoder(JSONEncoder):
    """JSON encoder that serializes numpy arrays as (nested) lists."""

    def default(self, obj):
        if isinstance(obj, numpy.ndarray):
            return obj.tolist()
        # Defer to the base encoder (raises TypeError for unknown types).
        return super().default(obj)
24
+
25
def write_txt(data, path):
    """Write each item of *data* to *path* as one str()-formatted line."""
    with open(path, 'w', encoding='utf-8') as fout:
        fout.writelines(f'{item}\n' for item in data)
29
+
30
def read_txt(path):
    """Read *path* (utf-8, undecodable bytes ignored) into a list of lines
    with trailing newlines stripped."""
    with open(path, 'r', encoding='utf-8', errors='ignore') as fin:
        return [line.strip('\n') for line in fin]
34
+
35
def read_jsonlines(path):
    """Parse a jsonlines file into a list of objects (one per line)."""
    with open(path) as fin:
        return [json.loads(raw) for raw in fin]
42
+
43
def write_jsonlines(data, path, cls=None, ensure_ascii=False):
    """Dump each record of *data* to *path* as one JSON object per line.

    *cls* is forwarded to json.dumps (custom encoder); non-ASCII text is
    written as-is unless ensure_ascii=True.
    """
    with open(path, 'w') as fout:
        for record in data:
            fout.write(json.dumps(record, ensure_ascii=ensure_ascii, cls=cls))
            fout.write('\n')
49
+
50
def read_parquet(path):
    """Load a parquet file and return its rows as a list of dicts."""
    frame = pd.read_parquet(path)
    return frame.to_dict('records')
53
+
54
def write_parquet(data, path):
    """Persist tabular *data* (e.g. a list of row dicts) to *path* as parquet."""
    pd.DataFrame(data).to_parquet(path)
57
+
58
def read_csv(path):
    """Read a comma-separated file into a list of row dicts.

    NOTE: uses the default comma separator, while write_csv in this module
    writes tab-separated output.
    """
    return pd.read_csv(path).to_dict(orient='records')
61
+
62
def write_csv(data, path, sep='\t'):
    """Write tabular *data* (e.g. a list of row dicts) to *path*, no index column.

    Args:
        data: rows accepted by pandas.DataFrame (list of dicts, dict of lists...).
        path: output file path.
        sep: field delimiter; defaults to tab for backward compatibility.
            NOTE(review): read_csv in this module uses the default comma
            separator, so a write_csv/read_csv round-trip only works when
            sep=',' is passed here.
    """
    frame = pd.DataFrame(data)
    frame.to_csv(path, index=False, sep=sep)
eval_scripts/Daily-Omni/Daily-Omni_pipeline.sh ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Daily-Omni evaluation pipeline.
# Usage: bash Daily-Omni_pipeline.sh <results_dir>
# Generates captions per model, merges them into the grouped QA json,
# runs the Gemini-based evaluation, then computes per-model accuracy.
MODEL_PATHS=(
    "path_to_AVoCaDO"
)
RESULTS_DIR="$1"

ORIGINAL_FILE="eval_scripts/Daily-Omni/grouped_data.json"
MERGED_FILE="$RESULTS_DIR/captioned_results.json"
# Seed the merged file from the pristine QA data on first run.
if [ ! -f "$MERGED_FILE" ]; then
    echo "MERGED_FILE not found. Creating from ORIGINAL_FILE..."
    cp "$ORIGINAL_FILE" "$MERGED_FILE"
fi
CAPTION_FILES_TO_MERGE=()
CAPTION_KEYS=()

# Step 1: caption generation (one jsonl per model)
for model_path in "${MODEL_PATHS[@]}"; do
    # Strip a trailing slash so basename yields the model name.
    CLEAN_PATH="${model_path%/}"
    model_name=$(basename "$CLEAN_PATH")

    caption_file="$RESULTS_DIR/${model_name}_caption.jsonl"
    echo "Output caption file will be: $caption_file"

    python eval_scripts/Daily-Omni/generate_caption.py \
        --model_path "$model_path" \
        --fout_path "$caption_file"

    if [ -f "$caption_file" ]; then
        CAPTION_FILES_TO_MERGE+=("$caption_file")
        CAPTION_KEYS+=("${model_name}_caption")
    else
        echo "Error: Caption file $caption_file not generated for model $model_path."
        exit 1
    fi
done

# Step 2: merge generated caption files
echo "Merging all generated caption files..."
python eval_scripts/Daily-Omni/merge_captions.py \
    --original_file "$MERGED_FILE" \
    --caption_files "${CAPTION_FILES_TO_MERGE[@]}" \
    --merged_file "$MERGED_FILE"


# Step 3: evaluation (caption-grounded QA via Gemini)
python eval_scripts/Daily-Omni/evaluation.py \
    --merged_file "$MERGED_FILE" \
    --caption_keys "${CAPTION_KEYS[@]}"

# Step 4: analysis and save evaluation results
for caption_key in "${CAPTION_KEYS[@]}"; do
    echo "Running analysis for caption key: $caption_key"

    result_file="$RESULTS_DIR/${caption_key}_result.jsonl"
    # evaluation.py stores model answers under <model>_resp.
    answer_key="${caption_key//_caption/_resp}"

    if [ -f "$result_file" ]; then
        python eval_scripts/Daily-Omni/analysis.py --result_file_path "$result_file" --answer_key "$answer_key"
    else
        echo "Warning: Result file '$result_file' not found for analysis."
    fi
done
eval_scripts/Daily-Omni/analysis.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import argparse
3
+ import os
4
+
5
+ if __name__ == "__main__":
6
+ parser = argparse.ArgumentParser(description="Analyze the evaluation results.")
7
+ parser.add_argument("--result_file_path", type=str, required=True, help="Path to the result file (.jsonl).")
8
+ parser.add_argument("--answer_key", type=str, required=True, help="The key for the model's response in the result file.")
9
+
10
+ args = parser.parse_args()
11
+
12
+ data = pd.read_json(args.result_file_path, lines=True)
13
+
14
+ acc = (data['answer'].str.upper() == data[args.answer_key].str.upper()).mean()
15
+ print(f"Accuracy for {args.answer_key} is: {acc:.2%}")
16
+
17
+ with open(f"{os.path.dirname(args.result_file_path)}/{args.answer_key}.log", "w", encoding='utf-8') as fout:
18
+ fout.write(f"Accuracy for {args.answer_key} is: {acc:.2%}")
eval_scripts/Daily-Omni/evaluation.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
###
# using a llm to answer questions regarding to the video with the specific caption
###
# fix: the original module used `json` and `SEED` before they were defined
# (imports and SEED sat below their first use), which raised NameError at
# import time. Imports are hoisted and SEED/set_seed moved ahead of CONFIG;
# the duplicate `from google import genai` is removed.
import os
import sys
import time
import json
import traceback
import multiprocessing
import random
import numpy as np
import argparse
import subprocess
from google import genai
from google.genai import types
from IPython.display import HTML, Image, Markdown, display
from google.genai.types import (
    FunctionDeclaration,
    GenerateContentConfig,
    GoogleSearch,
    HarmBlockThreshold,
    HarmCategory,
    Part,
    SafetySetting,
    ThinkingConfig,
    Tool,
    ToolCodeExecution,
)


def set_seed(seed):
    """Seed numpy and the stdlib RNG for reproducible evaluation runs."""
    np.random.seed(seed)
    random.seed(seed)


SEED = 42
set_seed(SEED)

# Vertex AI credentials / endpoint configuration (fill in before running).
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = ''
LOCATION = "global"
user_info_path = ''  # path to a json file holding {"project_id": ...}
user_info = json.load(open(user_info_path))
PROJECT_ID = user_info['project_id']
MODEL = "gemini-2.5-pro"

# Disable all safety blocking so grading answers are never withheld.
safety_settings = [
    SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.OFF),
    SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.OFF),
    SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.OFF),
    SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.OFF)
]

# Deterministic generation with a small thinking budget for QA grading.
CONFIG = types.GenerateContentConfig(
    temperature=0,
    top_p=0.001,
    thinking_config=types.ThinkingConfig(
        include_thoughts=True,
        thinking_budget=512
    ),
    safety_settings=safety_settings,
    seed=SEED,
    system_instruction='''
    You are a precise QA assistant. Your task is to answer multiple-choice questions based ONLY on the video caption provided.
    Do not use any outside knowledge or assumptions—your answer must strictly reflect information from the caption.
    Always output only the capital letter corresponding to your choice (e.g., A, B, C, D).
    If the caption does not provide enough information to answer the question, output "N/A" instead.
    '''
)
client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
69
+
70
def caption2json(json_path, caption_path):
    """Merge per-video captions from a jsonl file into the QA json.

    The model name is taken from the caption filename (text before the first
    underscore); matched entries gain a `<model>_caption` field and the result
    is written to `<model>_merge_data.json` in the working directory.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)
    model = os.path.basename(caption_path).split("_")[0]

    # Each jsonl line maps video_id -> caption; blank lines are skipped.
    captions = {}
    with open(caption_path, 'r', encoding='utf-8') as f:
        for raw_line in f:
            if not raw_line.strip():
                continue
            captions.update(json.loads(raw_line))

    for entry in json_data:
        vid = entry.get("video_id")
        if vid in captions:
            entry[f"{model}_caption"] = captions[vid]

    with open(f"{model}_merge_data.json", 'w', encoding='utf-8') as f:
        json.dump(json_data, f, ensure_ascii=False, indent=2)

    print(f"merged successfully, the output file is {model}_merge_data.json")
94
+
95
+
96
def generate(prompt):
    """Query the Gemini client with retries; return (answer, thinking) text.

    Retries up to 10 times, both on API exceptions and on empty answers,
    sleeping 3s between attempts. Returns (None, None) when no non-empty
    answer was obtained.
    """
    contents = [prompt]

    answer, thinking = None, None
    max_retries = 10

    for i in range(max_retries):
        try:
            response = client.models.generate_content(
                model=MODEL,
                contents=contents,
                config=CONFIG
            )

            # Separate the visible answer parts from "thought" parts.
            answer_parts, thought_parts = [], []
            for part in response.candidates[0].content.parts:
                if not getattr(part, "text", None):
                    continue
                if getattr(part, "thought", False):
                    thought_parts.append(part.text)
                else:
                    answer_parts.append(part.text)
            answer = "\n".join(answer_parts).strip()
            thinking = "\n".join(thought_parts).strip()
            if answer:
                break
            else:
                print(f"[WARN] Attempt {i+1}: empty answer, retrying ... ")
                time.sleep(3)
        except Exception as e:
            print(f"[ERROR] Attempt {i+1} failed: {e}")
            traceback.print_exc()
            time.sleep(3)
    if not answer:
        return None, None
    print(answer)
    return answer, thinking
133
+
134
def worker(task):
    """Answer one caption-grounded multiple-choice question.

    *task* is an 8-tuple; returns a jsonl-ready record where the model
    response is stored under *answer_key* (None when generation fails).
    """
    (vid, video_duration, question, choices, answer,
     caption_key, answer_key, caption) = task
    option_lines = "\n".join(f"{c}" for c in choices)
    prompt_filled = f'''
    Here is the video caption:
    "{caption}"

    Question: {question}
    Choices:
    {option_lines}'''
    record = {
        "video_id": vid,
        "video_duration": video_duration,
        "question": question,
        "choices": choices,
        "answer": answer,
        caption_key: caption,
        answer_key: None,
    }
    try:
        resp, _ = generate(prompt_filled)
        record[answer_key] = resp
    except Exception:
        traceback.print_exc()
    return record
166
+
167
def run_multiprocess_tasks(tasks, num_processes=None, fout_path=None):
    """Fan *tasks* out to a process pool running `worker`.

    Defaults to one process per CPU. When *fout_path* is given, results are
    also written there as one JSON object per line. Returns the result list.
    """
    pool_size = multiprocessing.cpu_count() if num_processes is None else num_processes

    with multiprocessing.Pool(processes=pool_size) as pool:
        results = pool.map(worker, tasks)

    if fout_path:
        with open(fout_path, "w", encoding='utf-8') as sink:
            for record in results:
                sink.write(json.dumps(record, ensure_ascii=False) + '\n')
                sink.flush()
    return results
180
+
181
def eval_dailyomni_caption_qas(file_path, caption_keys=None):
    """Evaluate caption-grounded QA for each caption key in the merged json.

    For every key in *caption_keys* (default: ["omni_caption"]) this builds
    one task per (video, question) pair, answers them via
    run_multiprocess_tasks, and writes `<key>_result.jsonl` next to
    *file_path*. Returns the concatenated result records.
    """
    if caption_keys is None:
        # fix: avoid a mutable default argument.
        caption_keys = ["omni_caption"]
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    all_results = []
    for caption_key in caption_keys:
        # Model responses are stored under the matching "<model>_resp" key.
        answer_key = caption_key.replace("_caption", "_resp")
        fout_path = f"{os.path.dirname(file_path)}/{caption_key}_result.jsonl"

        tasks = []
        for video_info in data:
            vid = video_info["video_id"]
            video_duration = video_info["video_duration"]
            caption = video_info[caption_key]
            for q in video_info["questions"]:
                tasks.append((
                    vid,
                    video_duration,
                    q["Question"],
                    q["Choice"],
                    q["Answer"],
                    caption_key,
                    answer_key,
                    caption
                ))

        results = run_multiprocess_tasks(tasks, num_processes=20, fout_path=fout_path)
        all_results.extend(results)

    return all_results
212
+
213
+ if __name__ == "__main__":
214
+ parser = argparse.ArgumentParser(description="Evaluate captions using Gemini.")
215
+ parser.add_argument("--merged_file", type=str, required=True, help="Path to the merged caption file.")
216
+ parser.add_argument(
217
+ "--caption_keys",
218
+ type=str,
219
+ nargs='+',
220
+ required=True,
221
+ help="A list of caption keys to evaluate"
222
+ )
223
+ args = parser.parse_args()
224
+
225
+ eval_dailyomni_caption_qas(args.merged_file, caption_keys=args.caption_keys)
eval_scripts/Daily-Omni/generate_caption.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
4
+ from qwen_omni_utils import process_mm_info
5
+ import argparse
6
+ import json
7
+ from tqdm import tqdm
8
+ from pathlib import Path
9
+ import multiprocessing as mp
10
+ import traceback
11
+ import random
12
+ import glob
13
+
14
+ VIDEO_MAX_PIXELS = 401408 # 512*28*28
15
+ VIDEO_TOTAL_PIXELS = 20070400 # 512*28*28*50
16
+ USE_AUDIO_IN_VIDEO = True
17
+ video_base_dir = "path_to_Daily-Omni_Videos"
18
+ os.environ['VIDEO_MAX_PIXELS'] = str(VIDEO_TOTAL_PIXELS)
19
+
20
def chat(file_path, prompt, model, processor, model_path, max_new_tokens=2048):
    """Caption a single video with Qwen2.5-Omni and return the reply text.

    Builds a system+user conversation around the video at `file_path`,
    runs greedy decoding, and returns only the assistant's portion of the
    decoded output.

    Args:
        file_path: Path to the input video file.
        prompt: Text instruction appended after the video.
        model: Loaded Qwen2_5OmniForConditionalGeneration instance.
        processor: Matching Qwen2_5OmniProcessor.
        model_path: Unused here; kept for call-site compatibility.
        max_new_tokens: Cap on newly generated (thinker) tokens.

    Returns:
        The generated caption string.
    """
    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
            ],
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": file_path,
                    "max_pixels": VIDEO_MAX_PIXELS,
                    "max_frames": 256,
                },
                {"type": "text", "text": prompt},
            ],
        },
    ]

    chat_text = processor.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    audios, images, videos = process_mm_info(messages, use_audio_in_video=USE_AUDIO_IN_VIDEO)
    model_inputs = processor(
        text=chat_text,
        audio=audios,
        images=images,
        videos=videos,
        return_tensors="pt",
        padding=True,
        use_audio_in_video=USE_AUDIO_IN_VIDEO,
    )
    model_inputs = model_inputs.to(model.device).to(model.dtype)

    generated_ids = model.generate(
        **model_inputs,
        use_audio_in_video=USE_AUDIO_IN_VIDEO,
        do_sample=False,
        thinker_max_new_tokens=max_new_tokens,
    )
    decoded = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    # Everything after the final "\nassistant\n" marker is the model's reply.
    return decoded.split("\nassistant\n")[-1]
56
+
57
def worker_proc(rank, gpu_id, model_path, video_paths, prompt, out_path):
    """Per-process worker: load the model on one GPU and caption a chunk of videos.

    Each result is written to `out_path` as one JSON line
    ({"video_id": ..., "caption": ...}); per-video failures are logged and
    skipped so one bad video does not kill the whole chunk.

    Args:
        rank: Worker index (used for logging only).
        gpu_id: CUDA device index to pin the model to.
        model_path: HF checkpoint path for Qwen2.5-Omni.
        video_paths: List of video paths this worker is responsible for.
        prompt: Captioning instruction forwarded to `chat`.
        out_path: JSONL output file for this worker's results.
    """
    device_map = {"": f"cuda:{gpu_id}"}

    model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        device_map=device_map,
        attn_implementation="flash_attention_2",
    )
    # Text-only evaluation: drop the speech-synthesis head to save memory.
    model.disable_talker()
    processor = Qwen2_5OmniProcessor.from_pretrained(model_path)

    # Context manager (was a bare open()) guarantees the file is closed even
    # if an exception escapes the loop.
    with open(out_path, "w", encoding="utf-8") as fout:
        for video_path in tqdm(video_paths, desc=f"Worker-{rank}[GPU-{gpu_id}]"):
            try:
                model_generation = chat(video_path, prompt, model, processor, model_path)

                # "clip.mp4" -> "clip": keep everything before the extension.
                video_id = os.path.basename(video_path).split(".mp4")[0]

                out_data = {
                    "video_id": video_id,
                    "caption": model_generation,
                }

                fout.write(json.dumps(out_data, ensure_ascii=False) + "\n")
                # Flush per item so partial progress survives a crash.
                fout.flush()
            except Exception as e:
                print(f"[Worker-{rank}] Error on {video_path}: {e}")
                traceback.print_exc()

    print(f"[Worker-{rank}] Done, wrote results to {out_path}")
90
+
91
def run_multi_gpu(model_path, video_paths, prompt_list, final_out_path, num_gpus=8):
    """Fan captioning out across GPUs, one worker process per chunk, then merge.

    Splits `video_paths` into at most `num_gpus` chunks, launches one
    `worker_proc` per chunk (each with a randomly sampled prompt), waits for
    all workers, then concatenates their part files into `final_out_path`
    and deletes the parts.

    Args:
        model_path: HF checkpoint path.
        video_paths: All videos to caption.
        prompt_list: Candidate prompts; one is sampled per worker.
        final_out_path: Merged JSONL output path (must end in ".jsonl").
        num_gpus: Number of GPUs / worker processes to use.
    """
    # Over-estimating split: yields at most num_gpus chunks.
    chunk_size = len(video_paths) // num_gpus + 1
    chunks = [video_paths[start:start + chunk_size]
              for start in range(0, len(video_paths), chunk_size)]

    workers = []
    part_paths = []

    for rank, chunk in enumerate(chunks):
        gpu_id = rank % num_gpus
        part_path = final_out_path.replace(".jsonl", f".part{rank}.jsonl")
        part_paths.append(part_path)
        prompt = random.choice(prompt_list)

        proc = mp.Process(
            target=worker_proc,
            args=(rank, gpu_id, model_path, chunk, prompt, part_path),
        )
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    # Merge the per-worker part files and clean them up.
    with open(final_out_path, "w", encoding="utf-8") as fout:
        for part_path in part_paths:
            with open(part_path, "r", encoding="utf-8") as fin:
                for line in fin:
                    fout.write(line)
            os.remove(part_path)

    print(f"All results merged into {final_out_path}")
+
123
if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Evaluate a model and save results.")
    cli.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
    cli.add_argument("--fout_path", type=str, required=True, help="Path to the output caption file")
    cli_args = cli.parse_args()
    # CUDA + multiprocessing requires the "spawn" start method.
    mp.set_start_method("spawn", force=True)

    # Recursively collect every .mp4 under the configured base directory.
    video_paths = glob.glob(os.path.join(video_base_dir, "**", "*.mp4"), recursive=True)

    # Paraphrased captioning prompts; one is sampled per worker for diversity.
    prompt_list = [
        "Provide a comprehensive description of all the content in the video, leaving out no details. Be sure to include as much of the audio information as possible, and ensure that your descriptions of the audio and video are closely aligned.",
        "Thoroughly describe everything in the video, capturing every detail. Include as much information from the audio as possible, and ensure that the descriptions of both audio and video are well-coordinated.",
        "Please describe all the information in the video without sparing every detail in it. As you describe, you should also describe as much of the information in the audio as possible, and pay attention to the synchronization between the audio and video descriptions.",
        "Offer a detailed description of the video, making sure to include every detail. Also, incorporate as much information from the audio as you can, and ensure that your descriptions of the audio and video are in sync.",
        "Describe every aspect of the video in full detail, covering all the information it contains. Additionally, include as much of the audio content as you can, and make sure your descriptions of the audio and video are synchronized.",
        "Please provide a thorough description of all the content in the video, including every detail. As you describe, ensure that you also cover as much information from the audio as possible, and be mindful of the synchronization between the audio and video as you do so.",
        "Give a detailed account of everything in the video, capturing all the specifics. While doing so, also include as much information from the audio as possible, ensuring that the descriptions of audio and video are well-synchronized."
    ]

    run_multi_gpu(cli_args.model_path, video_paths, prompt_list, cli_args.fout_path, num_gpus=8)
eval_scripts/Daily-Omni/grouped_data.json ADDED
The diff for this file is too large to render. See raw diff