Spaces:
Paused
Paused
Upload 68 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- README.md +92 -33
- assets/avocado.ico +0 -0
- assets/case_1.mp4 +3 -0
- assets/case_2.png +3 -0
- environment.yml +327 -0
- eval_scripts/DREAM-1K/dream_example.jsonl +0 -0
- eval_scripts/DREAM-1K/eval_DREAM-1K.sh +12 -0
- eval_scripts/DREAM-1K/generate_caption.py +91 -0
- eval_scripts/DREAM-1K/tarsier/LICENSE +201 -0
- eval_scripts/DREAM-1K/tarsier/configs/tarser2_default_config.yaml +14 -0
- eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/multi_images_parser.py +199 -0
- eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/object_tracking_parser.py +160 -0
- eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/standard_vision_parser.py +255 -0
- eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils.py +452 -0
- eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils_visualize.py +54 -0
- eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/video_permutation_parser.py +137 -0
- eval_scripts/DREAM-1K/tarsier/dataset/tarsier_datamodule.py +280 -0
- eval_scripts/DREAM-1K/tarsier/dataset/tarsier_processor.py +240 -0
- eval_scripts/DREAM-1K/tarsier/dataset/utils.py +186 -0
- eval_scripts/DREAM-1K/tarsier/evaluation/evaluate.py +177 -0
- eval_scripts/DREAM-1K/tarsier/evaluation/metrics/__init__.py +5 -0
- eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_caption_cider.py +82 -0
- eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_dream_gpt.py +436 -0
- eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_mc.py +159 -0
- eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_oe_gpt.py +153 -0
- eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_video_mme.py +358 -0
- eval_scripts/DREAM-1K/tarsier/models/modeling_qwen2_vl_fast.py +1320 -0
- eval_scripts/DREAM-1K/tarsier/models/modeling_tarsier.py +502 -0
- eval_scripts/DREAM-1K/tarsier/models/utils.py +17 -0
- eval_scripts/DREAM-1K/tarsier/scripts/run_demo_cli.sh +15 -0
- eval_scripts/DREAM-1K/tarsier/scripts/run_demo_gradio.sh +9 -0
- eval_scripts/DREAM-1K/tarsier/scripts/run_evaluation_only.sh +12 -0
- eval_scripts/DREAM-1K/tarsier/scripts/run_inference_benchmark.sh +80 -0
- eval_scripts/DREAM-1K/tarsier/scripts/run_inference_caption.sh +79 -0
- eval_scripts/DREAM-1K/tarsier/tasks/demo_cli.py +116 -0
- eval_scripts/DREAM-1K/tarsier/tasks/demo_gradio.py +230 -0
- eval_scripts/DREAM-1K/tarsier/tasks/inference_benchmark.py +197 -0
- eval_scripts/DREAM-1K/tarsier/tasks/inference_caption.py +165 -0
- eval_scripts/DREAM-1K/tarsier/tasks/inference_quick_start.py +91 -0
- eval_scripts/DREAM-1K/tarsier/tasks/utils.py +45 -0
- eval_scripts/DREAM-1K/tarsier/tools/color.py +36 -0
- eval_scripts/DREAM-1K/tarsier/tools/conversation.py +256 -0
- eval_scripts/DREAM-1K/tarsier/tools/ptbtokenizer.py +66 -0
- eval_scripts/DREAM-1K/tarsier/tools/rw_utils.py +64 -0
- eval_scripts/Daily-Omni/Daily-Omni_pipeline.sh +62 -0
- eval_scripts/Daily-Omni/analysis.py +18 -0
- eval_scripts/Daily-Omni/evaluation.py +225 -0
- eval_scripts/Daily-Omni/generate_caption.py +142 -0
- eval_scripts/Daily-Omni/grouped_data.json +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/case_1.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/case_2.png filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,48 +1,107 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
---
|
| 10 |
|
| 11 |
-
|
|
|
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
|
|
|
| 15 |
|
| 16 |
-
##
|
|
|
|
| 17 |
|
| 18 |
-
1.
|
| 19 |
-
|
| 20 |
-
3. Set `AVOCADO_CMD` as a Space secret or enter it in the UI textbox.
|
| 21 |
-
- The command must accept `{input}` and `{output}` placeholders.
|
| 22 |
-
- Example:
|
| 23 |
-
```bash
|
| 24 |
-
python -m avocado_captioner.cli --input {input} --output {output}
|
| 25 |
-
```
|
| 26 |
-
4. Launch the Space, upload videos, and click **Process videos**.
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
```
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
| 32 |
```
|
| 33 |
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
-
|
|
|
|
|
|
|
|
|
|
| 37 |
|
| 38 |
-
-
|
| 39 |
-
-
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# <img src="assets/avocado.ico" alt="AVoCaDO icon" width="28px"> AVoCaDO: An <u>A</u>udio<u>V</u>isual Vide<u>o</u> <u>Ca</u>ptioner <u>D</u>riven by Temporal <u>O</u>rchestration
|
| 2 |
+
|
| 3 |
+
<p align="left">
|
| 4 |
+
<a href="https://avocado-captioner.github.io/"><img src="https://img.shields.io/badge/Project%20webpage-558b2f?style=for-the-badge"></a>
|
| 5 |
+
<a href="https://huggingface.co/AVoCaDO-Captioner/AVoCaDO"><img src="https://img.shields.io/badge/Model-db8905?style=for-the-badge"></a>
|
| 6 |
+
<a href="https://arxiv.org/abs/2510.10395"><img src="https://img.shields.io/badge/arXiv-red?style=for-the-badge"></a>
|
| 7 |
+
</p>
|
| 8 |
+
|
| 9 |
---
|
| 10 |
|
| 11 |
+
## ✨ Overview
|
| 12 |
+
Audiovisual video captioning aims to generate semantically rich descriptions with temporal alignment between visual and auditory events, thereby benefiting both video understanding and generation. We introduce <b>AVoCaDO</b>, a powerful audiovisual video captioner driven by the temporal orchestration between audio and visual modalities. Experimental results demonstrate that AVoCaDO significantly outperforms existing open-source models across four audiovisual video captioning benchmarks, and also achieves competitive performance under visual-only settings.
|
| 13 |
|
| 14 |
+
## 🎬 Captioning Case of AVoCaDO
|
| 15 |
+
<img src="assets/case_2.png" alt="AVoCaDO caption">
|
| 16 |
+
An illustration of a video caption generated by AVoCaDO, featuring both <b>precise audiovisual temporal alignment</b> and <u>accurate dialogue rendering</u>.
|
| 17 |
|
| 18 |
+
## 🚀 Getting Started
|
| 19 |
+
Follow these simple steps to set up and run AVoCaDO on your machine.
|
| 20 |
|
| 21 |
+
### 1. Clone the repository
|
| 22 |
+
First, clone the project and navigate into the directory:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
+
```bash
|
| 25 |
+
git clone https://github.com/AVoCaDO-Captioner/AVoCaDO.git
|
| 26 |
+
cd AVoCaDO
|
| 27 |
+
```
|
| 28 |
|
| 29 |
+
### 2. Set Up the Environment
|
| 30 |
+
Create and activate the Conda environment using the provided ``environment.yml`` file.
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
conda env create -f environment.yml
|
| 34 |
+
conda activate AVoCaDO
|
| 35 |
```
|
| 36 |
+
|
| 37 |
+
### 3. Quick Usage
|
| 38 |
+
```python
|
| 39 |
+
python inference.py assets/case_1.mp4
|
| 40 |
```
|
| 41 |
|
| 42 |
+
## 📈 Benchmark Evaluation
|
| 43 |
+
We provide evaluation scripts for all evaluated benchmarks in our paper.
|
| 44 |
+
|
| 45 |
+
### Direct Audiovisual Caption Evaluation
|
| 46 |
+
1. **video-SALMONN2-testset:**
|
| 47 |
+
```bash
|
| 48 |
+
bash eval_scripts/video-SALMONN2-testset/eval_video-SALMONN2-test.sh <your_save_directory>
|
| 49 |
+
```
|
| 50 |
|
| 51 |
+
2. **UGC-VideoCap:**
|
| 52 |
+
```bash
|
| 53 |
+
bash eval_scripts/UGC-VideoCap/eval_UGC-VideoCap.sh <your_save_directory>
|
| 54 |
+
```
|
| 55 |
|
| 56 |
+
### QA-based Audiovisual Caption Evaluation
|
| 57 |
+
1. **Daily-Omni:**
|
| 58 |
+
```bash
|
| 59 |
+
bash eval_scripts/Daily-Omni/Daily-Omni_pipeline.sh <your_save_directory>
|
| 60 |
+
```
|
| 61 |
|
| 62 |
+
2. **WorldSense:**
|
| 63 |
+
```bash
|
| 64 |
+
bash eval_scripts/WorldSense/WorldSense_pipeline.sh <your_save_directory>
|
| 65 |
+
```
|
| 66 |
|
| 67 |
+
### Visual-only Caption Evaluation
|
| 68 |
+
1. **VDC:**
|
| 69 |
+
First, generate captions for the videos in the VDC benchmark.
|
| 70 |
+
```python
|
| 71 |
+
python eval_scripts/VDC/generate_caption.py \
|
| 72 |
+
--model_path <path_to_AVoCaDO> \
|
| 73 |
+
--fout_path <your_save_path>
|
| 74 |
+
```
|
| 75 |
|
| 76 |
+
Next, set up the judge server. This requires installing [SGLang](https://github.com/sgl-project/sglang) to deploy the [Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) as the judge model.
|
| 77 |
+
```python
|
| 78 |
+
# Deploy the judge model using SGLang
|
| 79 |
+
python -m sglang.launch_server \
|
| 80 |
+
--model-path path_to_Meta-Llama-3.1-8B-Instruct \
|
| 81 |
+
--port 30000 \
|
| 82 |
+
--dp 2 --tp 4
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Once the judge model is successfully deployed and running, you can start the evaluation.
|
| 86 |
+
```bash
|
| 87 |
+
bash AVoCaDO/eval_scripts/VDC/evaluation.sh <your_save_path>
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
2. **DREAM-1K:**
|
| 91 |
+
```bash
|
| 92 |
+
bash eval_scripts/DREAM-1K/eval_DREAM-1K.sh <your_save_directory>
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
## ✒️ Citation
|
| 97 |
+
|
| 98 |
+
If you find our work helpful for your research, please consider giving a star ⭐ and citing our paper. We appreciate your support!
|
| 99 |
+
|
| 100 |
+
```bibtex
|
| 101 |
+
@article{chen2025avocado,
|
| 102 |
+
title={AVoCaDO: An Audiovisual Video Captioner Driven by Temporal Orchestration},
|
| 103 |
+
author={Chen, Xinlong and Ding, Yue and Lin, Weihong and Hua, Jingyun and Yao, Linli and Shi, Yang and Li, Bozhou and Zhang, Yuanxing and Liu, Qiang and Wan, Pengfei and others},
|
| 104 |
+
journal={arXiv preprint arXiv:2510.10395},
|
| 105 |
+
year={2025}
|
| 106 |
+
}
|
| 107 |
+
```
|
assets/avocado.ico
ADDED
|
|
assets/case_1.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78c17b195eae977e5ffdeafb499fbeeec2ea50f9258973154c0da111c2b90b07
|
| 3 |
+
size 6417271
|
assets/case_2.png
ADDED
|
Git LFS Details
|
environment.yml
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: AVoCaDO
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
- defaults
|
| 5 |
+
dependencies:
|
| 6 |
+
- _libgcc_mutex=0.1=main
|
| 7 |
+
- _openmp_mutex=5.1=1_gnu
|
| 8 |
+
- bzip2=1.0.8=h5eee18b_6
|
| 9 |
+
- ca-certificates=2024.9.24=h06a4308_0
|
| 10 |
+
- ld_impl_linux-64=2.40=h12ee557_0
|
| 11 |
+
- libffi=3.4.4=h6a678d5_1
|
| 12 |
+
- libgcc-ng=11.2.0=h1234567_1
|
| 13 |
+
- libgomp=11.2.0=h1234567_1
|
| 14 |
+
- libstdcxx-ng=11.2.0=h1234567_1
|
| 15 |
+
- libuuid=1.41.5=h5eee18b_0
|
| 16 |
+
- mscorefonts=0.0.1=3
|
| 17 |
+
- ncurses=6.4=h6a678d5_0
|
| 18 |
+
- openssl=3.0.15=h5eee18b_0
|
| 19 |
+
- pip=24.2=py310h06a4308_0
|
| 20 |
+
- python=3.10.15=he870216_1
|
| 21 |
+
- readline=8.2=h5eee18b_0
|
| 22 |
+
- setuptools=75.1.0=py310h06a4308_0
|
| 23 |
+
- sqlite=3.45.3=h5eee18b_0
|
| 24 |
+
- tk=8.6.14=h39e8969_0
|
| 25 |
+
- wheel=0.44.0=py310h06a4308_0
|
| 26 |
+
- xz=5.4.6=h5eee18b_1
|
| 27 |
+
- zlib=1.2.13=h5eee18b_1
|
| 28 |
+
- pip:
|
| 29 |
+
- accelerate==0.34.1
|
| 30 |
+
- aiodns==3.2.0
|
| 31 |
+
- aiohappyeyeballs==2.4.3
|
| 32 |
+
- aiohttp==3.10.10
|
| 33 |
+
- aiosignal==1.3.1
|
| 34 |
+
- annotated-types==0.7.0
|
| 35 |
+
- antlr4-python3-runtime==4.11.0
|
| 36 |
+
- anyio==4.6.2.post1
|
| 37 |
+
- apscheduler==3.10.4
|
| 38 |
+
- asttokens==2.4.1
|
| 39 |
+
- async-timeout==4.0.3
|
| 40 |
+
- attrs==24.2.0
|
| 41 |
+
- audioread==3.0.1
|
| 42 |
+
- av==13.1.0
|
| 43 |
+
- backcall==0.2.0
|
| 44 |
+
- beartype==0.19.0
|
| 45 |
+
- beautifulsoup4==4.13.4
|
| 46 |
+
- bleach==6.2.0
|
| 47 |
+
- blinker==1.9.0
|
| 48 |
+
- cachetools==4.2.4
|
| 49 |
+
- cchardet==2.1.7
|
| 50 |
+
- certifi==2021.10.8
|
| 51 |
+
- cffi==1.17.1
|
| 52 |
+
- charset-normalizer==3.4.0
|
| 53 |
+
- cityhash==0.2.4.post11
|
| 54 |
+
- click==8.1.7
|
| 55 |
+
- cloudpickle==3.1.0
|
| 56 |
+
- colorama==0.4.6
|
| 57 |
+
- compressed-tensors==0.8.0
|
| 58 |
+
- contourpy==1.3.1
|
| 59 |
+
- cycler==0.12.1
|
| 60 |
+
- dapr==1.14.0
|
| 61 |
+
- dapr-ext-fastapi==1.14.0
|
| 62 |
+
- dask==2024.12.1
|
| 63 |
+
- datasets==3.1.0
|
| 64 |
+
- dbpool==1.2.1
|
| 65 |
+
- decorator==4.4.2
|
| 66 |
+
- decord==0.6.0
|
| 67 |
+
- deepspeed==0.16.2
|
| 68 |
+
- defusedxml==0.7.1
|
| 69 |
+
- dill==0.4.0
|
| 70 |
+
- diskcache==5.6.3
|
| 71 |
+
- distro==1.9.0
|
| 72 |
+
- docopt==0.6.2
|
| 73 |
+
- docstring-parser==0.16
|
| 74 |
+
- dsc-auth==0.1.18
|
| 75 |
+
- einops==0.8.0
|
| 76 |
+
- et-xmlfile==2.0.0
|
| 77 |
+
- exceptiongroup==1.2.2
|
| 78 |
+
- executing==2.1.0
|
| 79 |
+
- fastapi==0.115.5
|
| 80 |
+
- fastjsonschema==2.21.1
|
| 81 |
+
- ffmpeg-python==0.2.0
|
| 82 |
+
- filelock==3.13.1
|
| 83 |
+
- fire==0.7.0
|
| 84 |
+
- flash-attn==2.7.0.post2
|
| 85 |
+
- flask==3.1.0
|
| 86 |
+
- fonttools==4.55.0
|
| 87 |
+
- frozenlist==1.5.0
|
| 88 |
+
- fsspec==2024.2.0
|
| 89 |
+
- ftfy==6.3.1
|
| 90 |
+
- func-timeout==4.3.5
|
| 91 |
+
- future==1.0.0
|
| 92 |
+
- fvcore==0.1.5.post20221221
|
| 93 |
+
- gguf==0.10.0
|
| 94 |
+
- google-api-core==2.23.0
|
| 95 |
+
- google-auth==2.36.0
|
| 96 |
+
- google-cloud-aiplatform==1.71.1
|
| 97 |
+
- google-cloud-bigquery==3.27.0
|
| 98 |
+
- google-cloud-core==2.4.1
|
| 99 |
+
- google-cloud-resource-manager==1.13.1
|
| 100 |
+
- google-cloud-storage==2.18.2
|
| 101 |
+
- google-crc32c==1.6.0
|
| 102 |
+
- google-resumable-media==2.7.2
|
| 103 |
+
- googleapis-common-protos==1.66.0
|
| 104 |
+
- grpc-google-iam-v1==0.13.1
|
| 105 |
+
- grpcio==1.68.0
|
| 106 |
+
- grpcio-reflection==1.48.2
|
| 107 |
+
- grpcio-status==1.68.0
|
| 108 |
+
- h11==0.14.0
|
| 109 |
+
- h5py==3.12.1
|
| 110 |
+
- hf-xet==1.1.4
|
| 111 |
+
- hickle==5.0.3
|
| 112 |
+
- hiredis==2.4.0
|
| 113 |
+
- hjson==3.1.0
|
| 114 |
+
- httpcore==1.0.6
|
| 115 |
+
- httptools==0.6.4
|
| 116 |
+
- httpx==0.27.2
|
| 117 |
+
- huggingface-hub==0.33.0
|
| 118 |
+
- humanize==4.11.0
|
| 119 |
+
- icecream==2.1.3
|
| 120 |
+
- idna==3.10
|
| 121 |
+
- imageio==2.36.0
|
| 122 |
+
- imageio-ffmpeg==0.5.1
|
| 123 |
+
- importlib-metadata==8.5.0
|
| 124 |
+
- infra-component==1.4.7
|
| 125 |
+
- infra-framework==1.17.10
|
| 126 |
+
- infra-kconf==1.1.3
|
| 127 |
+
- infra-kess==1.1.5
|
| 128 |
+
- infra-keycenter==1.1.1
|
| 129 |
+
- infra-storage==1.3.1
|
| 130 |
+
- install==1.3.5
|
| 131 |
+
- interegular==0.3.3
|
| 132 |
+
- iopath==0.1.10
|
| 133 |
+
- ipdb==0.13.13
|
| 134 |
+
- ipython==8.12.3
|
| 135 |
+
- itsdangerous==2.2.0
|
| 136 |
+
- jedi==0.19.2
|
| 137 |
+
- jinja2==3.1.3
|
| 138 |
+
- jiter==0.7.0
|
| 139 |
+
- joblib==1.4.2
|
| 140 |
+
- jsonschema==4.23.0
|
| 141 |
+
- jsonschema-specifications==2024.10.1
|
| 142 |
+
- jupyter-client==8.6.3
|
| 143 |
+
- jupyter-core==5.8.1
|
| 144 |
+
- jupyterlab-pygments==0.3.0
|
| 145 |
+
- kazoo==2.10.0
|
| 146 |
+
- kiwisolver==1.4.7
|
| 147 |
+
- ks-kafka-python==2.0.3
|
| 148 |
+
- lark==1.2.2
|
| 149 |
+
- lazy-loader==0.4
|
| 150 |
+
- levenshtein==0.26.1
|
| 151 |
+
- librosa==0.11.0
|
| 152 |
+
- llvmlite==0.43.0
|
| 153 |
+
- lm-format-enforcer==0.10.6
|
| 154 |
+
- locket==1.0.0
|
| 155 |
+
- lxml==4.9.4
|
| 156 |
+
- lz4==3.1.10
|
| 157 |
+
- markupsafe==2.1.5
|
| 158 |
+
- matplotlib==3.10.0
|
| 159 |
+
- matplotlib-inline==0.1.7
|
| 160 |
+
- mergedeep==1.3.4
|
| 161 |
+
- mistral-common==1.5.1
|
| 162 |
+
- mistral-inference==1.5.0
|
| 163 |
+
- mistune==3.1.3
|
| 164 |
+
- moviepy==1.0.3
|
| 165 |
+
- mpmath==1.3.0
|
| 166 |
+
- msgpack==1.1.0
|
| 167 |
+
- msgpack-numpy==0.4.8
|
| 168 |
+
- msgspec==0.18.6
|
| 169 |
+
- multidict==6.1.0
|
| 170 |
+
- multiprocess==0.70.18
|
| 171 |
+
- mysql-connector-python==8.0.31
|
| 172 |
+
- nbclient==0.10.2
|
| 173 |
+
- nbconvert==7.16.6
|
| 174 |
+
- nbformat==5.10.4
|
| 175 |
+
- nest-asyncio==1.6.0
|
| 176 |
+
- networkx==3.2.1
|
| 177 |
+
- ninja==1.11.1.3
|
| 178 |
+
- numba==0.60.0
|
| 179 |
+
- numpy==1.26.3
|
| 180 |
+
- nvidia-cublas-cu12==12.4.5.8
|
| 181 |
+
- nvidia-cuda-cupti-cu12==12.4.127
|
| 182 |
+
- nvidia-cuda-nvrtc-cu12==12.4.127
|
| 183 |
+
- nvidia-cuda-runtime-cu12==12.4.127
|
| 184 |
+
- nvidia-cudnn-cu12==9.1.0.70
|
| 185 |
+
- nvidia-cufft-cu12==11.2.1.3
|
| 186 |
+
- nvidia-curand-cu12==10.3.5.147
|
| 187 |
+
- nvidia-cusolver-cu12==11.6.1.9
|
| 188 |
+
- nvidia-cusparse-cu12==12.3.1.170
|
| 189 |
+
- nvidia-cusparselt-cu12==0.6.2
|
| 190 |
+
- nvidia-ml-py==12.560.30
|
| 191 |
+
- nvidia-nccl-cu12==2.21.5
|
| 192 |
+
- nvidia-nvjitlink-cu12==12.4.127
|
| 193 |
+
- nvidia-nvtx-cu12==12.4.127
|
| 194 |
+
- nvitop==1.3.2
|
| 195 |
+
- omegaconf==2.3.0
|
| 196 |
+
- open-clip-torch==2.29.0
|
| 197 |
+
- openai==1.54.3
|
| 198 |
+
- opencv-python==4.10.0.84
|
| 199 |
+
- opencv-python-headless==4.10.0.84
|
| 200 |
+
- openpyxl==3.1.5
|
| 201 |
+
- outlines==0.0.46
|
| 202 |
+
- packaging==24.1
|
| 203 |
+
- pandas==2.2.3
|
| 204 |
+
- pandocfilters==1.5.1
|
| 205 |
+
- parameterized==0.9.0
|
| 206 |
+
- parso==0.8.4
|
| 207 |
+
- partd==1.4.2
|
| 208 |
+
- partial-json-parser==0.2.1.1.post4
|
| 209 |
+
- pathos==0.3.4
|
| 210 |
+
- peft==0.14.0
|
| 211 |
+
- pexpect==4.9.0
|
| 212 |
+
- pickleshare==0.7.5
|
| 213 |
+
- pillow==10.4.0
|
| 214 |
+
- pipreqs==0.5.0
|
| 215 |
+
- platformdirs==4.3.6
|
| 216 |
+
- pooch==1.8.2
|
| 217 |
+
- portalocker==3.0.0
|
| 218 |
+
- pox==0.3.6
|
| 219 |
+
- ppft==1.7.7
|
| 220 |
+
- prettytable==2.5.0
|
| 221 |
+
- proglog==0.1.10
|
| 222 |
+
- prometheus-client==0.21.0
|
| 223 |
+
- prometheus-fastapi-instrumentator==7.0.0
|
| 224 |
+
- prompt-toolkit==3.0.48
|
| 225 |
+
- propcache==0.2.0
|
| 226 |
+
- proto-plus==1.25.0
|
| 227 |
+
- protobuf==3.20.0
|
| 228 |
+
- psutil==6.1.0
|
| 229 |
+
- ptyprocess==0.7.0
|
| 230 |
+
- pure-eval==0.2.3
|
| 231 |
+
- py-cpuinfo==9.0.0
|
| 232 |
+
- pyairports==2.1.1
|
| 233 |
+
- pyarrow==18.0.0
|
| 234 |
+
- pyasn1==0.6.1
|
| 235 |
+
- pyasn1-modules==0.4.1
|
| 236 |
+
- pycares==4.4.0
|
| 237 |
+
- pycocoevalcap==1.2
|
| 238 |
+
- pycocotools==2.0.8
|
| 239 |
+
- pycountry==24.6.1
|
| 240 |
+
- pycparser==2.22
|
| 241 |
+
- pycryptodome==3.21.0
|
| 242 |
+
- pydantic==2.9.2
|
| 243 |
+
- pydantic-core==2.23.4
|
| 244 |
+
- pydub==0.25.1
|
| 245 |
+
- pygments==2.18.0
|
| 246 |
+
- pyparsing==3.2.0
|
| 247 |
+
- pysmhasher==0.2.5
|
| 248 |
+
- pysoundfile==0.9.0.post1
|
| 249 |
+
- python-dateutil==2.9.0.post0
|
| 250 |
+
- python-dotenv==1.0.1
|
| 251 |
+
- python-levenshtein==0.26.1
|
| 252 |
+
- python-snappy==0.6.1
|
| 253 |
+
- pytorchvideo==0.1.5
|
| 254 |
+
- pytube==15.0.0
|
| 255 |
+
- pytz==2021.3
|
| 256 |
+
- pytz-deprecation-shim==0.1.0.post0
|
| 257 |
+
- pyyaml==6.0.2
|
| 258 |
+
- pyzmq==26.2.0
|
| 259 |
+
- qwen-omni-utils==0.0.8
|
| 260 |
+
- qwen-vl-utils==0.0.8
|
| 261 |
+
- rapidfuzz==3.12.2
|
| 262 |
+
- ray==2.38.0
|
| 263 |
+
- redis==4.6.0
|
| 264 |
+
- referencing==0.35.1
|
| 265 |
+
- regex==2024.9.11
|
| 266 |
+
- requests==2.32.3
|
| 267 |
+
- rpds-py==0.21.0
|
| 268 |
+
- rsa==4.9
|
| 269 |
+
- safetensors==0.4.5
|
| 270 |
+
- scenedetect==0.6.4
|
| 271 |
+
- scikit-learn==1.6.0
|
| 272 |
+
- scipy==1.14.1
|
| 273 |
+
- seaborn==0.13.2
|
| 274 |
+
- sentencepiece==0.2.0
|
| 275 |
+
- setuptools-scm==8.1.0
|
| 276 |
+
- shapely==2.0.6
|
| 277 |
+
- simple-parsing==0.1.6
|
| 278 |
+
- six==1.16.0
|
| 279 |
+
- sniffio==1.3.1
|
| 280 |
+
- soundfile==0.13.1
|
| 281 |
+
- soupsieve==2.7
|
| 282 |
+
- soxr==0.5.0.post1
|
| 283 |
+
- sqlparse==0.4.4
|
| 284 |
+
- stack-data==0.6.3
|
| 285 |
+
- starlette==0.41.3
|
| 286 |
+
- sympy==1.13.1
|
| 287 |
+
- tabulate==0.9.0
|
| 288 |
+
- termcolor==2.5.0
|
| 289 |
+
- threadpoolctl==3.5.0
|
| 290 |
+
- tiktoken==0.7.0
|
| 291 |
+
- timm==1.0.12
|
| 292 |
+
- tinycss2==1.4.0
|
| 293 |
+
- tokenizers==0.21.1
|
| 294 |
+
- tomli==2.0.2
|
| 295 |
+
- toolz==1.0.0
|
| 296 |
+
- torch==2.6.0
|
| 297 |
+
- torchvision==0.21.0
|
| 298 |
+
- tornado==6.4.1
|
| 299 |
+
- tqdm==4.67.0
|
| 300 |
+
- traitlets==5.14.3
|
| 301 |
+
- transformers==4.52.3
|
| 302 |
+
- triton==3.2.0
|
| 303 |
+
- typeguard==4.4.1
|
| 304 |
+
- typing-extensions==4.12.2
|
| 305 |
+
- tzdata==2024.2
|
| 306 |
+
- tzlocal==4.3.1
|
| 307 |
+
- unpaddedbase64==2.1.0
|
| 308 |
+
- urllib3==1.26.20
|
| 309 |
+
- uvicorn==0.32.0
|
| 310 |
+
- uvloop==0.21.0
|
| 311 |
+
- vertexai==1.71.1
|
| 312 |
+
- vllm==0.6.3
|
| 313 |
+
- watchfiles==0.24.0
|
| 314 |
+
- wcwidth==0.2.13
|
| 315 |
+
- webencodings==0.5.1
|
| 316 |
+
- websockets==13.1
|
| 317 |
+
- werkzeug==3.1.3
|
| 318 |
+
- word2number==1.1
|
| 319 |
+
- xformers==0.0.27.post2
|
| 320 |
+
- xmltodict==0.12.0
|
| 321 |
+
- xxhash==3.5.0
|
| 322 |
+
- yacs==0.1.8
|
| 323 |
+
- yarg==0.1.9
|
| 324 |
+
- yarl==1.17.1
|
| 325 |
+
- zhon==2.1.1
|
| 326 |
+
- zipp==3.20.2
|
| 327 |
+
- zsvision==0.7.12
|
eval_scripts/DREAM-1K/dream_example.jsonl
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
eval_scripts/DREAM-1K/eval_DREAM-1K.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
MODEL_PATH="path_to_AVoCaDO" # TODO
|
| 3 |
+
OUTPUT_DIR="$1"
|
| 4 |
+
|
| 5 |
+
mkdir -p "$OUTPUT_DIR"
|
| 6 |
+
|
| 7 |
+
python eval_scripts/DREAM-1K/generate_caption.py \
|
| 8 |
+
--model_path "$MODEL_PATH" \
|
| 9 |
+
--save_path "$OUTPUT_DIR/model_caption.jsonl"
|
| 10 |
+
|
| 11 |
+
bash eval_scripts/DREAM-1K/tarsier/scripts/run_evaluation_only.sh "$OUTPUT_DIR/model_caption.jsonl"
|
| 12 |
+
|
eval_scripts/DREAM-1K/generate_caption.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
|
| 4 |
+
from qwen_omni_utils import process_mm_info
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
VIDEO_MAX_PIXELS = 401408 # 512*28*28
|
| 11 |
+
VIDEO_TOTAL_PIXELS = 20070400 # 512*28*28*50
|
| 12 |
+
USE_AUDIO_IN_VIDEO = False
|
| 13 |
+
os.environ['VIDEO_MAX_PIXELS'] = str(VIDEO_TOTAL_PIXELS)
|
| 14 |
+
script_dir = Path(__file__).resolve().parent
|
| 15 |
+
example_path = script_dir / "dream_example.jsonl"
|
| 16 |
+
video_dir = "" # TODO
|
| 17 |
+
|
| 18 |
+
parser = argparse.ArgumentParser(description="Evaluate a model and save results.")
|
| 19 |
+
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
|
| 20 |
+
parser.add_argument("--save_path", type=str, required=True, help="Path to save the evaluation results.")
|
| 21 |
+
args = parser.parse_args()
|
| 22 |
+
|
| 23 |
+
model_path = args.model_path
|
| 24 |
+
fout_path = args.save_path
|
| 25 |
+
|
| 26 |
+
f_example = open(example_path, 'r', encoding='utf-8')
|
| 27 |
+
fout = open(fout_path, 'w', encoding='utf-8')
|
| 28 |
+
|
| 29 |
+
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
| 30 |
+
model_path,
|
| 31 |
+
torch_dtype=torch.bfloat16,
|
| 32 |
+
device_map="auto",
|
| 33 |
+
attn_implementation="flash_attention_2",
|
| 34 |
+
)
|
| 35 |
+
model.disable_talker()
|
| 36 |
+
processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
|
| 37 |
+
|
| 38 |
+
def chat(data):
|
| 39 |
+
conversation = [
|
| 40 |
+
{
|
| 41 |
+
"role": "system",
|
| 42 |
+
"content": [
|
| 43 |
+
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
| 44 |
+
],
|
| 45 |
+
},
|
| 46 |
+
{
|
| 47 |
+
"role": "user",
|
| 48 |
+
"content": [
|
| 49 |
+
{
|
| 50 |
+
"type": "video",
|
| 51 |
+
"video": data["video_path"],
|
| 52 |
+
"max_pixels": VIDEO_MAX_PIXELS,
|
| 53 |
+
},
|
| 54 |
+
{
|
| 55 |
+
"type": "text",
|
| 56 |
+
"text": data["question"]
|
| 57 |
+
},
|
| 58 |
+
],
|
| 59 |
+
},
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
| 63 |
+
audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
|
| 64 |
+
inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO)
|
| 65 |
+
inputs = inputs.to(model.device).to(model.dtype)
|
| 66 |
+
|
| 67 |
+
text_ids = model.generate(**inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO, do_sample=False, thinker_max_new_tokens=2048)
|
| 68 |
+
|
| 69 |
+
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 70 |
+
model_generation = text.split("\nassistant\n")[-1]
|
| 71 |
+
|
| 72 |
+
return model_generation
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
for idx, line in tqdm(enumerate(f_example, start=1)):
|
| 76 |
+
data = json.loads(line)
|
| 77 |
+
video_path = os.path.join(video_dir, data["messages"][0]["content"][0]["video"]["video_file"])
|
| 78 |
+
question = "Imagine the video from these frames and describe it in detail."
|
| 79 |
+
|
| 80 |
+
temp_data = {
|
| 81 |
+
"video_path": video_path,
|
| 82 |
+
"question": question,
|
| 83 |
+
}
|
| 84 |
+
with torch.inference_mode():
|
| 85 |
+
response = chat(temp_data)
|
| 86 |
+
|
| 87 |
+
out_data = data
|
| 88 |
+
data["messages"][0]["content"][1]["text"] = question
|
| 89 |
+
out_data["messages"][1]["content"][0]["text"] = response
|
| 90 |
+
fout.write(json.dumps(out_data, ensure_ascii=False) + '\n')
|
| 91 |
+
fout.flush()
|
eval_scripts/DREAM-1K/tarsier/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
eval_scripts/DREAM-1K/tarsier/configs/tarser2_default_config.yaml
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
max_n_frames: 256
|
| 2 |
+
n_frames: 16
|
| 3 |
+
max_pixels: 460800 # 1280 * 720 // 2
|
| 4 |
+
min_pixels: 0
|
| 5 |
+
max_seq_len: 16384
|
| 6 |
+
is_training: false # 会影响:1. 训练和测试时采帧不同;2. 测试时忽略 response。
|
| 7 |
+
print_data_error: true
|
| 8 |
+
is_training: false
|
| 9 |
+
do_image_padding: false
|
| 10 |
+
do_image_crop: false
|
| 11 |
+
do_image_resize: false
|
| 12 |
+
video_sampling_strategy: {'video_sampler_version': 'v1', 'force_frames_n_divisible': 1, 'use_multi_images_for_video': true}
|
| 13 |
+
prompt: ""
|
| 14 |
+
train_task: sft
|
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/multi_images_parser.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import random
|
| 3 |
+
import re
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
from .utils import sample_video, read_image
|
| 7 |
+
|
| 8 |
+
class MultiImagesParser:
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
n_frames=8,
|
| 12 |
+
is_training=True,
|
| 13 |
+
):
|
| 14 |
+
self.n_frames = n_frames
|
| 15 |
+
self.is_training = is_training
|
| 16 |
+
# fmt: off
|
| 17 |
+
self.data_temp = {
|
| 18 |
+
"text": [
|
| 19 |
+
[{
|
| 20 |
+
"prompt": "Describe the image in short.",
|
| 21 |
+
"response": "A rollerblader rides high in a full pipe while others watch"
|
| 22 |
+
}],
|
| 23 |
+
[{
|
| 24 |
+
"prompt": "Describe the image in short.",
|
| 25 |
+
"response": "A woman in winter clothes is on the sidewalk with a phone."
|
| 26 |
+
}]
|
| 27 |
+
],
|
| 28 |
+
"image": [
|
| 29 |
+
{
|
| 30 |
+
"image_file": "/mnt/bn/videonaslq/images/flickr30k/images/3371533654.jpg"
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"image_file": "/mnt/bn/videonaslq/images/coco/train2014/COCO_train2014_000000177950.jpg"
|
| 34 |
+
},
|
| 35 |
+
{
|
| 36 |
+
"video_file": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4",
|
| 37 |
+
"frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
|
| 38 |
+
}
|
| 39 |
+
],
|
| 40 |
+
"dataset": "coco",
|
| 41 |
+
"task": "multi_images",
|
| 42 |
+
"image_processing_config": {},
|
| 43 |
+
}
|
| 44 |
+
# fmt: on
|
| 45 |
+
|
| 46 |
+
def check_format(self, data_dict: Dict, image_processing_config: Dict):
|
| 47 |
+
assert data_dict['dataset'] in ['coco', 'sharegpt4v_cap100k', 'sharegpt4v_mix665k', 'webvid', 'movie'], data_dict
|
| 48 |
+
|
| 49 |
+
# 目前多图数据应该没有包含坐标的数据吧
|
| 50 |
+
if image_processing_config.get('has_coordinates', False):
|
| 51 |
+
raise ValueError(f'do_crop and has_coordinates cannot be True at the same time in MultiImagesParser!')
|
| 52 |
+
|
| 53 |
+
# 检查是否能匹配到坐标
|
| 54 |
+
texts = data_dict['text']
|
| 55 |
+
for text in texts:
|
| 56 |
+
match = re.search(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', text['prompt'] + text['response'])
|
| 57 |
+
if match:
|
| 58 |
+
print(f'[Warning] 疑似检测到包含坐标的数据:{data_dict}')
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
|
| 62 |
+
self.check_format(data_dict, image_processing_config)
|
| 63 |
+
|
| 64 |
+
# shuffle
|
| 65 |
+
texts = data_dict['text']
|
| 66 |
+
images = data_dict['image']
|
| 67 |
+
images = self.load_images(images)
|
| 68 |
+
idxs = list(range(len(texts)))
|
| 69 |
+
random.shuffle(idxs)
|
| 70 |
+
texts = [texts[i] for i in idxs]
|
| 71 |
+
images = [images[i] for i in idxs]
|
| 72 |
+
|
| 73 |
+
# sample n_frames
|
| 74 |
+
if isinstance(self.n_frames, int):
|
| 75 |
+
n_frames = random.choice(list(range(1, self.n_frames + 1)))
|
| 76 |
+
else:
|
| 77 |
+
n_frames = random.choice(self.n_frames)
|
| 78 |
+
texts = texts[: n_frames]
|
| 79 |
+
images = images[: n_frames]
|
| 80 |
+
|
| 81 |
+
dataset = data_dict['dataset']
|
| 82 |
+
if dataset in ['coco', 'sharegpt4v_cap100k', 'webvid', 'movie']:
|
| 83 |
+
prompt, response = self.transform_for_caption_task(texts, dataset, images)
|
| 84 |
+
else:
|
| 85 |
+
prompt, response = self.transform_for_qa_task(texts, dataset, images)
|
| 86 |
+
|
| 87 |
+
messages = [
|
| 88 |
+
{
|
| 89 |
+
"role": "user",
|
| 90 |
+
"content": [
|
| 91 |
+
*[{"type": "image", "image": img} for img in images],
|
| 92 |
+
{"type": "text", "text": prompt},
|
| 93 |
+
]
|
| 94 |
+
},
|
| 95 |
+
{
|
| 96 |
+
"role": "assistant",
|
| 97 |
+
"content": [
|
| 98 |
+
{"type": "text", "text": response}
|
| 99 |
+
]
|
| 100 |
+
}
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
return messages
|
| 104 |
+
|
| 105 |
+
def transform_for_caption_task(self, texts, dataset, images):
|
| 106 |
+
idx = random.choice(list(range(len(texts))))
|
| 107 |
+
|
| 108 |
+
if dataset == 'coco':
|
| 109 |
+
if len(texts) == 1:
|
| 110 |
+
prompt = 'Describe the image in short.'
|
| 111 |
+
else:
|
| 112 |
+
prompt = f'Describe the images starting from frame {idx + 1} in short in order.'
|
| 113 |
+
elif dataset == 'sharegpt4v_cap100k':
|
| 114 |
+
if len(texts) == 1:
|
| 115 |
+
prompt = 'Describe the image in detail.'
|
| 116 |
+
else:
|
| 117 |
+
prompt = f'Describe the images starting from frame {idx + 1} in detail in order.'
|
| 118 |
+
else:
|
| 119 |
+
if len(texts) == 1:
|
| 120 |
+
prompt = 'Describe the image.'
|
| 121 |
+
else:
|
| 122 |
+
prompt = f'Describe the images starting from frame {idx + 1} in order.'
|
| 123 |
+
response = ''
|
| 124 |
+
for i, text in enumerate(texts):
|
| 125 |
+
if i < idx:
|
| 126 |
+
continue
|
| 127 |
+
if not isinstance(text, dict):
|
| 128 |
+
text = random.choice(text)
|
| 129 |
+
resp = text['response']
|
| 130 |
+
response += f'{resp}\n'
|
| 131 |
+
return prompt, response
|
| 132 |
+
|
| 133 |
+
def transform_for_qa_task(self, texts, dataset, images):
|
| 134 |
+
prompt, response = '', ''
|
| 135 |
+
for i, text in enumerate(texts):
|
| 136 |
+
if not isinstance(text, dict):
|
| 137 |
+
text = random.choice(text)
|
| 138 |
+
if len(texts) > 1:
|
| 139 |
+
prompt += f'Question for frame {i+1}:\n' + text['prompt'] + '\n'
|
| 140 |
+
response += f'Answer to question of frame {i+1}:\n' + text['response'] + '\n'
|
| 141 |
+
else:
|
| 142 |
+
prompt += text['prompt'] + '\n'
|
| 143 |
+
response += text['response'] + '\n'
|
| 144 |
+
return prompt, response
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def load_images(self, image_items: List[Dict]) -> List[Image.Image]:
|
| 148 |
+
"""
|
| 149 |
+
image_items: List[Dict]. each item like:
|
| 150 |
+
{"video_file": "path/to/video", "frame_indices": [1]}
|
| 151 |
+
or
|
| 152 |
+
{"image_file": "path/to/image"}
|
| 153 |
+
"""
|
| 154 |
+
if image_items is None:
|
| 155 |
+
raise ValueError(f'image_items is None!')
|
| 156 |
+
|
| 157 |
+
if isinstance(image_items, dict):
|
| 158 |
+
image_items = [image_items]
|
| 159 |
+
|
| 160 |
+
images = []
|
| 161 |
+
|
| 162 |
+
for image_item in image_items:
|
| 163 |
+
|
| 164 |
+
if 'video_file' in image_item:
|
| 165 |
+
file_key = 'video_file'
|
| 166 |
+
elif 'image_file' in image_item:
|
| 167 |
+
file_key = 'image_file'
|
| 168 |
+
else:
|
| 169 |
+
raise KeyError(f'video_file or image_file not in {image_item}')
|
| 170 |
+
|
| 171 |
+
file_path = image_item[file_key]
|
| 172 |
+
if file_key == 'video_file':
|
| 173 |
+
frame_indices = image_item.get('frame_indices', None)
|
| 174 |
+
if frame_indices is None:
|
| 175 |
+
raise ValueError(f'read 0 frame: {image_item}')
|
| 176 |
+
if isinstance(frame_indices, int):
|
| 177 |
+
frame_indices = [frame_indices]
|
| 178 |
+
frames = sample_video(file_path, frame_indices = frame_indices)
|
| 179 |
+
images.extend(frames)
|
| 180 |
+
else:
|
| 181 |
+
if isinstance(file_path, str):
|
| 182 |
+
file_path = [file_path]
|
| 183 |
+
images.extend([read_image(f) for f in file_path])
|
| 184 |
+
|
| 185 |
+
return images
|
| 186 |
+
|
| 187 |
+
if __name__ == '__main__':
|
| 188 |
+
# python3 -m xenon_generation.data.custom_data_parsers.multi_images_parser
|
| 189 |
+
|
| 190 |
+
from tqdm import tqdm
|
| 191 |
+
from tools.rw_utils import read_jsonlines
|
| 192 |
+
|
| 193 |
+
lines = read_jsonlines('/mnt/bn/videonaslq/VideoCaption/datasets_1009/sharegpt4v_cap100k/part_36.jsonl')
|
| 194 |
+
lines = lines[:10]
|
| 195 |
+
parser = MultiImagesParser(n_frames=8)
|
| 196 |
+
for i, l in tqdm(enumerate(lines)):
|
| 197 |
+
l_image_processing_config = l.get('image_processing_config', {})
|
| 198 |
+
messages = parser.transform(l, l_image_processing_config)
|
| 199 |
+
print(messages)
|
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/object_tracking_parser.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
import random
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from torchvision import transforms
|
| 6 |
+
|
| 7 |
+
from .utils import sample_video
|
| 8 |
+
|
| 9 |
+
def return_same(x):
|
| 10 |
+
return x
|
| 11 |
+
|
| 12 |
+
def _bbox_transform_for_padding(bbox, frame):
|
| 13 |
+
w1, h1, w2, h2 = bbox
|
| 14 |
+
width, height = frame.size
|
| 15 |
+
if width == height:
|
| 16 |
+
pass
|
| 17 |
+
elif width > height:
|
| 18 |
+
h1 += (width - height) // 2
|
| 19 |
+
h2 += (width - height) // 2
|
| 20 |
+
height = width
|
| 21 |
+
else:
|
| 22 |
+
w1 += (height - width) // 2
|
| 23 |
+
w2 += (height - width) // 2
|
| 24 |
+
width = height
|
| 25 |
+
new_bbox = [w1 / width, h1 / height, w2 / width, h2 / height]
|
| 26 |
+
new_bbox = [round(i, 2) for i in new_bbox]
|
| 27 |
+
return new_bbox
|
| 28 |
+
|
| 29 |
+
def _bbox_transform_for_resize(bbox, frame):
|
| 30 |
+
w1, h1, w2, h2 = bbox
|
| 31 |
+
width, height = frame.size
|
| 32 |
+
new_bbox = [w1 / width, h1 / height, w2 / width, h2 / height]
|
| 33 |
+
new_bbox = [round(i, 2) for i in new_bbox]
|
| 34 |
+
return new_bbox
|
| 35 |
+
|
| 36 |
+
class InAndOutCropAndResize(object):
|
| 37 |
+
"""Crop and resize for in_and_out boxes data according to yuchen
|
| 38 |
+
Args:
|
| 39 |
+
size: tuple of (width, height)
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(self, size):
|
| 43 |
+
self.size = size
|
| 44 |
+
|
| 45 |
+
def __call__(self, img):
|
| 46 |
+
"""
|
| 47 |
+
Args:
|
| 48 |
+
img (PIL Image): PIL Image
|
| 49 |
+
Returns:
|
| 50 |
+
PIL Image: PIL image.
|
| 51 |
+
"""
|
| 52 |
+
w = img.width
|
| 53 |
+
h = img.height
|
| 54 |
+
x0 = int(w * 0.5 - h * 0.375)
|
| 55 |
+
y0 = int(h * 0.125)
|
| 56 |
+
x1 = int(w * 0.5 + h * 0.375)
|
| 57 |
+
y1 = int(h * 0.875)
|
| 58 |
+
img = img.crop((x0, y0, x1, y1)).resize(self.size)
|
| 59 |
+
return img
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class ObjectTrackingParser:
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
n_frames = 8,
|
| 66 |
+
max_objects = 3,
|
| 67 |
+
is_training=True,
|
| 68 |
+
):
|
| 69 |
+
self.n_frames = n_frames
|
| 70 |
+
self.max_objects = max_objects
|
| 71 |
+
self.is_training = is_training
|
| 72 |
+
self.img_transform = self.get_img_transform()
|
| 73 |
+
# fmt: off
|
| 74 |
+
self.data_temp = {
|
| 75 |
+
"video_file": "/mnt/bn/llmdatalq/jiaxin/hdvila/20230926/saved/saved_video_clips/0076/lOjn__YCec4.624.1104.mp4",
|
| 76 |
+
"frame_indices": [154, 157, 160, 163, 166, 169, 172, 175, 178, 181, 184, 187, 190, 193, 196, 199, 202],
|
| 77 |
+
"objects": {
|
| 78 |
+
"0": {
|
| 79 |
+
"phrase": "person",
|
| 80 |
+
"all_frame_bounding_boxes": [[2, 0, 255, 250], [17, 0, 255, 251], [35, 0, 255, 253], [44, 0, 255, 255], [52, 0, 255, 255], [54, 0, 255, 255], [63, 0, 255, 255], [60, 0, 255, 255], [54, 0, 253, 255], [43, 0, 250, 255], [36, 1, 249, 255], [36, 0, 252, 254], [41, 0, 252, 254], [61, 0, 255, 253], [68, 4, 255, 255], [74, 8, 255, 255], [91, 3, 255, 255]]
|
| 81 |
+
}
|
| 82 |
+
},
|
| 83 |
+
"task": "object_tracking",
|
| 84 |
+
"dataset": "hdvila"
|
| 85 |
+
}
|
| 86 |
+
# fmt: on
|
| 87 |
+
|
| 88 |
+
def check_format(self, data_dict: Dict, image_processing_config: Dict):
|
| 89 |
+
# box tracking 数据不支持 do_crop!!!
|
| 90 |
+
if image_processing_config.get('do_crop', False):
|
| 91 |
+
raise ValueError(f'do_crop is not supported in ObjectTrackingParser!')
|
| 92 |
+
|
| 93 |
+
def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
|
| 94 |
+
self.check_format(data_dict, image_processing_config)
|
| 95 |
+
|
| 96 |
+
bbox_transform = _bbox_transform_for_padding if image_processing_config['do_padding'] else _bbox_transform_for_resize
|
| 97 |
+
|
| 98 |
+
# sample n_frames
|
| 99 |
+
if isinstance(self.n_frames, int):
|
| 100 |
+
n_frames = self.n_frames
|
| 101 |
+
else:
|
| 102 |
+
n_frames = random.choice(self.n_frames)
|
| 103 |
+
total_frames = list(range(len(data_dict['frame_indices'])))
|
| 104 |
+
idxs = random.sample(total_frames, min(n_frames, len(total_frames)))
|
| 105 |
+
idxs.sort()
|
| 106 |
+
|
| 107 |
+
frame_indices = [data_dict['frame_indices'][i] for i in idxs]
|
| 108 |
+
frames = sample_video(data_dict['video_file'], frame_indices=frame_indices)
|
| 109 |
+
img_transform = self.img_transform[data_dict['dataset']]
|
| 110 |
+
frames = [img_transform(f) for f in frames]
|
| 111 |
+
|
| 112 |
+
objects = []
|
| 113 |
+
for _, o in data_dict['objects'].items():
|
| 114 |
+
if o is None:
|
| 115 |
+
continue
|
| 116 |
+
all_frame_bounding_boxes = [o['all_frame_bounding_boxes'][i] for i in idxs]
|
| 117 |
+
all_frame_bounding_boxes_t = []
|
| 118 |
+
for bbox, frame in zip(all_frame_bounding_boxes, frames):
|
| 119 |
+
all_frame_bounding_boxes_t.append(bbox_transform(bbox, frame))
|
| 120 |
+
objects.append(all_frame_bounding_boxes_t)
|
| 121 |
+
if len(objects) >= self.max_objects:
|
| 122 |
+
break
|
| 123 |
+
|
| 124 |
+
prompt = "Given the bounding box coordinates of these objects in the first frame, output the bounding box coordinates in the following frames.\n{}"
|
| 125 |
+
response = ''
|
| 126 |
+
|
| 127 |
+
object_info = ''
|
| 128 |
+
for i, o in enumerate(objects):
|
| 129 |
+
object_info += f'object {i+1}: {o[0]}\n'
|
| 130 |
+
response += f'object {i+1}: {o[1:]}\n'
|
| 131 |
+
response = response.strip()
|
| 132 |
+
prompt = prompt.format(object_info)
|
| 133 |
+
|
| 134 |
+
messages = [
|
| 135 |
+
{
|
| 136 |
+
"role": "user",
|
| 137 |
+
"content": [
|
| 138 |
+
{"type": "video", "video": frames},
|
| 139 |
+
{"type": "text", "text": prompt}
|
| 140 |
+
]
|
| 141 |
+
},
|
| 142 |
+
{
|
| 143 |
+
"role": "assistant",
|
| 144 |
+
"content": [
|
| 145 |
+
{"type": "text", "text": response}
|
| 146 |
+
]
|
| 147 |
+
}
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
return messages
|
| 151 |
+
|
| 152 |
+
def get_img_transform(self):
|
| 153 |
+
return {
|
| 154 |
+
'webvid': return_same,
|
| 155 |
+
'hdvila': transforms.Compose([
|
| 156 |
+
transforms.Resize(size=256),
|
| 157 |
+
transforms.CenterCrop(size=(256, 256))
|
| 158 |
+
]),
|
| 159 |
+
'hdvila_in_and_out_boxes': InAndOutCropAndResize(size=(256, 256))
|
| 160 |
+
}
|
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/standard_vision_parser.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
from PIL import Image
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
from .utils import sample_video, read_image, adjust_bbox, filter_ocr_polygon
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class VisionParser:
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
n_frames=8,
|
| 12 |
+
max_n_frames=256,
|
| 13 |
+
is_training=True,
|
| 14 |
+
video_sampling_strategy={},
|
| 15 |
+
):
|
| 16 |
+
self.n_frames = n_frames
|
| 17 |
+
self.max_n_frames = max_n_frames
|
| 18 |
+
self.is_training = is_training
|
| 19 |
+
self.video_sampling_strategy = video_sampling_strategy
|
| 20 |
+
|
| 21 |
+
# fmt: off
|
| 22 |
+
self.data_temp = {
|
| 23 |
+
"messages": [
|
| 24 |
+
{
|
| 25 |
+
"role": "user",
|
| 26 |
+
"content": [
|
| 27 |
+
{"type": "text", "text": "Describe the image and the video."},
|
| 28 |
+
# 支持的 image 格式:
|
| 29 |
+
{"type": "image", "image": {"image_file": "/path/to/image"}},
|
| 30 |
+
{"type": "image", "image": {"video_file": "/path/to/video", "frame_indices": 0}},
|
| 31 |
+
# 支持的 video 格式:
|
| 32 |
+
{"type": "video", "video": {"video_file": "/path/to/video"}},
|
| 33 |
+
{"type": "video", "video": {"video_file": "/path/to/video", "frame_indices": [0, 1, 2]}},
|
| 34 |
+
{"type": "video", "video": {"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100}},
|
| 35 |
+
{"type": "video", "video": {"video_file": "/path/to/video", "time_indices": [0, 1, 2]}},
|
| 36 |
+
{"type": "video", "video": {"video_file": "/path/to/video", "start_time": 0, "end_time": 100}},
|
| 37 |
+
{"type": "video", "video": {"image_file": ["/path/to/image"]}, "frame_indices": [0, 1, 2]},
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"role": "assistant",
|
| 42 |
+
"content": [
|
| 43 |
+
{"type": "text","text": "xxx"}
|
| 44 |
+
]
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"dataset": "LSMDC",
|
| 48 |
+
"task": "video/caption"
|
| 49 |
+
}
|
| 50 |
+
# fmt: on
|
| 51 |
+
|
| 52 |
+
def check_format(self, data_dict: Dict, image_processing_config: Dict):
|
| 53 |
+
if image_processing_config.get('do_crop', False) and image_processing_config.get('has_coordinates', False):
|
| 54 |
+
raise ValueError(f'do_crop and has_coordinates cannot be True at the same time!')
|
| 55 |
+
|
| 56 |
+
"""
|
| 57 |
+
1. 将 messages 中的 image/video 替换成相应的 PIL.Image/List[PIL.Image]
|
| 58 |
+
2. text 的特殊处理:调整 box;过滤面积太小的OCR
|
| 59 |
+
"""
|
| 60 |
+
def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
|
| 61 |
+
self.check_format(data_dict, image_processing_config)
|
| 62 |
+
|
| 63 |
+
self.set_n_frames(data_dict)
|
| 64 |
+
|
| 65 |
+
first_image = None # ugly! 需要调整box/过滤面积太小的OCR的数据只有图片任务
|
| 66 |
+
|
| 67 |
+
for msg in data_dict['messages']:
|
| 68 |
+
if isinstance(msg['content'], dict):
|
| 69 |
+
msg['content'] = [msg['content']]
|
| 70 |
+
for content in msg['content']:
|
| 71 |
+
|
| 72 |
+
if content['type'] == 'image':
|
| 73 |
+
content['image'] = self.load_image_item(content['image'])
|
| 74 |
+
if first_image is None:
|
| 75 |
+
first_image = content['image']
|
| 76 |
+
elif content['type'] == 'video':
|
| 77 |
+
video = self.load_video_item(content['video'])
|
| 78 |
+
content['video'] = video.pop('frames')
|
| 79 |
+
if video:
|
| 80 |
+
data_dict['extra_info']['frame_disturb_info'] = video.pop('video_info', {})
|
| 81 |
+
elif content['type'] == 'text':
|
| 82 |
+
pass
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError(f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']")
|
| 85 |
+
for msg in data_dict['messages']:
|
| 86 |
+
for content in msg['content']:
|
| 87 |
+
if content['type'] == 'text':
|
| 88 |
+
self.postprocess_text(content, data_dict, image_processing_config, first_image)
|
| 89 |
+
|
| 90 |
+
return data_dict['messages']
|
| 91 |
+
|
| 92 |
+
# set n_frames for each vision item.
|
| 93 |
+
def set_n_frames(self, data_dict):
|
| 94 |
+
|
| 95 |
+
if isinstance(self.n_frames, int):
|
| 96 |
+
n_frames = self.n_frames
|
| 97 |
+
else:
|
| 98 |
+
n_frames = random.choice(self.n_frames)
|
| 99 |
+
|
| 100 |
+
assert n_frames <= self.max_n_frames
|
| 101 |
+
|
| 102 |
+
curr_n_frames = 0
|
| 103 |
+
has_dynamic = False
|
| 104 |
+
for msg in data_dict['messages']:
|
| 105 |
+
if isinstance(msg['content'], dict):
|
| 106 |
+
msg['content'] = [msg['content']]
|
| 107 |
+
|
| 108 |
+
for content in msg['content']:
|
| 109 |
+
|
| 110 |
+
if content['type'] == 'image':
|
| 111 |
+
curr_n_frames += 1
|
| 112 |
+
elif content['type'] == 'video':
|
| 113 |
+
if 'frame_indices' in content['video']:
|
| 114 |
+
curr_n_frames += len(content['video']['frame_indices'])
|
| 115 |
+
content['video']['n_frames'] = len(content['video']['frame_indices'])
|
| 116 |
+
elif 'time_indices' in content['video']:
|
| 117 |
+
curr_n_frames += len(content['video']['time_indices'])
|
| 118 |
+
content['video']['n_frames'] = len(content['video']['time_indices'])
|
| 119 |
+
elif 'min_n_frames' in content['video']:
|
| 120 |
+
content['video']['min_n_frames'] = int(content['video']['min_n_frames'])
|
| 121 |
+
curr_n_frames += content['video']['min_n_frames']
|
| 122 |
+
content['video']['n_frames'] = content['video']['min_n_frames']
|
| 123 |
+
has_dynamic = True
|
| 124 |
+
elif 'fps' in content['video']:
|
| 125 |
+
content['video']['n_frames'] = self.max_n_frames
|
| 126 |
+
curr_n_frames += self.max_n_frames
|
| 127 |
+
has_dynamic = True
|
| 128 |
+
else:
|
| 129 |
+
content['video']['n_frames'] = 0
|
| 130 |
+
has_dynamic = True
|
| 131 |
+
|
| 132 |
+
while curr_n_frames < n_frames and has_dynamic:
|
| 133 |
+
for msg in data_dict['messages']:
|
| 134 |
+
for content in msg['content']:
|
| 135 |
+
if content['type'] == 'video':
|
| 136 |
+
if 'frame_indices' in content['video']:
|
| 137 |
+
pass
|
| 138 |
+
elif 'time_indices' in content['video']:
|
| 139 |
+
pass
|
| 140 |
+
else:
|
| 141 |
+
if curr_n_frames < n_frames:
|
| 142 |
+
content['video']['n_frames'] += 1
|
| 143 |
+
curr_n_frames += 1
|
| 144 |
+
|
| 145 |
+
while curr_n_frames > self.max_n_frames and has_dynamic:
|
| 146 |
+
for msg in data_dict['messages']:
|
| 147 |
+
for content in msg['content']:
|
| 148 |
+
if content['type'] == 'video':
|
| 149 |
+
if 'frame_indices' in content['video']:
|
| 150 |
+
pass
|
| 151 |
+
elif 'time_indices' in content['video']:
|
| 152 |
+
pass
|
| 153 |
+
else:
|
| 154 |
+
if curr_n_frames > self.max_n_frames:
|
| 155 |
+
content['video']['n_frames'] -= 1
|
| 156 |
+
curr_n_frames -= 1
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
for msg in data_dict['messages']:
|
| 160 |
+
for content in msg['content']:
|
| 161 |
+
if content['type'] == 'video':
|
| 162 |
+
if 'frame_indices' in content['video']:
|
| 163 |
+
pass
|
| 164 |
+
elif 'time_indices' in content['video']:
|
| 165 |
+
pass
|
| 166 |
+
else:
|
| 167 |
+
n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
|
| 168 |
+
if n > 1 and content['video']['n_frames'] % n != 0:
|
| 169 |
+
content['video']['n_frames'] += n - content['video']['n_frames'] % n
|
| 170 |
+
|
| 171 |
+
def load_image_item(self, image_item) -> Image.Image:
|
| 172 |
+
"""
|
| 173 |
+
image_item:
|
| 174 |
+
{"image_file": {"lq": "/path/to/image"}}
|
| 175 |
+
{"video_file": {"lq": "/path/to/video"}, "frame_indices": 0}
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
# check format
|
| 179 |
+
if ("image_file" not in image_item) and ("video_file" not in image_item):
|
| 180 |
+
raise KeyError(f"Key 'image_file' or 'video_file' not found in image_item")
|
| 181 |
+
if 'image_file' in image_item:
|
| 182 |
+
if not isinstance(image_item['image_file'], str):
|
| 183 |
+
raise ValueError(f"{image_item['image_file']} is not a str!")
|
| 184 |
+
if 'video_file' in image_item:
|
| 185 |
+
if not isinstance(image_item['frame_indices'], int):
|
| 186 |
+
raise ValueError(f"{image_item['frame_indices']} is not a int!")
|
| 187 |
+
|
| 188 |
+
if 'image_file' in image_item:
|
| 189 |
+
image = read_image(image_item['image_file'])
|
| 190 |
+
else:
|
| 191 |
+
frame_indices = [image_item['frame_indices']]
|
| 192 |
+
image = sample_video(image_item['video_file'], frame_indices = frame_indices)[0]
|
| 193 |
+
|
| 194 |
+
return image
|
| 195 |
+
|
| 196 |
+
def load_video_item(self, video_item) -> List[Image.Image]:
|
| 197 |
+
"""
|
| 198 |
+
video_item:
|
| 199 |
+
{"video_file": {"lq": "/path/to/video"}, "n_frames": 8}
|
| 200 |
+
{"video_file": {"lq": "/path/to/video"}, "frame_indices": [0, 1, 2], "n_frames": 3}
|
| 201 |
+
{"video_file": {"lq": "/path/to/video"}, "start_frame": 0, "end_frame": 100, "n_frames": 8}
|
| 202 |
+
{"video_file": {"lq": "/path/to/video"}, "time_indices": [0, 1, 2], "n_frames": 3}
|
| 203 |
+
{"video_file": {"lq": "/path/to/video"}, "start_time": 0, "end_time": 100, "n_frames": 8}
|
| 204 |
+
{"image_file": {"lq": ["/path/to/image"]}, "frame_indices": [0, 1, 2], "n_frames": 3}
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
# check format
|
| 208 |
+
if ("image_file" not in video_item) and ("video_file" not in video_item):
|
| 209 |
+
raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")
|
| 210 |
+
|
| 211 |
+
video_path = video_item.get('video_file', video_item.get('image_file'))
|
| 212 |
+
n_frames = video_item.get('n_frames', None)
|
| 213 |
+
frame_indices = video_item.get('frame_indices', None)
|
| 214 |
+
start_frame = video_item.get('start_frame', None)
|
| 215 |
+
end_frame = video_item.get('end_frame', None)
|
| 216 |
+
time_indices = video_item.get('time_indices', None)
|
| 217 |
+
start_time = video_item.get('start_time', None)
|
| 218 |
+
end_time = video_item.get('end_time', None)
|
| 219 |
+
mask_boxes = video_item.get('mask_boxes', None)
|
| 220 |
+
fps = video_item.get('fps', None)
|
| 221 |
+
|
| 222 |
+
frames, frame_indices = sample_video(
|
| 223 |
+
video_path=video_path,
|
| 224 |
+
frame_indices=frame_indices,
|
| 225 |
+
start_frame=start_frame,
|
| 226 |
+
end_frame=end_frame,
|
| 227 |
+
n_frames=n_frames,
|
| 228 |
+
time_indices=time_indices,
|
| 229 |
+
start_time=start_time,
|
| 230 |
+
end_time=end_time,
|
| 231 |
+
sampling_fps=fps,
|
| 232 |
+
mask_boxes=mask_boxes,
|
| 233 |
+
is_training=self.is_training,
|
| 234 |
+
video_sampling_strategy=self.video_sampling_strategy,
|
| 235 |
+
return_frame_ids=True,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
if self.video_sampling_strategy.get('use_multi_images_for_video', False):
|
| 239 |
+
new_frames = []
|
| 240 |
+
for f in frames:
|
| 241 |
+
new_frames.extend([f, f])
|
| 242 |
+
frames = new_frames
|
| 243 |
+
|
| 244 |
+
if isinstance(frame_indices, dict):
|
| 245 |
+
return {
|
| 246 |
+
'frames': frames,
|
| 247 |
+
'video_info': frame_indices
|
| 248 |
+
}
|
| 249 |
+
return {'frames': frames}
|
| 250 |
+
|
| 251 |
+
def postprocess_text(self, content, data_dict, image_processing_config, first_image):
|
| 252 |
+
if image_processing_config.get('has_coordinates') and image_processing_config.get('do_padding'):
|
| 253 |
+
content['text'] = adjust_bbox(content['text'], frame=first_image)
|
| 254 |
+
if data_dict.get('task') == 'image/OCR' and image_processing_config.get('has_coordinates'):
|
| 255 |
+
content['text'] = filter_ocr_polygon(content['text'])
|
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils.py
ADDED
|
@@ -0,0 +1,452 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Union
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import tempfile
|
| 5 |
+
from PIL import Image, ImageSequence
|
| 6 |
+
import base64
|
| 7 |
+
import io
|
| 8 |
+
import re
|
| 9 |
+
import uuid
|
| 10 |
+
import json
|
| 11 |
+
import numpy as np
|
| 12 |
+
import pyarrow.fs as pf
|
| 13 |
+
import func_timeout
|
| 14 |
+
from func_timeout import func_set_timeout
|
| 15 |
+
import math
|
| 16 |
+
|
| 17 |
+
# fmt: on
|
| 18 |
+
import decord
|
| 19 |
+
# fmt: off
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def denorm_box(points, height, width):
|
| 23 |
+
new_points = []
|
| 24 |
+
for p in points:
|
| 25 |
+
new_points.append((round(p[0] * width), round(p[1] * height)))
|
| 26 |
+
return new_points
|
| 27 |
+
|
| 28 |
+
def process_image_for_tiktok(frames: List[Image.Image], mask_boxes):
|
| 29 |
+
mask_boxes = mask_boxes[:len(frames)]
|
| 30 |
+
frames = [np.array(f) for f in frames]
|
| 31 |
+
# assert len(mask_boxes) == len(frames)
|
| 32 |
+
height, width = frames[0].shape[:2]
|
| 33 |
+
|
| 34 |
+
new_frames = []
|
| 35 |
+
for boxes, frame in zip(mask_boxes, frames):
|
| 36 |
+
left, top, right, bottom = 0, 0, width, height
|
| 37 |
+
for box in boxes:
|
| 38 |
+
pts = np.array(denorm_box(box, height, width), np.int32)
|
| 39 |
+
upper_bound = max([p[1] for p in pts]) + 30
|
| 40 |
+
if bottom > upper_bound:
|
| 41 |
+
bottom = upper_bound
|
| 42 |
+
frame[pts[0][1]: pts[2][1], pts[0][0]: pts[1][0]] = 0
|
| 43 |
+
|
| 44 |
+
new_frames.append(Image.fromarray(frame[top: bottom, left: right]))
|
| 45 |
+
return new_frames
|
| 46 |
+
|
| 47 |
+
# 先将视频分成 n_frames 份。训练时,每份随机抽一帧;测试时,每份抽中间的那一帧。
|
| 48 |
+
def _sample_frame_indices_v2(
|
| 49 |
+
total_frames: int,
|
| 50 |
+
n_frames: int,
|
| 51 |
+
is_training=False,
|
| 52 |
+
video_sampling_strategy = {},
|
| 53 |
+
):
|
| 54 |
+
total_frames_idxs = list(range(total_frames))
|
| 55 |
+
if total_frames <= n_frames:
|
| 56 |
+
return total_frames_idxs
|
| 57 |
+
k, m = divmod(total_frames, n_frames)
|
| 58 |
+
frame_splits = [total_frames_idxs[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in list(range(n_frames))]
|
| 59 |
+
if is_training:
|
| 60 |
+
sample_ids = [random.choice(i) for i in frame_splits]
|
| 61 |
+
else:
|
| 62 |
+
sample_ids = [i[(len(i)+1)//2-1] for i in frame_splits]
|
| 63 |
+
return sample_ids
|
| 64 |
+
|
| 65 |
+
# 均匀抽帧,必采样首尾帧。
|
| 66 |
+
def _sample_frame_indices_v1(total_frames: int, n_frames: int, is_training=False, video_sampling_strategy = {}):
|
| 67 |
+
if n_frames == 1:
|
| 68 |
+
return [0] # sample first frame in default
|
| 69 |
+
if total_frames <= n_frames:
|
| 70 |
+
return list(range(total_frames))
|
| 71 |
+
sample_ids = [round(i * (total_frames - 1) / (n_frames - 1)) for i in range(n_frames)]
|
| 72 |
+
return sample_ids
|
| 73 |
+
|
| 74 |
+
def conduct_disturb_frame(frame_indices):
|
| 75 |
+
disturb_type = random.choice(['exchange', 'crop', 'reverse', 'discard'])
|
| 76 |
+
n_frames = len(frame_indices)
|
| 77 |
+
frame_indices_new = []
|
| 78 |
+
if disturb_type == 'exchange':
|
| 79 |
+
# 均等分成4个segments, 随机交换两个segment
|
| 80 |
+
seg_len = math.ceil(n_frames / 4)
|
| 81 |
+
seg_idxs = list(range(0, n_frames, seg_len))
|
| 82 |
+
target_idxs = random.sample(range(0, 4), 2)
|
| 83 |
+
seg_idxs[target_idxs[0]], seg_idxs[target_idxs[1]] = seg_idxs[target_idxs[1]], seg_idxs[target_idxs[0]]
|
| 84 |
+
for idx in seg_idxs:
|
| 85 |
+
frame_indices_new += frame_indices[idx: idx+seg_len]
|
| 86 |
+
elif disturb_type == 'crop':
|
| 87 |
+
# 随机截取出3/4时长,再采均匀n_frames帧
|
| 88 |
+
crop_len = math.ceil(n_frames / 4)
|
| 89 |
+
idx_s = random.choice(range(0, crop_len+1))
|
| 90 |
+
idx_e = n_frames - 1 - (crop_len - idx_s)
|
| 91 |
+
frame_indices_new = np.linspace(frame_indices[idx_s], frame_indices[idx_e], n_frames, dtype=int).tolist()
|
| 92 |
+
elif disturb_type == 'reverse':
|
| 93 |
+
# 随机选择长度为[1/2, 1]时长的片段进行顺序颠倒
|
| 94 |
+
reverse_len = math.ceil(random.uniform(0.5,1) * n_frames)
|
| 95 |
+
idx_s = random.choice(range(0, n_frames-reverse_len+1))
|
| 96 |
+
idx_e = idx_s + reverse_len - 1
|
| 97 |
+
frame_indices_new = frame_indices[:idx_s] + list(reversed(frame_indices[idx_s: idx_e+1])) + frame_indices[idx_e+1:]
|
| 98 |
+
elif disturb_type == 'discard':
|
| 99 |
+
# 随机丢弃一半帧
|
| 100 |
+
frame_indices_new = random.sample(frame_indices, n_frames//2)
|
| 101 |
+
frame_indices_new.sort()
|
| 102 |
+
return disturb_type, frame_indices_new
|
| 103 |
+
|
| 104 |
+
@func_set_timeout(60)
|
| 105 |
+
def _download_file(path):
|
| 106 |
+
if path.startswith("hdfs"):
|
| 107 |
+
local_path = os.path.join(tempfile.gettempdir(), f'{uuid.uuid4()}_' + os.path.basename(path))
|
| 108 |
+
|
| 109 |
+
fs = pf.HadoopFileSystem.from_uri(uri="hdfs://harunava")
|
| 110 |
+
hdfs_file = fs.open_input_file(path)
|
| 111 |
+
file_size = hdfs_file.size()
|
| 112 |
+
if file_size > 1024 * 1024 * 1024: # 1G
|
| 113 |
+
os.system(f"hadoop fs -get --ct 8 -c 512 '{path}' '{local_path}' > /dev/null 2>&1")
|
| 114 |
+
elif file_size > 1024 * 1024 * 100: # 100M
|
| 115 |
+
os.system(f"hadoop fs -get '{path}' '{local_path}' > /dev/null 2>&1")
|
| 116 |
+
else:
|
| 117 |
+
local_fs = pf.LocalFileSystem()
|
| 118 |
+
with local_fs.open_output_stream(local_path) as local_file:
|
| 119 |
+
while True:
|
| 120 |
+
chunk = hdfs_file.read(1024 * 1024 * 100) # Reading 1MB chunks, you can adjust this as needed
|
| 121 |
+
if not chunk:
|
| 122 |
+
break
|
| 123 |
+
local_file.write(chunk)
|
| 124 |
+
else:
|
| 125 |
+
local_path = path
|
| 126 |
+
|
| 127 |
+
if not os.path.exists(local_path):
|
| 128 |
+
raise FileNotFoundError(f'{local_path}')
|
| 129 |
+
|
| 130 |
+
return local_path
|
| 131 |
+
|
| 132 |
+
def download_file(path):
|
| 133 |
+
try:
|
| 134 |
+
# with timer(f'Download {path}'):
|
| 135 |
+
return _download_file(path)
|
| 136 |
+
except func_timeout.exceptions.FunctionTimedOut as e:
|
| 137 |
+
raise ValueError(e)
|
| 138 |
+
|
| 139 |
+
class VideoReader:
|
| 140 |
+
def __init__(self, path: str) -> None:
|
| 141 |
+
self.path = path
|
| 142 |
+
self.local_path = self.preprocess()
|
| 143 |
+
self.vr = decord.VideoReader(self.local_path, num_threads=1, ctx=decord.cpu(0), fault_tol=1)
|
| 144 |
+
self.vr.seek(0)
|
| 145 |
+
self._length = len(self.vr)
|
| 146 |
+
self._fps = self.vr.get_avg_fps()
|
| 147 |
+
|
| 148 |
+
@property
|
| 149 |
+
def length(self):
|
| 150 |
+
return self._length
|
| 151 |
+
|
| 152 |
+
@property
|
| 153 |
+
def fps(self):
|
| 154 |
+
return self._fps
|
| 155 |
+
|
| 156 |
+
def sample(self, frame_indices) -> List[Image.Image]:
|
| 157 |
+
frames = self.vr.get_batch(frame_indices).asnumpy()
|
| 158 |
+
frames = [Image.fromarray(f).convert('RGB') for f in frames]
|
| 159 |
+
return frames
|
| 160 |
+
|
| 161 |
+
def preprocess(self):
|
| 162 |
+
return download_file(self.path)
|
| 163 |
+
|
| 164 |
+
def postprocess(self):
|
| 165 |
+
if self.path.startswith("hdfs"):
|
| 166 |
+
os.remove(self.local_path)
|
| 167 |
+
|
| 168 |
+
class ImageSeqReader:
|
| 169 |
+
def __init__(self, path: List[str]) -> None:
|
| 170 |
+
self.path = path
|
| 171 |
+
self.local_path = self.preprocess()
|
| 172 |
+
self._length = len(self.local_path)
|
| 173 |
+
self._fps = None
|
| 174 |
+
|
| 175 |
+
@property
|
| 176 |
+
def length(self):
|
| 177 |
+
return self._length
|
| 178 |
+
|
| 179 |
+
@property
|
| 180 |
+
def fps(self):
|
| 181 |
+
return self._fps
|
| 182 |
+
|
| 183 |
+
def sample(self, frame_indices):
|
| 184 |
+
return [read_image(self.local_path[i]) for i in frame_indices]
|
| 185 |
+
|
| 186 |
+
def preprocess(self):
|
| 187 |
+
local_paths = []
|
| 188 |
+
for p in self.path:
|
| 189 |
+
local_paths.append(p)
|
| 190 |
+
return local_paths
|
| 191 |
+
|
| 192 |
+
def postprocess(self):
|
| 193 |
+
pass
|
| 194 |
+
|
| 195 |
+
class GIFReader:
|
| 196 |
+
def __init__(self, path: str) -> None:
|
| 197 |
+
self.path = path
|
| 198 |
+
self.local_path = self.preprocess()
|
| 199 |
+
self.gif = Image.open(self.local_path)
|
| 200 |
+
self._length = self.gif.n_frames
|
| 201 |
+
duration = self.gif.info.get('duration', 0) / 1000 # 转换为秒
|
| 202 |
+
if duration > 0:
|
| 203 |
+
self._fps = 1 / duration
|
| 204 |
+
else:
|
| 205 |
+
self._fps = None
|
| 206 |
+
|
| 207 |
+
@property
|
| 208 |
+
def length(self):
|
| 209 |
+
return self._length
|
| 210 |
+
|
| 211 |
+
@property
|
| 212 |
+
def fps(self):
|
| 213 |
+
return self._fps
|
| 214 |
+
|
| 215 |
+
def sample(self, frame_indices):
|
| 216 |
+
frames = []
|
| 217 |
+
i = 0
|
| 218 |
+
for frame in ImageSequence.Iterator(self.gif):
|
| 219 |
+
if i in frame_indices:
|
| 220 |
+
frames.append(frame.convert('RGB'))
|
| 221 |
+
i += 1
|
| 222 |
+
return frames
|
| 223 |
+
|
| 224 |
+
def preprocess(self):
|
| 225 |
+
return download_file(self.path)
|
| 226 |
+
|
| 227 |
+
def postprocess(self):
|
| 228 |
+
if self.path.startswith("hdfs"):
|
| 229 |
+
os.remove(self.local_path)
|
| 230 |
+
|
| 231 |
+
def check_frame_indices(frame_indices, total_frames, video_path):
|
| 232 |
+
if frame_indices[-1] == total_frames:
|
| 233 |
+
frame_indices[-1] = total_frames - 1
|
| 234 |
+
|
| 235 |
+
valid_frame_indices = [i for i in frame_indices if i >= 0 and i < total_frames]
|
| 236 |
+
|
| 237 |
+
if len(valid_frame_indices) != len(frame_indices):
|
| 238 |
+
print(f'[Error] frame out of index. video_path={video_path}, frame_indices={frame_indices}, total_frames={total_frames}', flush=True)
|
| 239 |
+
|
| 240 |
+
return valid_frame_indices
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def sample_video(
|
| 244 |
+
video_path: Union[str, List[str]],
|
| 245 |
+
frame_indices: List[int] = None,
|
| 246 |
+
start_frame:int=None,
|
| 247 |
+
end_frame:int=None,
|
| 248 |
+
n_frames:int = None,
|
| 249 |
+
time_indices: List[float] = None,
|
| 250 |
+
start_time:int=None,
|
| 251 |
+
end_time:int=None,
|
| 252 |
+
sampling_fps:float=None,
|
| 253 |
+
mask_boxes=None,
|
| 254 |
+
is_training:bool=False,
|
| 255 |
+
video_sampling_strategy={'video_sampler_version': 'v1'},
|
| 256 |
+
return_frame_ids: bool=False,
|
| 257 |
+
) -> List[Image.Image]:
|
| 258 |
+
|
| 259 |
+
do_frame_disturb = video_sampling_strategy.get('do_frame_disturb', False)
|
| 260 |
+
|
| 261 |
+
if isinstance(video_path, str):
|
| 262 |
+
if video_path.endswith('.gif'):
|
| 263 |
+
reader = GIFReader(video_path)
|
| 264 |
+
else:
|
| 265 |
+
reader = VideoReader(video_path)
|
| 266 |
+
else:
|
| 267 |
+
reader = ImageSeqReader(video_path)
|
| 268 |
+
|
| 269 |
+
total_frames = reader.length
|
| 270 |
+
fps = reader.fps
|
| 271 |
+
|
| 272 |
+
if sampling_fps is not None:
|
| 273 |
+
frame_indices = list(range(0, total_frames, round(fps / sampling_fps)))
|
| 274 |
+
if len(frame_indices) > n_frames:
|
| 275 |
+
frame_indices = None
|
| 276 |
+
|
| 277 |
+
if time_indices is not None:
|
| 278 |
+
frame_indices = [round(float(i) * fps) for i in time_indices]
|
| 279 |
+
|
| 280 |
+
if start_time is not None and end_time is not None:
|
| 281 |
+
start_frame = round(start_time * fps)
|
| 282 |
+
end_frame = round(end_time * fps)
|
| 283 |
+
|
| 284 |
+
if frame_indices is None:
|
| 285 |
+
start_frame = 0 if start_frame is None else round(start_frame)
|
| 286 |
+
end_frame = total_frames - 1 if end_frame is None else round(end_frame)
|
| 287 |
+
|
| 288 |
+
if end_frame == total_frames:
|
| 289 |
+
end_frame -= 1
|
| 290 |
+
|
| 291 |
+
if video_sampling_strategy['video_sampler_version'] == 'v1':
|
| 292 |
+
# 均匀抽帧,必采样首尾帧。
|
| 293 |
+
frame_indices = _sample_frame_indices_v1(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
|
| 294 |
+
elif video_sampling_strategy['video_sampler_version'] == 'v2':
|
| 295 |
+
frame_indices = _sample_frame_indices_v2(end_frame - start_frame + 1, n_frames, is_training, video_sampling_strategy)
|
| 296 |
+
else:
|
| 297 |
+
raise ValueError(f"video_sampler_version={video_sampling_strategy['video_sampler_version']} must be 'v1' or 'v2'")
|
| 298 |
+
frame_indices = [i + start_frame for i in frame_indices]
|
| 299 |
+
|
| 300 |
+
frame_indices = check_frame_indices(frame_indices, total_frames, video_path)
|
| 301 |
+
|
| 302 |
+
if do_frame_disturb:
|
| 303 |
+
frame_disturb_type, frame_indices_new = conduct_disturb_frame(frame_indices)
|
| 304 |
+
frame_indices_raw = frame_indices[:]
|
| 305 |
+
frame_indices = frame_indices_new
|
| 306 |
+
|
| 307 |
+
frames = reader.sample(frame_indices)
|
| 308 |
+
if mask_boxes is not None:
|
| 309 |
+
frames = process_image_for_tiktok(frames, mask_boxes)
|
| 310 |
+
|
| 311 |
+
n = video_sampling_strategy.get('force_frames_n_divisible', 1)
|
| 312 |
+
if n > 1 and len(frames) % n != 0:
|
| 313 |
+
new_n = n - len(frames) % n
|
| 314 |
+
frames.extend([Image.new(mode='RGB', size=frames[-1].size) for _ in range(new_n)])
|
| 315 |
+
|
| 316 |
+
reader.postprocess()
|
| 317 |
+
|
| 318 |
+
if do_frame_disturb:
|
| 319 |
+
return frames, {"frame_indices": frame_indices, "disturb_type": frame_disturb_type, "frame_indices_raw": frame_indices_raw}
|
| 320 |
+
if return_frame_ids:
|
| 321 |
+
return frames, frame_indices
|
| 322 |
+
return frames
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def load_image_from_base64String(img_path):
|
| 327 |
+
img = base64.b64decode(open(img_path, "rb").read())
|
| 328 |
+
buf = io.BytesIO(img)
|
| 329 |
+
img = Image.open(buf)
|
| 330 |
+
return img
|
| 331 |
+
|
| 332 |
+
def read_image(image_path):
|
| 333 |
+
local_file = download_file(image_path)
|
| 334 |
+
|
| 335 |
+
if local_file.endswith('.dat'):
|
| 336 |
+
image = load_image_from_base64String(local_file)
|
| 337 |
+
else:
|
| 338 |
+
image = Image.open(local_file).convert('RGB')
|
| 339 |
+
if image_path.startswith("hdfs"):
|
| 340 |
+
os.remove(local_file)
|
| 341 |
+
return image
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def adjust_bbox(text, frame):
|
| 345 |
+
|
| 346 |
+
width, height = frame.size
|
| 347 |
+
new_text = []
|
| 348 |
+
start_idx = 0
|
| 349 |
+
for match in re.finditer(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', text):
|
| 350 |
+
coordinate_matches = re.findall(r"([0-9.]+)", match.group(0))
|
| 351 |
+
xys = [float(coord) for coord in coordinate_matches]
|
| 352 |
+
|
| 353 |
+
new_xys = []
|
| 354 |
+
for i in range(len(xys)):
|
| 355 |
+
p = xys[i]
|
| 356 |
+
|
| 357 |
+
if width == height:
|
| 358 |
+
pass
|
| 359 |
+
|
| 360 |
+
if width > height and i % 2 != 0:
|
| 361 |
+
p = xys[i] * height
|
| 362 |
+
p += (width - height) // 2
|
| 363 |
+
p = round(p / width, 2)
|
| 364 |
+
|
| 365 |
+
if height > width and i % 2 == 0:
|
| 366 |
+
p = xys[i] * width
|
| 367 |
+
p += (height - width) // 2
|
| 368 |
+
p = round(p / height, 2)
|
| 369 |
+
|
| 370 |
+
new_xys.append(p)
|
| 371 |
+
|
| 372 |
+
new_text.append(text[start_idx: match.span()[0]])
|
| 373 |
+
new_text.append(str(new_xys))
|
| 374 |
+
start_idx = match.span()[1]
|
| 375 |
+
new_text.append(text[start_idx: ])
|
| 376 |
+
text = ''.join(new_text)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
return text
|
| 380 |
+
|
| 381 |
+
def bbox_area(vertices, convert_format = True):
|
| 382 |
+
if convert_format:
|
| 383 |
+
vertices = list(zip(vertices[::2], vertices[1::2]))
|
| 384 |
+
x0, y0 = vertices[0]
|
| 385 |
+
x1, y1 = vertices[1]
|
| 386 |
+
return abs((x1 - x0) * (y1 - y0))
|
| 387 |
+
|
| 388 |
+
def polygon_area(vertices, convert_format = True):
|
| 389 |
+
if convert_format:
|
| 390 |
+
vertices = list(zip(vertices[::2], vertices[1::2]))
|
| 391 |
+
n = len(vertices) # 多边形顶点的数量
|
| 392 |
+
if n == 2:
|
| 393 |
+
return bbox_area(vertices, convert_format=False)
|
| 394 |
+
area = 0
|
| 395 |
+
for i in range(n):
|
| 396 |
+
x1, y1 = vertices[i]
|
| 397 |
+
x2, y2 = vertices[(i + 1) % n]
|
| 398 |
+
area += x1 * y2 - x2 * y1
|
| 399 |
+
return abs(area) / 2
|
| 400 |
+
|
| 401 |
+
def get_text_len(text_line):
|
| 402 |
+
l = 0
|
| 403 |
+
for c in text_line:
|
| 404 |
+
if '\u4e00' <= c <= '\u9fff':
|
| 405 |
+
l += 1
|
| 406 |
+
else:
|
| 407 |
+
l += 0.5
|
| 408 |
+
return l
|
| 409 |
+
|
| 410 |
+
def filter_ocr_polygon(response, area_threshold=0.0005):
|
| 411 |
+
try:
|
| 412 |
+
resp = json.loads(response)
|
| 413 |
+
except:
|
| 414 |
+
return response
|
| 415 |
+
new_resp = []
|
| 416 |
+
for coords, text_line in resp:
|
| 417 |
+
area = polygon_area(coords, convert_format=True)
|
| 418 |
+
text_len = get_text_len(text_line)
|
| 419 |
+
if text_len == 0:
|
| 420 |
+
continue
|
| 421 |
+
if area / text_len < area_threshold:
|
| 422 |
+
continue
|
| 423 |
+
new_resp.append([coords, text_line])
|
| 424 |
+
new_resp = json.dumps(new_resp, ensure_ascii=False)
|
| 425 |
+
|
| 426 |
+
return new_resp
|
| 427 |
+
|
| 428 |
+
def put_pred_to_data_dict(prediction, data_dict):
|
| 429 |
+
msg = data_dict['messages'][-1]
|
| 430 |
+
if msg['role'] == 'assistant':
|
| 431 |
+
msg['content'][-1]['text'] = prediction
|
| 432 |
+
else:
|
| 433 |
+
data_dict['messages'].append({
|
| 434 |
+
"role": "assistant",
|
| 435 |
+
"content": [{"type": "text", "text": prediction}]
|
| 436 |
+
})
|
| 437 |
+
|
| 438 |
+
def get_prompt_from_data_dict(data_dict):
|
| 439 |
+
prompt = ""
|
| 440 |
+
for msg in data_dict['messages']:
|
| 441 |
+
role = msg['role']
|
| 442 |
+
assert role in {'system', 'user', 'assistant'}
|
| 443 |
+
for content in msg['content']:
|
| 444 |
+
if content['type'] == 'text':
|
| 445 |
+
if content['text']:
|
| 446 |
+
prompt += f"[{role}]: {content['text']}"
|
| 447 |
+
elif content['type'] == 'image':
|
| 448 |
+
prompt += f"[{role}]: <image>"
|
| 449 |
+
elif content['type'] == 'video':
|
| 450 |
+
prompt += f"[{role}]: <video>"
|
| 451 |
+
prompt += '\n'
|
| 452 |
+
return prompt
|
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/utils_visualize.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
from typing import Dict, List, Optional
|
| 3 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def scale_polygon(polygon, w, h):
|
| 7 |
+
new_polygon = []
|
| 8 |
+
for (x, y) in polygon:
|
| 9 |
+
new_polygon.append((x * w, y * h))
|
| 10 |
+
return new_polygon
|
| 11 |
+
|
| 12 |
+
def draw_polygon(image: Image.Image, points: List[List[int]], label: Optional[str] = None):
|
| 13 |
+
draw = ImageDraw.Draw(image)
|
| 14 |
+
if len(points) > 2:
|
| 15 |
+
draw.polygon(points, outline="red", width=3)
|
| 16 |
+
elif len(points) == 2:
|
| 17 |
+
draw.rectangle(points, outline="red", width=3)
|
| 18 |
+
else:
|
| 19 |
+
raise ValueError(f'points={points} only has one point!')
|
| 20 |
+
|
| 21 |
+
if label is not None:
|
| 22 |
+
font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 20)
|
| 23 |
+
draw.text(points[0], label, font=font, fill=(0, 0, 255))
|
| 24 |
+
return image
|
| 25 |
+
|
| 26 |
+
def visualize_image_bbox(data_dict, image_processing_config, processor):
|
| 27 |
+
if image_processing_config.get('has_coordinates') != True:
|
| 28 |
+
return
|
| 29 |
+
|
| 30 |
+
messages = data_dict['messages']
|
| 31 |
+
|
| 32 |
+
polygons = []
|
| 33 |
+
first_image_content = None
|
| 34 |
+
|
| 35 |
+
for msg in messages:
|
| 36 |
+
for content in msg['content']:
|
| 37 |
+
if content['type'] == 'text':
|
| 38 |
+
for match in re.finditer(r'\[(\d+(\.\d+)?,\s*)+\d+(\.\d+)?\]', content["text"]):
|
| 39 |
+
coordinate_matches = re.findall(r"([0-9.]+)", match.group(0))
|
| 40 |
+
coords = [float(coord) for coord in coordinate_matches]
|
| 41 |
+
polygons.append(list(zip(coords[::2], coords[1::2])))
|
| 42 |
+
elif first_image_content is None and content['type'] == 'image':
|
| 43 |
+
first_image_content = content
|
| 44 |
+
|
| 45 |
+
first_image = first_image_content['image']
|
| 46 |
+
first_image = processor.preprocess_image(first_image, image_processing_config)
|
| 47 |
+
w, h = first_image.size
|
| 48 |
+
|
| 49 |
+
if len(polygons) > 0:
|
| 50 |
+
for i, polygon in enumerate(polygons):
|
| 51 |
+
polygon = scale_polygon(polygon, w, h)
|
| 52 |
+
first_image = draw_polygon(first_image, polygon, label=str(i))
|
| 53 |
+
|
| 54 |
+
first_image_content['image'] = first_image
|
eval_scripts/DREAM-1K/tarsier/dataset/custom_data_parsers/video_permutation_parser.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
import random
|
| 3 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 4 |
+
|
| 5 |
+
from .utils import sample_video
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class VideoPermutationParser:
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
n_frames=8,
|
| 12 |
+
is_training=True,
|
| 13 |
+
frame_nums = list(range(8, 25)),
|
| 14 |
+
video_sampling_strategy={},
|
| 15 |
+
):
|
| 16 |
+
self.n_frames = n_frames
|
| 17 |
+
self.is_training = is_training
|
| 18 |
+
self.frame_nums = frame_nums
|
| 19 |
+
self.video_sampling_strategy = video_sampling_strategy
|
| 20 |
+
# fmt: off
|
| 21 |
+
self.data_temp = {
|
| 22 |
+
"text": [{
|
| 23 |
+
"prompt": "<video>",
|
| 24 |
+
"response": ""
|
| 25 |
+
}],
|
| 26 |
+
"video": [{
|
| 27 |
+
"video_file": {
|
| 28 |
+
"yg": "/mnt/bn/videonasyg/videos/webvid_10M_download/011851_011900/1047443473.mp4",
|
| 29 |
+
"lq": "/mnt/bn/llmdatalq/jiangnan/video_generation/webvid_10M_download/20230609/videos/011851_011900/1047443473.mp4"
|
| 30 |
+
},
|
| 31 |
+
"frame_indices": [0, 85, 171, 256, 342, 427, 513, 598]
|
| 32 |
+
}],
|
| 33 |
+
}
|
| 34 |
+
# fmt: on
|
| 35 |
+
|
| 36 |
+
def check_format(self, data_dict: Dict):
|
| 37 |
+
pass
|
| 38 |
+
# for k in self.data_temp.keys():
|
| 39 |
+
# assert k in data_dict
|
| 40 |
+
|
| 41 |
+
def transform(self, data_dict: Dict, image_processing_config: Dict = None) -> Dict:
|
| 42 |
+
self.check_format(data_dict)
|
| 43 |
+
|
| 44 |
+
frames = self.load_video_item(data_dict['video'][0])
|
| 45 |
+
|
| 46 |
+
# frames = self.add_text_to_frames(frames) # for debug
|
| 47 |
+
|
| 48 |
+
idxs = list(range(1, len(frames) + 1))
|
| 49 |
+
random.shuffle(idxs)
|
| 50 |
+
|
| 51 |
+
prefix_len = int(3/8*len(idxs))
|
| 52 |
+
|
| 53 |
+
shuffled_frames = [frames[i-1] for i in idxs]
|
| 54 |
+
|
| 55 |
+
prompt = f'Output the correct chronological order of scrambled video frames. The order of the first {prefix_len} ones are:\n'
|
| 56 |
+
prompt += '\n'.join([str(i) for i in idxs[: prefix_len]]) + '\nOutput the order of the following frames:'
|
| 57 |
+
response = '\n'.join([str(i) for i in idxs[prefix_len: ]])
|
| 58 |
+
|
| 59 |
+
messages = [
|
| 60 |
+
{
|
| 61 |
+
"role": "user",
|
| 62 |
+
"content": [
|
| 63 |
+
{"type": "video", "video": shuffled_frames},
|
| 64 |
+
{"type": "text", "text": prompt},
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"role": "assistant",
|
| 69 |
+
"content": [
|
| 70 |
+
{"type": "text", "text": response}
|
| 71 |
+
]
|
| 72 |
+
}
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
return messages
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def load_video_item(self, video_item) -> List[Image.Image]:
|
| 79 |
+
"""
|
| 80 |
+
video_item:
|
| 81 |
+
{"video_file": "/path/to/video", "n_frames": 8}
|
| 82 |
+
{"video_file": "/path/to/video", "frame_indices": [0, 1, 2], "n_frames": 3}
|
| 83 |
+
{"video_file": "/path/to/video", "start_frame": 0, "end_frame": 100, "n_frames": 8}
|
| 84 |
+
{"video_file": "/path/to/video", "time_indices": [0, 1, 2], "n_frames": 3}
|
| 85 |
+
{"video_file": "/path/to/video", "start_time": 0, "end_time": 100, "n_frames": 8}
|
| 86 |
+
{"image_file": ["/path/to/image"], "frame_indices": [0, 1, 2], "n_frames": 3}
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
# check format
|
| 90 |
+
if ("image_file" not in video_item) and ("video_file" not in video_item):
|
| 91 |
+
raise KeyError(f"Key 'image_file' or 'video_file' not found in video_item")
|
| 92 |
+
|
| 93 |
+
video_path = video_item.get('video_file', video_item.get('image_file'))
|
| 94 |
+
n_frames = video_item.get('n_frames', None)
|
| 95 |
+
frame_indices = video_item.get('frame_indices', None)
|
| 96 |
+
start_frame = video_item.get('start_frame', None)
|
| 97 |
+
end_frame = video_item.get('end_frame', None)
|
| 98 |
+
time_indices = video_item.get('time_indices', None)
|
| 99 |
+
start_time = video_item.get('start_time', None)
|
| 100 |
+
end_time = video_item.get('end_time', None)
|
| 101 |
+
mask_boxes = video_item.get('mask_boxes', None)
|
| 102 |
+
|
| 103 |
+
n_frames = random.choice(self.frame_nums)
|
| 104 |
+
n = self.video_sampling_strategy.get('force_frames_n_divisible', 1)
|
| 105 |
+
if n > 1 and n_frames % n != 0:
|
| 106 |
+
n_frames += n - n_frames % n
|
| 107 |
+
|
| 108 |
+
frames, frame_indices = sample_video(
|
| 109 |
+
video_path=video_path,
|
| 110 |
+
frame_indices=frame_indices,
|
| 111 |
+
start_frame=start_frame,
|
| 112 |
+
end_frame=end_frame,
|
| 113 |
+
n_frames=n_frames,
|
| 114 |
+
time_indices=time_indices,
|
| 115 |
+
start_time=start_time,
|
| 116 |
+
end_time=end_time,
|
| 117 |
+
mask_boxes=mask_boxes,
|
| 118 |
+
is_training=self.is_training,
|
| 119 |
+
video_sampling_strategy=self.video_sampling_strategy,
|
| 120 |
+
return_frame_ids=True,
|
| 121 |
+
)
|
| 122 |
+
return frames
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def add_text_to_frames(self, frames: List[Image.Image]):
|
| 126 |
+
new_frames = []
|
| 127 |
+
for i, image in enumerate(frames):
|
| 128 |
+
draw = ImageDraw.Draw(image)
|
| 129 |
+
|
| 130 |
+
font = ImageFont.truetype('/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf', 100)
|
| 131 |
+
text_position = (50, 50)
|
| 132 |
+
text_content = f'{i+1}'
|
| 133 |
+
text_color = (255, 0, 0)
|
| 134 |
+
draw.text(text_position, text_content, font=font, fill=text_color)
|
| 135 |
+
new_frames.append(image)
|
| 136 |
+
return new_frames
|
| 137 |
+
|
eval_scripts/DREAM-1K/tarsier/dataset/tarsier_datamodule.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Datamodule for Llava Pretraining and Finetuning"""
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import numpy as np
|
| 6 |
+
import re
|
| 7 |
+
import tempfile
|
| 8 |
+
from typing import Dict, List, Union, Tuple
|
| 9 |
+
import traceback
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from transformers import DataCollatorForSeq2Seq
|
| 15 |
+
|
| 16 |
+
from tools.rw_utils import read_jsonlines
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
|
| 19 |
+
np_str_obj_array_pattern = re.compile(r"[SaUO]")
|
| 20 |
+
|
| 21 |
+
default_collate_err_msg_format = (
|
| 22 |
+
"default_collate: batch must contain tensors, numpy arrays, numbers, "
|
| 23 |
+
"dicts or lists; found {}"
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
from .custom_data_parsers.standard_vision_parser import VisionParser
|
| 27 |
+
from .custom_data_parsers.object_tracking_parser import ObjectTrackingParser
|
| 28 |
+
from .custom_data_parsers.multi_images_parser import MultiImagesParser
|
| 29 |
+
from .custom_data_parsers.video_permutation_parser import VideoPermutationParser
|
| 30 |
+
from .custom_data_parsers.utils_visualize import visualize_image_bbox
|
| 31 |
+
|
| 32 |
+
from .tarsier_processor import TarsierProcessor
|
| 33 |
+
|
| 34 |
+
from tools.rw_utils import NumpyArrayEncoder
|
| 35 |
+
from .utils import DictToObject
|
| 36 |
+
|
| 37 |
+
class TarsierDataProcessor:
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
processor: TarsierProcessor,
|
| 41 |
+
n_frames: Union[int, list],
|
| 42 |
+
max_n_frames=256,
|
| 43 |
+
max_pixels=int(1280 * 720 // 2),
|
| 44 |
+
min_pixels=0,
|
| 45 |
+
max_seq_len=None,
|
| 46 |
+
is_training=True, # 会影响:1. 训练和测试时采帧不同;2. 测试时忽略 response。
|
| 47 |
+
print_data_error=True,
|
| 48 |
+
do_image_padding=False,
|
| 49 |
+
do_image_crop=False,
|
| 50 |
+
do_image_resize=True,
|
| 51 |
+
video_sampling_strategy={},
|
| 52 |
+
prompt='',
|
| 53 |
+
train_task='sft',
|
| 54 |
+
**kwargs
|
| 55 |
+
):
|
| 56 |
+
self.kwargs = kwargs
|
| 57 |
+
|
| 58 |
+
self.processor = processor
|
| 59 |
+
self.pad_collator = DataCollatorForSeq2Seq(processor.tokenizer, padding='longest')
|
| 60 |
+
|
| 61 |
+
self.processor.max_seq_len = self.tokenizer.model_max_length if max_seq_len is None else max_seq_len
|
| 62 |
+
|
| 63 |
+
self.n_frames = n_frames
|
| 64 |
+
self.max_n_frames = max_n_frames
|
| 65 |
+
self.max_pixels = max_pixels
|
| 66 |
+
self.min_pixels = min_pixels
|
| 67 |
+
|
| 68 |
+
self.is_training = is_training
|
| 69 |
+
self.print_data_error = print_data_error
|
| 70 |
+
self.do_image_padding = do_image_padding
|
| 71 |
+
self.do_image_crop = do_image_crop
|
| 72 |
+
self.do_image_resize = do_image_resize
|
| 73 |
+
self.video_sampling_strategy = video_sampling_strategy
|
| 74 |
+
self.prompt = prompt
|
| 75 |
+
self.train_task = train_task
|
| 76 |
+
|
| 77 |
+
self.object_tracking_parser = ObjectTrackingParser(
|
| 78 |
+
n_frames=self.n_frames,
|
| 79 |
+
max_objects=4,
|
| 80 |
+
is_training=self.is_training,
|
| 81 |
+
)
|
| 82 |
+
self.multi_images_parser = MultiImagesParser(
|
| 83 |
+
n_frames=self.n_frames,
|
| 84 |
+
is_training=self.is_training,
|
| 85 |
+
)
|
| 86 |
+
self.video_permutation_parser = VideoPermutationParser(
|
| 87 |
+
n_frames=self.n_frames,
|
| 88 |
+
is_training=self.is_training,
|
| 89 |
+
video_sampling_strategy=self.video_sampling_strategy,
|
| 90 |
+
)
|
| 91 |
+
self.vision_parser = VisionParser(
|
| 92 |
+
n_frames=self.n_frames,
|
| 93 |
+
max_n_frames=self.max_n_frames,
|
| 94 |
+
is_training=self.is_training,
|
| 95 |
+
video_sampling_strategy=self.video_sampling_strategy
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def select_parser(self, data_dict):
|
| 99 |
+
if data_dict.get('task', None) == 'video/object_tracking':
|
| 100 |
+
return self.object_tracking_parser
|
| 101 |
+
elif data_dict.get('task', None) == 'multi_images':
|
| 102 |
+
return self.multi_images_parser
|
| 103 |
+
elif data_dict.get('dataset', None) == 'video_permutation':
|
| 104 |
+
return self.video_permutation_parser
|
| 105 |
+
else:
|
| 106 |
+
return self.vision_parser
|
| 107 |
+
|
| 108 |
+
def parse_image_processing_config(self, data_dict):
|
| 109 |
+
image_processing_config=data_dict.get('image_processing_config', {})
|
| 110 |
+
|
| 111 |
+
do_padding = image_processing_config.get('do_padding', self.do_image_padding)
|
| 112 |
+
do_crop = image_processing_config.get('do_crop', self.do_image_crop)
|
| 113 |
+
do_resize = image_processing_config.get('do_resize', self.do_image_resize)
|
| 114 |
+
max_pixels = image_processing_config.get('max_pixels', self.max_pixels)
|
| 115 |
+
min_pixels = image_processing_config.get('min_pixels', self.min_pixels)
|
| 116 |
+
|
| 117 |
+
assert min_pixels <= max_pixels
|
| 118 |
+
|
| 119 |
+
image_processing_config['do_padding'] = do_padding
|
| 120 |
+
image_processing_config['do_crop'] = do_crop
|
| 121 |
+
image_processing_config['do_resize'] = do_resize
|
| 122 |
+
image_processing_config['max_pixels'] = max_pixels
|
| 123 |
+
image_processing_config['min_pixels'] = min_pixels
|
| 124 |
+
|
| 125 |
+
return image_processing_config
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _transform(self, raw_data_dict: Dict) -> Dict:
|
| 129 |
+
data_dict = json.loads(json.dumps(raw_data_dict, cls=NumpyArrayEncoder))
|
| 130 |
+
del raw_data_dict
|
| 131 |
+
|
| 132 |
+
if self.prompt:
|
| 133 |
+
for msg in data_dict['messages']:
|
| 134 |
+
if msg['role'] == 'user':
|
| 135 |
+
for content in msg['content']:
|
| 136 |
+
if content['type'] == 'text':
|
| 137 |
+
content['text'] = self.prompt
|
| 138 |
+
|
| 139 |
+
data_dict_copy = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
|
| 140 |
+
|
| 141 |
+
image_processing_config = self.parse_image_processing_config(data_dict)
|
| 142 |
+
parser = self.select_parser(data_dict)
|
| 143 |
+
messages = parser.transform(data_dict, image_processing_config)
|
| 144 |
+
data_dict_copy['extra_info'] = data_dict.pop('extra_info', {})
|
| 145 |
+
|
| 146 |
+
# visualize_image_bbox(data_dict, image_processing_config, self.processor)
|
| 147 |
+
outputs = self.processor(messages, image_processing_config, is_training=self.is_training)
|
| 148 |
+
|
| 149 |
+
# if not self.is_training:
|
| 150 |
+
outputs['raw_data_dict'] = data_dict_copy
|
| 151 |
+
|
| 152 |
+
return [outputs]
|
| 153 |
+
|
| 154 |
+
def _split_chosen_rejected(self, data_dict: Dict):
|
| 155 |
+
chosen_data_dict = data_dict
|
| 156 |
+
rejected_data_dict = json.loads(json.dumps(data_dict, cls=NumpyArrayEncoder))
|
| 157 |
+
for msg in chosen_data_dict['messages']:
|
| 158 |
+
if msg['role'] == 'assistant':
|
| 159 |
+
for content in msg['content']:
|
| 160 |
+
if content['type'] == 'text':
|
| 161 |
+
content['text'] = content['chosen']
|
| 162 |
+
|
| 163 |
+
for msg in rejected_data_dict['messages']:
|
| 164 |
+
if msg['role'] == 'assistant':
|
| 165 |
+
for content in msg['content']:
|
| 166 |
+
if content['type'] == 'text':
|
| 167 |
+
content['text'] = content['rejected']
|
| 168 |
+
|
| 169 |
+
return chosen_data_dict, rejected_data_dict
|
| 170 |
+
|
| 171 |
+
def transform(self, data_dict: Dict) -> Dict:
|
| 172 |
+
try:
|
| 173 |
+
if self.train_task == 'dpo':
|
| 174 |
+
chosen_data_dict, rejected_data_dict = self._split_chosen_rejected(data_dict)
|
| 175 |
+
return self._transform(chosen_data_dict) + self._transform(rejected_data_dict)
|
| 176 |
+
return self._transform(data_dict)
|
| 177 |
+
except Exception as e:
|
| 178 |
+
if self.print_data_error:
|
| 179 |
+
print(traceback.format_exc())
|
| 180 |
+
print(f'Error occurs when processing: \n{data_dict}')
|
| 181 |
+
return []
|
| 182 |
+
|
| 183 |
+
def batch_transform(self, batch_data: List[Dict]) -> Dict:
|
| 184 |
+
model_inputs = {}
|
| 185 |
+
# if not self.is_training:
|
| 186 |
+
raw_data_dict = [d.pop('raw_data_dict') for d in batch_data]
|
| 187 |
+
model_inputs['raw_data_dict'] = raw_data_dict
|
| 188 |
+
|
| 189 |
+
batch_pixel_values = [d.pop('pixel_values') for d in batch_data if 'pixel_values' in d]
|
| 190 |
+
batch_image_grid_thw = [d.pop('image_grid_thw') for d in batch_data if 'image_grid_thw' in d]
|
| 191 |
+
if len(batch_pixel_values) == 0:
|
| 192 |
+
vision_placeholder = self.get_vision_placeholder()
|
| 193 |
+
batch_pixel_values = [vision_placeholder.get('pixel_values')]
|
| 194 |
+
batch_image_grid_thw = [vision_placeholder.get('image_grid_thw')] if 'image_grid_thw' in vision_placeholder else []
|
| 195 |
+
|
| 196 |
+
model_inputs['pixel_values'] = torch.cat(batch_pixel_values, dim=0)
|
| 197 |
+
if len(batch_image_grid_thw) > 0:
|
| 198 |
+
model_inputs['image_grid_thw'] = torch.cat(batch_image_grid_thw, dim=0)
|
| 199 |
+
|
| 200 |
+
batch_num_images = [d.pop('num_images') for d in batch_data]
|
| 201 |
+
model_inputs['num_images'] = torch.tensor(batch_num_images)
|
| 202 |
+
model_inputs.update(self.pad_collator(batch_data))
|
| 203 |
+
return model_inputs
|
| 204 |
+
|
| 205 |
+
def __call__(self, batch_data: Union[Dict, List[Dict]]) -> Dict:
|
| 206 |
+
if isinstance(batch_data, dict):
|
| 207 |
+
batch_data = [batch_data]
|
| 208 |
+
batch = [self.transform(d)[0] for d in batch_data]
|
| 209 |
+
return self.batch_transform(batch)
|
| 210 |
+
|
| 211 |
+
def get_vision_placeholder(self):
|
| 212 |
+
messages = [{"role": "user", "content": [{"type": "image", "image": Image.new(mode='RGB', size=(336, 336))}]}]
|
| 213 |
+
image_processing_config = self.parse_image_processing_config({})
|
| 214 |
+
return self.processor(messages, image_processing_config)
|
| 215 |
+
|
| 216 |
+
def get_text_placeholder(self):
|
| 217 |
+
messages = [
|
| 218 |
+
{"role": "user", "content": [{"type": "text", "text": "Hello!"}]},
|
| 219 |
+
{"role": "assistant", "content": [{"type": "text", "text": "Thank you very much"}]},
|
| 220 |
+
]
|
| 221 |
+
image_processing_config = self.parse_image_processing_config({})
|
| 222 |
+
return self.processor(messages, image_processing_config)
|
| 223 |
+
|
| 224 |
+
def init_processor(processor: Union[TarsierProcessor, str]=None, config: Dict=None):
|
| 225 |
+
config = DictToObject(config) if isinstance(config, dict) else config
|
| 226 |
+
if isinstance(processor, str):
|
| 227 |
+
sub_processor = TarsierProcessor.from_pretrained(
|
| 228 |
+
processor,
|
| 229 |
+
padding_side='left',
|
| 230 |
+
trust_remote_code=True
|
| 231 |
+
)
|
| 232 |
+
else:
|
| 233 |
+
sub_processor = processor
|
| 234 |
+
processor = TarsierDataProcessor(
|
| 235 |
+
processor=sub_processor,
|
| 236 |
+
n_frames=config.n_frames,
|
| 237 |
+
max_n_frames=config.max_n_frames,
|
| 238 |
+
max_pixels=config.max_pixels,
|
| 239 |
+
min_pixels=config.min_pixels,
|
| 240 |
+
max_seq_len=config.max_seq_len,
|
| 241 |
+
is_training=config.is_training,
|
| 242 |
+
print_data_error=config.print_data_error,
|
| 243 |
+
do_image_padding=config.do_image_padding,
|
| 244 |
+
do_image_crop=config.do_image_crop,
|
| 245 |
+
do_image_resize=config.do_image_resize,
|
| 246 |
+
video_sampling_strategy=config.video_sampling_strategy,
|
| 247 |
+
prompt=config.prompt,
|
| 248 |
+
train_task=config.train_task
|
| 249 |
+
)
|
| 250 |
+
return processor
|
| 251 |
+
|
| 252 |
+
class TarsierDataset(Dataset):
|
| 253 |
+
def __init__(self, ann_path="", anns=None, config: Dict=None, processor: Union[TarsierDataProcessor, TarsierProcessor, str]=None):
|
| 254 |
+
self.config = DictToObject(config) if isinstance(config, dict) else config
|
| 255 |
+
if not isinstance(processor, TarsierDataProcessor):
|
| 256 |
+
self.processor = init_processor(processor, config)
|
| 257 |
+
else:
|
| 258 |
+
self.processor = processor
|
| 259 |
+
if anns is None:
|
| 260 |
+
self.anns = []
|
| 261 |
+
if isinstance(ann_path, str):
|
| 262 |
+
ann_path = [ann_path]
|
| 263 |
+
for path in ann_path:
|
| 264 |
+
self.anns.extend(read_jsonlines(path))
|
| 265 |
+
else:
|
| 266 |
+
self.anns = anns
|
| 267 |
+
|
| 268 |
+
def __len__(self):
|
| 269 |
+
return len(self.anns)
|
| 270 |
+
|
| 271 |
+
def __getitem__(self, index):
|
| 272 |
+
if index < 0 or index >= len(self.anns):
|
| 273 |
+
raise IndexError("Index out of range")
|
| 274 |
+
try:
|
| 275 |
+
ann = self.anns[index]
|
| 276 |
+
model_inputs = self.processor(ann)
|
| 277 |
+
except Exception as e:
|
| 278 |
+
print(f"Load data error: {e}")
|
| 279 |
+
return ann, None
|
| 280 |
+
return ann, model_inputs
|
eval_scripts/DREAM-1K/tarsier/dataset/tarsier_processor.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Union
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from transformers.feature_extraction_utils import BatchFeature
|
| 7 |
+
from transformers.image_utils import ImageInput, get_image_size, to_numpy_array
|
| 8 |
+
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack, _validate_images_text_input_order
|
| 9 |
+
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
| 10 |
+
from transformers.utils import logging
|
| 11 |
+
from transformers import Qwen2VLImageProcessor
|
| 12 |
+
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
|
| 13 |
+
|
| 14 |
+
logger = logging.get_logger(__name__)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TarsierProcessorKwargs(ProcessingKwargs, total=False):
|
| 18 |
+
_defaults = {
|
| 19 |
+
"text_kwargs": {},
|
| 20 |
+
"images_kwargs": {},
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class TarsierProcessor(ProcessorMixin):
|
| 25 |
+
|
| 26 |
+
attributes = ["image_processor", "tokenizer"]
|
| 27 |
+
valid_kwargs = ["chat_template", "image_token", "patch_size", "merge_size", "temporal_patch_size", "max_seq_len"]
|
| 28 |
+
image_processor_class = "AutoImageProcessor"
|
| 29 |
+
tokenizer_class = "AutoTokenizer"
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
image_processor=None,
|
| 34 |
+
tokenizer=None,
|
| 35 |
+
chat_template=None,
|
| 36 |
+
image_token="<image>",
|
| 37 |
+
patch_size=None,
|
| 38 |
+
merge_size=1,
|
| 39 |
+
temporal_patch_size=1,
|
| 40 |
+
max_seq_len=8192,
|
| 41 |
+
**kwargs,
|
| 42 |
+
) -> None:
|
| 43 |
+
|
| 44 |
+
self.image_token = image_token
|
| 45 |
+
self.patch_size = patch_size
|
| 46 |
+
self.merge_size = merge_size
|
| 47 |
+
self.temporal_patch_size = temporal_patch_size
|
| 48 |
+
self.max_seq_len = max_seq_len
|
| 49 |
+
self.max_pixels_per_sample = 128 * 384 * 384
|
| 50 |
+
|
| 51 |
+
super().__init__(image_processor, tokenizer, chat_template=chat_template)
|
| 52 |
+
|
| 53 |
+
def __call__(
|
| 54 |
+
self,
|
| 55 |
+
messages,
|
| 56 |
+
image_processing_config=None,
|
| 57 |
+
is_training=True,
|
| 58 |
+
) -> torch.Tensor:
|
| 59 |
+
|
| 60 |
+
output_kwargs = self._merge_kwargs(
|
| 61 |
+
TarsierProcessorKwargs,
|
| 62 |
+
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# 【图片处理】
|
| 66 |
+
pixel_values, image_grid_thw = [], []
|
| 67 |
+
num_images = 0
|
| 68 |
+
for msg in messages:
|
| 69 |
+
for content in msg['content']:
|
| 70 |
+
if content['type'] == 'image':
|
| 71 |
+
num_images += self.temporal_patch_size
|
| 72 |
+
elif content['type'] == 'video':
|
| 73 |
+
num_images += len(content['video'])
|
| 74 |
+
if num_images > 0 and self.max_pixels_per_sample // num_images < image_processing_config['max_pixels']:
|
| 75 |
+
image_processing_config['max_pixels'] = self.max_pixels_per_sample // num_images
|
| 76 |
+
image_processing_config['min_pixels'] = min(image_processing_config['min_pixels'], image_processing_config['max_pixels'])
|
| 77 |
+
|
| 78 |
+
for msg in messages:
|
| 79 |
+
for content in msg['content']:
|
| 80 |
+
if content['type'] == 'image':
|
| 81 |
+
content['image'] = self.preprocess_image(content['image'], image_processing_config)
|
| 82 |
+
content['image'] = self.image_processor(images = content['image'], **output_kwargs["images_kwargs"], return_tensors="pt")
|
| 83 |
+
content['num_vision_tokens'] = self.get_num_vision_tokens(content)
|
| 84 |
+
pixel_values.append(content['image']['pixel_values'])
|
| 85 |
+
if 'image_grid_thw' in content['image']:
|
| 86 |
+
image_grid_thw.extend(content['image']['image_grid_thw'])
|
| 87 |
+
elif content['type'] == 'video':
|
| 88 |
+
content['video'] = self.preprocess_image(content['video'], image_processing_config)
|
| 89 |
+
if isinstance(self.image_processor, Qwen2VLImageProcessor):
|
| 90 |
+
content['video'] = self.image_processor(images = None, videos = content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
|
| 91 |
+
pixel_values.append(content['video']['pixel_values_videos'])
|
| 92 |
+
else:
|
| 93 |
+
content['video'] = self.image_processor(images = content['video'], **output_kwargs["images_kwargs"], return_tensors="pt")
|
| 94 |
+
pixel_values.append(content['video']['pixel_values'])
|
| 95 |
+
|
| 96 |
+
if 'video_grid_thw' in content['video']:
|
| 97 |
+
image_grid_thw.extend(content['video']['video_grid_thw'])
|
| 98 |
+
content['num_vision_tokens'] = self.get_num_vision_tokens(content)
|
| 99 |
+
|
| 100 |
+
#【文本处理】
|
| 101 |
+
add_generation_prompt = (not is_training and messages[-1]['role'] != 'assistant')
|
| 102 |
+
strip_final_eos = (not is_training and messages[-1]['role'] == 'assistant')
|
| 103 |
+
text_inputs = self.tokenizer.apply_chat_template(
|
| 104 |
+
messages,
|
| 105 |
+
chat_template = self.chat_template,
|
| 106 |
+
tokenize=True,
|
| 107 |
+
tokenizer_kwargs = output_kwargs["text_kwargs"],
|
| 108 |
+
return_assistant_tokens_mask=True,
|
| 109 |
+
return_dict=True,
|
| 110 |
+
add_generation_prompt=add_generation_prompt,
|
| 111 |
+
strip_final_eos=strip_final_eos,
|
| 112 |
+
)
|
| 113 |
+
labels = [-100 if j == 0 else i for i, j in zip(text_inputs['input_ids'], text_inputs['assistant_masks'])]
|
| 114 |
+
labels = labels[:self.max_seq_len]
|
| 115 |
+
input_ids = text_inputs['input_ids'][:self.max_seq_len]
|
| 116 |
+
|
| 117 |
+
image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)
|
| 118 |
+
if image_token_id in text_inputs['input_ids'][self.max_seq_len:]:
|
| 119 |
+
raise ValueError(f'Too long sequence! {len(text_inputs["input_ids"])}')
|
| 120 |
+
|
| 121 |
+
outputs = {
|
| 122 |
+
'input_ids': input_ids,
|
| 123 |
+
'labels': labels,
|
| 124 |
+
'num_images': num_images,
|
| 125 |
+
}
|
| 126 |
+
if len(pixel_values) > 0:
|
| 127 |
+
outputs['pixel_values'] = torch.cat(pixel_values, dim=0)
|
| 128 |
+
if len(image_grid_thw) > 0:
|
| 129 |
+
outputs['image_grid_thw'] = torch.stack(image_grid_thw)
|
| 130 |
+
return outputs
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def preprocess_image(self, pil_img: Union[Image.Image, List[Image.Image]], image_processing_config):
|
| 134 |
+
if image_processing_config is None:
|
| 135 |
+
return pil_img
|
| 136 |
+
images = pil_img
|
| 137 |
+
if isinstance(pil_img, Image.Image):
|
| 138 |
+
images = [images]
|
| 139 |
+
if image_processing_config['do_crop']:
|
| 140 |
+
images = [self.centralcrop(img, rate=[4, 3]) for img in images]
|
| 141 |
+
if image_processing_config['do_padding']:
|
| 142 |
+
images = [self.expand2square(
|
| 143 |
+
img,
|
| 144 |
+
# tuple(int(x * 255) for x in self.processor.image_processor.image_mean)
|
| 145 |
+
tuple(int(x * 255) for x in [0, 0, 0])
|
| 146 |
+
) for img in images]
|
| 147 |
+
if image_processing_config['do_resize']:
|
| 148 |
+
images = [self.resize2square(img) for img in images]
|
| 149 |
+
if image_processing_config.get('max_pixels'):
|
| 150 |
+
images = [self.resize2pixels(
|
| 151 |
+
img,
|
| 152 |
+
int(image_processing_config['max_pixels']),
|
| 153 |
+
int(image_processing_config['min_pixels'])
|
| 154 |
+
) for img in images]
|
| 155 |
+
if isinstance(pil_img, Image.Image):
|
| 156 |
+
images = images[0]
|
| 157 |
+
return images
|
| 158 |
+
|
| 159 |
+
def expand2square(self, pil_img, background_color):
|
| 160 |
+
width, height = pil_img.size
|
| 161 |
+
if width == height:
|
| 162 |
+
return pil_img
|
| 163 |
+
elif width > height:
|
| 164 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 165 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 166 |
+
return result
|
| 167 |
+
else:
|
| 168 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 169 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 170 |
+
return result
|
| 171 |
+
|
| 172 |
+
def resize2square(self, pil_img: Image.Image):
|
| 173 |
+
width, height = pil_img.size
|
| 174 |
+
pil_img = pil_img.resize((max(width, height), max(width, height)))
|
| 175 |
+
return pil_img
|
| 176 |
+
|
| 177 |
+
def centralcrop(self, pil_img: Image.Image, rate=[4, 3]):
|
| 178 |
+
width, height = pil_img.size
|
| 179 |
+
size = (width, height)
|
| 180 |
+
min_len = min(size)
|
| 181 |
+
longer_side = 0 if width >= height else 1
|
| 182 |
+
center = (width/2, height/2)
|
| 183 |
+
box = [0, 0, size[0], size[1]]
|
| 184 |
+
|
| 185 |
+
# if longer_side == 0:
|
| 186 |
+
# box[0] = max(0, center[0] - 1/2*min_len/rate[1]*rate[0])
|
| 187 |
+
# box[2] = min(width, center[0] + 1/2*min_len/rate[1]*rate[0])
|
| 188 |
+
# else:
|
| 189 |
+
# box[1] = max(0, center[1] - 1/2*min_len/rate[1]*rate[0])
|
| 190 |
+
# box[3] = min(height, center[1] + 1/2*min_len/rate[1]*rate[0])
|
| 191 |
+
box[longer_side] = max(0, center[longer_side] - 1/2*min_len/rate[1]*rate[0])
|
| 192 |
+
box[2 + longer_side] = min(size[longer_side], center[longer_side] + 1/2*min_len/rate[1]*rate[0])
|
| 193 |
+
|
| 194 |
+
# box = (width/2-min_len/2, height/2-min_len/2, width/2+min_len/2, height/2+min_len/2)
|
| 195 |
+
pil_img = pil_img.crop(box)
|
| 196 |
+
return pil_img
|
| 197 |
+
|
| 198 |
+
def resize2pixels(self, pil_img: Image.Image, max_pixels=None, min_pixels=None):
|
| 199 |
+
width, height = pil_img.size
|
| 200 |
+
new_height, new_width = smart_resize(height, width, factor=1, max_pixels=max_pixels, min_pixels=min_pixels)
|
| 201 |
+
pil_img = pil_img.resize((new_width, new_height))
|
| 202 |
+
return pil_img
|
| 203 |
+
|
| 204 |
+
def get_num_vision_tokens(self, content):
|
| 205 |
+
if isinstance(self.image_processor, Qwen2VLImageProcessor):
|
| 206 |
+
merge_length = self.image_processor.merge_size**2
|
| 207 |
+
if content['type'] == 'image':
|
| 208 |
+
num_image_tokens = content['image']['image_grid_thw'].prod() // merge_length
|
| 209 |
+
else:
|
| 210 |
+
num_image_tokens = content['video']['video_grid_thw'].prod() // merge_length
|
| 211 |
+
return num_image_tokens
|
| 212 |
+
else:
|
| 213 |
+
# 其他模型:image tokens (-> 2x2 compressed) -> add image_newline and image_new
|
| 214 |
+
k = 'image'if content['type'] == 'image' else 'video'
|
| 215 |
+
pixel_values = content[k]['pixel_values'][0]
|
| 216 |
+
n_frames = len(content[k]['pixel_values'])
|
| 217 |
+
|
| 218 |
+
height, width = get_image_size(to_numpy_array(pixel_values))
|
| 219 |
+
num_image_tokens = (height // (self.patch_size * self.merge_size)) * (width // (self.patch_size * self.merge_size) + 1) + 1
|
| 220 |
+
return num_image_tokens * n_frames
|
| 221 |
+
|
| 222 |
+
def batch_decode(self, *args, **kwargs):
|
| 223 |
+
"""
|
| 224 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
| 225 |
+
refer to the docstring of this method for more information.
|
| 226 |
+
"""
|
| 227 |
+
return self.tokenizer.batch_decode(*args, **kwargs)
|
| 228 |
+
|
| 229 |
+
def decode(self, *args, **kwargs):
|
| 230 |
+
"""
|
| 231 |
+
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
| 232 |
+
the docstring of this method for more information.
|
| 233 |
+
"""
|
| 234 |
+
return self.tokenizer.decode(*args, **kwargs)
|
| 235 |
+
|
| 236 |
+
@property
|
| 237 |
+
def model_input_names(self):
|
| 238 |
+
tokenizer_input_names = self.tokenizer.model_input_names
|
| 239 |
+
image_processor_input_names = self.image_processor.model_input_names
|
| 240 |
+
return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))
|
eval_scripts/DREAM-1K/tarsier/dataset/utils.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from typing import List
|
| 15 |
+
import os
|
| 16 |
+
from PIL import Image, ImageSequence
|
| 17 |
+
import decord
|
| 18 |
+
|
| 19 |
+
VALID_DATA_FORMAT_STRING = "Input data must be {'.jpg', '.jpeg', '.png', '.tif'} for image; or {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv', '.gif'} for videos!"
|
| 20 |
+
|
| 21 |
+
# 均匀抽帧,必采样首尾帧。
|
| 22 |
+
def sample_frame_indices(start_frame, total_frames: int, n_frames: int):
|
| 23 |
+
if n_frames == 1:
|
| 24 |
+
return [0] # sample first frame in default
|
| 25 |
+
sample_ids = [round(i * (total_frames - 1) / (n_frames - 1)) for i in range(n_frames)]
|
| 26 |
+
sample_ids = [i + start_frame for i in sample_ids]
|
| 27 |
+
return sample_ids
|
| 28 |
+
|
| 29 |
+
def sample_video(
|
| 30 |
+
video_path: str,
|
| 31 |
+
n_frames: int = None,
|
| 32 |
+
start_time: int = 0,
|
| 33 |
+
end_time: int = -1
|
| 34 |
+
) -> List[Image.Image]:
|
| 35 |
+
|
| 36 |
+
assert os.path.exists(video_path), f"File not found: {video_path}"
|
| 37 |
+
vr = decord.VideoReader(video_path, num_threads=1, ctx=decord.cpu(0))
|
| 38 |
+
vr.seek(0)
|
| 39 |
+
total_frames = len(vr)
|
| 40 |
+
fps = vr.get_avg_fps()
|
| 41 |
+
|
| 42 |
+
start_frame = 0
|
| 43 |
+
end_frame = total_frames - 1
|
| 44 |
+
if start_time > 0:
|
| 45 |
+
start_frame = min((total_frames-1), int(fps*start_time))
|
| 46 |
+
if end_time > 0:
|
| 47 |
+
end_frame = max(start_frame, int(fps*end_time))
|
| 48 |
+
end_frame = min(end_frame, (total_frames-1))
|
| 49 |
+
frame_indices = sample_frame_indices(
|
| 50 |
+
start_frame=start_frame,
|
| 51 |
+
total_frames=end_frame - start_frame + 1,
|
| 52 |
+
n_frames=n_frames,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
frames = vr.get_batch(frame_indices).asnumpy()
|
| 56 |
+
frames = [Image.fromarray(f).convert('RGB') for f in frames]
|
| 57 |
+
return frames
|
| 58 |
+
|
| 59 |
+
def sample_gif(
|
| 60 |
+
gif_path: str,
|
| 61 |
+
n_frames:int = None,
|
| 62 |
+
start_time: int = 0,
|
| 63 |
+
end_time: int = -1
|
| 64 |
+
) -> List[Image.Image]:
|
| 65 |
+
|
| 66 |
+
assert os.path.exists(gif_path), f"File not found: {gif_path}"
|
| 67 |
+
|
| 68 |
+
gif_frames = Image.open(gif_path)
|
| 69 |
+
|
| 70 |
+
start_frame = 0
|
| 71 |
+
end_frame = gif_frames.n_frames - 1
|
| 72 |
+
frame_indices = sample_frame_indices(
|
| 73 |
+
start_frame=start_frame,
|
| 74 |
+
total_frames=end_frame - start_frame + 1,
|
| 75 |
+
n_frames=n_frames,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
frames = []
|
| 79 |
+
i = 0
|
| 80 |
+
for frame in ImageSequence.Iterator(gif_frames):
|
| 81 |
+
if i in frame_indices:
|
| 82 |
+
frames.append(frame.convert('RGB'))
|
| 83 |
+
i += 1
|
| 84 |
+
return frames
|
| 85 |
+
|
| 86 |
+
def sample_image(
|
| 87 |
+
image_path: str,
|
| 88 |
+
n_frames: int = None,
|
| 89 |
+
start_time: int = 0,
|
| 90 |
+
end_time: int = -1
|
| 91 |
+
):
|
| 92 |
+
assert os.path.exists(image_path), f"File not found: {image_path}"
|
| 93 |
+
image = Image.open(image_path).convert('RGB')
|
| 94 |
+
return [image]
|
| 95 |
+
|
| 96 |
+
def get_visual_type(input_file):
|
| 97 |
+
ext = os.path.splitext(input_file)[-1]
|
| 98 |
+
if ext in {'.gif'}:
|
| 99 |
+
return 'gif'
|
| 100 |
+
elif ext in {'.mp4', '.avi', '.webm', '.mov', '.mkv', '.wmv'}:
|
| 101 |
+
return 'video'
|
| 102 |
+
elif ext in {'.jpg', '.jpeg', '.png', '.tif'}:
|
| 103 |
+
return 'image'
|
| 104 |
+
else:
|
| 105 |
+
print(f"{VALID_DATA_FORMAT_STRING} But found {ext}!")
|
| 106 |
+
return 'unk'
|
| 107 |
+
|
| 108 |
+
def get_benchmarks(benchmarks):
|
| 109 |
+
final_benchmarks = []
|
| 110 |
+
type2bm = {
|
| 111 |
+
'dream': ['dream'],
|
| 112 |
+
'caption': ['msvd-caption', 'msr-vtt-caption', 'vatex-caption'],
|
| 113 |
+
'mc_qa': ['next-qa', 'egoschema', 'mvbench', 'video-mme'],
|
| 114 |
+
'oe_qa': ['msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa'],
|
| 115 |
+
}
|
| 116 |
+
for bm in benchmarks:
|
| 117 |
+
bm = bm.lower()
|
| 118 |
+
if bm in final_benchmarks:
|
| 119 |
+
continue
|
| 120 |
+
if bm == 'all':
|
| 121 |
+
for v in type2bm.values():
|
| 122 |
+
final_benchmarks.extend(v)
|
| 123 |
+
return final_benchmarks
|
| 124 |
+
if bm in type2bm:
|
| 125 |
+
final_benchmarks.extend(type2bm[bm])
|
| 126 |
+
else:
|
| 127 |
+
final_benchmarks.append(bm)
|
| 128 |
+
return final_benchmarks
|
| 129 |
+
|
| 130 |
+
def check_data_format(data):
|
| 131 |
+
for msg in data['messages']:
|
| 132 |
+
if isinstance(msg['content'], dict):
|
| 133 |
+
msg['content'] = [msg['content']]
|
| 134 |
+
for content in msg['content']:
|
| 135 |
+
assert content['type'] in {'image', 'video', 'text'}, f"content['type']={content['type']} MUST be one of ['image', 'video', 'text']"
|
| 136 |
+
if content['type'] != "text":
|
| 137 |
+
media_path_key = f"{content['type']}_file"
|
| 138 |
+
meida_paths = content[content['type']][media_path_key]
|
| 139 |
+
if isinstance(meida_paths, str):
|
| 140 |
+
meida_paths = [meida_paths]
|
| 141 |
+
for path in meida_paths:
|
| 142 |
+
assert os.path.exists(path), f"File not found: {path}"
|
| 143 |
+
|
| 144 |
+
def format_one_sample(media_file=None, prompt="Describe the video in detail."):
|
| 145 |
+
sample = {
|
| 146 |
+
"messages": []
|
| 147 |
+
}
|
| 148 |
+
user_content = {
|
| 149 |
+
"role": "user",
|
| 150 |
+
"content": []
|
| 151 |
+
}
|
| 152 |
+
if media_file is not None:
|
| 153 |
+
media_type = get_visual_type(media_file)
|
| 154 |
+
if media_type in ("video", "gif"):
|
| 155 |
+
media_type = "video"
|
| 156 |
+
media_path_key = f"{media_type}_file"
|
| 157 |
+
user_content["content"].append({
|
| 158 |
+
"type": media_type,
|
| 159 |
+
media_type: {
|
| 160 |
+
media_path_key: media_file,
|
| 161 |
+
}
|
| 162 |
+
})
|
| 163 |
+
user_content["content"].append({
|
| 164 |
+
"type": "text",
|
| 165 |
+
"text": prompt
|
| 166 |
+
})
|
| 167 |
+
|
| 168 |
+
assistant_content = {
|
| 169 |
+
"role": "assistant",
|
| 170 |
+
"content": []
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
sample["messages"].append(user_content)
|
| 174 |
+
sample["messages"].append(assistant_content)
|
| 175 |
+
if media_file is not None:
|
| 176 |
+
sample["task"] = f"{media_type}/QA"
|
| 177 |
+
else:
|
| 178 |
+
sample["task"] = 'text-only'
|
| 179 |
+
check_data_format(sample)
|
| 180 |
+
return sample
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class DictToObject(object):
|
| 184 |
+
def __init__(self, dictionary):
|
| 185 |
+
for key, value in dictionary.items():
|
| 186 |
+
setattr(self, key, value)
|
eval_scripts/DREAM-1K/tarsier/evaluation/evaluate.py
ADDED
|
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import os
|
| 15 |
+
import random
|
| 16 |
+
|
| 17 |
+
from .metrics import CIDErMetric, GPTMetric, DREAMGPTMetric, AccuracyMetric, VideoMMEAccuracyMetric
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append('eval_scripts/DREAM-1K/tarsier')
|
| 20 |
+
from tools.rw_utils import read_jsonlines
|
| 21 |
+
from tools.color import Color
|
| 22 |
+
from dataset.utils import get_benchmarks
|
| 23 |
+
|
| 24 |
+
def extract_item_for_eval(data_dict):
|
| 25 |
+
item = {}
|
| 26 |
+
prompt, prediction, reference = [], [], []
|
| 27 |
+
for msg in data_dict['messages']:
|
| 28 |
+
for content in msg['content']:
|
| 29 |
+
if content['type'] == 'text':
|
| 30 |
+
if msg['role'] == 'user':
|
| 31 |
+
prompt.append(content['text'])
|
| 32 |
+
elif msg['role'] == 'assistant':
|
| 33 |
+
if content.get('reference'):
|
| 34 |
+
reference.append(content['reference'])
|
| 35 |
+
prediction.append(content['text'])
|
| 36 |
+
# prediction.append(content['reference']) # debug
|
| 37 |
+
|
| 38 |
+
item['prompt'] = ''.join(prompt)
|
| 39 |
+
item['prediction'] = ''.join(prediction)
|
| 40 |
+
item['response'] = ''.join(reference)
|
| 41 |
+
|
| 42 |
+
item['dataset'] = data_dict['dataset']
|
| 43 |
+
item['idx'] = f"{data_dict['dataset']}_{data_dict['idx']}"
|
| 44 |
+
extra_info = data_dict.get('extra_info', None)
|
| 45 |
+
vid = data_dict.get('vid', None)
|
| 46 |
+
if vid is not None:
|
| 47 |
+
item['vid'] = vid
|
| 48 |
+
if extra_info:
|
| 49 |
+
item['events'] = extra_info.get('events', None)
|
| 50 |
+
item['extra_info'] = extra_info
|
| 51 |
+
if 'is_hard' in data_dict:
|
| 52 |
+
item['is_hard'] = data_dict['is_hard']
|
| 53 |
+
|
| 54 |
+
return item
|
| 55 |
+
|
| 56 |
+
def read_dataset(path, dataset_name):
|
| 57 |
+
if os.path.isdir(path):
|
| 58 |
+
lines = []
|
| 59 |
+
for f in os.listdir(path):
|
| 60 |
+
if f.endswith('.jsonl'):
|
| 61 |
+
lines.extend(read_jsonlines(os.path.join(path, f)))
|
| 62 |
+
else:
|
| 63 |
+
lines = read_jsonlines(path)
|
| 64 |
+
dataset = []
|
| 65 |
+
idxs = set()
|
| 66 |
+
for l in lines:
|
| 67 |
+
if l['dataset'].split('/')[0] != dataset_name:
|
| 68 |
+
continue
|
| 69 |
+
idx = f"{l['dataset']}_{l['idx']}"
|
| 70 |
+
if idx in idxs:
|
| 71 |
+
continue
|
| 72 |
+
|
| 73 |
+
idxs.add(idx)
|
| 74 |
+
item = extract_item_for_eval(l)
|
| 75 |
+
dataset.append(item)
|
| 76 |
+
return dataset
|
| 77 |
+
|
| 78 |
+
METRIC_MAPPING = {
|
| 79 |
+
'CIDErMetric': CIDErMetric,
|
| 80 |
+
'GPTMetric': GPTMetric,
|
| 81 |
+
'AccuracyMetric': AccuracyMetric,
|
| 82 |
+
'DREAMGPTMetric': DREAMGPTMetric,
|
| 83 |
+
'VideoMMEAccuracyMetric': VideoMMEAccuracyMetric
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
def evaluate(pred_file, METRIC, dataset_name, sample_num=-1, verbose = False):
|
| 87 |
+
dataset = read_dataset(pred_file, dataset_name)
|
| 88 |
+
if len(dataset) == 0:
|
| 89 |
+
return
|
| 90 |
+
if sample_num > 0:
|
| 91 |
+
dataset = random.sample(dataset, sample_num)
|
| 92 |
+
metric = METRIC(dataset_name=dataset_name, verbose=verbose)
|
| 93 |
+
metric.process(dataset)
|
| 94 |
+
metric.summarize_metric()
|
| 95 |
+
metric.save_results(pred_file)
|
| 96 |
+
if isinstance(metric, DREAMGPTMetric):
|
| 97 |
+
metric.save_eval_infos(pred_file)
|
| 98 |
+
|
| 99 |
+
def evaluate_all(pred_file, METRIC2DATASET, sample_num=-1, verbose = False):
|
| 100 |
+
for METRIC, dataset_name in METRIC2DATASET:
|
| 101 |
+
if isinstance(METRIC, str):
|
| 102 |
+
METRIC = METRIC_MAPPING[METRIC]
|
| 103 |
+
print(f"### Start Evaluating on {dataset_name}")
|
| 104 |
+
evaluate(pred_file, METRIC, dataset_name, sample_num, verbose)
|
| 105 |
+
print(f"### Finish Evaluating on {dataset_name}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == '__main__':
|
| 109 |
+
import argparse
|
| 110 |
+
|
| 111 |
+
parser = argparse.ArgumentParser()
|
| 112 |
+
parser.add_argument('--pred_file', type=str)
|
| 113 |
+
parser.add_argument('--benchmarks', nargs='+', default=["all"], help="Default as 'all' to evaluate on all benchmarks; Also could be task types: ('dream', 'caption', 'mc_qa', 'oe_qa'); And specific benchmark names: ('dream', 'msvd-caption', 'msr-vtt-caption', 'vatex-caption', 'next-qa', 'egoschema', 'mvbench', 'tvbench', 'video-mme', 'msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa', 'favor-bench')")
|
| 114 |
+
parser.add_argument('--sample_num', type=int, default=-1)
|
| 115 |
+
parser.add_argument('--verbose', action='store_true')
|
| 116 |
+
|
| 117 |
+
args = parser.parse_args()
|
| 118 |
+
|
| 119 |
+
args.benchmarks = get_benchmarks(args.benchmarks)
|
| 120 |
+
print("### Selected Benchmarks:", args.benchmarks)
|
| 121 |
+
|
| 122 |
+
Benchmark2Metric = {
|
| 123 |
+
# Multi-chocie QA
|
| 124 |
+
'next-qa': 'AccuracyMetric',
|
| 125 |
+
'egoschema': 'AccuracyMetric',
|
| 126 |
+
'mvbench': 'AccuracyMetric',
|
| 127 |
+
'tvbench': 'AccuracyMetric',
|
| 128 |
+
'video-mme': 'VideoMMEAccuracyMetric',
|
| 129 |
+
'favor-bench': 'AccuracyMetric',
|
| 130 |
+
|
| 131 |
+
# Open-ended QA
|
| 132 |
+
'msvd-qa': 'GPTMetric',
|
| 133 |
+
'msr-vtt-qa': 'GPTMetric',
|
| 134 |
+
'tgif-qa': 'GPTMetric',
|
| 135 |
+
'anet-qa': 'GPTMetric',
|
| 136 |
+
|
| 137 |
+
# Caption DREAM
|
| 138 |
+
'dream': 'DREAMGPTMetric',
|
| 139 |
+
|
| 140 |
+
# Caption CIDEr
|
| 141 |
+
'msvd-caption': 'CIDErMetric',
|
| 142 |
+
'msr-vtt-caption': 'CIDErMetric',
|
| 143 |
+
'vatex-caption': 'CIDErMetric',
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
Benchmark2Dataset = {
|
| 147 |
+
'dream': 'DREAM',
|
| 148 |
+
|
| 149 |
+
'next-qa': 'Next-QA-val-multi_choice',
|
| 150 |
+
'egoschema': 'EgoSchema',
|
| 151 |
+
'mvbench': 'MVBench',
|
| 152 |
+
'tvbench': 'TVBench',
|
| 153 |
+
'video-mme': 'Video-MME',
|
| 154 |
+
'favor-bench': 'FAVOR-Bench',
|
| 155 |
+
|
| 156 |
+
'msvd-qa': 'MSVD-QA-val',
|
| 157 |
+
'msr-vtt-qa': 'MSR-VTT-QA-val',
|
| 158 |
+
'tgif-qa': 'TGIF-QA-test',
|
| 159 |
+
'anet-qa': 'ActivityNet-QA-test',
|
| 160 |
+
|
| 161 |
+
'msvd-caption': 'MSVD-Caption-test',
|
| 162 |
+
'msr-vtt-caption': 'MSR-VTT-Caption-test',
|
| 163 |
+
'vatex-caption': 'VATEX-test',
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
METRIC2DATASET = []
|
| 167 |
+
|
| 168 |
+
for bm in args.benchmarks:
|
| 169 |
+
if bm not in Benchmark2Metric:
|
| 170 |
+
print(Color.red(f"Unknown benchmark: {bm}"))
|
| 171 |
+
continue
|
| 172 |
+
|
| 173 |
+
METRIC2DATASET.append([Benchmark2Metric[bm], Benchmark2Dataset[bm]])
|
| 174 |
+
|
| 175 |
+
evaluate_all(args.pred_file, METRIC2DATASET, args.sample_num, args.verbose)
|
| 176 |
+
|
| 177 |
+
# python3 -m evaluation.evaluate --pred_file $pred_file --sample_num=100
|
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluate_caption_cider import CIDErMetric
|
| 2 |
+
from .evaluate_qa_oe_gpt import GPTMetric
|
| 3 |
+
from .evaluate_qa_mc import AccuracyMetric
|
| 4 |
+
from .evaluate_dream_gpt import DREAMGPTMetric
|
| 5 |
+
from .evaluate_video_mme import VideoMMEAccuracyMetric
|
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_caption_cider.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
from typing import List, Dict
|
| 16 |
+
import os
|
| 17 |
+
from pycocoevalcap.cider.cider import Cider
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append('eval_scripts/DREAM-1K/tarsier')
|
| 20 |
+
from tools.ptbtokenizer import PTBTokenizer
|
| 21 |
+
|
| 22 |
+
from tools.color import Color
|
| 23 |
+
|
| 24 |
+
class CIDErMetric:
|
| 25 |
+
def __init__(self, dataset_name, verbose=False) -> None:
|
| 26 |
+
self.dataset_name = dataset_name
|
| 27 |
+
self.tokenizer = PTBTokenizer()
|
| 28 |
+
self.scorer = Cider()
|
| 29 |
+
self.score = None
|
| 30 |
+
self.results = []
|
| 31 |
+
self.dataset = []
|
| 32 |
+
self.verbose = verbose
|
| 33 |
+
|
| 34 |
+
def add(self, data):
|
| 35 |
+
self.dataset.append(data)
|
| 36 |
+
|
| 37 |
+
def process(self, dataset: List[Dict]):
|
| 38 |
+
references, predictions = {}, {}
|
| 39 |
+
for i, data in enumerate(dataset):
|
| 40 |
+
ref = data['response']
|
| 41 |
+
pred = data['prediction']
|
| 42 |
+
|
| 43 |
+
if isinstance(ref, str):
|
| 44 |
+
ref = [ref]
|
| 45 |
+
|
| 46 |
+
references[i] = [{'caption': r.lower()} for r in ref]
|
| 47 |
+
predictions[i] = [{'caption': pred.lower()}]
|
| 48 |
+
|
| 49 |
+
references = self.tokenizer.tokenize(references)
|
| 50 |
+
predictions = self.tokenizer.tokenize(predictions)
|
| 51 |
+
score, scores = self.scorer.compute_score(references, predictions)
|
| 52 |
+
self.score = score
|
| 53 |
+
for data, s in zip(dataset, scores):
|
| 54 |
+
self.results.append({
|
| 55 |
+
'score': s,
|
| 56 |
+
'data': data,
|
| 57 |
+
})
|
| 58 |
+
|
| 59 |
+
def summarize_metric(self):
|
| 60 |
+
if self.verbose:
|
| 61 |
+
for result in self.results:
|
| 62 |
+
print(Color.blue(json.dumps(result['data'])))
|
| 63 |
+
print(Color.red(f"CIDEr score: {result['score']}"))
|
| 64 |
+
print(f'=====Evaluation Summary=====')
|
| 65 |
+
self.eval_records = [
|
| 66 |
+
f'Dataset: {self.dataset_name}\tMetric: CIDEr',
|
| 67 |
+
f'#Successful Results: {len(self.results)}',
|
| 68 |
+
f'CIDEr score: {round(self.score*100, 1)}'
|
| 69 |
+
]
|
| 70 |
+
for info in self.eval_records:
|
| 71 |
+
print(info)
|
| 72 |
+
|
| 73 |
+
def save_results(self, pred_path):
|
| 74 |
+
if os.path.isdir(pred_path):
|
| 75 |
+
output_dir = os.path.join(pred_path, 'eval_records')
|
| 76 |
+
else:
|
| 77 |
+
output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
|
| 78 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 79 |
+
fout = open(os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt'), 'w')
|
| 80 |
+
for info in self.eval_records:
|
| 81 |
+
fout.write(info+'\n')
|
| 82 |
+
fout.close()
|
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_dream_gpt.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
import numpy as np
|
| 16 |
+
import ast
|
| 17 |
+
import time
|
| 18 |
+
from typing import List, Dict
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from pathos.multiprocessing import ProcessingPool as Pool
|
| 21 |
+
import func_timeout
|
| 22 |
+
from func_timeout import func_set_timeout
|
| 23 |
+
|
| 24 |
+
import sys
|
| 25 |
+
sys.path.append('eval_scripts/DREAM-1K/tarsier')
|
| 26 |
+
import re
|
| 27 |
+
import os
|
| 28 |
+
from copy import deepcopy
|
| 29 |
+
from traceback import format_exc
|
| 30 |
+
|
| 31 |
+
try:
|
| 32 |
+
with open("apikey.txt", "r") as f:
|
| 33 |
+
api_key = f.read()
|
| 34 |
+
except:
|
| 35 |
+
api_key = ''
|
| 36 |
+
|
| 37 |
+
def call_gpt35(msg):
|
| 38 |
+
while True:
|
| 39 |
+
try:
|
| 40 |
+
response = openai.ChatCompletion.create(
|
| 41 |
+
model="gpt-3.5-turbo",
|
| 42 |
+
messages=msg,
|
| 43 |
+
api_key=api_key,
|
| 44 |
+
request_timeout=5)
|
| 45 |
+
break
|
| 46 |
+
except:
|
| 47 |
+
print("Timeout, retrying...")
|
| 48 |
+
time.sleep(5)
|
| 49 |
+
|
| 50 |
+
output_text = response['choices'][0]['message']['content']
|
| 51 |
+
return output_text
|
| 52 |
+
|
| 53 |
+
def count_f1(r, p):
|
| 54 |
+
return 2*r*p/(r+p)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def call_azure_gpt_api(events, reference, prediction, model):
|
| 58 |
+
if len(events) == 0:
|
| 59 |
+
events = [reference.replace('\n', ' ')]
|
| 60 |
+
messages=[
|
| 61 |
+
{
|
| 62 |
+
"role": "user",
|
| 63 |
+
"content":
|
| 64 |
+
"Given a video description and a list of events. For each event, classify the relationship between the video description and the event into three classes: entailment, neutral, contradiction.\n"
|
| 65 |
+
"- \"entailment\" means that the video description entails the event.\n"
|
| 66 |
+
"- \"contradiction\" means that some detail in the video description contradicts with the event.\n"
|
| 67 |
+
"- \"neutral\" means that the relationship is neither \"entailment\" or \"contradiction\".\n\n"
|
| 68 |
+
f"Video Description:\n{prediction}\n\n"
|
| 69 |
+
f"Events: {events}\n"
|
| 70 |
+
|
| 71 |
+
"Output a JSON formed as:\n"
|
| 72 |
+
"{\n"
|
| 73 |
+
" \"events\": [\n"
|
| 74 |
+
" {\"event\": \"copy an event here\", \"relationship\": \"put class name here\", \"reason\": \"give your reason here\"},\n"
|
| 75 |
+
" ...\n"
|
| 76 |
+
" ]\n"
|
| 77 |
+
"}\n\n"
|
| 78 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only output the JSON. Output:"
|
| 79 |
+
}
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
completion = call_gpt35(messages)
|
| 83 |
+
return completion
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def call_azure_gpt_api_for_events(caption, model):
|
| 87 |
+
messages=[
|
| 88 |
+
{
|
| 89 |
+
"role": "user",
|
| 90 |
+
"content":
|
| 91 |
+
"Bellow is a description of a video clip:\n"
|
| 92 |
+
f"Video Description: {caption}\n\n"
|
| 93 |
+
|
| 94 |
+
"Extract at most 10 key events from the above video description paragraph. Requirements\n:"
|
| 95 |
+
"- An event must include an action, motion or movement (NOT STATIC INFORMATION). DON'T repeat same events.\n"
|
| 96 |
+
"- Every event is represented by a brief sentence within 10 words, with a subject, a predicate and optionally an object, avoid unnecessary appearance descriptions.\n"
|
| 97 |
+
"- Every event must be atomic, meaning that it cannot be further split into multiple events.\n"
|
| 98 |
+
"- Scene cuts and camera motions are NOT events.\n"
|
| 99 |
+
"- Substitute pronouns by the nouns they refer to.\n\n"
|
| 100 |
+
"Please generate the response in the form of a Python dictionary string with keys \"events\". The value of \"events\" is a List(str), of which each item is an event. "
|
| 101 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 102 |
+
"For example, your response should look like this: {\"events\": [event1, event2, ...]}"
|
| 103 |
+
}
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
completion = call_gpt35(messages)
|
| 107 |
+
return completion
|
| 108 |
+
|
| 109 |
+
def try_call_api_for_eval(events, answer, prediction, model, verbose=False, max_retry=100):
|
| 110 |
+
for i in range(max_retry):
|
| 111 |
+
gpt_q = call_azure_gpt_api(events, answer, prediction, model)
|
| 112 |
+
if gpt_q is not None:
|
| 113 |
+
gpt_q = gpt_q.strip()
|
| 114 |
+
gpt_q = re.sub(r'\n+', '\n', gpt_q)
|
| 115 |
+
gpt_q = re.sub(r'\s+', ' ', gpt_q)
|
| 116 |
+
|
| 117 |
+
if gpt_q.startswith("```json"):
|
| 118 |
+
gpt_q = gpt_q.replace("```json", "").replace("```", "").strip()
|
| 119 |
+
elif gpt_q.startswith("```python"):
|
| 120 |
+
gpt_q = gpt_q.replace("```python", "").replace("```", "").strip()
|
| 121 |
+
if not gpt_q.startswith('{'):
|
| 122 |
+
gpt_q = '{' + gpt_q
|
| 123 |
+
if not gpt_q.endswith('}'):
|
| 124 |
+
gpt_q = gpt_q + '}'
|
| 125 |
+
gpt_q = gpt_q.replace("True", "true").replace("False", "false")
|
| 126 |
+
gpt_q = gpt_q.replace("} {", "}, {").replace("}{", "}, {")
|
| 127 |
+
gpt_q = gpt_q.replace(",\n}", "\n}").replace(", \n}", "\n}").replace(", }", "}").replace(",}", "}")
|
| 128 |
+
gpt_q = gpt_q.replace(",\n]", "\n]").replace(", \n]", "\n]").replace(", ]", "]").replace(",]", "]")
|
| 129 |
+
gpt_q = gpt_q.replace("[Placeholder]", "null")
|
| 130 |
+
gpt_q = gpt_q.replace("{Events:", "").strip()
|
| 131 |
+
|
| 132 |
+
return gpt_q, True
|
| 133 |
+
|
| 134 |
+
return f"Exceed max try: {max_retry}", False
|
| 135 |
+
|
| 136 |
+
def try_call_api_for_events(caption, model, verbose=False):
|
| 137 |
+
for i in range(100):
|
| 138 |
+
gpt_q = call_azure_gpt_api_for_events(caption, model)
|
| 139 |
+
if gpt_q is not None:
|
| 140 |
+
if gpt_q.startswith("```json"):
|
| 141 |
+
gpt_q = gpt_q.replace("```json", "").replace("```", "").strip()
|
| 142 |
+
elif gpt_q.startswith("```python"):
|
| 143 |
+
gpt_q = gpt_q.replace("```python", "").replace("```", "").strip()
|
| 144 |
+
return gpt_q, True
|
| 145 |
+
|
| 146 |
+
return "Exceed max try: 5", False
|
| 147 |
+
|
| 148 |
+
def extract_events(inputs, is_pred=False, max_retry=100):
|
| 149 |
+
data, model, verbose = inputs
|
| 150 |
+
if is_pred:
|
| 151 |
+
caption = data['prediction'].lower()
|
| 152 |
+
else:
|
| 153 |
+
caption = data['response'].lower()
|
| 154 |
+
caption = caption.replace("\"", "\'")
|
| 155 |
+
retry = 0
|
| 156 |
+
while True and (retry<max_retry or max_retry<0):
|
| 157 |
+
retry += 1
|
| 158 |
+
result, success = try_call_api_for_events(caption, model, verbose)
|
| 159 |
+
if not success:
|
| 160 |
+
print(f"[error]: try_call_api_for_events failed!", flush=True)
|
| 161 |
+
continue
|
| 162 |
+
try:
|
| 163 |
+
result = ast.literal_eval(result)
|
| 164 |
+
events = result['events']
|
| 165 |
+
if verbose:
|
| 166 |
+
print("pred_events=" if is_pred else "gt events=", events, ":", caption)
|
| 167 |
+
assert isinstance(events, list) and (len(events)==0 or isinstance(events[0], str))
|
| 168 |
+
return events
|
| 169 |
+
except Exception as e:
|
| 170 |
+
print(format_exc(), flush=True)
|
| 171 |
+
continue
|
| 172 |
+
print("[error]: Exceed max_retry!", flush=True)
|
| 173 |
+
raise ValueError("[error]: Exceed max_retry!")
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def evaluate_one_sample(events, response, prediction, model, verbose, return_hit_num=False, is_recall=False, max_retry=100):
|
| 177 |
+
retry = 0
|
| 178 |
+
while True and (retry<max_retry or max_retry<0):
|
| 179 |
+
retry += 1
|
| 180 |
+
try:
|
| 181 |
+
assert isinstance(events, list)
|
| 182 |
+
result = None
|
| 183 |
+
result, success = try_call_api_for_eval(events, response, prediction, model, verbose)
|
| 184 |
+
if not success:
|
| 185 |
+
print("[error]: try_call_api_for_eval failed!", flush=True)
|
| 186 |
+
continue
|
| 187 |
+
try:
|
| 188 |
+
events_filled = json.loads(result)
|
| 189 |
+
events_filled = events_filled['events']
|
| 190 |
+
except Exception as e:
|
| 191 |
+
print("load json failed:", result)
|
| 192 |
+
continue
|
| 193 |
+
assert len(events) == len(events_filled) or (len(events) == 0 and len(events_filled) == 1)
|
| 194 |
+
num_matched_events = 0
|
| 195 |
+
try:
|
| 196 |
+
for event in events_filled:
|
| 197 |
+
pred = event['relationship'].strip().lower()
|
| 198 |
+
assert pred in ['entailment', 'neutral', 'contradiction']
|
| 199 |
+
pos_classes = ['entailment'] if is_recall else ['entailment', 'neutral']
|
| 200 |
+
if pred in pos_classes:
|
| 201 |
+
num_matched_events += 1
|
| 202 |
+
except Exception as e:
|
| 203 |
+
print(f"Invalid response: {events_filled}")
|
| 204 |
+
continue
|
| 205 |
+
if len(events) == 0:
|
| 206 |
+
motion_score = 1.0
|
| 207 |
+
else:
|
| 208 |
+
motion_score = num_matched_events / len(events)
|
| 209 |
+
if return_hit_num:
|
| 210 |
+
return motion_score, events_filled, f"hit: {num_matched_events} / {len(events)}"
|
| 211 |
+
return motion_score
|
| 212 |
+
except Exception as e:
|
| 213 |
+
print(format_exc(), flush=True)
|
| 214 |
+
continue
|
| 215 |
+
time.sleep(1)
|
| 216 |
+
print("[error]: Exceed max_retry!", flush=True)
|
| 217 |
+
raise ValueError(f"[error]: Exceed max_retry!")
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def process_one_sample(inputs):
|
| 221 |
+
data, model, verbose = inputs
|
| 222 |
+
response, prediction = data['response'].lower(), data['prediction'].lower()
|
| 223 |
+
result = None
|
| 224 |
+
try:
|
| 225 |
+
if isinstance(data.get('events', None), list):
|
| 226 |
+
gt_events = data['events']
|
| 227 |
+
else:
|
| 228 |
+
gt_events = extract_events(inputs, is_pred=False)
|
| 229 |
+
pred_events = extract_events(inputs, is_pred=True)
|
| 230 |
+
assert isinstance(gt_events, list) and isinstance(pred_events, list)
|
| 231 |
+
result = {}
|
| 232 |
+
motion_score_r, events_filled_r, hit_num_r = evaluate_one_sample(gt_events, response, prediction, model, verbose, return_hit_num=True, is_recall=True)
|
| 233 |
+
motion_score_p, events_filled_p, hit_num_p = evaluate_one_sample(pred_events, prediction, response, model, verbose, return_hit_num=True, is_recall=True)
|
| 234 |
+
result['score_r'] = motion_score_r
|
| 235 |
+
result['score_p'] = motion_score_p
|
| 236 |
+
result['eval_infos'] = {
|
| 237 |
+
'idx': data['idx'],
|
| 238 |
+
'gt': response,
|
| 239 |
+
'pred': prediction,
|
| 240 |
+
'events_gt': events_filled_r,
|
| 241 |
+
'hit_num_recall': hit_num_r,
|
| 242 |
+
'events_pred': events_filled_p,
|
| 243 |
+
"hit_num_precision": hit_num_p,
|
| 244 |
+
}
|
| 245 |
+
if 'extra_info' in data:
|
| 246 |
+
result['extra_info'] = data['extra_info']
|
| 247 |
+
except Exception as e:
|
| 248 |
+
if verbose:
|
| 249 |
+
print(e)
|
| 250 |
+
print(f'invalid GPT response: {result}')
|
| 251 |
+
result = None
|
| 252 |
+
return {'success': False, 'result': result, 'data': data}
|
| 253 |
+
return {'success': True, 'result': result, 'data': data}
|
| 254 |
+
|
| 255 |
+
class DREAMGPTMetric:
|
| 256 |
+
def __init__(self, dataset_name, verbose=False) -> None:
|
| 257 |
+
self.dataset_name = dataset_name
|
| 258 |
+
self.num_worker = 64
|
| 259 |
+
# self.model = 'gpt-35-turbo'
|
| 260 |
+
self.model = 'gpt-35-turbo-0125'
|
| 261 |
+
# self.model='gpt-4-1106-preview'
|
| 262 |
+
self.results = []
|
| 263 |
+
self.invalid_results = []
|
| 264 |
+
self.dataset = []
|
| 265 |
+
self.verbose = verbose
|
| 266 |
+
self.eval_infos = []
|
| 267 |
+
self.buckets = {
|
| 268 |
+
"subjects": {
|
| 269 |
+
'<=1': [], '==2': [], '==3': [], '>=4': []
|
| 270 |
+
},
|
| 271 |
+
"shots": {'<=1': [], '==2': [], '==3': [], '>=4': []
|
| 272 |
+
},
|
| 273 |
+
"events": {'<=3': [], 'in [4, 5]': [], 'in [6, 7]': [], '>=8': []
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
def add(self, data):
|
| 278 |
+
self.dataset.append(data)
|
| 279 |
+
|
| 280 |
+
def select_bucket(self, bucket_name, num):
|
| 281 |
+
for key in self.buckets[bucket_name]:
|
| 282 |
+
if eval(f"{num}{key}"):
|
| 283 |
+
return key
|
| 284 |
+
return ''
|
| 285 |
+
|
| 286 |
+
def add_to_bucket(self, bucket_name, data):
|
| 287 |
+
sub_bucket = self.select_bucket(bucket_name, data['result']['extra_info'][f'n_{bucket_name}'])
|
| 288 |
+
if sub_bucket:
|
| 289 |
+
self.buckets[bucket_name][sub_bucket].append(data)
|
| 290 |
+
|
| 291 |
+
def process(self, dataset: List[Dict]):
|
| 292 |
+
self._process_group_by_subtask(dataset)
|
| 293 |
+
|
| 294 |
+
def _process(self, dataset: List[Dict], subtask=None):
|
| 295 |
+
pool = Pool(processes = self.num_worker, )
|
| 296 |
+
inputs = [(d, self.model, self.verbose) for d in dataset]
|
| 297 |
+
results = pool.uimap(process_one_sample, inputs, chunksize = 1)
|
| 298 |
+
|
| 299 |
+
for result in tqdm(results, total = len(dataset), desc=f'eval {subtask}'):
|
| 300 |
+
if subtask:
|
| 301 |
+
result['subtask'] = subtask
|
| 302 |
+
self.update_metric(result)
|
| 303 |
+
pool.close()
|
| 304 |
+
pool.join()
|
| 305 |
+
pool.clear() # MUST
|
| 306 |
+
|
| 307 |
+
def _process_group_by_subtask(self, dataset: List[Dict]):
|
| 308 |
+
def _group_by_subtask(dataset):
|
| 309 |
+
subtasks = {}
|
| 310 |
+
for data in dataset:
|
| 311 |
+
if data['dataset'] not in subtasks:
|
| 312 |
+
subtasks[data['dataset']] = []
|
| 313 |
+
subtasks[data['dataset']].append(data)
|
| 314 |
+
return subtasks
|
| 315 |
+
subtasks = _group_by_subtask(dataset)
|
| 316 |
+
for subtask, subdata in subtasks.items():
|
| 317 |
+
self._process(subdata, subtask)
|
| 318 |
+
|
| 319 |
+
def update_metric(self, result):
|
| 320 |
+
if result['success']:
|
| 321 |
+
self.results.append(result)
|
| 322 |
+
else:
|
| 323 |
+
self.invalid_results.append(result)
|
| 324 |
+
|
| 325 |
+
def summarize_metric(self):
|
| 326 |
+
self._summarize_metric_by_subtask()
|
| 327 |
+
self._summarize_metric_by_bucket()
|
| 328 |
+
|
| 329 |
+
def _summarize_metric_by_subtask(self):
|
| 330 |
+
from prettytable import PrettyTable
|
| 331 |
+
self.table = PrettyTable(['Task', 'F1 Score', 'Action Recall', 'Action Precision', 'Success', 'Failed'])
|
| 332 |
+
def _group_by_subtask():
|
| 333 |
+
sub_results = {}
|
| 334 |
+
sub_invalid_results = {}
|
| 335 |
+
for data in self.results:
|
| 336 |
+
if data['subtask'] not in sub_results:
|
| 337 |
+
sub_results[data['subtask']] = []
|
| 338 |
+
sub_results[data['subtask']].append(data)
|
| 339 |
+
for data in self.invalid_results:
|
| 340 |
+
if data['subtask'] not in sub_invalid_results:
|
| 341 |
+
sub_invalid_results[data['subtask']] = []
|
| 342 |
+
sub_invalid_results[data['subtask']].append(data)
|
| 343 |
+
return sub_results, sub_invalid_results
|
| 344 |
+
sub_results, sub_invalid_results = _group_by_subtask()
|
| 345 |
+
overall_avg_recall = []
|
| 346 |
+
overall_avg_precision = []
|
| 347 |
+
subtasks = list(sub_results.keys())
|
| 348 |
+
subtasks.sort()
|
| 349 |
+
for subtask in subtasks:
|
| 350 |
+
sub_rsts = sub_results[subtask]
|
| 351 |
+
sub_in_rsts = sub_invalid_results.get(subtask, [])
|
| 352 |
+
recalls = []
|
| 353 |
+
precisions = []
|
| 354 |
+
for result in sub_rsts:
|
| 355 |
+
r, p, infos = result['result']['score_r'], result['result']['score_p'], result['result']['eval_infos']
|
| 356 |
+
recalls.append(r)
|
| 357 |
+
precisions.append(p)
|
| 358 |
+
self.eval_infos.append(infos)
|
| 359 |
+
avg_recall = np.average(recalls)
|
| 360 |
+
avg_precision = np.average(precisions)
|
| 361 |
+
f1 = count_f1(avg_recall, avg_precision)
|
| 362 |
+
overall_avg_recall.append(avg_recall)
|
| 363 |
+
overall_avg_precision.append(avg_precision)
|
| 364 |
+
task_name = subtask
|
| 365 |
+
self.table.add_row([task_name, round(f1, 3), round(avg_recall, 3), round(avg_precision, 3), len(sub_rsts), len(sub_in_rsts)])
|
| 366 |
+
overall_recall = np.average(overall_avg_recall)
|
| 367 |
+
overall_precision = np.average(overall_avg_precision)
|
| 368 |
+
overall_f1 = count_f1(overall_recall, overall_precision)
|
| 369 |
+
self.table.add_row(['OVERALL', round(overall_f1, 3), round(overall_recall, 3), round(overall_precision, 3), len(self.results), len(self.invalid_results)])
|
| 370 |
+
print(f'=====DREAM Evaluation Summary=====')
|
| 371 |
+
print(self.table)
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _summarize_metric_by_bucket(self):
|
| 375 |
+
from prettytable import PrettyTable
|
| 376 |
+
self.bucket_tables = []
|
| 377 |
+
for bucket in self.buckets:
|
| 378 |
+
table = PrettyTable(['Score'] + list(self.buckets[bucket].keys()))
|
| 379 |
+
for data in self.results:
|
| 380 |
+
self.add_to_bucket(bucket_name=bucket, data=data)
|
| 381 |
+
bucket_result = {}
|
| 382 |
+
for sub_bucket in self.buckets[bucket]:
|
| 383 |
+
recalls = []
|
| 384 |
+
precisions = []
|
| 385 |
+
for result in self.buckets[bucket][sub_bucket]:
|
| 386 |
+
r, p = result['result']['score_r'], result['result']['score_p']
|
| 387 |
+
recalls.append(r)
|
| 388 |
+
precisions.append(p)
|
| 389 |
+
avg_recall = np.average(recalls)
|
| 390 |
+
avg_precision = np.average(precisions)
|
| 391 |
+
f1 = count_f1(avg_recall, avg_precision)
|
| 392 |
+
bucket_result[sub_bucket] = (avg_recall, avg_precision, f1)
|
| 393 |
+
|
| 394 |
+
raw = []
|
| 395 |
+
scores = ['Recall', 'Precision', 'F1']
|
| 396 |
+
for i in range(len(scores)):
|
| 397 |
+
raw = [scores[i]]
|
| 398 |
+
for sub_bucket in bucket_result:
|
| 399 |
+
raw.append(round(bucket_result[sub_bucket][i], 3))
|
| 400 |
+
table.add_row(raw)
|
| 401 |
+
sample_num = ['Count']
|
| 402 |
+
for k in self.buckets[bucket]:
|
| 403 |
+
sample_num.append(len(self.buckets[bucket][k]))
|
| 404 |
+
table.add_row(sample_num)
|
| 405 |
+
bucket_info = f'\n=====DREAM Evaluation Split by Bucket #{bucket}====='
|
| 406 |
+
print(bucket_info)
|
| 407 |
+
print(table)
|
| 408 |
+
self.bucket_tables.append(bucket_info)
|
| 409 |
+
self.bucket_tables.append(deepcopy(table))
|
| 410 |
+
|
| 411 |
+
def save_results(self, pred_path):
|
| 412 |
+
if os.path.isdir(pred_path):
|
| 413 |
+
output_dir = os.path.join(pred_path, 'eval_records')
|
| 414 |
+
else:
|
| 415 |
+
output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
|
| 416 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 417 |
+
model_flag = os.path.basename(pred_path).split('.')[0]
|
| 418 |
+
fout = open(os.path.join(output_dir, f'{self.dataset_name}_{model_flag}_eval_result.txt'), 'w')
|
| 419 |
+
print(self.table, file=fout)
|
| 420 |
+
for bucket_info in self.bucket_tables:
|
| 421 |
+
print(bucket_info)
|
| 422 |
+
fout.close()
|
| 423 |
+
|
| 424 |
+
def save_eval_infos(self, pred_path):
|
| 425 |
+
if os.path.isdir(pred_path):
|
| 426 |
+
output_dir = os.path.join(pred_path, 'eval_records')
|
| 427 |
+
else:
|
| 428 |
+
output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
|
| 429 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 430 |
+
model_flag = os.path.basename(pred_path).split('.')[0]
|
| 431 |
+
fout = open(os.path.join(output_dir, f'DREAM_{model_flag}_eval_infos.jsonl'), 'w')
|
| 432 |
+
for info in self.eval_infos:
|
| 433 |
+
fout.write(json.dumps(info) +'\n')
|
| 434 |
+
fout.close()
|
| 435 |
+
print(f"DREAM evaluation information saved in: {os.path.join(output_dir, 'DREAM_eval_infos.jsonl')}", flush=True)
|
| 436 |
+
|
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_mc.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
import numpy as np
|
| 16 |
+
import os
|
| 17 |
+
from typing import List, Dict
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append('eval_scripts/DREAM-1K/tarsier')
|
| 20 |
+
from tools.color import Color
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class AccuracyMetric:
|
| 24 |
+
def __init__(self, dataset_name, verbose=False) -> None:
|
| 25 |
+
self.dataset_name = dataset_name
|
| 26 |
+
self.results = []
|
| 27 |
+
self.invalid_results = []
|
| 28 |
+
self.dataset = []
|
| 29 |
+
self.verbose = verbose
|
| 30 |
+
|
| 31 |
+
def add(self, data):
|
| 32 |
+
self.dataset.append(data)
|
| 33 |
+
|
| 34 |
+
def process(self, dataset: List[Dict]):
|
| 35 |
+
if self.dataset_name in ['MVBench', 'TVBench', 'FAVOR-Bench']:
|
| 36 |
+
return self._process_group_by_subtask(dataset)
|
| 37 |
+
else:
|
| 38 |
+
return self._process(dataset)
|
| 39 |
+
|
| 40 |
+
def _process(self, dataset: List[Dict], subtask=None):
|
| 41 |
+
for data in dataset:
|
| 42 |
+
prompt, response, prediction = data['prompt'], data['response'], data['prediction']
|
| 43 |
+
prediction = prediction.replace('(', '').replace(')', '').strip()
|
| 44 |
+
response = response.replace('(', '').replace(')', '').strip()[0]
|
| 45 |
+
if len(prediction) <= 0:
|
| 46 |
+
success = False
|
| 47 |
+
else:
|
| 48 |
+
prediction = prediction[0]
|
| 49 |
+
if '0'<=prediction<='5':
|
| 50 |
+
prediction = chr(int(prediction) + ord('A'))
|
| 51 |
+
success = prediction.isupper() and prediction.isalpha() and len(prediction) == 1
|
| 52 |
+
if success:
|
| 53 |
+
rst = {
|
| 54 |
+
'success': success,
|
| 55 |
+
'data': data,
|
| 56 |
+
'result': {'acc': response == prediction}
|
| 57 |
+
}
|
| 58 |
+
if subtask:
|
| 59 |
+
rst['subtask'] = subtask
|
| 60 |
+
self.results.append(rst)
|
| 61 |
+
else:
|
| 62 |
+
rst = {
|
| 63 |
+
'success': success,
|
| 64 |
+
'data': data,
|
| 65 |
+
'result': {'acc': response == prediction}
|
| 66 |
+
}
|
| 67 |
+
if subtask:
|
| 68 |
+
rst['subtask'] = subtask
|
| 69 |
+
self.invalid_results.append(rst)
|
| 70 |
+
|
| 71 |
+
def _process_group_by_subtask(self, dataset: List[Dict]):
|
| 72 |
+
def _group_by_subtask(dataset):
|
| 73 |
+
subtasks = {}
|
| 74 |
+
for data in dataset:
|
| 75 |
+
if data['dataset'] not in subtasks:
|
| 76 |
+
subtasks[data['dataset']] = []
|
| 77 |
+
subtasks[data['dataset']].append(data)
|
| 78 |
+
return subtasks
|
| 79 |
+
subtasks = _group_by_subtask(dataset)
|
| 80 |
+
for subtask, subdata in subtasks.items():
|
| 81 |
+
self._process(subdata, subtask)
|
| 82 |
+
|
| 83 |
+
def summarize_metric(self):
|
| 84 |
+
if self.dataset_name in ['MVBench', 'TVBench', 'FAVOR-Bench']:
|
| 85 |
+
return self._summarize_metric_by_subtask()
|
| 86 |
+
else:
|
| 87 |
+
return self._summarize_metric()
|
| 88 |
+
|
| 89 |
+
def _summarize_metric(self):
|
| 90 |
+
if self.verbose:
|
| 91 |
+
for result in self.results + self.invalid_results:
|
| 92 |
+
print(f"{Color.red('Success: ' + str(result['success']))}")
|
| 93 |
+
print(Color.blue(json.dumps(result['data'], ensure_ascii=False)))
|
| 94 |
+
print(f"{Color.green('Accuracy: ' + str(result['result']['acc']))}")
|
| 95 |
+
|
| 96 |
+
accs = []
|
| 97 |
+
for result in self.results:
|
| 98 |
+
acc = result['result']['acc']
|
| 99 |
+
accs.append(acc)
|
| 100 |
+
avg_acc = np.average(accs)
|
| 101 |
+
|
| 102 |
+
self.eval_records = [
|
| 103 |
+
f'=====Evaluation Summary=====',
|
| 104 |
+
f'Dataset: {self.dataset_name}\tMetric: Accuracy',
|
| 105 |
+
f'#Successful Results: {len(self.results)}\n#Failed Results: {len(self.invalid_results)}',
|
| 106 |
+
f'Accuracy: {round(avg_acc*100, 1)}',
|
| 107 |
+
]
|
| 108 |
+
for info in self.eval_records:
|
| 109 |
+
print(info)
|
| 110 |
+
|
| 111 |
+
def _summarize_metric_by_subtask(self):
|
| 112 |
+
from prettytable import PrettyTable
|
| 113 |
+
self.table = PrettyTable(['Task','Accuracy','Success','Failed'])
|
| 114 |
+
def _group_by_subtask():
|
| 115 |
+
sub_results = {}
|
| 116 |
+
sub_invalid_results = {}
|
| 117 |
+
for data in self.results:
|
| 118 |
+
if data['subtask'] not in sub_results:
|
| 119 |
+
sub_results[data['subtask']] = []
|
| 120 |
+
sub_results[data['subtask']].append(data)
|
| 121 |
+
for data in self.invalid_results:
|
| 122 |
+
if data['subtask'] not in sub_invalid_results:
|
| 123 |
+
sub_invalid_results[data['subtask']] = []
|
| 124 |
+
sub_invalid_results[data['subtask']].append(data)
|
| 125 |
+
return sub_results, sub_invalid_results
|
| 126 |
+
sub_results, sub_invalid_results = _group_by_subtask()
|
| 127 |
+
oa_accs = []
|
| 128 |
+
subtasks = list(sub_results.keys())
|
| 129 |
+
# subtasks.sort(key=lambda x:f"{x.split('/')[-1].split(' ')[0][0]}{x.split('/')[-1].split(' ')[1][0]}")
|
| 130 |
+
subtasks.sort(key=lambda x:x.split('/')[-1])
|
| 131 |
+
for subtask in subtasks:
|
| 132 |
+
sub_rsts = sub_results[subtask]
|
| 133 |
+
sub_in_rsts = sub_invalid_results.get(subtask, [])
|
| 134 |
+
accs = []
|
| 135 |
+
for result in sub_rsts:
|
| 136 |
+
acc = result['result']['acc']
|
| 137 |
+
accs.append(acc)
|
| 138 |
+
oa_accs.append(acc)
|
| 139 |
+
avg_acc = np.average(accs)
|
| 140 |
+
# task_name = f"{subtask.split('/')[-1].split(' ')[0][0]}{subtask.split('/')[-1].split(' ')[1][0]}"
|
| 141 |
+
task_name = subtask.split('/')[-1]
|
| 142 |
+
self.table.add_row([task_name, round(avg_acc*100, 1), len(sub_rsts), len(sub_in_rsts)])
|
| 143 |
+
self.table.add_row(['OVERALL', round(np.average(oa_accs)*100, 1), len(self.results), len(self.invalid_results)])
|
| 144 |
+
print(f'=====Evaluation Summary=====')
|
| 145 |
+
print(self.table)
|
| 146 |
+
|
| 147 |
+
def save_results(self, pred_path):
|
| 148 |
+
if os.path.isdir(pred_path):
|
| 149 |
+
output_dir = os.path.join(pred_path, 'eval_records')
|
| 150 |
+
else:
|
| 151 |
+
output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
|
| 152 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 153 |
+
fout = open(os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt'), 'w')
|
| 154 |
+
if self.dataset_name in ['MVBench', 'TVBench', 'FAVOR-Bench']:
|
| 155 |
+
print(self.table, file=fout)
|
| 156 |
+
else:
|
| 157 |
+
for info in self.eval_records:
|
| 158 |
+
fout.write(info+'\n')
|
| 159 |
+
fout.close()
|
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_qa_oe_gpt.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
import numpy as np
|
| 16 |
+
import ast
|
| 17 |
+
import time
|
| 18 |
+
from typing import List, Dict
|
| 19 |
+
from tqdm import tqdm
|
| 20 |
+
from pathos.multiprocessing import ProcessingPool as Pool
|
| 21 |
+
import func_timeout
|
| 22 |
+
from func_timeout import func_set_timeout
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
sys.path.append('eval_scripts/DREAM-1K/tarsier')
|
| 26 |
+
from tools.color import Color
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def call_azure_gpt_api(question, answer, prediction, model):
|
| 30 |
+
|
| 31 |
+
messages=[
|
| 32 |
+
{
|
| 33 |
+
"role": "system",
|
| 34 |
+
"content":
|
| 35 |
+
"You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. "
|
| 36 |
+
"Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:"
|
| 37 |
+
"------"
|
| 38 |
+
"##INSTRUCTIONS: "
|
| 39 |
+
"- Focus on the meaningful match between the predicted answer and the correct answer.\n"
|
| 40 |
+
"- Consider synonyms or paraphrases as valid matches.\n"
|
| 41 |
+
"- Evaluate the correctness of the prediction compared to the answer."
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"role": "user",
|
| 45 |
+
"content":
|
| 46 |
+
"Please evaluate the following video-based question-answer pair:\n\n"
|
| 47 |
+
f"Question: {question}\n"
|
| 48 |
+
f"Correct Answer: {answer}\n"
|
| 49 |
+
f"Predicted Answer: {prediction}\n\n"
|
| 50 |
+
"Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. "
|
| 51 |
+
"Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING."
|
| 52 |
+
"DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. "
|
| 53 |
+
"For example, your response should look like this: {'pred': 'yes', 'score': 4.8}"
|
| 54 |
+
}
|
| 55 |
+
]
|
| 56 |
+
completion = call_gpt35(messages)
|
| 57 |
+
return completion
|
| 58 |
+
|
| 59 |
+
def try_call_api(question, answer, prediction, model, verbose=False):
|
| 60 |
+
for i in range(5):
|
| 61 |
+
gpt_q = call_azure_gpt_api(question, answer, prediction, model)
|
| 62 |
+
if gpt_q is not None:
|
| 63 |
+
return gpt_q, True
|
| 64 |
+
|
| 65 |
+
return None, False
|
| 66 |
+
|
| 67 |
+
def process_one_sample(inputs):
|
| 68 |
+
data, model, verbose = inputs
|
| 69 |
+
prompt, response, prediction = data['question'], data['response'].lower(), data['prediction'].lower()
|
| 70 |
+
result = None
|
| 71 |
+
try:
|
| 72 |
+
result, success = try_call_api(prompt, response, prediction, model, verbose)
|
| 73 |
+
if not success:
|
| 74 |
+
raise ValueError(result)
|
| 75 |
+
result = ast.literal_eval(result)
|
| 76 |
+
pred, score = result['pred'], result['score']
|
| 77 |
+
# check pred
|
| 78 |
+
if pred not in ['yes', 'no']:
|
| 79 |
+
raise ValueError()
|
| 80 |
+
# check score
|
| 81 |
+
result['score'] = float(result['score'])
|
| 82 |
+
if score < 0 or score > 5:
|
| 83 |
+
raise ValueError()
|
| 84 |
+
except Exception as e:
|
| 85 |
+
if verbose:
|
| 86 |
+
print(e)
|
| 87 |
+
print(f'invalid GPT response: {result}')
|
| 88 |
+
return {'success': False, 'result': result, 'data': data}
|
| 89 |
+
return {'success': True, 'result': result, 'data': data}
|
| 90 |
+
|
| 91 |
+
class GPTMetric:
|
| 92 |
+
def __init__(self, dataset_name, verbose=False) -> None:
|
| 93 |
+
self.dataset_name = dataset_name
|
| 94 |
+
self.num_worker = 64
|
| 95 |
+
self.model = 'gpt-35-turbo-0125'
|
| 96 |
+
self.results = []
|
| 97 |
+
self.invalid_results = []
|
| 98 |
+
self.dataset = []
|
| 99 |
+
self.verbose = verbose
|
| 100 |
+
|
| 101 |
+
def add(self, data):
|
| 102 |
+
self.dataset.append(data)
|
| 103 |
+
|
| 104 |
+
def process(self, dataset: List[Dict]):
|
| 105 |
+
pool = Pool(processes = self.num_worker, )
|
| 106 |
+
inputs = [(d, self.model, self.verbose) for d in dataset]
|
| 107 |
+
results = pool.uimap(process_one_sample, inputs, chunksize = 1)
|
| 108 |
+
|
| 109 |
+
for result in tqdm(results, total = len(dataset)):
|
| 110 |
+
self.update_metric(result)
|
| 111 |
+
pool.close()
|
| 112 |
+
pool.join()
|
| 113 |
+
pool.clear() # MUST
|
| 114 |
+
|
| 115 |
+
def update_metric(self, result):
|
| 116 |
+
if result['success']:
|
| 117 |
+
self.results.append(result)
|
| 118 |
+
else:
|
| 119 |
+
self.invalid_results.append(result)
|
| 120 |
+
|
| 121 |
+
def summarize_metric(self):
|
| 122 |
+
if self.verbose:
|
| 123 |
+
for result in self.results + self.invalid_results:
|
| 124 |
+
print(f"Success: {Color.red(str(result['success']))}")
|
| 125 |
+
print(Color.blue(json.dumps(result['data'], ensure_ascii=False)))
|
| 126 |
+
print(Color.green(json.dumps(result['result'], ensure_ascii=False)))
|
| 127 |
+
preds, scores = [], []
|
| 128 |
+
for result in self.results:
|
| 129 |
+
pred, score = result['result']['pred'], result['result']['score']
|
| 130 |
+
preds.append(pred)
|
| 131 |
+
scores.append(score)
|
| 132 |
+
avg_score = np.average(scores)
|
| 133 |
+
acc = np.average([p == 'yes' for p in preds])
|
| 134 |
+
print(f'=====Evaluation Summary=====')
|
| 135 |
+
self.eval_records = [
|
| 136 |
+
f'Dataset: {self.dataset_name}\tMetric: GPT Accuracy',
|
| 137 |
+
f'#Successful Results: {len(self.results)}\n#Failed Results: {len(self.invalid_results)}',
|
| 138 |
+
f'Accuracy: {round(acc*100, 1)}',
|
| 139 |
+
f'Average Score: {round(avg_score, 3)}',
|
| 140 |
+
]
|
| 141 |
+
for info in self.eval_records:
|
| 142 |
+
print(info)
|
| 143 |
+
|
| 144 |
+
def save_results(self, pred_path):
|
| 145 |
+
if os.path.isdir(pred_path):
|
| 146 |
+
output_dir = os.path.join(pred_path, 'eval_records')
|
| 147 |
+
else:
|
| 148 |
+
output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
|
| 149 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 150 |
+
fout = open(os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt'), 'w')
|
| 151 |
+
for info in self.eval_records:
|
| 152 |
+
fout.write(info+'\n')
|
| 153 |
+
fout.close()
|
eval_scripts/DREAM-1K/tarsier/evaluation/metrics/evaluate_video_mme.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
from typing import List, Dict
|
| 16 |
+
from typing import Optional, List, Union
|
| 17 |
+
import os
|
| 18 |
+
import sys
|
| 19 |
+
sys.path.append('eval_scripts/DREAM-1K/tarsier')from tools.color import Color
|
| 20 |
+
|
| 21 |
+
CATEGORIES = [
|
| 22 |
+
"Knowledge",
|
| 23 |
+
"Film & Television",
|
| 24 |
+
"Sports Competition",
|
| 25 |
+
"Artistic Performance",
|
| 26 |
+
"Life Record",
|
| 27 |
+
"Multilingual"
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
SUB_CATEGORIES = [
|
| 31 |
+
"Humanity & History",
|
| 32 |
+
"Literature & Art",
|
| 33 |
+
"Biology & Medicine",
|
| 34 |
+
"Finance & Commerce",
|
| 35 |
+
"Astronomy",
|
| 36 |
+
"Geography",
|
| 37 |
+
"Law",
|
| 38 |
+
"Life Tip",
|
| 39 |
+
"Technology",
|
| 40 |
+
"Animation",
|
| 41 |
+
"Movie & TV Show",
|
| 42 |
+
"Documentary",
|
| 43 |
+
"News Report",
|
| 44 |
+
"Esports",
|
| 45 |
+
"Basketball",
|
| 46 |
+
"Football",
|
| 47 |
+
"Athletics",
|
| 48 |
+
"Other Sports",
|
| 49 |
+
"Stage Play",
|
| 50 |
+
"Magic Show",
|
| 51 |
+
"Variety Show",
|
| 52 |
+
"Acrobatics",
|
| 53 |
+
"Handicraft",
|
| 54 |
+
"Food",
|
| 55 |
+
"Fashion",
|
| 56 |
+
"Daily Life",
|
| 57 |
+
"Travel",
|
| 58 |
+
"Pet & Animal",
|
| 59 |
+
"Exercise",
|
| 60 |
+
"Multilingual"
|
| 61 |
+
]
|
| 62 |
+
|
| 63 |
+
TASK_CATEGORIES = [
|
| 64 |
+
"Temporal Perception",
|
| 65 |
+
"Spatial Perception",
|
| 66 |
+
"Attribute Perception",
|
| 67 |
+
"Action Recognition",
|
| 68 |
+
"Object Recognition",
|
| 69 |
+
"OCR Problems",
|
| 70 |
+
"Counting Problem",
|
| 71 |
+
"Temporal Reasoning",
|
| 72 |
+
"Spatial Reasoning",
|
| 73 |
+
"Action Reasoning",
|
| 74 |
+
"Object Reasoning",
|
| 75 |
+
"Information Synopsis",
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class VideoMMEAccuracyMetric:
|
| 80 |
+
def __init__(self, dataset_name, verbose=False) -> None:
|
| 81 |
+
self.dataset_name = dataset_name
|
| 82 |
+
self.results = []
|
| 83 |
+
self.invalid_results = []
|
| 84 |
+
self.dataset = []
|
| 85 |
+
self.verbose = verbose
|
| 86 |
+
|
| 87 |
+
def add(self, data):
|
| 88 |
+
self.dataset.append(data)
|
| 89 |
+
|
| 90 |
+
def process(self, dataset: List[Dict]):
|
| 91 |
+
return self._process(dataset)
|
| 92 |
+
|
| 93 |
+
def _process(self, dataset: List[Dict]):
|
| 94 |
+
for data in dataset:
|
| 95 |
+
prompt, response, prediction = data['prompt'], data['response'], data['prediction']
|
| 96 |
+
extra_info = data['extra_info']
|
| 97 |
+
prediction = prediction.replace('(', '').replace(')', '').strip()
|
| 98 |
+
if len(prediction) <= 0:
|
| 99 |
+
success = False
|
| 100 |
+
else:
|
| 101 |
+
prediction = prediction[0]
|
| 102 |
+
if '1'<=prediction<='5':
|
| 103 |
+
prediction = chr(int(prediction) + ord('A'))
|
| 104 |
+
success = prediction.isupper() and prediction.isalpha() and len(prediction) == 1
|
| 105 |
+
if success:
|
| 106 |
+
rst = {
|
| 107 |
+
'success': success,
|
| 108 |
+
'data': data,
|
| 109 |
+
'result': {'acc': response == prediction},
|
| 110 |
+
'extra_info': extra_info,
|
| 111 |
+
'missing': False
|
| 112 |
+
}
|
| 113 |
+
self.results.append(rst)
|
| 114 |
+
else:
|
| 115 |
+
rst = {
|
| 116 |
+
'success': success,
|
| 117 |
+
'data': data,
|
| 118 |
+
'result': {'acc': False},
|
| 119 |
+
'extra_info': extra_info,
|
| 120 |
+
'missing': True
|
| 121 |
+
}
|
| 122 |
+
self.results.append(rst)
|
| 123 |
+
self.invalid_results.append(rst)
|
| 124 |
+
|
| 125 |
+
def summarize_metric(self):
|
| 126 |
+
if self.verbose:
|
| 127 |
+
for result in self.results + self.invalid_results:
|
| 128 |
+
print(f"{Color.red('Success: ' + str(result['success']))}")
|
| 129 |
+
print(Color.blue(json.dumps(result['data'], ensure_ascii=False)))
|
| 130 |
+
print(f"{Color.green('Accuracy: ' + str(result['result']['acc']))}")
|
| 131 |
+
print(f'=====Evaluation Summary=====')
|
| 132 |
+
print(f'Dataset: {self.dataset_name}\tMetric: Accuracy')
|
| 133 |
+
print(f'#Successful Results: {len(self.results) - len(self.invalid_results)}\n#Failed Results: {len(self.invalid_results)}')
|
| 134 |
+
self.eval_your_results(
|
| 135 |
+
video_types = ["short","medium","long"],
|
| 136 |
+
skip_missing = True,
|
| 137 |
+
return_categories_accuracy = True,
|
| 138 |
+
return_sub_categories_accuracy = False,
|
| 139 |
+
return_task_types_accuracy = False,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def merge_results(self):
|
| 143 |
+
results_merged_by_vid = {}
|
| 144 |
+
for result in self.results:
|
| 145 |
+
vid = result['extra_info']['vid']
|
| 146 |
+
if vid not in results_merged_by_vid:
|
| 147 |
+
results_merged_by_vid[vid] = {
|
| 148 |
+
'video_id': vid,
|
| 149 |
+
"duration": result['extra_info']['duration'],
|
| 150 |
+
"domain": result['extra_info']['domain'],
|
| 151 |
+
"sub_category": result['extra_info']['sub_category'],
|
| 152 |
+
'questions': [],
|
| 153 |
+
'missing': False
|
| 154 |
+
}
|
| 155 |
+
if result['missing']:
|
| 156 |
+
results_merged_by_vid[vid]['missing'] = True
|
| 157 |
+
results_merged_by_vid[vid]['questions'].append({
|
| 158 |
+
'qid': result['extra_info']['idx'],
|
| 159 |
+
'task_type': result['extra_info']['task_type'],
|
| 160 |
+
'acc': result['result']['acc']
|
| 161 |
+
}
|
| 162 |
+
)
|
| 163 |
+
return results_merged_by_vid
|
| 164 |
+
|
| 165 |
+
def eval_your_results(
|
| 166 |
+
self,
|
| 167 |
+
video_types: Optional[Union[List[str], str]] = None,
|
| 168 |
+
skip_missing: Optional[bool] = False,
|
| 169 |
+
return_categories_accuracy: Optional[bool] = True,
|
| 170 |
+
return_sub_categories_accuracy: Optional[bool] = False,
|
| 171 |
+
return_task_types_accuracy: Optional[bool] = False,
|
| 172 |
+
gt_answer_key: Optional[str] = "answer",
|
| 173 |
+
your_answer_key: Optional[str] = "response"
|
| 174 |
+
|
| 175 |
+
):
|
| 176 |
+
"""
|
| 177 |
+
This copy from https://github.com/thanku-all/parse_answer/blob/main/eval_your_results.py
|
| 178 |
+
Evaluate your results against the ground truth
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
- your_results_path (str): Path to your results file
|
| 182 |
+
- video_types (Optional[List[str], str]): List of video types to evaluate.
|
| 183 |
+
- skip_missing (Optional[bool]): If True, missing files will be skipped. If False, an error will be raised if there are missing files.
|
| 184 |
+
- return_categories_accuracy (Optional[bool]): If True, the accuracy for each video category will be returned.
|
| 185 |
+
- return_sub_categories_accuracy (Optional[bool]): If True, the accuracy for each video sub category will be returned.
|
| 186 |
+
- return_task_types_accuracy (Optional[bool]): If True, the accuracy for each task category will be returned.
|
| 187 |
+
- gt_answer_key (Optional[str]): Key to access the ground truth answer in the results file.
|
| 188 |
+
- your_answer_key (Optional[str]): Key to access your answer in the results file.
|
| 189 |
+
"""
|
| 190 |
+
|
| 191 |
+
# Load your results
|
| 192 |
+
# with open(your_results_path, 'r') as f:
|
| 193 |
+
# your_results = json.load(f)
|
| 194 |
+
your_results = list(self.merge_results().values())
|
| 195 |
+
self.eval_records = []
|
| 196 |
+
if isinstance(video_types, str):
|
| 197 |
+
video_types = video_types.split(",")
|
| 198 |
+
|
| 199 |
+
q_type_dict = {}
|
| 200 |
+
v_type_dict = {}
|
| 201 |
+
v_sub_type_dict = {}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
for video_type in video_types:
|
| 205 |
+
|
| 206 |
+
# Filter your results based on video types
|
| 207 |
+
your_results_video_type = [item for item in your_results if item['duration'] == video_type]
|
| 208 |
+
|
| 209 |
+
# Task Categories
|
| 210 |
+
q_type_dict[video_type] = {}
|
| 211 |
+
for q_type in TASK_CATEGORIES:
|
| 212 |
+
q_type_dict[video_type][q_type] = {"correct": 0, "answered": 0}
|
| 213 |
+
|
| 214 |
+
# Video categories
|
| 215 |
+
v_type_dict[video_type] = {}
|
| 216 |
+
for v_type in CATEGORIES:
|
| 217 |
+
v_type_dict[video_type][v_type] = {"correct": 0, "answered": 0}
|
| 218 |
+
|
| 219 |
+
v_sub_type_dict[video_type] = {}
|
| 220 |
+
for v_sub_type in SUB_CATEGORIES:
|
| 221 |
+
v_sub_type_dict[video_type][v_sub_type] = {"correct": 0, "answered": 0}
|
| 222 |
+
|
| 223 |
+
if not skip_missing:
|
| 224 |
+
# Check if the number of files in your results and ground truth are the same
|
| 225 |
+
print(len(your_results_video_type))
|
| 226 |
+
assert len(your_results_video_type) == 300, f"Number of files in {video_type} is not 300. Check if there are missing files."
|
| 227 |
+
|
| 228 |
+
for item in your_results_video_type:
|
| 229 |
+
|
| 230 |
+
if skip_missing and item["missing"]:
|
| 231 |
+
continue
|
| 232 |
+
|
| 233 |
+
# Get the video category, sub category and question category
|
| 234 |
+
video_category = item["domain"]
|
| 235 |
+
video_sub_category = item["sub_category"]
|
| 236 |
+
|
| 237 |
+
questions = item["questions"]
|
| 238 |
+
|
| 239 |
+
for question in questions:
|
| 240 |
+
q_type = question["task_type"]
|
| 241 |
+
|
| 242 |
+
# Get the ground truth and your response
|
| 243 |
+
# gt_answer = question[gt_answer_key]
|
| 244 |
+
# response = question[your_answer_key]
|
| 245 |
+
acc = question['acc']
|
| 246 |
+
|
| 247 |
+
# Extract the answer from the response
|
| 248 |
+
# extration = extract_characters_regex(response)
|
| 249 |
+
|
| 250 |
+
if acc is not None:
|
| 251 |
+
q_type_dict[video_type][q_type]["answered"] += 1
|
| 252 |
+
q_type_dict[video_type][q_type]["correct"] += acc
|
| 253 |
+
|
| 254 |
+
v_type_dict[video_type][video_category]["answered"] += 1
|
| 255 |
+
v_type_dict[video_type][video_category]["correct"] += acc
|
| 256 |
+
|
| 257 |
+
v_sub_type_dict[video_type][video_sub_category]["answered"] += 1
|
| 258 |
+
v_sub_type_dict[video_type][video_sub_category]["correct"] += acc
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
# Print the results for each video type
|
| 262 |
+
for video_type in video_types:
|
| 263 |
+
info = f"=====================================\nEvaluation on video Type: {video_type}\n====================================="
|
| 264 |
+
self.eval_records.append(info)
|
| 265 |
+
print(info)
|
| 266 |
+
if return_categories_accuracy:
|
| 267 |
+
info = f"-------------------------------------\nVideo Categories\n-------------------------------------"
|
| 268 |
+
self.eval_records.append(info)
|
| 269 |
+
print(info)
|
| 270 |
+
for v_type in v_type_dict[video_type]:
|
| 271 |
+
info = f"{v_type}: {100 * v_type_dict[video_type][v_type]['correct'] / v_type_dict[video_type][v_type]['answered'] if v_type_dict[video_type][v_type]['answered'] > 0 else 0 : .1f}%"
|
| 272 |
+
self.eval_records.append(info)
|
| 273 |
+
print(info)
|
| 274 |
+
if return_sub_categories_accuracy:
|
| 275 |
+
info = f"-------------------------------------\nVideo Sub Categories\n-------------------------------------"
|
| 276 |
+
self.eval_records.append(info)
|
| 277 |
+
for v_sub_type in v_sub_type_dict[video_type]:
|
| 278 |
+
info = f"{v_sub_type}: {100 * v_sub_type_dict[video_type][v_sub_type]['correct'] / v_sub_type_dict[video_type][v_sub_type]['answered'] if v_sub_type_dict[video_type][v_sub_type]['answered'] > 0 else 0 : .1f}%"
|
| 279 |
+
self.eval_records.append(info)
|
| 280 |
+
print(info)
|
| 281 |
+
if return_task_types_accuracy:
|
| 282 |
+
info = f"-------------------------------------\nTask Categories\n-------------------------------------"
|
| 283 |
+
self.eval_records.append(info)
|
| 284 |
+
print(info)
|
| 285 |
+
for q_type in q_type_dict[video_type]:
|
| 286 |
+
info = f"{q_type}: {100 * q_type_dict[video_type][q_type]['correct'] / q_type_dict[video_type][q_type]['answered'] if q_type_dict[video_type][q_type]['answered'] > 0 else 0 : .1f}%"
|
| 287 |
+
self.eval_records.append(info)
|
| 288 |
+
print(info)
|
| 289 |
+
info = f"-------------------------------------\nOverall Performance\n-------------------------------------"
|
| 290 |
+
|
| 291 |
+
print(info)
|
| 292 |
+
total_correct = sum([q_type_dict[video_type][q_type]["correct"] for q_type in TASK_CATEGORIES])
|
| 293 |
+
total_answered = sum([q_type_dict[video_type][q_type]["answered"] for q_type in TASK_CATEGORIES])
|
| 294 |
+
info = f"Overall: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%"
|
| 295 |
+
self.eval_records.append(info)
|
| 296 |
+
print(info+'\n')
|
| 297 |
+
|
| 298 |
+
# Print the results for the entire dataset
|
| 299 |
+
info = f"=====================================\nEvaluation on the entire dataset\n====================================="
|
| 300 |
+
self.eval_records.append(info)
|
| 301 |
+
print(info)
|
| 302 |
+
|
| 303 |
+
if return_categories_accuracy:
|
| 304 |
+
info = f"-------------------------------------\nVideo Categories\n-------------------------------------"
|
| 305 |
+
self.eval_records.append(info)
|
| 306 |
+
print(info)
|
| 307 |
+
for v_type in CATEGORIES:
|
| 308 |
+
total_correct = sum([v_type_dict[video_type][v_type]["correct"] for video_type in video_types])
|
| 309 |
+
total_answered = sum([v_type_dict[video_type][v_type]["answered"] for video_type in video_types])
|
| 310 |
+
info = f"{v_type}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%"
|
| 311 |
+
self.eval_records.append(info)
|
| 312 |
+
print(info)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
if return_sub_categories_accuracy:
|
| 316 |
+
info = f"-------------------------------------\nVideo Sub Categories\n-------------------------------------"
|
| 317 |
+
self.eval_records.append(info)
|
| 318 |
+
print(info)
|
| 319 |
+
|
| 320 |
+
for v_sub_type in SUB_CATEGORIES:
|
| 321 |
+
total_correct = sum([v_sub_type_dict[video_type][v_sub_type]["correct"] for video_type in video_types])
|
| 322 |
+
total_answered = sum([v_sub_type_dict[video_type][v_sub_type]["answered"] for video_type in video_types])
|
| 323 |
+
info = f"{v_sub_type}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%"
|
| 324 |
+
self.eval_records.append(info)
|
| 325 |
+
print(info)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
if return_task_types_accuracy:
|
| 329 |
+
info = f"-------------------------------------\nTask Categories\n-------------------------------------"
|
| 330 |
+
self.eval_records.append(info)
|
| 331 |
+
print(info)
|
| 332 |
+
for q_type in TASK_CATEGORIES:
|
| 333 |
+
|
| 334 |
+
total_correct = sum([q_type_dict[video_type][q_type]["correct"] for video_type in video_types])
|
| 335 |
+
total_answered = sum([q_type_dict[video_type][q_type]["answered"] for video_type in video_types])
|
| 336 |
+
info = f"{q_type}: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%"
|
| 337 |
+
self.eval_records.append(info)
|
| 338 |
+
print(info)
|
| 339 |
+
|
| 340 |
+
info = f"*************************************\nOverall Performance\n*************************************"
|
| 341 |
+
self.eval_records.append(info)
|
| 342 |
+
print(info)
|
| 343 |
+
total_correct = sum([sum([q_type_dict[video_type][q_type]["correct"] for q_type in TASK_CATEGORIES]) for video_type in video_types])
|
| 344 |
+
total_answered = sum([sum([q_type_dict[video_type][q_type]["answered"] for q_type in TASK_CATEGORIES]) for video_type in video_types])
|
| 345 |
+
info = f"Overall: {100 * total_correct / total_answered if total_answered > 0 else 0 : .1f}%"
|
| 346 |
+
self.eval_records.append(info)
|
| 347 |
+
print(info)
|
| 348 |
+
|
| 349 |
+
def save_results(self, pred_path):
|
| 350 |
+
if os.path.isdir(pred_path):
|
| 351 |
+
output_dir = os.path.join(pred_path, 'eval_records')
|
| 352 |
+
else:
|
| 353 |
+
output_dir = os.path.join(os.path.dirname(pred_path), 'eval_records')
|
| 354 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 355 |
+
fout = open(os.path.join(output_dir, f'{self.dataset_name}_eval_result.txt'), 'w')
|
| 356 |
+
for info in self.eval_records:
|
| 357 |
+
fout.write(info+'\n')
|
| 358 |
+
fout.close()
|
eval_scripts/DREAM-1K/tarsier/models/modeling_qwen2_vl_fast.py
ADDED
|
@@ -0,0 +1,1320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.nn import LayerNorm
|
| 10 |
+
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 13 |
+
from transformers.modeling_rope_utils import rope_config_validation, ROPE_INIT_FUNCTIONS
|
| 14 |
+
from transformers.cache_utils import Cache, SlidingWindowCache, StaticCache
|
| 15 |
+
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
|
| 16 |
+
from transformers.utils import (
|
| 17 |
+
add_start_docstrings,
|
| 18 |
+
add_start_docstrings_to_model_forward,
|
| 19 |
+
is_flash_attn_2_available,
|
| 20 |
+
is_flash_attn_greater_or_equal_2_10,
|
| 21 |
+
logging,
|
| 22 |
+
replace_return_docstrings,
|
| 23 |
+
)
|
| 24 |
+
from transformers.modeling_outputs import (
|
| 25 |
+
BaseModelOutputWithPast,
|
| 26 |
+
ModelOutput,
|
| 27 |
+
)
|
| 28 |
+
from transformers.activations import ACT2FN
|
| 29 |
+
from transformers.generation import GenerationMixin
|
| 30 |
+
|
| 31 |
+
if is_flash_attn_2_available():
|
| 32 |
+
from flash_attn import flash_attn_varlen_func
|
| 33 |
+
|
| 34 |
+
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
| 35 |
+
else:
|
| 36 |
+
flash_attn_varlen_func = None
|
| 37 |
+
|
| 38 |
+
# from apex.normalization.fused_layer_norm import fused_rms_norm_affine
|
| 39 |
+
|
| 40 |
+
logger = logging.get_logger(__name__)
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class Qwen2VLCausalLMOutputWithPast(ModelOutput):
|
| 44 |
+
"""
|
| 45 |
+
Base class for Qwen2VL causal language model (or autoregressive) outputs.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
| 49 |
+
Language modeling loss (for next-token prediction).
|
| 50 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
| 51 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 52 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
| 53 |
+
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
| 54 |
+
`(batch_size, num_heads, sequence_length, embed_size_per_head)`)
|
| 55 |
+
|
| 56 |
+
Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
|
| 57 |
+
`past_key_values` input) to speed up sequential decoding.
|
| 58 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
| 59 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
| 60 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
| 61 |
+
|
| 62 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
| 63 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
| 64 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
| 65 |
+
sequence_length)`.
|
| 66 |
+
|
| 67 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 68 |
+
heads.
|
| 69 |
+
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
|
| 70 |
+
The rope index difference between sequence length and multimodal rope.
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
loss: Optional[torch.FloatTensor] = None
|
| 74 |
+
logits: torch.FloatTensor = None
|
| 75 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 76 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 77 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 78 |
+
|
| 79 |
+
class Qwen2VLVisionConfig(PretrainedConfig):
|
| 80 |
+
model_type = "qwen2_vl"
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
depth=32,
|
| 85 |
+
embed_dim=1280,
|
| 86 |
+
hidden_size=3584,
|
| 87 |
+
hidden_act="quick_gelu",
|
| 88 |
+
mlp_ratio=4,
|
| 89 |
+
num_heads=16,
|
| 90 |
+
in_channels=3,
|
| 91 |
+
patch_size=14,
|
| 92 |
+
spatial_merge_size=2,
|
| 93 |
+
temporal_patch_size=2,
|
| 94 |
+
attn_implementation='flash_attention_2',
|
| 95 |
+
**kwargs,
|
| 96 |
+
):
|
| 97 |
+
super().__init__(**kwargs)
|
| 98 |
+
|
| 99 |
+
self.depth = depth
|
| 100 |
+
self.embed_dim = embed_dim
|
| 101 |
+
self.hidden_size = hidden_size
|
| 102 |
+
self.hidden_act = hidden_act
|
| 103 |
+
self.mlp_ratio = mlp_ratio
|
| 104 |
+
self.num_heads = num_heads
|
| 105 |
+
self.in_channels = in_channels
|
| 106 |
+
self.patch_size = patch_size
|
| 107 |
+
self.spatial_merge_size = spatial_merge_size
|
| 108 |
+
self.temporal_patch_size = temporal_patch_size
|
| 109 |
+
self.attn_implementation = attn_implementation
|
| 110 |
+
|
| 111 |
+
@classmethod
|
| 112 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
| 113 |
+
cls._set_token_in_kwargs(kwargs)
|
| 114 |
+
|
| 115 |
+
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
| 116 |
+
|
| 117 |
+
if config_dict.get("model_type") == "qwen2_vl":
|
| 118 |
+
config_dict = config_dict["vision_config"]
|
| 119 |
+
|
| 120 |
+
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
| 121 |
+
logger.warning(
|
| 122 |
+
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
| 123 |
+
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
| 124 |
+
)
|
| 125 |
+
|
| 126 |
+
return cls.from_dict(config_dict, **kwargs)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Qwen2VLConfig(PretrainedConfig):
|
| 130 |
+
r"""
|
| 131 |
+
This is the configuration class to store the configuration of a [`Qwen2VLModel`]. It is used to instantiate a
|
| 132 |
+
Qwen2-VL model according to the specified arguments, defining the model architecture. Instantiating a configuration
|
| 133 |
+
with the defaults will yield a similar configuration to that of
|
| 134 |
+
Qwen2-VL-7B-Instruct [Qwen/Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct).
|
| 135 |
+
|
| 136 |
+
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
| 137 |
+
documentation from [`PretrainedConfig`] for more information.
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
vocab_size (`int`, *optional*, defaults to 152064):
|
| 142 |
+
Vocabulary size of the Qwen2VL model. Defines the number of different tokens that can be represented by the
|
| 143 |
+
`inputs_ids` passed when calling [`Qwen2VLModel`]
|
| 144 |
+
hidden_size (`int`, *optional*, defaults to 8192):
|
| 145 |
+
Dimension of the hidden representations.
|
| 146 |
+
intermediate_size (`int`, *optional*, defaults to 29568):
|
| 147 |
+
Dimension of the MLP representations.
|
| 148 |
+
num_hidden_layers (`int`, *optional*, defaults to 80):
|
| 149 |
+
Number of hidden layers in the Transformer encoder.
|
| 150 |
+
num_attention_heads (`int`, *optional*, defaults to 64):
|
| 151 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
| 152 |
+
num_key_value_heads (`int`, *optional*, defaults to 8):
|
| 153 |
+
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
| 154 |
+
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
| 155 |
+
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
| 156 |
+
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
| 157 |
+
by meanpooling all the original heads within that group. For more details checkout [this
|
| 158 |
+
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
|
| 159 |
+
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
| 160 |
+
The non-linear activation function (function or string) in the decoder.
|
| 161 |
+
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
| 162 |
+
The maximum sequence length that this model might ever be used with.
|
| 163 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
| 164 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
| 165 |
+
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
|
| 166 |
+
The epsilon used by the rms normalization layers.
|
| 167 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
| 168 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
| 169 |
+
relevant if `config.is_decoder=True`.
|
| 170 |
+
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
| 171 |
+
Whether the model's input and output word embeddings should be tied.
|
| 172 |
+
rope_theta (`float`, *optional*, defaults to 1000000.0):
|
| 173 |
+
The base period of the RoPE embeddings.
|
| 174 |
+
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
| 175 |
+
Whether to use sliding window attention.
|
| 176 |
+
sliding_window (`int`, *optional*, defaults to 4096):
|
| 177 |
+
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
| 178 |
+
max_window_layers (`int`, *optional*, defaults to 80):
|
| 179 |
+
The number of layers that use SWA (Sliding Window Attention). The bottom layers use SWA while the top use full attention.
|
| 180 |
+
attention_dropout (`float`, *optional*, defaults to 0.0):
|
| 181 |
+
The dropout ratio for the attention probabilities.
|
| 182 |
+
vision_config (`Dict`, *optional*):
|
| 183 |
+
The config for the visual encoder initialization.
|
| 184 |
+
rope_scaling (`Dict`, *optional*):
|
| 185 |
+
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
| 186 |
+
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
| 187 |
+
accordingly.
|
| 188 |
+
Expected contents:
|
| 189 |
+
`rope_type` (`str`):
|
| 190 |
+
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
| 191 |
+
'llama3'], with 'default' being the original RoPE implementation.
|
| 192 |
+
`factor` (`float`, *optional*):
|
| 193 |
+
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
| 194 |
+
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
| 195 |
+
original maximum pre-trained length.
|
| 196 |
+
`original_max_position_embeddings` (`int`, *optional*):
|
| 197 |
+
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
| 198 |
+
pretraining.
|
| 199 |
+
`attention_factor` (`float`, *optional*):
|
| 200 |
+
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
| 201 |
+
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
| 202 |
+
`factor` field to infer the suggested value.
|
| 203 |
+
`beta_fast` (`float`, *optional*):
|
| 204 |
+
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
| 205 |
+
ramp function. If unspecified, it defaults to 32.
|
| 206 |
+
`beta_slow` (`float`, *optional*):
|
| 207 |
+
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
| 208 |
+
ramp function. If unspecified, it defaults to 1.
|
| 209 |
+
`short_factor` (`List[float]`, *optional*):
|
| 210 |
+
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
| 211 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 212 |
+
size divided by the number of attention heads divided by 2
|
| 213 |
+
`long_factor` (`List[float]`, *optional*):
|
| 214 |
+
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
| 215 |
+
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
| 216 |
+
size divided by the number of attention heads divided by 2
|
| 217 |
+
`low_freq_factor` (`float`, *optional*):
|
| 218 |
+
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
| 219 |
+
`high_freq_factor` (`float`, *optional*):
|
| 220 |
+
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
| 221 |
+
|
| 222 |
+
```python
|
| 223 |
+
>>> from transformers import Qwen2VLForConditionalGeneration, Qwen2VLConfig
|
| 224 |
+
|
| 225 |
+
>>> # Initializing a Qwen2VL style configuration
|
| 226 |
+
>>> configuration = Qwen2VLConfig()
|
| 227 |
+
|
| 228 |
+
>>> # Initializing a model from the Qwen2-VL-7B style configuration
|
| 229 |
+
>>> model = Qwen2VLForConditionalGeneration(configuration)
|
| 230 |
+
|
| 231 |
+
>>> # Accessing the model configuration
|
| 232 |
+
>>> configuration = model.config
|
| 233 |
+
```"""
|
| 234 |
+
|
| 235 |
+
model_type = "qwen2_vl"
|
| 236 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
| 237 |
+
|
| 238 |
+
def __init__(
|
| 239 |
+
self,
|
| 240 |
+
vocab_size=152064,
|
| 241 |
+
hidden_size=8192,
|
| 242 |
+
intermediate_size=29568,
|
| 243 |
+
num_hidden_layers=80,
|
| 244 |
+
num_attention_heads=64,
|
| 245 |
+
num_key_value_heads=8,
|
| 246 |
+
hidden_act="silu",
|
| 247 |
+
max_position_embeddings=32768,
|
| 248 |
+
initializer_range=0.02,
|
| 249 |
+
rms_norm_eps=1e-05,
|
| 250 |
+
use_cache=True,
|
| 251 |
+
tie_word_embeddings=False,
|
| 252 |
+
rope_theta=1000000.0,
|
| 253 |
+
use_sliding_window=False,
|
| 254 |
+
sliding_window=4096,
|
| 255 |
+
max_window_layers=80,
|
| 256 |
+
attention_dropout=0.0,
|
| 257 |
+
rope_scaling=None,
|
| 258 |
+
spatial_merge_size=2,
|
| 259 |
+
attn_implementation='flash_attention_2',
|
| 260 |
+
**kwargs,
|
| 261 |
+
):
|
| 262 |
+
|
| 263 |
+
self.vocab_size = vocab_size
|
| 264 |
+
self.max_position_embeddings = max_position_embeddings
|
| 265 |
+
self.hidden_size = hidden_size
|
| 266 |
+
self.intermediate_size = intermediate_size
|
| 267 |
+
self.num_hidden_layers = num_hidden_layers
|
| 268 |
+
self.num_attention_heads = num_attention_heads
|
| 269 |
+
self.use_sliding_window = use_sliding_window
|
| 270 |
+
self.sliding_window = sliding_window
|
| 271 |
+
self.max_window_layers = max_window_layers
|
| 272 |
+
|
| 273 |
+
# for backward compatibility
|
| 274 |
+
if num_key_value_heads is None:
|
| 275 |
+
num_key_value_heads = num_attention_heads
|
| 276 |
+
|
| 277 |
+
self.num_key_value_heads = num_key_value_heads
|
| 278 |
+
self.hidden_act = hidden_act
|
| 279 |
+
self.initializer_range = initializer_range
|
| 280 |
+
self.rms_norm_eps = rms_norm_eps
|
| 281 |
+
self.use_cache = use_cache
|
| 282 |
+
self.rope_theta = rope_theta
|
| 283 |
+
self.attention_dropout = attention_dropout
|
| 284 |
+
self.rope_scaling = rope_scaling
|
| 285 |
+
self.spatial_merge_size = spatial_merge_size
|
| 286 |
+
self.attn_implementation = attn_implementation
|
| 287 |
+
|
| 288 |
+
# Validate the correctness of rotary position embeddings parameters
|
| 289 |
+
# BC: if there is a 'type' field, move it to 'rope_type'.
|
| 290 |
+
# and change type from 'mrope' to 'default' because `mrope` does defeault RoPE calculations
|
| 291 |
+
# one can set it to "linear"/"dynamic" etc. to have scaled RoPE
|
| 292 |
+
# TODO: @raushan update config in the hub
|
| 293 |
+
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
| 294 |
+
if self.rope_scaling["type"] == "mrope":
|
| 295 |
+
self.rope_scaling["type"] = "default"
|
| 296 |
+
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
| 297 |
+
rope_config_validation(self, ignore_keys={"mrope_section"})
|
| 298 |
+
|
| 299 |
+
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
| 300 |
+
|
| 301 |
+
# Copied from transformers.models.llama.modeling_llama.rotate_half
|
| 302 |
+
def rotate_half(x):
|
| 303 |
+
"""Rotates half the hidden dims of the input."""
|
| 304 |
+
x1 = x[..., : x.shape[-1] // 2]
|
| 305 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
| 306 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 307 |
+
|
| 308 |
+
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
|
| 309 |
+
"""Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).
|
| 310 |
+
|
| 311 |
+
Explanation:
|
| 312 |
+
Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
|
| 313 |
+
sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
|
| 314 |
+
vision embedding part, we apply rotary position embedding on temporal, height and width dimension seperately.
|
| 315 |
+
Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
|
| 316 |
+
For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
|
| 317 |
+
height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
|
| 318 |
+
difference with modern LLMs.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
q (`torch.Tensor`): The query tensor.
|
| 322 |
+
k (`torch.Tensor`): The key tensor.
|
| 323 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
| 324 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
| 325 |
+
position_ids (`torch.Tensor`):
|
| 326 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
| 327 |
+
used to pass offsetted position ids when working with a KV-cache.
|
| 328 |
+
mrope_section(`List(int)`):
|
| 329 |
+
Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
|
| 330 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
| 331 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
| 332 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
| 333 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
| 334 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
| 335 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
| 336 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
| 337 |
+
Returns:
|
| 338 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
| 339 |
+
"""
|
| 340 |
+
mrope_section = mrope_section * 2
|
| 341 |
+
cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
| 342 |
+
unsqueeze_dim
|
| 343 |
+
)
|
| 344 |
+
sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
|
| 345 |
+
unsqueeze_dim
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
| 349 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
| 350 |
+
return q_embed, k_embed
|
| 351 |
+
|
| 352 |
+
def apply_rotary_pos_emb_vision(tensor: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
| 353 |
+
orig_dtype = tensor.dtype
|
| 354 |
+
tensor = tensor.float()
|
| 355 |
+
cos = freqs.cos()
|
| 356 |
+
sin = freqs.sin()
|
| 357 |
+
cos = cos.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 358 |
+
sin = sin.unsqueeze(1).repeat(1, 1, 2).unsqueeze(0).float()
|
| 359 |
+
output = (tensor * cos) + (rotate_half(tensor) * sin)
|
| 360 |
+
output = output.to(orig_dtype)
|
| 361 |
+
return output
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 365 |
+
def __init__(self, dim: int, theta: float = 10000.0) -> None:
|
| 366 |
+
super().__init__()
|
| 367 |
+
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
|
| 368 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 369 |
+
|
| 370 |
+
def forward(self, seqlen: int) -> torch.Tensor:
|
| 371 |
+
seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
|
| 372 |
+
freqs = torch.outer(seq, self.inv_freq)
|
| 373 |
+
return freqs
|
| 374 |
+
|
| 375 |
+
class PatchEmbed(nn.Module):
|
| 376 |
+
def __init__(
|
| 377 |
+
self,
|
| 378 |
+
patch_size: int = 14,
|
| 379 |
+
temporal_patch_size: int = 2,
|
| 380 |
+
in_channels: int = 3,
|
| 381 |
+
embed_dim: int = 1152,
|
| 382 |
+
) -> None:
|
| 383 |
+
super().__init__()
|
| 384 |
+
self.patch_size = patch_size
|
| 385 |
+
self.temporal_patch_size = temporal_patch_size
|
| 386 |
+
self.in_channels = in_channels
|
| 387 |
+
self.embed_dim = embed_dim
|
| 388 |
+
|
| 389 |
+
kernel_size = [temporal_patch_size, patch_size, patch_size]
|
| 390 |
+
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
|
| 391 |
+
|
| 392 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 393 |
+
target_dtype = self.proj.weight.dtype
|
| 394 |
+
hidden_states = hidden_states.view(
|
| 395 |
+
-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
|
| 396 |
+
)
|
| 397 |
+
hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
|
| 398 |
+
return hidden_states
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class PatchMerger(nn.Module):
|
| 402 |
+
def __init__(self, dim: int, context_dim: int, spatial_merge_size: int = 2) -> None:
|
| 403 |
+
super().__init__()
|
| 404 |
+
self.hidden_size = context_dim * (spatial_merge_size**2)
|
| 405 |
+
self.ln_q = LayerNorm(context_dim, eps=1e-6)
|
| 406 |
+
self.mlp = nn.Sequential(
|
| 407 |
+
nn.Linear(self.hidden_size, self.hidden_size),
|
| 408 |
+
nn.GELU(),
|
| 409 |
+
nn.Linear(self.hidden_size, dim),
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 413 |
+
x = self.mlp(self.ln_q(x).view(-1, self.hidden_size))
|
| 414 |
+
return x
|
| 415 |
+
|
| 416 |
+
class VisionMlp(nn.Module):
|
| 417 |
+
def __init__(self, dim: int, hidden_dim: int, hidden_act: str) -> None:
|
| 418 |
+
super().__init__()
|
| 419 |
+
self.fc1 = nn.Linear(dim, hidden_dim)
|
| 420 |
+
self.act = ACT2FN[hidden_act]
|
| 421 |
+
self.fc2 = nn.Linear(hidden_dim, dim)
|
| 422 |
+
|
| 423 |
+
def forward(self, x) -> torch.Tensor:
|
| 424 |
+
return self.fc2(self.act(self.fc1(x)))
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class VisionAttention(nn.Module):
|
| 428 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 429 |
+
super().__init__()
|
| 430 |
+
self.num_heads = num_heads
|
| 431 |
+
self.head_dim = dim // num_heads
|
| 432 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 433 |
+
self.proj = nn.Linear(dim, dim)
|
| 434 |
+
|
| 435 |
+
def forward(
|
| 436 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 437 |
+
) -> torch.Tensor:
|
| 438 |
+
seq_length = hidden_states.shape[0]
|
| 439 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 440 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 441 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 442 |
+
|
| 443 |
+
attention_mask = torch.full(
|
| 444 |
+
[1, seq_length, seq_length], torch.finfo(q.dtype).min, device=q.device, dtype=q.dtype
|
| 445 |
+
)
|
| 446 |
+
for i in range(1, len(cu_seqlens)):
|
| 447 |
+
attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
|
| 448 |
+
|
| 449 |
+
q = q.transpose(0, 1)
|
| 450 |
+
k = k.transpose(0, 1)
|
| 451 |
+
v = v.transpose(0, 1)
|
| 452 |
+
attn_weights = torch.matmul(q, k.transpose(1, 2)) / math.sqrt(self.head_dim)
|
| 453 |
+
attn_weights = attn_weights + attention_mask
|
| 454 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
|
| 455 |
+
attn_output = torch.matmul(attn_weights, v)
|
| 456 |
+
attn_output = attn_output.transpose(0, 1)
|
| 457 |
+
attn_output = attn_output.reshape(seq_length, -1)
|
| 458 |
+
attn_output = self.proj(attn_output)
|
| 459 |
+
return attn_output
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class VisionFlashAttention2(nn.Module):
|
| 463 |
+
def __init__(self, dim: int, num_heads: int = 16) -> None:
|
| 464 |
+
super().__init__()
|
| 465 |
+
self.num_heads = num_heads
|
| 466 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=True)
|
| 467 |
+
self.proj = nn.Linear(dim, dim)
|
| 468 |
+
|
| 469 |
+
def forward(
|
| 470 |
+
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
|
| 471 |
+
) -> torch.Tensor:
|
| 472 |
+
seq_length = hidden_states.shape[0]
|
| 473 |
+
q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
|
| 474 |
+
q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 475 |
+
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
|
| 476 |
+
|
| 477 |
+
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
| 478 |
+
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
|
| 479 |
+
seq_length, -1
|
| 480 |
+
)
|
| 481 |
+
attn_output = self.proj(attn_output)
|
| 482 |
+
return attn_output
|
| 483 |
+
|
| 484 |
+
QWEN2_VL_VISION_ATTENTION_CLASSES = {
|
| 485 |
+
"eager": VisionAttention,
|
| 486 |
+
"flash_attention_2": VisionFlashAttention2,
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
class Qwen2VLVisionBlock(nn.Module):
|
| 491 |
+
def __init__(self, config, attn_implementation: str = "sdpa") -> None:
|
| 492 |
+
super().__init__()
|
| 493 |
+
self.norm1 = LayerNorm(config.embed_dim, eps=1e-6)
|
| 494 |
+
self.norm2 = LayerNorm(config.embed_dim, eps=1e-6)
|
| 495 |
+
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
|
| 496 |
+
|
| 497 |
+
self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
|
| 498 |
+
config.embed_dim, num_heads=config.num_heads
|
| 499 |
+
)
|
| 500 |
+
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
|
| 501 |
+
|
| 502 |
+
def forward(self, hidden_states, cu_seqlens, rotary_pos_emb) -> torch.Tensor:
|
| 503 |
+
hidden_states = hidden_states + self.attn(
|
| 504 |
+
self.norm1(hidden_states), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
| 505 |
+
)
|
| 506 |
+
hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
|
| 507 |
+
return hidden_states
|
| 508 |
+
|
| 509 |
+
class Qwen2VLPreTrainedModel(PreTrainedModel):
|
| 510 |
+
config_class = Qwen2VLConfig
|
| 511 |
+
base_model_prefix = "model"
|
| 512 |
+
supports_gradient_checkpointing = True
|
| 513 |
+
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
|
| 514 |
+
_skip_keys_device_placement = "past_key_values"
|
| 515 |
+
_supports_flash_attn_2 = True
|
| 516 |
+
_supports_sdpa = False
|
| 517 |
+
_supports_cache_class = True
|
| 518 |
+
_supports_static_cache = True
|
| 519 |
+
|
| 520 |
+
def _init_weights(self, module):
|
| 521 |
+
std = self.config.initializer_range
|
| 522 |
+
if isinstance(module, (nn.Linear, nn.Conv3d)):
|
| 523 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 524 |
+
if module.bias is not None:
|
| 525 |
+
module.bias.data.zero_()
|
| 526 |
+
elif isinstance(module, nn.Embedding):
|
| 527 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 528 |
+
if module.padding_idx is not None:
|
| 529 |
+
module.weight.data[module.padding_idx].zero_()
|
| 530 |
+
|
| 531 |
+
class Qwen2VisionTransformerPretrainedModel(Qwen2VLPreTrainedModel):
|
| 532 |
+
config_class = Qwen2VLVisionConfig
|
| 533 |
+
_no_split_modules = ["Qwen2VLVisionBlock"]
|
| 534 |
+
|
| 535 |
+
def __init__(self, config) -> None:
|
| 536 |
+
super().__init__(config)
|
| 537 |
+
self.spatial_merge_size = config.spatial_merge_size
|
| 538 |
+
|
| 539 |
+
self.patch_embed = PatchEmbed(
|
| 540 |
+
patch_size=config.patch_size,
|
| 541 |
+
temporal_patch_size=config.temporal_patch_size,
|
| 542 |
+
in_channels=config.in_channels,
|
| 543 |
+
embed_dim=config.embed_dim,
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
head_dim = config.embed_dim // config.num_heads
|
| 547 |
+
self.rotary_pos_emb = VisionRotaryEmbedding(head_dim // 2)
|
| 548 |
+
|
| 549 |
+
self.blocks = nn.ModuleList(
|
| 550 |
+
[Qwen2VLVisionBlock(config, config.attn_implementation) for _ in range(config.depth)]
|
| 551 |
+
)
|
| 552 |
+
self.merger = PatchMerger(
|
| 553 |
+
dim=config.hidden_size, context_dim=config.embed_dim, spatial_merge_size=config.spatial_merge_size
|
| 554 |
+
)
|
| 555 |
+
# Initialize weights and apply final processing
|
| 556 |
+
self.gradient_checkpointing = False
|
| 557 |
+
self.post_init()
|
| 558 |
+
|
| 559 |
+
def get_dtype(self) -> torch.dtype:
|
| 560 |
+
return self.blocks[0].mlp.fc2.weight.dtype
|
| 561 |
+
|
| 562 |
+
def get_device(self) -> torch.device:
|
| 563 |
+
return self.blocks[0].mlp.fc2.weight.device
|
| 564 |
+
|
| 565 |
+
def rot_pos_emb(self, grid_thw):
|
| 566 |
+
pos_ids = []
|
| 567 |
+
for t, h, w in grid_thw:
|
| 568 |
+
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
|
| 569 |
+
hpos_ids = hpos_ids.reshape(
|
| 570 |
+
h // self.spatial_merge_size,
|
| 571 |
+
self.spatial_merge_size,
|
| 572 |
+
w // self.spatial_merge_size,
|
| 573 |
+
self.spatial_merge_size,
|
| 574 |
+
)
|
| 575 |
+
hpos_ids = hpos_ids.permute(0, 2, 1, 3)
|
| 576 |
+
hpos_ids = hpos_ids.flatten()
|
| 577 |
+
|
| 578 |
+
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
|
| 579 |
+
wpos_ids = wpos_ids.reshape(
|
| 580 |
+
h // self.spatial_merge_size,
|
| 581 |
+
self.spatial_merge_size,
|
| 582 |
+
w // self.spatial_merge_size,
|
| 583 |
+
self.spatial_merge_size,
|
| 584 |
+
)
|
| 585 |
+
wpos_ids = wpos_ids.permute(0, 2, 1, 3)
|
| 586 |
+
wpos_ids = wpos_ids.flatten()
|
| 587 |
+
pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
|
| 588 |
+
pos_ids = torch.cat(pos_ids, dim=0)
|
| 589 |
+
max_grid_size = grid_thw[:, 1:].max()
|
| 590 |
+
rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
|
| 591 |
+
rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
|
| 592 |
+
return rotary_pos_emb
|
| 593 |
+
|
| 594 |
+
def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor:
|
| 595 |
+
hidden_states = self.patch_embed(hidden_states)
|
| 596 |
+
rotary_pos_emb = self.rot_pos_emb(grid_thw)
|
| 597 |
+
|
| 598 |
+
cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
|
| 599 |
+
dim=0, dtype=torch.int32
|
| 600 |
+
)
|
| 601 |
+
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
|
| 602 |
+
|
| 603 |
+
for blk in self.blocks:
|
| 604 |
+
if self.gradient_checkpointing and self.training:
|
| 605 |
+
hidden_states = self._gradient_checkpointing_func(
|
| 606 |
+
blk.__call__,
|
| 607 |
+
hidden_states,
|
| 608 |
+
cu_seqlens,
|
| 609 |
+
rotary_pos_emb,
|
| 610 |
+
)
|
| 611 |
+
else:
|
| 612 |
+
hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb)
|
| 613 |
+
|
| 614 |
+
return self.merger(hidden_states)
|
| 615 |
+
|
| 616 |
+
# class Qwen2RMSNorm(nn.Module):
|
| 617 |
+
# def __init__(self, hidden_size, eps=1e-6):
|
| 618 |
+
# """
|
| 619 |
+
# Qwen2RMSNorm is equivalent to T5LayerNorm
|
| 620 |
+
# """
|
| 621 |
+
# super().__init__()
|
| 622 |
+
# self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 623 |
+
# self.variance_epsilon = eps
|
| 624 |
+
# self.normalized_shape = torch.Size((hidden_size, ))
|
| 625 |
+
|
| 626 |
+
# def forward(self, hidden_states):
|
| 627 |
+
# return fused_rms_norm_affine(input=hidden_states,
|
| 628 |
+
# weight=self.weight,
|
| 629 |
+
# normalized_shape=self.normalized_shape,
|
| 630 |
+
# eps=self.variance_epsilon,
|
| 631 |
+
# memory_efficient=True)
|
| 632 |
+
|
| 633 |
+
class Qwen2RMSNorm(nn.Module):
|
| 634 |
+
def __init__(self, hidden_size, eps=1e-6):
|
| 635 |
+
"""
|
| 636 |
+
Qwen2RMSNorm is equivalent to T5LayerNorm
|
| 637 |
+
"""
|
| 638 |
+
super().__init__()
|
| 639 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
| 640 |
+
self.variance_epsilon = eps
|
| 641 |
+
|
| 642 |
+
def forward(self, hidden_states):
|
| 643 |
+
input_dtype = hidden_states.dtype
|
| 644 |
+
hidden_states = hidden_states.to(torch.float32)
|
| 645 |
+
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
| 646 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
| 647 |
+
return self.weight * hidden_states.to(input_dtype)
|
| 648 |
+
|
| 649 |
+
def extra_repr(self):
|
| 650 |
+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
| 651 |
+
|
| 652 |
+
class Qwen2VLRotaryEmbedding(nn.Module):
|
| 653 |
+
def __init__(
|
| 654 |
+
self,
|
| 655 |
+
dim=None,
|
| 656 |
+
max_position_embeddings=2048,
|
| 657 |
+
base=10000,
|
| 658 |
+
device=None,
|
| 659 |
+
scaling_factor=1.0,
|
| 660 |
+
rope_type="default",
|
| 661 |
+
config: Optional[Qwen2VLConfig] = None,
|
| 662 |
+
):
|
| 663 |
+
super().__init__()
|
| 664 |
+
# TODO (joao): remove the `if` below, only used for BC
|
| 665 |
+
self.rope_kwargs = {}
|
| 666 |
+
if config is None:
|
| 667 |
+
logger.warning_once(
|
| 668 |
+
"`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the "
|
| 669 |
+
"`config` argument. All other arguments will be removed in v4.46"
|
| 670 |
+
)
|
| 671 |
+
self.rope_kwargs = {
|
| 672 |
+
"rope_type": rope_type,
|
| 673 |
+
"factor": scaling_factor,
|
| 674 |
+
"dim": dim,
|
| 675 |
+
"base": base,
|
| 676 |
+
"max_position_embeddings": max_position_embeddings,
|
| 677 |
+
}
|
| 678 |
+
self.rope_type = rope_type
|
| 679 |
+
self.max_seq_len_cached = max_position_embeddings
|
| 680 |
+
self.original_max_seq_len = max_position_embeddings
|
| 681 |
+
else:
|
| 682 |
+
# BC: "rope_type" was originally "type"
|
| 683 |
+
if config.rope_scaling is not None:
|
| 684 |
+
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
|
| 685 |
+
else:
|
| 686 |
+
self.rope_type = "default"
|
| 687 |
+
self.max_seq_len_cached = config.max_position_embeddings
|
| 688 |
+
self.original_max_seq_len = config.max_position_embeddings
|
| 689 |
+
|
| 690 |
+
self.config = config
|
| 691 |
+
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
|
| 692 |
+
|
| 693 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
|
| 694 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
| 695 |
+
self.original_inv_freq = self.inv_freq
|
| 696 |
+
|
| 697 |
+
def _dynamic_frequency_update(self, position_ids, device):
|
| 698 |
+
"""
|
| 699 |
+
dynamic RoPE layers should recompute `inv_freq` in the following situations:
|
| 700 |
+
1 - growing beyond the cached sequence length (allow scaling)
|
| 701 |
+
2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
|
| 702 |
+
"""
|
| 703 |
+
seq_len = torch.max(position_ids) + 1
|
| 704 |
+
if seq_len > self.max_seq_len_cached: # growth
|
| 705 |
+
inv_freq, self.attention_scaling = self.rope_init_fn(
|
| 706 |
+
self.config, device, seq_len=seq_len, **self.rope_kwargs
|
| 707 |
+
)
|
| 708 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
|
| 709 |
+
self.max_seq_len_cached = seq_len
|
| 710 |
+
|
| 711 |
+
if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
|
| 712 |
+
self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
|
| 713 |
+
self.max_seq_len_cached = self.original_max_seq_len
|
| 714 |
+
|
| 715 |
+
@torch.no_grad()
|
| 716 |
+
def forward(self, x, position_ids):
|
| 717 |
+
position_ids = position_ids.permute(2, 0, 1)
|
| 718 |
+
if "dynamic" in self.rope_type:
|
| 719 |
+
self._dynamic_frequency_update(position_ids, device=x.device)
|
| 720 |
+
|
| 721 |
+
# Core RoPE block. In contrast to other models, Qwen2_VL has different position ids for thw grids
|
| 722 |
+
# So we expand the inv_freq to shape (3, ...)
|
| 723 |
+
inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
|
| 724 |
+
position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
|
| 725 |
+
# Force float32 (see https://github.com/huggingface/transformers/pull/29285)
|
| 726 |
+
device_type = x.device.type
|
| 727 |
+
device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
|
| 728 |
+
with torch.autocast(device_type=device_type, enabled=False):
|
| 729 |
+
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
|
| 730 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 731 |
+
cos = emb.cos()
|
| 732 |
+
sin = emb.sin()
|
| 733 |
+
|
| 734 |
+
# Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
|
| 735 |
+
cos = cos * self.attention_scaling
|
| 736 |
+
sin = sin * self.attention_scaling
|
| 737 |
+
|
| 738 |
+
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 739 |
+
|
| 740 |
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2MLP
|
| 741 |
+
class Qwen2MLP(nn.Module):
|
| 742 |
+
def __init__(self, config):
|
| 743 |
+
super().__init__()
|
| 744 |
+
self.hidden_size = config.hidden_size
|
| 745 |
+
self.intermediate_size = config.intermediate_size
|
| 746 |
+
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 747 |
+
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
|
| 748 |
+
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
|
| 749 |
+
self.act_fn = ACT2FN[config.hidden_act]
|
| 750 |
+
|
| 751 |
+
def forward(self, hidden_state):
|
| 752 |
+
return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
|
| 753 |
+
|
| 754 |
+
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
| 755 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
| 756 |
+
"""
|
| 757 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
| 758 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
| 759 |
+
"""
|
| 760 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
| 761 |
+
if n_rep == 1:
|
| 762 |
+
return hidden_states
|
| 763 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
|
| 764 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
| 765 |
+
|
| 766 |
+
class Qwen2VLAttention(nn.Module):
|
| 767 |
+
"""
|
| 768 |
+
Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
|
| 769 |
+
and "Generating Long Sequences with Sparse Transformers".
|
| 770 |
+
"""
|
| 771 |
+
|
| 772 |
+
def __init__(self, config: Qwen2VLConfig, layer_idx: Optional[int] = None):
|
| 773 |
+
super().__init__()
|
| 774 |
+
self.config = config
|
| 775 |
+
self.layer_idx = layer_idx
|
| 776 |
+
if layer_idx is None:
|
| 777 |
+
logger.warning_once(
|
| 778 |
+
f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
|
| 779 |
+
"to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
| 780 |
+
"when creating this class."
|
| 781 |
+
)
|
| 782 |
+
|
| 783 |
+
self.hidden_size = config.hidden_size
|
| 784 |
+
self.num_heads = config.num_attention_heads
|
| 785 |
+
self.head_dim = self.hidden_size // self.num_heads
|
| 786 |
+
self.num_key_value_heads = config.num_key_value_heads
|
| 787 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
| 788 |
+
self.max_position_embeddings = config.max_position_embeddings
|
| 789 |
+
self.rope_theta = config.rope_theta
|
| 790 |
+
self.is_causal = True
|
| 791 |
+
self.attention_dropout = config.attention_dropout
|
| 792 |
+
self.rope_scaling = config.rope_scaling
|
| 793 |
+
|
| 794 |
+
if (self.head_dim * self.num_heads) != self.hidden_size:
|
| 795 |
+
raise ValueError(
|
| 796 |
+
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
|
| 797 |
+
f" and `num_heads`: {self.num_heads})."
|
| 798 |
+
)
|
| 799 |
+
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
|
| 800 |
+
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 801 |
+
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
|
| 802 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
|
| 803 |
+
|
| 804 |
+
|
| 805 |
+
class Qwen2VLFlashAttention2(Qwen2VLAttention):
|
| 806 |
+
"""
|
| 807 |
+
Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
|
| 808 |
+
as the weights of the module stays untouched. The only required change would be on the forward pass
|
| 809 |
+
where it needs to correctly call the public API of flash attention and deal with padding tokens
|
| 810 |
+
in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
|
| 811 |
+
config.max_window_layers layers.
|
| 812 |
+
"""
|
| 813 |
+
|
| 814 |
+
def __init__(self, *args, **kwargs):
|
| 815 |
+
super().__init__(*args, **kwargs)
|
| 816 |
+
|
| 817 |
+
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
| 818 |
+
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
| 819 |
+
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
| 820 |
+
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
| 821 |
+
|
| 822 |
+
def forward(
|
| 823 |
+
self,
|
| 824 |
+
hidden_states: torch.Tensor,
|
| 825 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 826 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 827 |
+
past_key_value: Optional[Cache] = None,
|
| 828 |
+
output_attentions: bool = False,
|
| 829 |
+
use_cache: bool = False,
|
| 830 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 831 |
+
use_rmpad: Optional[bool] = False,
|
| 832 |
+
cu_seqlens: Optional[torch.Tensor] = False,
|
| 833 |
+
):
|
| 834 |
+
"""
|
| 835 |
+
Train:
|
| 836 |
+
unpad: (bsz, q_len) = (1, acc_seqlen)
|
| 837 |
+
pad: (bsz, q_len) = (bsz, q_len)
|
| 838 |
+
Test:
|
| 839 |
+
first_iter: (bsz, q_len) = (bsz, q_len)
|
| 840 |
+
other: (bsz, q_len) = (bsz, 1)
|
| 841 |
+
"""
|
| 842 |
+
bsz, q_len, _ = hidden_states.size()
|
| 843 |
+
|
| 844 |
+
query_states = self.q_proj(hidden_states)
|
| 845 |
+
key_states = self.k_proj(hidden_states)
|
| 846 |
+
value_states = self.v_proj(hidden_states)
|
| 847 |
+
|
| 848 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
| 849 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 850 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
| 851 |
+
|
| 852 |
+
cos, sin = position_embeddings
|
| 853 |
+
|
| 854 |
+
query_states, key_states = apply_multimodal_rotary_pos_emb(
|
| 855 |
+
query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
|
| 856 |
+
)
|
| 857 |
+
|
| 858 |
+
if past_key_value is not None:
|
| 859 |
+
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
| 860 |
+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 861 |
+
|
| 862 |
+
# repeat k/v heads if n_kv_heads < n_heads
|
| 863 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
| 864 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
| 865 |
+
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
| 866 |
+
|
| 867 |
+
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
| 868 |
+
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
| 869 |
+
# cast them back in float16 just to be sure everything works as expected.
|
| 870 |
+
input_dtype = query_states.dtype
|
| 871 |
+
if input_dtype == torch.float32:
|
| 872 |
+
if torch.is_autocast_enabled():
|
| 873 |
+
target_dtype = torch.get_autocast_gpu_dtype()
|
| 874 |
+
# Handle the case where the model is quantized
|
| 875 |
+
elif hasattr(self.config, "_pre_quantization_dtype"):
|
| 876 |
+
target_dtype = self.config._pre_quantization_dtype
|
| 877 |
+
else:
|
| 878 |
+
target_dtype = self.q_proj.weight.dtype
|
| 879 |
+
|
| 880 |
+
logger.warning_once(
|
| 881 |
+
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
| 882 |
+
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
| 883 |
+
f" {target_dtype}."
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
query_states = query_states.to(target_dtype)
|
| 887 |
+
key_states = key_states.to(target_dtype)
|
| 888 |
+
value_states = value_states.to(target_dtype)
|
| 889 |
+
|
| 890 |
+
# Reashape to the expected shape for Flash Attention
|
| 891 |
+
query_states = query_states.transpose(1, 2)
|
| 892 |
+
key_states = key_states.transpose(1, 2)
|
| 893 |
+
value_states = value_states.transpose(1, 2)
|
| 894 |
+
|
| 895 |
+
if use_rmpad:
|
| 896 |
+
max_seqlen = torch.max(cu_seqlens[1:] - cu_seqlens[:-1]).item()
|
| 897 |
+
attn_output = flash_attn_varlen_func(
|
| 898 |
+
query_states.squeeze(0), key_states.squeeze(0), value_states.squeeze(0),
|
| 899 |
+
cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
|
| 900 |
+
max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen,
|
| 901 |
+
dropout_p=dropout_rate,
|
| 902 |
+
causal=self.is_causal, window_size=(-1, -1),
|
| 903 |
+
)
|
| 904 |
+
else:
|
| 905 |
+
attn_output = _flash_attention_forward(
|
| 906 |
+
query_states, key_states, value_states,
|
| 907 |
+
attention_mask,
|
| 908 |
+
q_len,
|
| 909 |
+
dropout=dropout_rate,
|
| 910 |
+
sliding_window=None,
|
| 911 |
+
is_causal=self.is_causal,
|
| 912 |
+
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
| 913 |
+
)
|
| 914 |
+
|
| 915 |
+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
| 916 |
+
attn_output = self.o_proj(attn_output)
|
| 917 |
+
|
| 918 |
+
if not output_attentions:
|
| 919 |
+
attn_weights = None
|
| 920 |
+
|
| 921 |
+
return attn_output, attn_weights, past_key_value
|
| 922 |
+
|
| 923 |
+
QWEN2_VL_ATTENTION_CLASSES = {
|
| 924 |
+
"flash_attention_2": Qwen2VLFlashAttention2,
|
| 925 |
+
}
|
| 926 |
+
|
| 927 |
+
class Qwen2VLDecoderLayer(nn.Module):
|
| 928 |
+
def __init__(self, config: Qwen2VLConfig, layer_idx: int):
|
| 929 |
+
super().__init__()
|
| 930 |
+
self.hidden_size = config.hidden_size
|
| 931 |
+
|
| 932 |
+
if config.attn_implementation != "flash_attention_2":
|
| 933 |
+
logger.error(
|
| 934 |
+
f"只支持 flash_attention_2!config.attn_implementation={config.attn_implementation}"
|
| 935 |
+
)
|
| 936 |
+
self.self_attn = QWEN2_VL_ATTENTION_CLASSES[config.attn_implementation](config, layer_idx)
|
| 937 |
+
|
| 938 |
+
self.mlp = Qwen2MLP(config)
|
| 939 |
+
self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 940 |
+
self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 941 |
+
|
| 942 |
+
def forward(
|
| 943 |
+
self,
|
| 944 |
+
hidden_states: torch.Tensor,
|
| 945 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 946 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 947 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
| 948 |
+
output_attentions: Optional[bool] = False,
|
| 949 |
+
use_cache: Optional[bool] = False,
|
| 950 |
+
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
|
| 951 |
+
use_rmpad: Optional[bool] = False,
|
| 952 |
+
cu_seqlens: Optional[torch.Tensor] = False,
|
| 953 |
+
**kwargs,
|
| 954 |
+
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
| 955 |
+
"""
|
| 956 |
+
Args:
|
| 957 |
+
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
| 958 |
+
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
| 959 |
+
`(batch, sequence_length)` where padding elements are indicated by 0.
|
| 960 |
+
output_attentions (`bool`, *optional*):
|
| 961 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
| 962 |
+
returned tensors for more detail.
|
| 963 |
+
use_cache (`bool`, *optional*):
|
| 964 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
| 965 |
+
(see `past_key_values`).
|
| 966 |
+
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
| 967 |
+
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
| 968 |
+
Indices depicting the position of the input sequence tokens in the sequence.
|
| 969 |
+
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
|
| 970 |
+
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
|
| 971 |
+
with `head_dim` being the embedding dimension of each attention head.
|
| 972 |
+
kwargs (`dict`, *optional*):
|
| 973 |
+
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
|
| 974 |
+
into the model
|
| 975 |
+
"""
|
| 976 |
+
|
| 977 |
+
residual = hidden_states
|
| 978 |
+
|
| 979 |
+
hidden_states = self.input_layernorm(hidden_states)
|
| 980 |
+
|
| 981 |
+
# Self Attention
|
| 982 |
+
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
| 983 |
+
hidden_states=hidden_states,
|
| 984 |
+
attention_mask=attention_mask,
|
| 985 |
+
position_ids=position_ids,
|
| 986 |
+
past_key_value=past_key_value,
|
| 987 |
+
output_attentions=output_attentions,
|
| 988 |
+
use_cache=use_cache,
|
| 989 |
+
position_embeddings=position_embeddings,
|
| 990 |
+
use_rmpad=use_rmpad,
|
| 991 |
+
cu_seqlens=cu_seqlens,
|
| 992 |
+
)
|
| 993 |
+
hidden_states = residual + hidden_states
|
| 994 |
+
|
| 995 |
+
# Fully Connected
|
| 996 |
+
residual = hidden_states
|
| 997 |
+
hidden_states = self.post_attention_layernorm(hidden_states)
|
| 998 |
+
hidden_states = self.mlp(hidden_states)
|
| 999 |
+
hidden_states = residual + hidden_states
|
| 1000 |
+
|
| 1001 |
+
outputs = (hidden_states,)
|
| 1002 |
+
|
| 1003 |
+
if output_attentions:
|
| 1004 |
+
outputs += (self_attn_weights,)
|
| 1005 |
+
|
| 1006 |
+
if use_cache:
|
| 1007 |
+
outputs += (present_key_value,)
|
| 1008 |
+
|
| 1009 |
+
return outputs
|
| 1010 |
+
|
| 1011 |
+
class Qwen2VLModel(Qwen2VLPreTrainedModel):
|
| 1012 |
+
def __init__(self, config: Qwen2VLConfig):
|
| 1013 |
+
super().__init__(config)
|
| 1014 |
+
self.padding_idx = config.pad_token_id
|
| 1015 |
+
self.vocab_size = config.vocab_size
|
| 1016 |
+
|
| 1017 |
+
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
|
| 1018 |
+
self.layers = nn.ModuleList([Qwen2VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
|
| 1019 |
+
self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
| 1020 |
+
self.rotary_emb = Qwen2VLRotaryEmbedding(config=config)
|
| 1021 |
+
|
| 1022 |
+
self.gradient_checkpointing = False
|
| 1023 |
+
# Initialize weights and apply final processing
|
| 1024 |
+
self.post_init()
|
| 1025 |
+
|
| 1026 |
+
def get_input_embeddings(self):
|
| 1027 |
+
return self.embed_tokens
|
| 1028 |
+
|
| 1029 |
+
def set_input_embeddings(self, value):
|
| 1030 |
+
self.embed_tokens = value
|
| 1031 |
+
|
| 1032 |
+
def forward(
|
| 1033 |
+
self,
|
| 1034 |
+
input_ids: torch.LongTensor = None,
|
| 1035 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1036 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1037 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1038 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1039 |
+
use_cache: Optional[bool] = None,
|
| 1040 |
+
output_attentions: Optional[bool] = None,
|
| 1041 |
+
output_hidden_states: Optional[bool] = None,
|
| 1042 |
+
return_dict: Optional[bool] = None,
|
| 1043 |
+
use_rmpad: Optional[bool] = False,
|
| 1044 |
+
cu_seqlens: Optional[torch.Tensor] = False,
|
| 1045 |
+
) -> Union[Tuple, BaseModelOutputWithPast]:
|
| 1046 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1047 |
+
output_hidden_states = (
|
| 1048 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1049 |
+
)
|
| 1050 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 1051 |
+
|
| 1052 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1053 |
+
|
| 1054 |
+
if (input_ids is None) ^ (inputs_embeds is not None):
|
| 1055 |
+
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
|
| 1056 |
+
|
| 1057 |
+
if self.gradient_checkpointing and self.training:
|
| 1058 |
+
if use_cache:
|
| 1059 |
+
logger.warning_once(
|
| 1060 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
| 1061 |
+
)
|
| 1062 |
+
use_cache = False
|
| 1063 |
+
|
| 1064 |
+
|
| 1065 |
+
hidden_states = inputs_embeds
|
| 1066 |
+
|
| 1067 |
+
# create position embeddings to be shared across the decoder layers
|
| 1068 |
+
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 1069 |
+
|
| 1070 |
+
# decoder layers
|
| 1071 |
+
all_hidden_states = () if output_hidden_states else None
|
| 1072 |
+
all_self_attns = () if output_attentions else None
|
| 1073 |
+
next_decoder_cache = None
|
| 1074 |
+
|
| 1075 |
+
for decoder_layer in self.layers:
|
| 1076 |
+
if output_hidden_states:
|
| 1077 |
+
all_hidden_states += (hidden_states,)
|
| 1078 |
+
|
| 1079 |
+
if self.gradient_checkpointing and self.training:
|
| 1080 |
+
layer_outputs = self._gradient_checkpointing_func(
|
| 1081 |
+
decoder_layer.__call__,
|
| 1082 |
+
hidden_states,
|
| 1083 |
+
attention_mask,
|
| 1084 |
+
position_ids,
|
| 1085 |
+
past_key_values,
|
| 1086 |
+
output_attentions,
|
| 1087 |
+
use_cache,
|
| 1088 |
+
position_embeddings,
|
| 1089 |
+
use_rmpad,
|
| 1090 |
+
cu_seqlens,
|
| 1091 |
+
)
|
| 1092 |
+
else:
|
| 1093 |
+
layer_outputs = decoder_layer(
|
| 1094 |
+
hidden_states,
|
| 1095 |
+
attention_mask=attention_mask,
|
| 1096 |
+
position_ids=position_ids,
|
| 1097 |
+
past_key_value=past_key_values,
|
| 1098 |
+
output_attentions=output_attentions,
|
| 1099 |
+
use_cache=use_cache,
|
| 1100 |
+
position_embeddings=position_embeddings,
|
| 1101 |
+
use_rmpad=use_rmpad,
|
| 1102 |
+
cu_seqlens=cu_seqlens,
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
hidden_states = layer_outputs[0]
|
| 1106 |
+
|
| 1107 |
+
if use_cache:
|
| 1108 |
+
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
| 1109 |
+
|
| 1110 |
+
if output_attentions:
|
| 1111 |
+
all_self_attns += (layer_outputs[1],)
|
| 1112 |
+
|
| 1113 |
+
hidden_states = self.norm(hidden_states)
|
| 1114 |
+
|
| 1115 |
+
# add hidden states from the last decoder layer
|
| 1116 |
+
if output_hidden_states:
|
| 1117 |
+
all_hidden_states += (hidden_states,)
|
| 1118 |
+
|
| 1119 |
+
next_cache = next_decoder_cache if use_cache else None
|
| 1120 |
+
|
| 1121 |
+
if not return_dict:
|
| 1122 |
+
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
| 1123 |
+
return BaseModelOutputWithPast(
|
| 1124 |
+
last_hidden_state=hidden_states,
|
| 1125 |
+
past_key_values=next_cache,
|
| 1126 |
+
hidden_states=all_hidden_states,
|
| 1127 |
+
attentions=all_self_attns,
|
| 1128 |
+
)
|
| 1129 |
+
|
| 1130 |
+
class Qwen2VLForCausalLM(Qwen2VLPreTrainedModel, GenerationMixin):
|
| 1131 |
+
_tied_weights_keys = ["lm_head.weight"]
|
| 1132 |
+
|
| 1133 |
+
def __init__(self, config):
|
| 1134 |
+
super().__init__(config)
|
| 1135 |
+
self.model = Qwen2VLModel(config)
|
| 1136 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
| 1137 |
+
self.padding_side = "left" # set it to left by default, user can use setter to change padding_sides
|
| 1138 |
+
|
| 1139 |
+
# Initialize weights and apply final processing
|
| 1140 |
+
self.post_init()
|
| 1141 |
+
|
| 1142 |
+
def get_input_embeddings(self):
|
| 1143 |
+
return self.model.embed_tokens
|
| 1144 |
+
|
| 1145 |
+
def set_input_embeddings(self, value):
|
| 1146 |
+
self.model.embed_tokens = value
|
| 1147 |
+
|
| 1148 |
+
def get_output_embeddings(self):
|
| 1149 |
+
return self.lm_head
|
| 1150 |
+
|
| 1151 |
+
def set_output_embeddings(self, new_embeddings):
|
| 1152 |
+
self.lm_head = new_embeddings
|
| 1153 |
+
|
| 1154 |
+
def set_decoder(self, decoder):
|
| 1155 |
+
self.model = decoder
|
| 1156 |
+
|
| 1157 |
+
def get_decoder(self):
|
| 1158 |
+
return self.model
|
| 1159 |
+
|
| 1160 |
+
def get_rope_index(
|
| 1161 |
+
self,
|
| 1162 |
+
input_ids: torch.LongTensor,
|
| 1163 |
+
image_grid_thw: Optional[torch.LongTensor] = None,
|
| 1164 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1165 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 1166 |
+
"""
|
| 1167 |
+
Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
|
| 1168 |
+
|
| 1169 |
+
Explanation:
|
| 1170 |
+
Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
|
| 1171 |
+
|
| 1172 |
+
For pure text embedding sequence, the rotary position embedding has no difference with mordern LLMs.
|
| 1173 |
+
Examples:
|
| 1174 |
+
input_ids: [T T T T T], here T is for text.
|
| 1175 |
+
temporal position_ids: [0, 1, 2, 3, 4]
|
| 1176 |
+
height position_ids: [0, 1, 2, 3, 4]
|
| 1177 |
+
width position_ids: [0, 1, 2, 3, 4]
|
| 1178 |
+
|
| 1179 |
+
For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
|
| 1180 |
+
and 1D rotary position embeddin for text part.
|
| 1181 |
+
Examples:
|
| 1182 |
+
Assume we have a video input with 3 temporal patches, 2 height patches and 2 width patches.
|
| 1183 |
+
input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
|
| 1184 |
+
vision temporal position_ids: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
|
| 1185 |
+
vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
|
| 1186 |
+
vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
|
| 1187 |
+
text temporal position_ids: [3, 4, 5, 6, 7]
|
| 1188 |
+
text height position_ids: [3, 4, 5, 6, 7]
|
| 1189 |
+
text width position_ids: [3, 4, 5, 6, 7]
|
| 1190 |
+
Here we calculate the text start position_ids as the max vision position_ids plus 1.
|
| 1191 |
+
|
| 1192 |
+
Args:
|
| 1193 |
+
input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
|
| 1194 |
+
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
|
| 1195 |
+
it.
|
| 1196 |
+
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
|
| 1197 |
+
The temporal, height and width of feature shape of each image in LLM.
|
| 1198 |
+
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
|
| 1199 |
+
The temporal, height and width of feature shape of each video in LLM.
|
| 1200 |
+
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
| 1201 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
| 1202 |
+
|
| 1203 |
+
- 1 for tokens that are **not masked**,
|
| 1204 |
+
- 0 for tokens that are **masked**.
|
| 1205 |
+
|
| 1206 |
+
Returns:
|
| 1207 |
+
position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
|
| 1208 |
+
mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
|
| 1209 |
+
"""
|
| 1210 |
+
spatial_merge_size = self.config.spatial_merge_size
|
| 1211 |
+
vision_token_id = self.config.image_token_id
|
| 1212 |
+
vision_start_token_id = self.config.vision_start_token_id
|
| 1213 |
+
assert image_grid_thw is not None # TODO:测试纯文本会不会卡住
|
| 1214 |
+
total_input_ids = input_ids
|
| 1215 |
+
position_ids = torch.ones(
|
| 1216 |
+
3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
|
| 1217 |
+
)
|
| 1218 |
+
vision_index = 0
|
| 1219 |
+
for i, input_ids in enumerate(total_input_ids):
|
| 1220 |
+
if attention_mask is not None:
|
| 1221 |
+
input_ids = input_ids[attention_mask[i] == 1]
|
| 1222 |
+
vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
|
| 1223 |
+
vision_num = (input_ids[vision_start_indices + 1] == vision_token_id).sum()
|
| 1224 |
+
input_tokens = input_ids.tolist()
|
| 1225 |
+
llm_pos_ids_list: list = []
|
| 1226 |
+
st = 0
|
| 1227 |
+
remain_vision_num = vision_num
|
| 1228 |
+
for _ in range(vision_num):
|
| 1229 |
+
if vision_token_id in input_tokens and remain_vision_num > 0:
|
| 1230 |
+
ed_vision = input_tokens.index(vision_token_id, st)
|
| 1231 |
+
else:
|
| 1232 |
+
ed_vision = len(input_tokens) + 1
|
| 1233 |
+
|
| 1234 |
+
t, h, w = (
|
| 1235 |
+
image_grid_thw[vision_index][0],
|
| 1236 |
+
image_grid_thw[vision_index][1],
|
| 1237 |
+
image_grid_thw[vision_index][2],
|
| 1238 |
+
)
|
| 1239 |
+
vision_index += 1
|
| 1240 |
+
remain_vision_num -= 1
|
| 1241 |
+
ed = ed_vision
|
| 1242 |
+
|
| 1243 |
+
llm_grid_t, llm_grid_h, llm_grid_w = (
|
| 1244 |
+
t.item(),
|
| 1245 |
+
h.item() // spatial_merge_size,
|
| 1246 |
+
w.item() // spatial_merge_size,
|
| 1247 |
+
)
|
| 1248 |
+
text_len = ed - st
|
| 1249 |
+
|
| 1250 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 1251 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 1252 |
+
|
| 1253 |
+
t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
|
| 1254 |
+
h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
|
| 1255 |
+
w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
|
| 1256 |
+
llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
|
| 1257 |
+
st = ed + llm_grid_t * llm_grid_h * llm_grid_w
|
| 1258 |
+
|
| 1259 |
+
if st < len(input_tokens):
|
| 1260 |
+
st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
|
| 1261 |
+
text_len = len(input_tokens) - st
|
| 1262 |
+
llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
|
| 1263 |
+
|
| 1264 |
+
llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
|
| 1265 |
+
position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
|
| 1266 |
+
position_ids = position_ids.permute(1, 2, 0)
|
| 1267 |
+
return position_ids
|
| 1268 |
+
|
| 1269 |
+
def forward(
|
| 1270 |
+
self,
|
| 1271 |
+
input_ids: torch.LongTensor = None,
|
| 1272 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 1273 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 1274 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1275 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 1276 |
+
labels: Optional[torch.LongTensor] = None,
|
| 1277 |
+
use_cache: Optional[bool] = None,
|
| 1278 |
+
output_attentions: Optional[bool] = None,
|
| 1279 |
+
output_hidden_states: Optional[bool] = None,
|
| 1280 |
+
return_dict: Optional[bool] = None,
|
| 1281 |
+
use_rmpad: Optional[bool] = False,
|
| 1282 |
+
cu_seqlens: Optional[torch.Tensor] = False,
|
| 1283 |
+
) -> Union[Tuple, Qwen2VLCausalLMOutputWithPast]:
|
| 1284 |
+
|
| 1285 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 1286 |
+
output_hidden_states = (
|
| 1287 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 1288 |
+
)
|
| 1289 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 1290 |
+
|
| 1291 |
+
|
| 1292 |
+
outputs = self.model(
|
| 1293 |
+
input_ids=input_ids,
|
| 1294 |
+
attention_mask=attention_mask,
|
| 1295 |
+
position_ids=position_ids,
|
| 1296 |
+
past_key_values=past_key_values,
|
| 1297 |
+
inputs_embeds=inputs_embeds,
|
| 1298 |
+
use_cache=use_cache,
|
| 1299 |
+
output_attentions=output_attentions,
|
| 1300 |
+
output_hidden_states=output_hidden_states,
|
| 1301 |
+
return_dict=return_dict,
|
| 1302 |
+
use_rmpad=use_rmpad,
|
| 1303 |
+
cu_seqlens=cu_seqlens,
|
| 1304 |
+
)
|
| 1305 |
+
|
| 1306 |
+
hidden_states = outputs[0]
|
| 1307 |
+
logits = self.lm_head(hidden_states)
|
| 1308 |
+
|
| 1309 |
+
if not return_dict:
|
| 1310 |
+
output = (logits,) + outputs[1:]
|
| 1311 |
+
return output
|
| 1312 |
+
|
| 1313 |
+
return Qwen2VLCausalLMOutputWithPast(
|
| 1314 |
+
logits=logits,
|
| 1315 |
+
past_key_values=outputs.past_key_values,
|
| 1316 |
+
hidden_states=outputs.hidden_states,
|
| 1317 |
+
attentions=outputs.attentions,
|
| 1318 |
+
)
|
| 1319 |
+
|
| 1320 |
+
|
eval_scripts/DREAM-1K/tarsier/models/modeling_tarsier.py
ADDED
|
@@ -0,0 +1,502 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Optional, Tuple, Union, Dict, Any
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch.utils.checkpoint
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from transformers import PreTrainedModel, AutoConfig, AutoModel
|
| 10 |
+
from transformers.activations import ACT2FN
|
| 11 |
+
from transformers.cache_utils import Cache
|
| 12 |
+
from transformers.modeling_outputs import ModelOutput
|
| 13 |
+
from transformers.utils import logging
|
| 14 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 15 |
+
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
| 16 |
+
from transformers.models.auto import AutoModel, AutoModelForCausalLM, CONFIG_MAPPING
|
| 17 |
+
from transformers.generation import GenerationMixin
|
| 18 |
+
|
| 19 |
+
from transformers import LlamaForCausalLM, Qwen2ForCausalLM
|
| 20 |
+
# from models.modeling_qwen2 import Qwen2ForCausalLM
|
| 21 |
+
from models.modeling_qwen2_vl_fast import Qwen2VLForCausalLM
|
| 22 |
+
from models.utils import _pad_input, _unpad_input
|
| 23 |
+
|
| 24 |
+
logger = logging.get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class LlavaConfig(PretrainedConfig):
|
| 28 |
+
|
| 29 |
+
model_type = "llava"
|
| 30 |
+
is_composition = False
|
| 31 |
+
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
vision_config=None,
|
| 35 |
+
text_config=None,
|
| 36 |
+
ignore_index=-100,
|
| 37 |
+
image_token_index=32000,
|
| 38 |
+
projector_hidden_act="gelu",
|
| 39 |
+
vision_feature_select_strategy="default",
|
| 40 |
+
vision_feature_layer=-2,
|
| 41 |
+
image_newline_idx=32002,
|
| 42 |
+
image_new_idx=32003,
|
| 43 |
+
projection_head="MLP",
|
| 44 |
+
**kwargs,
|
| 45 |
+
):
|
| 46 |
+
self.ignore_index = ignore_index
|
| 47 |
+
self.image_token_index = image_token_index
|
| 48 |
+
self.projector_hidden_act = projector_hidden_act
|
| 49 |
+
self.vision_feature_select_strategy = vision_feature_select_strategy
|
| 50 |
+
self.vision_feature_layer = vision_feature_layer
|
| 51 |
+
self.image_newline_idx = image_newline_idx
|
| 52 |
+
self.image_new_idx = image_new_idx
|
| 53 |
+
self.projection_head = projection_head
|
| 54 |
+
|
| 55 |
+
self.vision_config = vision_config
|
| 56 |
+
|
| 57 |
+
if isinstance(self.vision_config, dict):
|
| 58 |
+
vision_config["model_type"] = (
|
| 59 |
+
vision_config["model_type"] if "model_type" in vision_config else "clip_vision_model"
|
| 60 |
+
)
|
| 61 |
+
if 'auto_map' in vision_config:
|
| 62 |
+
repo_id, class_ref = vision_config['auto_map']['AutoConfig'].split("--")
|
| 63 |
+
config_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
|
| 64 |
+
self.vision_config = config_class(**vision_config)
|
| 65 |
+
elif vision_config["model_type"] in CONFIG_MAPPING:
|
| 66 |
+
self.vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
|
| 67 |
+
else:
|
| 68 |
+
raise ValueError(f'vision_config["model_type"] = {vision_config["model_type"]} not supported!')
|
| 69 |
+
|
| 70 |
+
self.text_config = text_config
|
| 71 |
+
|
| 72 |
+
if isinstance(self.text_config, dict):
|
| 73 |
+
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "llama"
|
| 74 |
+
if 'auto_map' in text_config:
|
| 75 |
+
repo_id, class_ref = text_config['auto_map']['AutoConfig'].split("--")
|
| 76 |
+
config_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
|
| 77 |
+
self.text_config = config_class(**text_config)
|
| 78 |
+
elif text_config["model_type"] in CONFIG_MAPPING:
|
| 79 |
+
self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
|
| 80 |
+
else:
|
| 81 |
+
raise ValueError(f'text_config["model_type"] = {text_config["model_type"]} not supported!')
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
super().__init__(**kwargs)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@dataclass
|
| 89 |
+
# Copied from transformers.models.idefics.modeling_idefics.IdeficsCausalLMOutputWithPast with Idefics->Llava
|
| 90 |
+
class LlavaCausalLMOutputWithPast(ModelOutput):
|
| 91 |
+
|
| 92 |
+
loss: Optional[torch.FloatTensor] = None
|
| 93 |
+
logits: torch.FloatTensor = None
|
| 94 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 95 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 96 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 97 |
+
position_ids: Optional[torch.LongTensor] = None
|
| 98 |
+
|
| 99 |
+
def add_split_tokens(image_features, image_newline_embed, image_new_embed):
|
| 100 |
+
num_images, num_image_patches, embed_dim = image_features.shape
|
| 101 |
+
num_height_patches, num_width_patches = int(math.sqrt(num_image_patches)), int(math.sqrt(num_image_patches))
|
| 102 |
+
|
| 103 |
+
# add image_newline
|
| 104 |
+
image_features = image_features.view(num_images, num_height_patches, num_width_patches, embed_dim)
|
| 105 |
+
image_features = torch.cat([
|
| 106 |
+
image_features,
|
| 107 |
+
image_newline_embed.expand((num_images, num_height_patches, 1, embed_dim))
|
| 108 |
+
], dim=2)
|
| 109 |
+
num_image_patches += num_height_patches
|
| 110 |
+
image_features = image_features.view(num_images, num_image_patches, embed_dim)
|
| 111 |
+
|
| 112 |
+
# add image_new
|
| 113 |
+
image_features = torch.cat([
|
| 114 |
+
image_features,
|
| 115 |
+
image_new_embed.expand((num_images, 1, embed_dim))
|
| 116 |
+
], dim = 1)
|
| 117 |
+
|
| 118 |
+
return image_features
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class LlavaMultiModalProjector(nn.Module):
|
| 122 |
+
def __init__(self, config: LlavaConfig):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.config = config
|
| 125 |
+
|
| 126 |
+
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size, bias=True)
|
| 127 |
+
self.act = ACT2FN[config.projector_hidden_act]
|
| 128 |
+
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True)
|
| 129 |
+
|
| 130 |
+
image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
|
| 131 |
+
image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
|
| 132 |
+
self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
|
| 133 |
+
self.register_buffer('image_new_idx', image_new_idx, persistent=False)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def forward(self, image_features, input_embeddings):
|
| 137 |
+
|
| 138 |
+
selected_image_feature = image_features[self.config.vision_feature_layer]
|
| 139 |
+
|
| 140 |
+
if self.config.vision_feature_select_strategy == "default":
|
| 141 |
+
selected_image_feature = selected_image_feature[:, 1:]
|
| 142 |
+
elif self.config.vision_feature_select_strategy == "full":
|
| 143 |
+
selected_image_feature = selected_image_feature
|
| 144 |
+
else:
|
| 145 |
+
raise ValueError(
|
| 146 |
+
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
hidden_states = self.linear_1(selected_image_feature)
|
| 150 |
+
hidden_states = self.act(hidden_states)
|
| 151 |
+
hidden_states = self.linear_2(hidden_states)
|
| 152 |
+
|
| 153 |
+
image_newline_embed = input_embeddings(self.image_newline_idx).squeeze()
|
| 154 |
+
image_new_embed = input_embeddings(self.image_new_idx).squeeze()
|
| 155 |
+
hidden_states = add_split_tokens(hidden_states, image_newline_embed, image_new_embed)
|
| 156 |
+
return hidden_states
|
| 157 |
+
|
| 158 |
+
class PixelShuffleMultiModalProjector(nn.Module):
|
| 159 |
+
def __init__(self, config: LlavaConfig):
|
| 160 |
+
super().__init__()
|
| 161 |
+
self.config = config
|
| 162 |
+
|
| 163 |
+
self.downsample_ratio = 0.5
|
| 164 |
+
vit_hidden_size = config.vision_config.hidden_size
|
| 165 |
+
llm_hidden_size = config.text_config.hidden_size
|
| 166 |
+
|
| 167 |
+
self.mlp = nn.Sequential(
|
| 168 |
+
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
| 169 |
+
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size),
|
| 170 |
+
nn.GELU(),
|
| 171 |
+
nn.Linear(llm_hidden_size, llm_hidden_size)
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
image_newline_idx = torch.tensor([config.image_newline_idx], dtype=torch.long)
|
| 175 |
+
image_new_idx = torch.tensor([config.image_new_idx], dtype=torch.long)
|
| 176 |
+
self.register_buffer('image_newline_idx', image_newline_idx, persistent=False)
|
| 177 |
+
self.register_buffer('image_new_idx', image_new_idx, persistent=False)
|
| 178 |
+
|
| 179 |
+
def forward(self, image_features, input_embeddings):
|
| 180 |
+
selected_image_feature = image_features[self.config.vision_feature_layer]
|
| 181 |
+
|
| 182 |
+
if self.config.vision_feature_select_strategy == "default":
|
| 183 |
+
selected_image_feature = selected_image_feature[:, 1:]
|
| 184 |
+
elif self.config.vision_feature_select_strategy == "full":
|
| 185 |
+
selected_image_feature = selected_image_feature
|
| 186 |
+
else:
|
| 187 |
+
raise ValueError(
|
| 188 |
+
f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
image_features = self.pixel_shuffle(selected_image_feature)
|
| 192 |
+
hidden_states = self.mlp(image_features)
|
| 193 |
+
|
| 194 |
+
image_newline_embed = input_embeddings(self.image_newline_idx).squeeze()
|
| 195 |
+
image_new_embed = input_embeddings(self.image_new_idx).squeeze()
|
| 196 |
+
hidden_states = add_split_tokens(hidden_states, image_newline_embed, image_new_embed)
|
| 197 |
+
|
| 198 |
+
return hidden_states
|
| 199 |
+
|
| 200 |
+
def pixel_shuffle(self, x, scale_factor=0.5):
|
| 201 |
+
if scale_factor == 1:
|
| 202 |
+
return x
|
| 203 |
+
n, wh, c = x.shape
|
| 204 |
+
h, w = int(math.sqrt(wh)), int(math.sqrt(wh))
|
| 205 |
+
x = x.view(n, h, w, c)
|
| 206 |
+
|
| 207 |
+
n, w, h, c = x.size()
|
| 208 |
+
# N, W, H, C --> N, W, H * scale, C // scale
|
| 209 |
+
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
| 210 |
+
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
| 211 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 212 |
+
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
| 213 |
+
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
|
| 214 |
+
int(c / (scale_factor * scale_factor)))
|
| 215 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 216 |
+
x = x.view(x.shape[0], -1, x.shape[-1])
|
| 217 |
+
return x
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
LLAVA_START_DOCSTRING = r"""
|
| 221 |
+
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
| 222 |
+
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
| 223 |
+
etc.)
|
| 224 |
+
|
| 225 |
+
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
| 226 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
| 227 |
+
and behavior.
|
| 228 |
+
|
| 229 |
+
Parameters:
|
| 230 |
+
config ([`LlavaConfig`] or [`LlavaVisionConfig`]):
|
| 231 |
+
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
| 232 |
+
load the weights associated with the model, only the configuration. Check out the
|
| 233 |
+
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
| 234 |
+
"""
|
| 235 |
+
|
| 236 |
+
class TarsierPreTrainedModel(PreTrainedModel):
|
| 237 |
+
config_class = LlavaConfig
|
| 238 |
+
base_model_prefix = "llm"
|
| 239 |
+
supports_gradient_checkpointing = True # TODO: support latest gc
|
| 240 |
+
_skip_keys_device_placement = "past_key_values"
|
| 241 |
+
_supports_flash_attn_2 = True
|
| 242 |
+
_supports_sdpa = False
|
| 243 |
+
_supports_cache_class = True # TODO: support different cache
|
| 244 |
+
_supports_static_cache = True
|
| 245 |
+
|
| 246 |
+
def _init_weights(self, module):
|
| 247 |
+
std = (
|
| 248 |
+
self.config.initializer_range
|
| 249 |
+
if hasattr(self.config, "initializer_range")
|
| 250 |
+
else self.config.text_config.initializer_range
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
if hasattr(module, "class_embedding"):
|
| 254 |
+
module.class_embedding.data.normal_(mean=0.0, std=std)
|
| 255 |
+
|
| 256 |
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv3d)):
|
| 257 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 258 |
+
if module.bias is not None:
|
| 259 |
+
module.bias.data.zero_()
|
| 260 |
+
elif isinstance(module, nn.Embedding):
|
| 261 |
+
module.weight.data.normal_(mean=0.0, std=std)
|
| 262 |
+
if module.padding_idx is not None:
|
| 263 |
+
module.weight.data[module.padding_idx].zero_()
|
| 264 |
+
elif isinstance(module, nn.LayerNorm):
|
| 265 |
+
module.weight.data.fill_(1.0)
|
| 266 |
+
if module.bias is not None:
|
| 267 |
+
module.bias.data.zero_()
|
| 268 |
+
@property
|
| 269 |
+
def _no_split_modules(self):
|
| 270 |
+
return self.language_model._no_split_modules + self.vision_tower._no_split_modules
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class TarsierForConditionalGeneration(TarsierPreTrainedModel, GenerationMixin):
|
| 274 |
+
def __init__(self, config: LlavaConfig):
|
| 275 |
+
super().__init__(config)
|
| 276 |
+
self.vision_tower = AutoModel.from_config(config.vision_config, trust_remote_code=True)
|
| 277 |
+
if config.text_config.model_type == 'qwen2':
|
| 278 |
+
self.language_model = Qwen2ForCausalLM(config.text_config)
|
| 279 |
+
elif config.text_config.model_type == 'qwen2_vl':
|
| 280 |
+
self.language_model = Qwen2VLForCausalLM(config.text_config)
|
| 281 |
+
elif config.text_config.model_type == 'llama':
|
| 282 |
+
self.language_model = LlamaForCausalLM(config.text_config)
|
| 283 |
+
else:
|
| 284 |
+
raise ValueError(f'{config.text_config.model_type} not supported!')
|
| 285 |
+
|
| 286 |
+
if config.projection_head == 'Pixel_Shuffle':
|
| 287 |
+
self.multi_modal_projector = PixelShuffleMultiModalProjector(config)
|
| 288 |
+
elif config.projection_head == 'MLP':
|
| 289 |
+
self.multi_modal_projector = LlavaMultiModalProjector(config)
|
| 290 |
+
elif config.projection_head == 'auto_map':
|
| 291 |
+
repo_id, class_ref = config.auto_map['ProjectionLayer'].split("--")
|
| 292 |
+
model_class = get_class_from_dynamic_module(class_ref, repo_id)
|
| 293 |
+
self.multi_modal_projector = model_class(config)
|
| 294 |
+
elif config.projection_head is None:
|
| 295 |
+
self.multi_modal_projector = lambda x, *args, **kwargs: x
|
| 296 |
+
|
| 297 |
+
self.post_init()
|
| 298 |
+
|
| 299 |
+
def get_input_embeddings(self):
|
| 300 |
+
return self.language_model.get_input_embeddings()
|
| 301 |
+
|
| 302 |
+
def set_input_embeddings(self, value):
|
| 303 |
+
self.language_model.set_input_embeddings(value)
|
| 304 |
+
|
| 305 |
+
def get_output_embeddings(self):
|
| 306 |
+
return self.language_model.get_output_embeddings()
|
| 307 |
+
|
| 308 |
+
def set_output_embeddings(self, new_embeddings):
|
| 309 |
+
self.language_model.set_output_embeddings(new_embeddings)
|
| 310 |
+
|
| 311 |
+
def set_decoder(self, decoder):
|
| 312 |
+
self.language_model.set_decoder(decoder)
|
| 313 |
+
|
| 314 |
+
def get_decoder(self):
|
| 315 |
+
return self.language_model.get_decoder()
|
| 316 |
+
|
| 317 |
+
def tie_weights(self):
|
| 318 |
+
return self.language_model.tie_weights()
|
| 319 |
+
|
| 320 |
+
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
|
| 321 |
+
model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
|
| 322 |
+
# update vocab size
|
| 323 |
+
self.config.text_config.vocab_size = model_embeds.num_embeddings
|
| 324 |
+
return model_embeds
|
| 325 |
+
|
| 326 |
+
def forward(
|
| 327 |
+
self,
|
| 328 |
+
input_ids: torch.LongTensor = None,
|
| 329 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 330 |
+
position_ids: Optional[torch.LongTensor] = None,
|
| 331 |
+
pixel_values: torch.FloatTensor = None,
|
| 332 |
+
image_grid_thw: Optional[torch.Tensor] = None,
|
| 333 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 334 |
+
labels: Optional[torch.LongTensor] = None,
|
| 335 |
+
num_images: Optional[torch.Tensor] = None,
|
| 336 |
+
use_cache: Optional[bool] = None,
|
| 337 |
+
output_attentions: Optional[bool] = None,
|
| 338 |
+
output_hidden_states: Optional[bool] = None,
|
| 339 |
+
return_dict: Optional[bool] = None,
|
| 340 |
+
use_rmpad: Optional[bool] = False,
|
| 341 |
+
**kwargs,
|
| 342 |
+
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
|
| 343 |
+
|
| 344 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 345 |
+
output_hidden_states = (
|
| 346 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 347 |
+
)
|
| 348 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
if input_ids is None:
|
| 352 |
+
raise ValueError("You must specify input_ids")
|
| 353 |
+
|
| 354 |
+
bsz, max_seq_len = input_ids.shape[0], input_ids.shape[1]
|
| 355 |
+
|
| 356 |
+
if max_seq_len > 1:
|
| 357 |
+
special_image_mask = input_ids == self.config.image_token_index
|
| 358 |
+
print(f'[{input_ids.device}] num_images: {num_images.tolist()} num_image_tokens: {special_image_mask.sum(-1).tolist()}', flush=True)
|
| 359 |
+
|
| 360 |
+
if position_ids is None:
|
| 361 |
+
if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
|
| 362 |
+
position_ids = self.language_model.get_rope_index(input_ids, image_grid_thw, attention_mask) # [bsz, seqlen, 3]
|
| 363 |
+
else:
|
| 364 |
+
position_ids = attention_mask.long().cumsum(-1) - 1 # # [bsz, seqlen]
|
| 365 |
+
position_ids.masked_fill_(attention_mask == 0, 1)
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
if use_rmpad:
|
| 369 |
+
input_ids, input_ids_indices, cu_seqlens, _ = _unpad_input(input_ids, attention_mask) # [bsz, seqlen] -> [1, seqlen]
|
| 370 |
+
position_ids, _, _, _ = _unpad_input(position_ids, attention_mask)
|
| 371 |
+
input_ids, position_ids = input_ids.unsqueeze(0), position_ids.unsqueeze(0)
|
| 372 |
+
else:
|
| 373 |
+
input_ids_indices, cu_seqlens = None, None
|
| 374 |
+
|
| 375 |
+
inputs_embeds = self.get_input_embeddings()(input_ids) # [1, seqlen, dim]
|
| 376 |
+
|
| 377 |
+
image_features = None
|
| 378 |
+
if pixel_values is not None: # training / first step in generation
|
| 379 |
+
if 'Qwen2VLForCausalLM' in self.language_model.__class__.__name__:
|
| 380 |
+
pixel_values = pixel_values.type(self.vision_tower.get_dtype())
|
| 381 |
+
image_features = self.vision_tower(pixel_values, image_grid_thw)
|
| 382 |
+
else:
|
| 383 |
+
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
| 384 |
+
image_features = self.multi_modal_projector(
|
| 385 |
+
image_outputs.hidden_states,
|
| 386 |
+
self.get_input_embeddings(),
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
special_image_mask = (input_ids == self.config.image_token_index).to(inputs_embeds.device)
|
| 390 |
+
if special_image_mask.sum() > 0:
|
| 391 |
+
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
| 392 |
+
inputs_embeds = inputs_embeds.masked_scatter(
|
| 393 |
+
special_image_mask.unsqueeze(-1).expand_as(inputs_embeds),
|
| 394 |
+
image_features
|
| 395 |
+
)
|
| 396 |
+
else:
|
| 397 |
+
inputs_embeds = image_features.sum(dim=(0,1)) * 0. + inputs_embeds
|
| 398 |
+
|
| 399 |
+
outputs = self.language_model(
|
| 400 |
+
attention_mask=attention_mask,
|
| 401 |
+
position_ids=position_ids,
|
| 402 |
+
past_key_values=past_key_values,
|
| 403 |
+
inputs_embeds=inputs_embeds,
|
| 404 |
+
use_cache=use_cache,
|
| 405 |
+
output_attentions=output_attentions,
|
| 406 |
+
output_hidden_states=output_hidden_states,
|
| 407 |
+
return_dict=return_dict,
|
| 408 |
+
use_rmpad=use_rmpad,
|
| 409 |
+
cu_seqlens=cu_seqlens,
|
| 410 |
+
)
|
| 411 |
+
|
| 412 |
+
logits = outputs[0]
|
| 413 |
+
|
| 414 |
+
loss = None
|
| 415 |
+
if labels is not None:
|
| 416 |
+
loss_fct = nn.CrossEntropyLoss()
|
| 417 |
+
if use_rmpad:
|
| 418 |
+
labels = labels.view(-1)[input_ids_indices.long()]
|
| 419 |
+
shift_labels = torch.cat((labels[1:], labels.new_ones((1))*-100))
|
| 420 |
+
shift_labels.requires_grad = False
|
| 421 |
+
lbl_seq_lens = (cu_seqlens[1:]-1).long()
|
| 422 |
+
shift_labels[lbl_seq_lens] = -100
|
| 423 |
+
loss = loss_fct(logits.squeeze(0), shift_labels)
|
| 424 |
+
else:
|
| 425 |
+
# Shift so that tokens < n predict n
|
| 426 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
| 427 |
+
shift_labels = labels[..., 1:].contiguous()
|
| 428 |
+
# Flatten the tokens
|
| 429 |
+
shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size)
|
| 430 |
+
shift_labels = shift_labels.view(-1)
|
| 431 |
+
# Enable model parallelism
|
| 432 |
+
shift_labels = shift_labels.to(shift_logits.device)
|
| 433 |
+
loss = loss_fct(shift_logits, shift_labels)
|
| 434 |
+
elif use_rmpad: # 训练的时候,就不 unpad logits 了,节省显存。
|
| 435 |
+
logits = _pad_input(logits.squeeze(0), input_ids_indices, bsz, max_seq_len)
|
| 436 |
+
|
| 437 |
+
if not return_dict:
|
| 438 |
+
output = (logits,) + outputs[1:]
|
| 439 |
+
return (loss,) + output if loss is not None else output
|
| 440 |
+
|
| 441 |
+
return LlavaCausalLMOutputWithPast(
|
| 442 |
+
loss=loss,
|
| 443 |
+
logits=logits,
|
| 444 |
+
past_key_values=outputs.past_key_values,
|
| 445 |
+
hidden_states=outputs.hidden_states,
|
| 446 |
+
attentions=outputs.attentions,
|
| 447 |
+
position_ids=position_ids,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
def prepare_inputs_for_generation(
|
| 451 |
+
self,
|
| 452 |
+
input_ids,
|
| 453 |
+
attention_mask=None,
|
| 454 |
+
position_ids=None,
|
| 455 |
+
past_key_values=None,
|
| 456 |
+
cache_position=None,
|
| 457 |
+
use_cache=True,
|
| 458 |
+
pixel_values=None,
|
| 459 |
+
image_grid_thw=None,
|
| 460 |
+
**kwargs,
|
| 461 |
+
):
|
| 462 |
+
if past_key_values is not None:
|
| 463 |
+
past_length = past_key_values.get_seq_length()
|
| 464 |
+
input_ids = input_ids[:, past_length:]
|
| 465 |
+
|
| 466 |
+
model_inputs = {
|
| 467 |
+
"input_ids": input_ids,
|
| 468 |
+
"attention_mask": attention_mask,
|
| 469 |
+
"past_key_values": past_key_values,
|
| 470 |
+
"use_cache": use_cache,
|
| 471 |
+
}
|
| 472 |
+
if kwargs.get('num_images') is not None:
|
| 473 |
+
model_inputs['num_images'] = kwargs['num_images']
|
| 474 |
+
|
| 475 |
+
if cache_position[0] == 0:
|
| 476 |
+
# If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
|
| 477 |
+
# Otherwise we need pixel values to be passed to model
|
| 478 |
+
model_inputs["pixel_values"] = pixel_values
|
| 479 |
+
model_inputs["image_grid_thw"] = image_grid_thw
|
| 480 |
+
else:
|
| 481 |
+
model_inputs['position_ids'] = position_ids[:, -1, ...].unsqueeze(1).to(device=input_ids.device) + 1
|
| 482 |
+
return model_inputs
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
def _update_model_kwargs_for_generation(
|
| 486 |
+
self,
|
| 487 |
+
outputs: ModelOutput,
|
| 488 |
+
model_kwargs: Dict[str, Any],
|
| 489 |
+
is_encoder_decoder: bool = False,
|
| 490 |
+
num_new_tokens: int = 1,
|
| 491 |
+
) -> Dict[str, Any]:
|
| 492 |
+
model_kwargs = super()._update_model_kwargs_for_generation(
|
| 493 |
+
outputs=outputs,
|
| 494 |
+
model_kwargs=model_kwargs,
|
| 495 |
+
is_encoder_decoder=is_encoder_decoder,
|
| 496 |
+
num_new_tokens=num_new_tokens,
|
| 497 |
+
)
|
| 498 |
+
|
| 499 |
+
if getattr(outputs, "position_ids", None) is not None:
|
| 500 |
+
model_kwargs["position_ids"] = outputs.position_ids
|
| 501 |
+
|
| 502 |
+
return model_kwargs
|
eval_scripts/DREAM-1K/tarsier/models/utils.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
|
| 5 |
+
def _unpad_input(input_ids, attention_mask):
|
| 6 |
+
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
| 7 |
+
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
| 8 |
+
max_seqlen_in_batch = seqlens_in_batch.max().item()
|
| 9 |
+
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
|
| 10 |
+
input_ids = rearrange(input_ids, 'b s ... -> (b s) ...')[indices]
|
| 11 |
+
return input_ids, indices, cu_seqlens, max_seqlen_in_batch
|
| 12 |
+
|
| 13 |
+
def _pad_input(hidden_states, indices, batch, seqlen):
|
| 14 |
+
output = torch.zeros(batch * seqlen, *hidden_states.shape[1:], device=hidden_states.device,
|
| 15 |
+
dtype=hidden_states.dtype)
|
| 16 |
+
output[indices] = hidden_states
|
| 17 |
+
return rearrange(output, '(b s) ... -> b s ...', b=batch)
|
eval_scripts/DREAM-1K/tarsier/scripts/run_demo_cli.sh
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
model_path=$1
|
| 4 |
+
n_frames=${2:-8}
|
| 5 |
+
max_new_tokens=${3:-512}
|
| 6 |
+
top_p=${4:-0.8}
|
| 7 |
+
temperature=${5:-0}
|
| 8 |
+
|
| 9 |
+
python3 -m tasks.demo_cli \
|
| 10 |
+
--model_name_or_path $model_path \
|
| 11 |
+
--config "configs/tarser2_default_config.yaml" \
|
| 12 |
+
--max_n_frames $n_frames \
|
| 13 |
+
--max_new_tokens $max_new_tokens \
|
| 14 |
+
--top_p $top_p \
|
| 15 |
+
--temperature $temperature
|
eval_scripts/DREAM-1K/tarsier/scripts/run_demo_gradio.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
model_path=$1
|
| 4 |
+
max_n_frames=${2:-8}
|
| 5 |
+
|
| 6 |
+
export MODEL_PATH=$model_path
|
| 7 |
+
export MAX_N_FRAMES=$max_n_frames
|
| 8 |
+
|
| 9 |
+
python3 -m tasks.demo_gradio
|
eval_scripts/DREAM-1K/tarsier/scripts/run_evaluation_only.sh
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
export AZURE_ENDPOINT=...
|
| 3 |
+
export OPENAI_API_KEY=...
|
| 4 |
+
|
| 5 |
+
pred_file=$1
|
| 6 |
+
# benchmarks=${@:2}
|
| 7 |
+
benchmarks=dream
|
| 8 |
+
benchmarks=${benchmarks:-"all"}
|
| 9 |
+
|
| 10 |
+
python -m evaluation.evaluate \
|
| 11 |
+
--pred_file $pred_file \
|
| 12 |
+
--benchmarks $benchmarks
|
eval_scripts/DREAM-1K/tarsier/scripts/run_inference_benchmark.sh
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Copy and Modified on: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/video_inference/scripts/video/eval/video_detail_description_eval_shard.sh
|
| 4 |
+
|
| 5 |
+
#
|
| 6 |
+
|
| 7 |
+
model_name_or_path=$1
|
| 8 |
+
output_dir=$2
|
| 9 |
+
benchmarks=${@:3}
|
| 10 |
+
benchmarks=${benchmarks:-"all"}
|
| 11 |
+
resume=True
|
| 12 |
+
CHUNKS=8
|
| 13 |
+
|
| 14 |
+
mkdir $output_dir
|
| 15 |
+
|
| 16 |
+
echo "Using $CHUNKS GPUs"
|
| 17 |
+
|
| 18 |
+
# Assuming GPULIST is a bash array containing your GPUs
|
| 19 |
+
GPULIST=(0 1 2 3 4 5 6 7)
|
| 20 |
+
# GPULIST=(0 1)
|
| 21 |
+
|
| 22 |
+
# Get the number of GPUs
|
| 23 |
+
NUM_GPUS=${#GPULIST[@]}
|
| 24 |
+
|
| 25 |
+
# Calculate GPUs per chunk
|
| 26 |
+
GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS))
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
for IDX in $(seq 1 $CHUNKS); do
|
| 30 |
+
START=$(((IDX-1) * GPUS_PER_CHUNK))
|
| 31 |
+
LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index
|
| 32 |
+
|
| 33 |
+
CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})
|
| 34 |
+
|
| 35 |
+
# Convert the chunk GPUs array to a comma-separated string
|
| 36 |
+
CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")
|
| 37 |
+
|
| 38 |
+
ALL_GPUS_FREE=0
|
| 39 |
+
while [ $ALL_GPUS_FREE -eq 0 ]; do
|
| 40 |
+
ALL_GPUS_FREE=1 # Assume all GPUs are free initially
|
| 41 |
+
|
| 42 |
+
for GPU_ID in $CHUNK_GPUS; do
|
| 43 |
+
MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]')
|
| 44 |
+
|
| 45 |
+
# Assuming a GPU is considered free if its memory usage is less than 100 MiB
|
| 46 |
+
if [ $MEM_USAGE -ge 100 ]; then
|
| 47 |
+
ALL_GPUS_FREE=0
|
| 48 |
+
echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB."
|
| 49 |
+
break # Exit the loop early as we found a GPU that is not free
|
| 50 |
+
fi
|
| 51 |
+
done
|
| 52 |
+
|
| 53 |
+
if [ $ALL_GPUS_FREE -eq 0 ]; then
|
| 54 |
+
echo "Not all GPUs in chunk are free. Checking again in 10 seconds..."
|
| 55 |
+
sleep 10
|
| 56 |
+
fi
|
| 57 |
+
done
|
| 58 |
+
|
| 59 |
+
echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR"
|
| 60 |
+
CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 -m tasks.inference_benchmark \
|
| 61 |
+
--model_name_or_path $model_name_or_path \
|
| 62 |
+
--config "configs/tarser2_default_config.yaml" \
|
| 63 |
+
--max_new_tokens 512 \
|
| 64 |
+
--top_p 1 \
|
| 65 |
+
--temperature 0 \
|
| 66 |
+
--output_dir $output_dir \
|
| 67 |
+
--output_name predictions \
|
| 68 |
+
--max_n_samples_per_benchmark -1 \
|
| 69 |
+
--benchmarks $benchmarks \
|
| 70 |
+
--resume $resume \
|
| 71 |
+
--num_chunks $CHUNKS \
|
| 72 |
+
--chunk_idx $(($IDX - 1)) > $output_dir/run_$IDX.log 2>&1 &
|
| 73 |
+
|
| 74 |
+
done
|
| 75 |
+
|
| 76 |
+
wait
|
| 77 |
+
|
| 78 |
+
python3 -m evaluation.evaluate \
|
| 79 |
+
--pred_file $output_dir \
|
| 80 |
+
--benchmarks $benchmarks
|
eval_scripts/DREAM-1K/tarsier/scripts/run_inference_caption.sh
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Copy and Modified on: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/video_inference/scripts/video/eval/video_detail_description_eval_shard.sh
|
| 4 |
+
|
| 5 |
+
#
|
| 6 |
+
|
| 7 |
+
model_name_or_path=$1
|
| 8 |
+
input_file=$2
|
| 9 |
+
output_dir=$3
|
| 10 |
+
CHUNKS=1
|
| 11 |
+
resume=True
|
| 12 |
+
|
| 13 |
+
mkdir $output_dir
|
| 14 |
+
|
| 15 |
+
echo "Using $CHUNKS GPUs"
|
| 16 |
+
|
| 17 |
+
# Assuming GPULIST is a bash array containing your GPUs
|
| 18 |
+
# GPULIST=(0 1 2 3 4 5 6 7)
|
| 19 |
+
GPULIST=(0)
|
| 20 |
+
|
| 21 |
+
# Get the number of GPUs
|
| 22 |
+
NUM_GPUS=${#GPULIST[@]}
|
| 23 |
+
|
| 24 |
+
# Calculate GPUs per chunk
|
| 25 |
+
GPUS_PER_CHUNK=$((NUM_GPUS / CHUNKS))
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
for IDX in $(seq 1 $CHUNKS); do
|
| 29 |
+
START=$(((IDX-1) * GPUS_PER_CHUNK))
|
| 30 |
+
LENGTH=$GPUS_PER_CHUNK # Length for slicing, not the end index
|
| 31 |
+
|
| 32 |
+
CHUNK_GPUS=(${GPULIST[@]:$START:$LENGTH})
|
| 33 |
+
|
| 34 |
+
# Convert the chunk GPUs array to a comma-separated string
|
| 35 |
+
CHUNK_GPUS_STR=$(IFS=,; echo "${CHUNK_GPUS[*]}")
|
| 36 |
+
|
| 37 |
+
ALL_GPUS_FREE=0
|
| 38 |
+
while [ $ALL_GPUS_FREE -eq 0 ]; do
|
| 39 |
+
ALL_GPUS_FREE=1 # Assume all GPUs are free initially
|
| 40 |
+
|
| 41 |
+
for GPU_ID in $CHUNK_GPUS; do
|
| 42 |
+
MEM_USAGE=$(nvidia-smi --query-gpu=memory.used --format=csv,noheader,nounits -i $GPU_ID | tr -d '[:space:]')
|
| 43 |
+
|
| 44 |
+
# Assuming a GPU is considered free if its memory usage is less than 100 MiB
|
| 45 |
+
if [ "$MEM_USAGE" -ge 100 ]; then
|
| 46 |
+
ALL_GPUS_FREE=0
|
| 47 |
+
echo "GPU $GPU_ID is in use. Memory used: ${MEM_USAGE}MiB."
|
| 48 |
+
break # Exit the loop early as we found a GPU that is not free
|
| 49 |
+
fi
|
| 50 |
+
done
|
| 51 |
+
|
| 52 |
+
if [ $ALL_GPUS_FREE -eq 0 ]; then
|
| 53 |
+
echo "Not all GPUs in chunk are free. Checking again in 10 seconds..."
|
| 54 |
+
sleep 10
|
| 55 |
+
fi
|
| 56 |
+
done
|
| 57 |
+
|
| 58 |
+
echo "CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR"
|
| 59 |
+
CUDA_VISIBLE_DEVICES=$CHUNK_GPUS_STR python3 -m tasks.inference_caption \
|
| 60 |
+
--model_name_or_path $model_name_or_path \
|
| 61 |
+
--config "configs/tarser2_default_config.yaml" \
|
| 62 |
+
--max_new_tokens 512 \
|
| 63 |
+
--top_p 1 \
|
| 64 |
+
--temperature 0 \
|
| 65 |
+
--input_file $input_file \
|
| 66 |
+
--output_dir $output_dir \
|
| 67 |
+
--output_name predictions \
|
| 68 |
+
--max_n_samples_per_benchmark -1 \
|
| 69 |
+
--resume $resume \
|
| 70 |
+
--num_chunks $CHUNKS \
|
| 71 |
+
--chunk_idx $(($IDX - 1)) > $output_dir/run_$IDX.log 2>&1 &
|
| 72 |
+
|
| 73 |
+
done
|
| 74 |
+
|
| 75 |
+
wait
|
| 76 |
+
|
| 77 |
+
# python3 -m evaluation.evaluate \
|
| 78 |
+
# --pred_file $output_dir \
|
| 79 |
+
# --benchmarks $benchmarks
|
eval_scripts/DREAM-1K/tarsier/tasks/demo_cli.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import os
|
| 16 |
+
import torch
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
from transformers import StoppingCriteriaList
|
| 19 |
+
from tasks.utils import load_model_and_processor
|
| 20 |
+
from dataset.utils import *
|
| 21 |
+
from tools.conversation import Chat, conv_templates, StoppingCriteriaSub
|
| 22 |
+
from transformers import TextStreamer
|
| 23 |
+
from tools.color import Color
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 27 |
+
|
| 28 |
+
def main(args):
|
| 29 |
+
# Load Model
|
| 30 |
+
print(f"### Start loading model...")
|
| 31 |
+
model, processor = load_model_and_processor(args.model_name_or_path, args.config)
|
| 32 |
+
print(f"### Finish loading model.")
|
| 33 |
+
if 'tarsier2' in args.model_name_or_path.lower():
|
| 34 |
+
conv_type = 'tarsier2-7b'
|
| 35 |
+
else:
|
| 36 |
+
if '7b' in args.model_name_or_path.lower():
|
| 37 |
+
conv_type = 'tarsier-7b'
|
| 38 |
+
elif '13b' in args.model_name_or_path.lower():
|
| 39 |
+
conv_type = 'tarsier-13b'
|
| 40 |
+
elif '34b' in args.model_name_or_path.lower():
|
| 41 |
+
conv_type = 'tarsier-34b'
|
| 42 |
+
else:
|
| 43 |
+
raise ValueError(f"Unknow model: {args.model_name_or_path}")
|
| 44 |
+
|
| 45 |
+
chat = Chat(model, processor, device=device, debug = args.debug)
|
| 46 |
+
conv = deepcopy(conv_templates[conv_type])
|
| 47 |
+
|
| 48 |
+
img_path = ''
|
| 49 |
+
has_img = False
|
| 50 |
+
while True:
|
| 51 |
+
if not has_img:
|
| 52 |
+
try:
|
| 53 |
+
img_path = input(Color.green(f"{conv.roles[1]}: ") + "Input a file path of your image/video:")
|
| 54 |
+
img_path = img_path.strip()
|
| 55 |
+
if not (os.path.exists(img_path) and get_visual_type(img_path) in ['video', 'gif', 'image']):
|
| 56 |
+
continue
|
| 57 |
+
has_img = True
|
| 58 |
+
conv.messages.append([conv.roles[0], {"type": "video", "text": img_path}])
|
| 59 |
+
print(Color.green(f"{conv.roles[1]}: ") + "Received your file. Now let's start conversation! :)")
|
| 60 |
+
print(Color.red(f"<Input \'exit\' to exit and \'reset\' to restart>"))
|
| 61 |
+
except Exception as e:
|
| 62 |
+
print(f"Error: {e}")
|
| 63 |
+
print("exit...")
|
| 64 |
+
exit()
|
| 65 |
+
inp = ""
|
| 66 |
+
while inp == "":
|
| 67 |
+
inp = input(Color.blue(f"{conv.roles[0]}: ")).strip()
|
| 68 |
+
if inp.strip() == 'exit':
|
| 69 |
+
print("exit...")
|
| 70 |
+
exit()
|
| 71 |
+
elif inp.strip() == "reset":
|
| 72 |
+
conv = deepcopy(conv_templates[conv_type])
|
| 73 |
+
img_path = ''
|
| 74 |
+
continue
|
| 75 |
+
conv = chat.ask(inp, conv)
|
| 76 |
+
|
| 77 |
+
stop_words_ids = [torch.tensor([processor.processor.tokenizer.eos_token_id]).to(device)]
|
| 78 |
+
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
| 79 |
+
streamer = TextStreamer(processor.processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
|
| 80 |
+
|
| 81 |
+
inputs, conv = chat.prepare_model_inputs(conv, args.max_n_frames)
|
| 82 |
+
print("conv:", conv)
|
| 83 |
+
print(Color.green(f"{conv.roles[1]}: "), end="")
|
| 84 |
+
with torch.inference_mode():
|
| 85 |
+
outputs = model.generate(
|
| 86 |
+
**inputs,
|
| 87 |
+
do_sample=True if args.temperature > 0 else False,
|
| 88 |
+
temperature=args.temperature,
|
| 89 |
+
top_p=args.top_p,
|
| 90 |
+
max_new_tokens=args.max_new_tokens,
|
| 91 |
+
streamer=streamer,
|
| 92 |
+
use_cache=True,
|
| 93 |
+
stopping_criteria=[stopping_criteria])
|
| 94 |
+
outputs = processor.processor.tokenizer.decode(outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
|
| 95 |
+
conv.messages.append(
|
| 96 |
+
[conv.roles[1], {"text": outputs, "type": "text"}]
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
if args.debug:
|
| 100 |
+
print(f"Conversation state: {conv}")
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
# python3 -m tasks.demo_cli --model_name_or_path /tmp/tarsier2-1226-dpo --config configs/tarser2_default_config.yaml
|
| 104 |
+
import argparse
|
| 105 |
+
|
| 106 |
+
parser = argparse.ArgumentParser()
|
| 107 |
+
parser.add_argument('--model_name_or_path', type=str)
|
| 108 |
+
parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
|
| 109 |
+
parser.add_argument("--max_n_frames", type=int, default=16, help="Max number of frames to apply average sampling from the given video.")
|
| 110 |
+
parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
|
| 111 |
+
parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
|
| 112 |
+
parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")
|
| 113 |
+
parser.add_argument("--debug", action="store_true")
|
| 114 |
+
args = parser.parse_args()
|
| 115 |
+
|
| 116 |
+
main(args)
|
eval_scripts/DREAM-1K/tarsier/tasks/demo_gradio.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# copy and modify from: https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/demo/demo.py
|
| 16 |
+
|
| 17 |
+
# import spaces # for deploying on huggingface ZeroGPU
|
| 18 |
+
from copy import deepcopy
|
| 19 |
+
import gradio as gr
|
| 20 |
+
from gradio.themes.utils import colors, fonts, sizes
|
| 21 |
+
from tools.conversation import Chat, conv_templates
|
| 22 |
+
from tasks.utils import load_model_and_processor, file_to_base64
|
| 23 |
+
from dataset.tarsier_datamodule import init_processor
|
| 24 |
+
import os
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
# huggingface-cli login
|
| 28 |
+
|
| 29 |
+
model_path = os.getenv("MODEL_PATH", "omni-research/Tarsier2-7b")
|
| 30 |
+
config_path = "configs/tarser2_default_config.yaml"
|
| 31 |
+
max_n_frames = int(os.getenv("MAX_N_FRAMES", 16))
|
| 32 |
+
debug = False
|
| 33 |
+
device = 'cuda' if not debug else 'cpu'
|
| 34 |
+
|
| 35 |
+
# ========================================
|
| 36 |
+
# Model Initialization
|
| 37 |
+
# ========================================
|
| 38 |
+
def init_model():
|
| 39 |
+
print("Start Initialization...")
|
| 40 |
+
# if torch.cuda.is_available():
|
| 41 |
+
if not debug:
|
| 42 |
+
model, processor = load_model_and_processor(model_path, config_path)
|
| 43 |
+
else:
|
| 44 |
+
print(f"No Valid GPU! Lauch in debug mode!")
|
| 45 |
+
processor = init_processor(model_path, config_path)
|
| 46 |
+
model = None
|
| 47 |
+
chat = Chat(model, processor, device, debug)c
|
| 48 |
+
print('Initialization Finished')
|
| 49 |
+
return chat
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ========================================
|
| 53 |
+
# Gradio Setting
|
| 54 |
+
# ========================================
|
| 55 |
+
def gradio_reset(chat_state, img_file):
|
| 56 |
+
if chat_state is not None:
|
| 57 |
+
chat_state.messages = []
|
| 58 |
+
img_file = None
|
| 59 |
+
return None, gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(value=None, interactive=True), gr.update(placeholder='Please upload your video first', interactive=False),gr.update(value="Upload & Start Chat", interactive=True), chat_state, img_file
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def upload_img(gr_img, gr_video, gr_gif, chat_state, num_frames):
|
| 63 |
+
print("video, image or gif:", gr_video, gr_img, gr_gif)
|
| 64 |
+
conv_type = ''
|
| 65 |
+
if 'tarsier2-7b' in model_path.lower():
|
| 66 |
+
conv_type = 'tarsier2-7b'
|
| 67 |
+
# elif '7b' in model_path.lower():
|
| 68 |
+
# conv_type = 'tarsier-7b'
|
| 69 |
+
# elif '13b' in model_path.lower():
|
| 70 |
+
# conv_type = 'tarsier-13b'
|
| 71 |
+
# elif '34b' in model_path.lower():
|
| 72 |
+
# conv_type = 'tarsier-34b'
|
| 73 |
+
else:
|
| 74 |
+
raise ValueError(f"Unknow model: {model_path}")
|
| 75 |
+
chat_state = deepcopy(conv_templates[conv_type])
|
| 76 |
+
|
| 77 |
+
if gr_img is None and gr_video is None and gr_gif is None:
|
| 78 |
+
return None, None, None, gr.update(interactive=True), gr.update(interactive=True, placeholder='Please upload video/image first!'), chat_state, None, None
|
| 79 |
+
if gr_video or gr_img or gr_gif:
|
| 80 |
+
for img_file in [gr_video, gr_img, gr_gif]:
|
| 81 |
+
if img_file is not None:
|
| 82 |
+
break
|
| 83 |
+
chat_state.messages.append([chat_state.roles[0], {"type": "video", "text": img_file}])
|
| 84 |
+
return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_file
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def gradio_ask(user_message, chatbot, chat_state):
|
| 88 |
+
if len(user_message) == 0:
|
| 89 |
+
return gr.update(interactive=True, placeholder='Input should not be empty!'), chatbot, chat_state
|
| 90 |
+
chat_state = chat.ask(user_message, chat_state)
|
| 91 |
+
chatbot = chatbot + [[user_message, None]]
|
| 92 |
+
return '', chatbot, chat_state
|
| 93 |
+
|
| 94 |
+
# @spaces.GPU(duration=120) # for deploying on huggingface ZeroGPU
|
| 95 |
+
def gradio_answer(chatbot, chat_state, img_file, top_p, temperature, n_frames=None):
|
| 96 |
+
llm_message, chat_state = chat.answer(conv=chat_state, n_frames=n_frames, max_new_tokens=256, num_beams=1, temperature=temperature, top_p=top_p)
|
| 97 |
+
chatbot[-1][1] = llm_message
|
| 98 |
+
print(chat_state)
|
| 99 |
+
print(f"Answer: {llm_message}")
|
| 100 |
+
return chatbot, chat_state
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class OpenGVLab(gr.themes.base.Base):
|
| 104 |
+
def __init__(
|
| 105 |
+
self,
|
| 106 |
+
*,
|
| 107 |
+
primary_hue=colors.blue,
|
| 108 |
+
secondary_hue=colors.sky,
|
| 109 |
+
neutral_hue=colors.gray,
|
| 110 |
+
spacing_size=sizes.spacing_md,
|
| 111 |
+
radius_size=sizes.radius_sm,
|
| 112 |
+
text_size=sizes.text_md,
|
| 113 |
+
font=(
|
| 114 |
+
fonts.GoogleFont("Noto Sans"),
|
| 115 |
+
"ui-sans-serif",
|
| 116 |
+
"sans-serif",
|
| 117 |
+
),
|
| 118 |
+
font_mono=(
|
| 119 |
+
fonts.GoogleFont("IBM Plex Mono"),
|
| 120 |
+
"ui-monospace",
|
| 121 |
+
"monospace",
|
| 122 |
+
),
|
| 123 |
+
):
|
| 124 |
+
super().__init__(
|
| 125 |
+
primary_hue=primary_hue,
|
| 126 |
+
secondary_hue=secondary_hue,
|
| 127 |
+
neutral_hue=neutral_hue,
|
| 128 |
+
spacing_size=spacing_size,
|
| 129 |
+
radius_size=radius_size,
|
| 130 |
+
text_size=text_size,
|
| 131 |
+
font=font,
|
| 132 |
+
font_mono=font_mono,
|
| 133 |
+
)
|
| 134 |
+
super().set(
|
| 135 |
+
body_background_fill="*neutral_50",
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
gvlabtheme = OpenGVLab(primary_hue=colors.blue,
|
| 140 |
+
secondary_hue=colors.sky,
|
| 141 |
+
neutral_hue=colors.gray,
|
| 142 |
+
spacing_size=sizes.spacing_md,
|
| 143 |
+
radius_size=sizes.radius_sm,
|
| 144 |
+
text_size=sizes.text_md,
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
logo_b64 = file_to_base64("assets/figures/tarsier_logo.jpg")
|
| 148 |
+
title = f"""<center><a href="https://github.com/bytedance/tarsier"><img src="data:image/jpeg;base64,{logo_b64}" alt="Tarsier" border="0" style="margin: 0 auto; height: 140px;" /></a></center>"""
|
| 149 |
+
description ="""<center><p><a href='https://github.com/bytedance/tarsier'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p></center>
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
with gr.Blocks(title="Tarsier",theme=gvlabtheme,css="#chatbot {overflow:auto; height:500px;} #InputVideo {overflow:visible; height:320px;} footer {visibility: none}") as demo:
|
| 154 |
+
gr.Markdown(title)
|
| 155 |
+
gr.Markdown(description)
|
| 156 |
+
with gr.Row():
|
| 157 |
+
with gr.Column(scale=0.5, visible=True) as video_upload:
|
| 158 |
+
with gr.Column(elem_id="image", scale=0.5) as img_part:
|
| 159 |
+
with gr.Tab("Video", elem_id='video_tab'):
|
| 160 |
+
up_video = gr.Video(interactive=True, include_audio=True, elem_id="video_upload", height=360)
|
| 161 |
+
with gr.Tab("Image", elem_id='image_tab'):
|
| 162 |
+
up_image = gr.Image(type="filepath", interactive=True, elem_id="image_upload", height=360)
|
| 163 |
+
with gr.Tab("GIF", elem_id='gif_tab'):
|
| 164 |
+
up_gif = gr.File(type="filepath", file_count="single", file_types=[".gif"], interactive=True, elem_id="gif_upload", height=360)
|
| 165 |
+
upload_button = gr.Button(value="Upload & Start Chat", interactive=True, variant="primary")
|
| 166 |
+
clear = gr.Button("Restart")
|
| 167 |
+
|
| 168 |
+
# num_beams = gr.Slider(
|
| 169 |
+
# minimum=1,
|
| 170 |
+
# maximum=10,
|
| 171 |
+
# value=1,
|
| 172 |
+
# step=1,
|
| 173 |
+
# interactive=True,
|
| 174 |
+
# label="beam search numbers)",
|
| 175 |
+
# )
|
| 176 |
+
|
| 177 |
+
temperature = gr.Slider(
|
| 178 |
+
minimum=0.0,
|
| 179 |
+
maximum=1.0,
|
| 180 |
+
value=0.0,
|
| 181 |
+
step=0.1,
|
| 182 |
+
interactive=True,
|
| 183 |
+
label="Temperature",
|
| 184 |
+
)
|
| 185 |
+
|
| 186 |
+
top_p = gr.Slider(
|
| 187 |
+
minimum=0.1,
|
| 188 |
+
maximum=1.0,
|
| 189 |
+
value=1.0,
|
| 190 |
+
step=0.1,
|
| 191 |
+
interactive=True,
|
| 192 |
+
label="Top_p",
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
num_frames = gr.Slider(
|
| 196 |
+
minimum=4,
|
| 197 |
+
maximum=16,
|
| 198 |
+
value=16,
|
| 199 |
+
step=2,
|
| 200 |
+
interactive=True,
|
| 201 |
+
label="#Frames",
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
with gr.Column(visible=True) as input_raws:
|
| 205 |
+
chat_state = gr.State()
|
| 206 |
+
img_file = gr.State()
|
| 207 |
+
chatbot = gr.Chatbot(elem_id="chatbot",label='VideoChat')
|
| 208 |
+
with gr.Row():
|
| 209 |
+
with gr.Column(scale=0.7):
|
| 210 |
+
text_input = gr.Textbox(show_label=False, placeholder='Please upload your video first', interactive=False, container=False)
|
| 211 |
+
with gr.Column(scale=0.15, min_width=0):
|
| 212 |
+
run = gr.Button("💭Send")
|
| 213 |
+
with gr.Column(scale=0.15, min_width=0):
|
| 214 |
+
clear = gr.Button("🔄Clear️")
|
| 215 |
+
|
| 216 |
+
chat = init_model()
|
| 217 |
+
upload_button.click(upload_img, [up_image, up_video, up_gif, chat_state, num_frames], [up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file])
|
| 218 |
+
|
| 219 |
+
text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
|
| 220 |
+
gradio_answer, [chatbot, chat_state, img_file, top_p, temperature, num_frames], [chatbot, chat_state]
|
| 221 |
+
)
|
| 222 |
+
run.click(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
|
| 223 |
+
gradio_answer, [chatbot, chat_state, img_file, top_p, temperature, num_frames], [chatbot, chat_state]
|
| 224 |
+
)
|
| 225 |
+
run.click(lambda: "", None, text_input)
|
| 226 |
+
clear.click(gradio_reset, [chat_state, img_file], [chatbot, up_image, up_video, up_gif, text_input, upload_button, chat_state, img_file], queue=False)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
demo.launch()
|
| 230 |
+
# demo.launch(server_name="0.0.0.0", server_port=11451)
|
eval_scripts/DREAM-1K/tarsier/tasks/inference_benchmark.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import torch
|
| 16 |
+
from tasks.utils import load_model_and_processor
|
| 17 |
+
# from dataset.mm_dataset import MMDataset
|
| 18 |
+
from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
|
| 19 |
+
from dataset.tarsier_datamodule import TarsierDataset
|
| 20 |
+
from dataset.utils import *
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import math
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
import yaml
|
| 27 |
+
|
| 28 |
+
ANN_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + '/../data/annotations'
|
| 29 |
+
|
| 30 |
+
Benchmark2fname = {
|
| 31 |
+
'dream': 'DREAM-1k.jsonl',
|
| 32 |
+
|
| 33 |
+
'next-qa': 'Next-QA-val-multi_choice.jsonl',
|
| 34 |
+
'egoschema': 'EgoSchema_subset.jsonl', # change to EgoSchema_fullset.jsonl if you test on the fullset
|
| 35 |
+
'mvbench': 'MVBench.jsonl',
|
| 36 |
+
'tvbench': 'TVBench.jsonl',
|
| 37 |
+
'video-mme': 'Video-MME.jsonl',
|
| 38 |
+
'favor-bench': 'FAVOR-Bench.jsonl',
|
| 39 |
+
|
| 40 |
+
'msvd-qa': 'MSVD-QA-val.jsonl',
|
| 41 |
+
'msr-vtt-qa': 'MSR-VTT-QA-val.jsonl',
|
| 42 |
+
'tgif-qa': 'TGIF-QA-test.jsonl',
|
| 43 |
+
'anet-qa': 'ActivityNet-QA-test.jsonl',
|
| 44 |
+
|
| 45 |
+
'msvd-caption': 'MSVD-Caption-test.jsonl',
|
| 46 |
+
'msr-vtt-caption': 'MSR-VTT-Caption-test.jsonl',
|
| 47 |
+
'vatex-caption': 'VATEX-test.jsonl',
|
| 48 |
+
|
| 49 |
+
'video_caption': "caption-test.jsonl", # custom for video caption test
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
def get_ann_file_path(benchmark):
|
| 53 |
+
ann_fpath = os.path.join(ANN_ROOT_DIR, Benchmark2fname[benchmark])
|
| 54 |
+
assert os.path.exists(ann_fpath), f"The annotation file for {benchmark} not exists: {ann_fpath}"
|
| 55 |
+
return ann_fpath
|
| 56 |
+
|
| 57 |
+
def split_list(lst, n):
|
| 58 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
| 59 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
| 60 |
+
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def get_chunk(lst, n, k):
|
| 64 |
+
chunks = split_list(lst, n)
|
| 65 |
+
return chunks[k]
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def parse_args():
|
| 69 |
+
"""
|
| 70 |
+
Parse command-line arguments.
|
| 71 |
+
"""
|
| 72 |
+
parser = argparse.ArgumentParser()
|
| 73 |
+
|
| 74 |
+
# Define the command-line arguments
|
| 75 |
+
|
| 76 |
+
parser.add_argument('--model_name_or_path', type=str, required=True)
|
| 77 |
+
parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
|
| 78 |
+
# parser.add_argument("--max_n_frames", type=int, default=8, help="Max number of frames to apply average sampling from the given video.")
|
| 79 |
+
parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
|
| 80 |
+
parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
|
| 81 |
+
parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")
|
| 82 |
+
|
| 83 |
+
parser.add_argument("--output_dir", type=str, help="Directory to save the model results", required=True)
|
| 84 |
+
parser.add_argument("--output_name", type=str, default="predictions", help="Name of the file for storing results")
|
| 85 |
+
|
| 86 |
+
parser.add_argument("--num_chunks", type=int, default=1)
|
| 87 |
+
parser.add_argument("--chunk_idx", type=int, default=0)
|
| 88 |
+
|
| 89 |
+
parser.add_argument("--max_n_samples_per_benchmark", type=int, default=-1, help="Set as a small number (like 100) to run as debug.")
|
| 90 |
+
parser.add_argument('--benchmarks', nargs='+', default=["all"], help="Default as 'all' to inference on all benchmarks; Also could be task types: ('dream', 'caption', 'mc_qa', 'oe_qa'); And specific benchmark names: ('dream', 'msvd-caption', 'msr-vtt-caption', 'vatex-caption', 'next-qa', 'egoschema', 'mvbench', 'video-mme', 'msvd-qa', 'msr-vtt-qa', 'tgif-qa', 'anet-qa')")
|
| 91 |
+
|
| 92 |
+
parser.add_argument("--resume", type=lambda x: (str(x).lower() == 'true'), default=True, help="Resume from existing inference results file or overwrite them.")
|
| 93 |
+
|
| 94 |
+
args = parser.parse_args()
|
| 95 |
+
|
| 96 |
+
args.benchmarks = get_benchmarks(args.benchmarks)
|
| 97 |
+
print("### Selected Benchmarks:", args.benchmarks)
|
| 98 |
+
|
| 99 |
+
return args
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def run_inference(args):
|
| 103 |
+
"""
|
| 104 |
+
Run inference on selected benchmarks.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
args: Command-line arguments.
|
| 108 |
+
"""
|
| 109 |
+
# Initialize the model
|
| 110 |
+
# model, processor = load_model_and_processor(args.model_name_or_path, args.max_n_frames) # max_n_frames set in config_file
|
| 111 |
+
data_config = yaml.safe_load(open(args.config, 'r'))
|
| 112 |
+
model, processor = load_model_and_processor(args.model_name_or_path, data_config=data_config)
|
| 113 |
+
|
| 114 |
+
all_chunks = []
|
| 115 |
+
count = 0
|
| 116 |
+
print(f"Start loading dataset...")
|
| 117 |
+
for benchmark in args.benchmarks:
|
| 118 |
+
ann_fpath = get_ann_file_path(benchmark)
|
| 119 |
+
cur_anns = [json.loads(line) for line in open(ann_fpath)]
|
| 120 |
+
if args.max_n_samples_per_benchmark > 0:
|
| 121 |
+
cur_anns = cur_anns[:args.max_n_samples_per_benchmark]
|
| 122 |
+
count += len(cur_anns)
|
| 123 |
+
cur_chunk = get_chunk(cur_anns, args.num_chunks, args.chunk_idx)
|
| 124 |
+
all_chunks.extend(cur_chunk)
|
| 125 |
+
print(f"### [{benchmark}] Load chunk with {len(cur_chunk)} samples from {len(cur_anns)} samples.")
|
| 126 |
+
print(f"### Finish loading chunk with {len(all_chunks)} samples from {count} samples in total.")
|
| 127 |
+
|
| 128 |
+
# Create the output directory if it doesn't exist
|
| 129 |
+
if not os.path.exists(args.output_dir):
|
| 130 |
+
os.makedirs(args.output_dir)
|
| 131 |
+
|
| 132 |
+
if args.num_chunks > 1:
|
| 133 |
+
output_name = f"{args.output_name}_{args.num_chunks}_{args.chunk_idx}"
|
| 134 |
+
else:
|
| 135 |
+
output_name = args.output_name
|
| 136 |
+
answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")
|
| 137 |
+
if args.resume and os.path.exists(answers_file):
|
| 138 |
+
processed_data = [json.loads(line) for line in open(answers_file)]
|
| 139 |
+
processed_idxs = set([f"{d['dataset']}-{d['idx']}" for d in processed_data])
|
| 140 |
+
all_chunks = [d for d in all_chunks if f"{d['dataset']}-{d['idx']}" not in processed_idxs]
|
| 141 |
+
print(f"### Resume from {len(processed_idxs)} samples. {len(all_chunks)} samples to run.", flush=True)
|
| 142 |
+
ans_file = open(answers_file, "a")
|
| 143 |
+
else:
|
| 144 |
+
ans_file = open(answers_file, "w")
|
| 145 |
+
|
| 146 |
+
dataset = TarsierDataset(
|
| 147 |
+
anns=all_chunks, config=data_config, processor=processor
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
generate_kwargs = {
|
| 151 |
+
"do_sample": True if args.temperature > 0 else False,
|
| 152 |
+
"max_new_tokens": args.max_new_tokens,
|
| 153 |
+
"top_p": args.top_p,
|
| 154 |
+
"temperature": args.temperature,
|
| 155 |
+
"use_cache": True
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
if len(dataset) == 0:
|
| 159 |
+
return
|
| 160 |
+
for ann, inputs in tqdm(dataset, total=len(dataset)):
|
| 161 |
+
if inputs is not None:
|
| 162 |
+
if "prompt" in inputs:
|
| 163 |
+
prompt = get_prompt_from_data_dict(ann)
|
| 164 |
+
print(f"###Prompt:\n{prompt}", flush=True)
|
| 165 |
+
# print(f"Input: {processor.processor.tokenizer.decode(inputs['input_ids'][0]), skip_special_tokens=True}", flush=True)
|
| 166 |
+
try:
|
| 167 |
+
model_inputs = {}
|
| 168 |
+
for k, v in inputs.items():
|
| 169 |
+
if not isinstance(v, torch.Tensor):
|
| 170 |
+
continue
|
| 171 |
+
model_inputs[k] = v.to(model.device)
|
| 172 |
+
outputs = model.generate(
|
| 173 |
+
**model_inputs,
|
| 174 |
+
**generate_kwargs,
|
| 175 |
+
)
|
| 176 |
+
output_text = processor.processor.tokenizer.decode(outputs[0][model_inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
|
| 177 |
+
except Exception as e:
|
| 178 |
+
print(f"Error: {e}")
|
| 179 |
+
output_text = "<error>"
|
| 180 |
+
print(f"###Prediction:\n{output_text}", flush=True)
|
| 181 |
+
answer = ann['messages'][-1]['content'][-1]['reference']
|
| 182 |
+
print(f"###Answer:\n{answer}", flush=True)
|
| 183 |
+
put_pred_to_data_dict(output_text, ann)
|
| 184 |
+
else:
|
| 185 |
+
put_pred_to_data_dict("<error>", ann)
|
| 186 |
+
try:
|
| 187 |
+
ans_file.write(json.dumps(ann, ensure_ascii=False) + "\n")
|
| 188 |
+
except:
|
| 189 |
+
ans_file.write(json.dumps(ann) + "\n")
|
| 190 |
+
ans_file.flush()
|
| 191 |
+
|
| 192 |
+
ans_file.close()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if __name__ == "__main__":
|
| 196 |
+
args = parse_args()
|
| 197 |
+
run_inference(args)
|
eval_scripts/DREAM-1K/tarsier/tasks/inference_caption.py
ADDED
|
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import argparse
|
| 15 |
+
import torch
|
| 16 |
+
from tasks.utils import load_model_and_processor
|
| 17 |
+
# from dataset.mm_dataset import MMDataset
|
| 18 |
+
from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
|
| 19 |
+
from dataset.tarsier_datamodule import TarsierDataset
|
| 20 |
+
from dataset.utils import *
|
| 21 |
+
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import math
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
import yaml
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
ANN_ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) + '/../data/annotations'
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def split_list(lst, n):
|
| 33 |
+
"""Split a list into n (roughly) equal-sized chunks"""
|
| 34 |
+
chunk_size = math.ceil(len(lst) / n) # integer division
|
| 35 |
+
return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_chunk(lst, n, k):
|
| 39 |
+
chunks = split_list(lst, n)
|
| 40 |
+
return chunks[k]
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def parse_args():
|
| 44 |
+
"""
|
| 45 |
+
Parse command-line arguments.
|
| 46 |
+
"""
|
| 47 |
+
parser = argparse.ArgumentParser()
|
| 48 |
+
|
| 49 |
+
# Define the command-line arguments
|
| 50 |
+
|
| 51 |
+
parser.add_argument('--model_name_or_path', type=str, required=True)
|
| 52 |
+
parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
|
| 53 |
+
# parser.add_argument("--max_n_frames", type=int, default=8, help="Max number of frames to apply average sampling from the given video.")
|
| 54 |
+
parser.add_argument("--max_new_tokens", type=int, default=512, help="max number of generated tokens")
|
| 55 |
+
parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
|
| 56 |
+
parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")
|
| 57 |
+
|
| 58 |
+
parser.add_argument("--input_file", type=str, help="Directory to input_file (jsonline)", required=True)
|
| 59 |
+
parser.add_argument("--output_dir", type=str, help="Directory to save the model results", required=True)
|
| 60 |
+
parser.add_argument("--output_name", type=str, default="predictions", help="Name of the file for storing results")
|
| 61 |
+
|
| 62 |
+
parser.add_argument("--num_chunks", type=int, default=1)
|
| 63 |
+
parser.add_argument("--chunk_idx", type=int, default=0)
|
| 64 |
+
|
| 65 |
+
parser.add_argument("--max_n_samples_per_benchmark", type=int, default=-1, help="Set as a small number (like 100) to run as debug.")
|
| 66 |
+
parser.add_argument("--resume", type=lambda x: (str(x).lower() == 'true'), default=True, help="Resume from existing inference results file or overwrite them.")
|
| 67 |
+
|
| 68 |
+
args = parser.parse_args()
|
| 69 |
+
|
| 70 |
+
return args
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def run_inference(args):
|
| 74 |
+
"""
|
| 75 |
+
Run inference on selected benchmarks.
|
| 76 |
+
|
| 77 |
+
Args:
|
| 78 |
+
args: Command-line arguments.
|
| 79 |
+
"""
|
| 80 |
+
# Initialize the model
|
| 81 |
+
# model, processor = load_model_and_processor(args.model_name_or_path, args.max_n_frames) # max_n_frames set in config_file
|
| 82 |
+
data_config = yaml.safe_load(open(args.config, 'r'))
|
| 83 |
+
model, processor = load_model_and_processor(args.model_name_or_path, data_config=data_config)
|
| 84 |
+
|
| 85 |
+
all_chunks = []
|
| 86 |
+
count = 0
|
| 87 |
+
print(f"Start loading dataset...")
|
| 88 |
+
ann_fpath = args.input_file
|
| 89 |
+
cur_anns = [json.loads(line) for line in open(ann_fpath)]
|
| 90 |
+
if args.max_n_samples_per_benchmark > 0:
|
| 91 |
+
cur_anns = cur_anns[:args.max_n_samples_per_benchmark]
|
| 92 |
+
count += len(cur_anns)
|
| 93 |
+
cur_chunk = get_chunk(cur_anns, args.num_chunks, args.chunk_idx)
|
| 94 |
+
all_chunks.extend(cur_chunk)
|
| 95 |
+
print(f"### Load chunk with {len(cur_chunk)} samples from {len(cur_anns)} samples.")
|
| 96 |
+
print(f"### Finish loading chunk with {len(all_chunks)} samples from {count} samples in total.")
|
| 97 |
+
|
| 98 |
+
# Create the output directory if it doesn't exist
|
| 99 |
+
if not os.path.exists(args.output_dir):
|
| 100 |
+
os.makedirs(args.output_dir)
|
| 101 |
+
|
| 102 |
+
if args.num_chunks > 1:
|
| 103 |
+
output_name = f"{args.output_name}_{args.num_chunks}_{args.chunk_idx}"
|
| 104 |
+
else:
|
| 105 |
+
output_name = args.output_name
|
| 106 |
+
answers_file = os.path.join(args.output_dir, f"{output_name}.jsonl")
|
| 107 |
+
if args.resume and os.path.exists(answers_file):
|
| 108 |
+
processed_data = [json.loads(line) for line in open(answers_file)]
|
| 109 |
+
processed_idxs = set([f"{d['dataset']}-{d['idx']}" for d in processed_data])
|
| 110 |
+
all_chunks = [d for d in all_chunks if f"{d['dataset']}-{d['idx']}" not in processed_idxs]
|
| 111 |
+
print(f"### Resume from {len(processed_idxs)} samples. {len(all_chunks)} samples to run.", flush=True)
|
| 112 |
+
ans_file = open(answers_file, "a")
|
| 113 |
+
else:
|
| 114 |
+
ans_file = open(answers_file, "w")
|
| 115 |
+
|
| 116 |
+
dataset = TarsierDataset(
|
| 117 |
+
anns=all_chunks, config=data_config, processor=processor
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
generate_kwargs = {
|
| 121 |
+
"do_sample": True if args.temperature > 0 else False,
|
| 122 |
+
"max_new_tokens": args.max_new_tokens,
|
| 123 |
+
"top_p": args.top_p,
|
| 124 |
+
"temperature": args.temperature,
|
| 125 |
+
"use_cache": True
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
if len(dataset) == 0:
|
| 129 |
+
return
|
| 130 |
+
for ann, inputs in tqdm(dataset):
|
| 131 |
+
if inputs is not None:
|
| 132 |
+
prompt = get_prompt_from_data_dict(ann)
|
| 133 |
+
print(f"###Prompt:\n{prompt}", flush=True)
|
| 134 |
+
# print(f"Input: {processor.processor.tokenizer.decode(inputs['input_ids'][0]), skip_special_tokens=True}", flush=True)
|
| 135 |
+
try:
|
| 136 |
+
model_inputs = {}
|
| 137 |
+
for k, v in inputs.items():
|
| 138 |
+
if not isinstance(v, torch.Tensor):
|
| 139 |
+
continue
|
| 140 |
+
model_inputs[k] = v.to(model.device)
|
| 141 |
+
outputs = model.generate(
|
| 142 |
+
**model_inputs,
|
| 143 |
+
**generate_kwargs,
|
| 144 |
+
)
|
| 145 |
+
output_text = processor.processor.tokenizer.decode(outputs[0][model_inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
|
| 146 |
+
except Exception as e:
|
| 147 |
+
print(f"Error: {e}")
|
| 148 |
+
output_text = "<error>"
|
| 149 |
+
print(f"###Prediction:\n{output_text}", flush=True)
|
| 150 |
+
put_pred_to_data_dict(output_text, ann)
|
| 151 |
+
else:
|
| 152 |
+
put_pred_to_data_dict("<error>", ann)
|
| 153 |
+
try:
|
| 154 |
+
ans_file.write(json.dumps(ann, ensure_ascii=False) + "\n")
|
| 155 |
+
except:
|
| 156 |
+
ans_file.write(json.dumps(ann) + "\n")
|
| 157 |
+
ans_file.flush()
|
| 158 |
+
|
| 159 |
+
ans_file.close()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
if __name__ == "__main__":
|
| 163 |
+
# python3 -m tasks.inference_caption --model_name_or_path /tmp/tarsier2-1226-dpo --config configs/tarser2_default_config.yaml --input_file data/annotations/caption-test-new.jsonl --output_dir tmp_outputs
|
| 164 |
+
args = parse_args()
|
| 165 |
+
run_inference(args)
|
eval_scripts/DREAM-1K/tarsier/tasks/inference_quick_start.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from tasks.utils import load_model_and_processor
|
| 15 |
+
from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
|
| 16 |
+
from dataset.utils import *
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import torch
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
import yaml
|
| 22 |
+
|
| 23 |
+
def process_one(model, processor, prompt, video_file, generate_kwargs):
|
| 24 |
+
# inputs = processor(prompt, video_file, edit_prompt=True, return_prompt=True)
|
| 25 |
+
sample = format_one_sample(video_file, prompt)
|
| 26 |
+
batch_data = processor(sample)
|
| 27 |
+
print(f"###Prompt:\n{get_prompt_from_data_dict(sample)}")
|
| 28 |
+
model_inputs = {}
|
| 29 |
+
for k, v in batch_data.items():
|
| 30 |
+
if not isinstance(v, torch.Tensor):
|
| 31 |
+
continue
|
| 32 |
+
model_inputs[k] = v.to(model.device)
|
| 33 |
+
outputs = model.generate(
|
| 34 |
+
**model_inputs,
|
| 35 |
+
**generate_kwargs,
|
| 36 |
+
)
|
| 37 |
+
# print(processor.processor.tokenizer.decode(outputs[0][:model_inputs['input_ids'][0].shape[0]], skip_special_tokens=True))
|
| 38 |
+
output_text = processor.processor.tokenizer.decode(outputs[0][model_inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
|
| 39 |
+
return output_text
|
| 40 |
+
|
| 41 |
+
def run():
|
| 42 |
+
import argparse
|
| 43 |
+
|
| 44 |
+
parser = argparse.ArgumentParser()
|
| 45 |
+
parser.add_argument('--model_name_or_path', type=str)
|
| 46 |
+
parser.add_argument('--config', type=str, default="configs/tarser2_default_config.yaml")
|
| 47 |
+
parser.add_argument('--instruction', type=str, default="Describe the video in detail.", help='Input prompt.')
|
| 48 |
+
parser.add_argument('--input_path', type=str, default="assets/examples", help='Path to video/image; or Dir to videos/images')
|
| 49 |
+
# parser.add_argument("--max_n_frames", type=int, default=16, help="Max number of frames to apply average sampling from the given video.")
|
| 50 |
+
parser.add_argument("--max_new_tokens", type=int, default=256, help="max number of generated tokens")
|
| 51 |
+
parser.add_argument("--top_p", type=float, default=1, help="Top_p sampling")
|
| 52 |
+
parser.add_argument("--temperature", type=float, default=0, help="Set temperature > 0 to enable sampling generation.")
|
| 53 |
+
|
| 54 |
+
args = parser.parse_args()
|
| 55 |
+
|
| 56 |
+
# model, processor = load_model_and_processor(args.model_name_or_path, max_n_frames=args.max_n_frames) # max_n_frames set in config_file
|
| 57 |
+
data_config = yaml.safe_load(open(args.config, 'r'))
|
| 58 |
+
model, processor = load_model_and_processor(args.model_name_or_path, data_config=data_config)
|
| 59 |
+
|
| 60 |
+
generate_kwargs = {
|
| 61 |
+
"do_sample": True if args.temperature > 0 else False,
|
| 62 |
+
"max_new_tokens": args.max_new_tokens,
|
| 63 |
+
"top_p": args.top_p,
|
| 64 |
+
"temperature": args.temperature,
|
| 65 |
+
"use_cache": True
|
| 66 |
+
}
|
| 67 |
+
assert os.path.exists(args.input_path), f"input_path not exist: {args.input_path}"
|
| 68 |
+
if os.path.isdir(args.input_path):
|
| 69 |
+
input_files = [os.path.join(args.input_path, fn) for fn in os.listdir(args.input_path) if get_visual_type(fn) in ['video', 'gif', 'image']]
|
| 70 |
+
elif get_visual_type(args.input_path) in ['video', 'gif', 'image']:
|
| 71 |
+
input_files = [args.input_path]
|
| 72 |
+
assert len(input_files) > 0, f"None valid input file in: {args.input_path} {VALID_DATA_FORMAT_STRING}"
|
| 73 |
+
|
| 74 |
+
for input_file in tqdm(input_files, desc="Generating..."):
|
| 75 |
+
visual_type = get_visual_type(input_file)
|
| 76 |
+
if args.instruction:
|
| 77 |
+
prompt = args.instruction
|
| 78 |
+
else:
|
| 79 |
+
if visual_type == 'image':
|
| 80 |
+
prompt = "Describe the image in detail."
|
| 81 |
+
else:
|
| 82 |
+
prompt = "Describe the video in detail."
|
| 83 |
+
|
| 84 |
+
pred = process_one(model, processor, prompt, input_file, generate_kwargs)
|
| 85 |
+
print(f"###Prediction:\n{pred}")
|
| 86 |
+
print('-'*100)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
if __name__ == "__main__":
|
| 90 |
+
# python3 -m tasks.inference_quick_start --model_name_or_path /tmp/tarsier2-1226-dpo --config configs/tarser2_default_config.yaml --input_path /mnt/bn/videonasi18n/wangjw/workspace/tarsier/diving.mp4 --instruction "List the names of all sponsors on the background wall."
|
| 91 |
+
run()
|
eval_scripts/DREAM-1K/tarsier/tasks/utils.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
from models.modeling_tarsier import TarsierForConditionalGeneration, LlavaConfig
|
| 15 |
+
# from dataset.processor import Processor
|
| 16 |
+
from dataset.tarsier_datamodule import init_processor
|
| 17 |
+
import torch
|
| 18 |
+
import base64
|
| 19 |
+
from tools.color import Color
|
| 20 |
+
import yaml
|
| 21 |
+
|
| 22 |
+
def load_model_and_processor(model_name_or_path, data_config):
|
| 23 |
+
print(Color.red(f"Load model and processor from: {model_name_or_path}"), flush=True)
|
| 24 |
+
if isinstance(data_config, str):
|
| 25 |
+
data_config = yaml.safe_load(open(data_config, 'r'))
|
| 26 |
+
processor = init_processor(model_name_or_path, data_config)
|
| 27 |
+
model_config = LlavaConfig.from_pretrained(
|
| 28 |
+
model_name_or_path,
|
| 29 |
+
trust_remote_code=True,
|
| 30 |
+
)
|
| 31 |
+
model = TarsierForConditionalGeneration.from_pretrained(
|
| 32 |
+
model_name_or_path,
|
| 33 |
+
config=model_config,
|
| 34 |
+
device_map='auto',
|
| 35 |
+
torch_dtype=torch.bfloat16,
|
| 36 |
+
trust_remote_code=True
|
| 37 |
+
)
|
| 38 |
+
model.eval()
|
| 39 |
+
return model, processor
|
| 40 |
+
|
| 41 |
+
def file_to_base64(img_path):
|
| 42 |
+
with open(img_path, 'rb') as video_file:
|
| 43 |
+
video_b64_str = base64.b64encode(video_file.read()).decode()
|
| 44 |
+
return video_b64_str
|
| 45 |
+
|
eval_scripts/DREAM-1K/tarsier/tools/color.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
class Color:
|
| 15 |
+
|
| 16 |
+
@staticmethod
|
| 17 |
+
def red(x):
|
| 18 |
+
return '\33[31m' +x + '\033[0m'
|
| 19 |
+
|
| 20 |
+
@staticmethod
|
| 21 |
+
def green(x):
|
| 22 |
+
return '\33[32m' +x + '\033[0m'
|
| 23 |
+
|
| 24 |
+
@staticmethod
|
| 25 |
+
def yellow(x):
|
| 26 |
+
return '\33[33m' +x + '\033[0m'
|
| 27 |
+
|
| 28 |
+
@staticmethod
|
| 29 |
+
def blue(x):
|
| 30 |
+
return '\33[34m' +x + '\033[0m'
|
| 31 |
+
|
| 32 |
+
@staticmethod
|
| 33 |
+
def violet(x):
|
| 34 |
+
return '\33[35m' +x + '\033[0m'
|
| 35 |
+
|
| 36 |
+
|
eval_scripts/DREAM-1K/tarsier/tools/conversation.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# copy and modify from: https://github.com/OpenGVLab/Ask-Anything/blob/main/video_chat2/conversation.py
|
| 16 |
+
from PIL import Image
|
| 17 |
+
import torch
|
| 18 |
+
from transformers import StoppingCriteria, StoppingCriteriaList
|
| 19 |
+
from dataset.custom_data_parsers.utils import put_pred_to_data_dict, get_prompt_from_data_dict
|
| 20 |
+
from dataset.tarsier_datamodule import TarsierDataProcessor
|
| 21 |
+
from dataset.utils import *
|
| 22 |
+
|
| 23 |
+
from enum import auto, Enum
|
| 24 |
+
import os
|
| 25 |
+
import re
|
| 26 |
+
|
| 27 |
+
data_dict_tmp = {
|
| 28 |
+
"messages": [
|
| 29 |
+
{
|
| 30 |
+
"role": "user",
|
| 31 |
+
"content": [
|
| 32 |
+
{
|
| 33 |
+
"type": "video",
|
| 34 |
+
"video": {
|
| 35 |
+
"video_file": "/mnt/hdfs/vlm/videos/movies_aligned_0523/tt8266310/tt8266310_1.50.24-1.50.29.mp4"}
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"type": "text",
|
| 39 |
+
"text": "Describe the video in detail."
|
| 40 |
+
}
|
| 41 |
+
]
|
| 42 |
+
},
|
| 43 |
+
{
|
| 44 |
+
"role": "assistant",
|
| 45 |
+
"content": [
|
| 46 |
+
{
|
| 47 |
+
"type": "text",
|
| 48 |
+
"text": "A man in the driver's seat, wearing a black jacket with a maroon shirt, fastens his seatbelt while smiling at the man in the passenger seat, who is adjusting his position. The passenger, also wearing a black jacket with a maroon shirt, turns to look forward and smiles. The driver then leans forward to start the car and leans back in his seat. In the background, a beige car is visible through the window."
|
| 49 |
+
}]}
|
| 50 |
+
],
|
| 51 |
+
"dataset": "video_caption",
|
| 52 |
+
"task": "video/caption",
|
| 53 |
+
"idx": 0,
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
IMAGE_TOKEN = "<image>"
|
| 58 |
+
VIDEO_TOKEN = "<video>"
|
| 59 |
+
|
| 60 |
+
class SeparatorStyle(Enum):
|
| 61 |
+
"""Different separator style."""
|
| 62 |
+
SINGLE = auto()
|
| 63 |
+
TWO = auto()
|
| 64 |
+
|
| 65 |
+
def get_data_dict(conv, max_n_frames=None):
|
| 66 |
+
data_dict = {
|
| 67 |
+
"messages": []
|
| 68 |
+
}
|
| 69 |
+
for i, (role, message) in enumerate(conv.messages):
|
| 70 |
+
if message:
|
| 71 |
+
text = message["text"]
|
| 72 |
+
content_type = message["type"]
|
| 73 |
+
content = {}
|
| 74 |
+
if content_type == "text":
|
| 75 |
+
content['type'] = 'text'
|
| 76 |
+
content['text'] = text
|
| 77 |
+
task = "text-only"
|
| 78 |
+
elif content_type == "video":
|
| 79 |
+
content['type'] = 'video'
|
| 80 |
+
content['video'] = {
|
| 81 |
+
"video_file": text
|
| 82 |
+
}
|
| 83 |
+
if max_n_frames is not None:
|
| 84 |
+
content['video']['n_frames'] = max_n_frames
|
| 85 |
+
task = "video/QA"
|
| 86 |
+
elif content_type == "image":
|
| 87 |
+
content['type'] = 'image'
|
| 88 |
+
content['image'] = {
|
| 89 |
+
"image_file": text
|
| 90 |
+
}
|
| 91 |
+
task = "image/QA"
|
| 92 |
+
else:
|
| 93 |
+
content['type'] = 'text'
|
| 94 |
+
content['text'] = text
|
| 95 |
+
task = "text-only"
|
| 96 |
+
if data_dict['messages'] and data_dict['messages'][-1]['role'] == role:
|
| 97 |
+
data_dict['messages'][-1]['content'].append(content)
|
| 98 |
+
else:
|
| 99 |
+
data_dict['messages'].append({
|
| 100 |
+
"role": role,
|
| 101 |
+
"content": [content]
|
| 102 |
+
})
|
| 103 |
+
data_dict['dataset'] = task
|
| 104 |
+
data_dict['task'] = task
|
| 105 |
+
check_data_format(data_dict)
|
| 106 |
+
return data_dict
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class StoppingCriteriaSub(StoppingCriteria):
|
| 110 |
+
def __init__(self, stops=[], encounters=1):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.stops = stops
|
| 113 |
+
|
| 114 |
+
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
| 115 |
+
for stop in self.stops:
|
| 116 |
+
if torch.all((stop == input_ids[0][-len(stop):])).item():
|
| 117 |
+
return True
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class Chat:
|
| 122 |
+
def __init__(self, model, processor: TarsierDataProcessor, device='cuda', debug=False):
|
| 123 |
+
self.model = model
|
| 124 |
+
self.processor = processor
|
| 125 |
+
self.device = device
|
| 126 |
+
self.debug = debug
|
| 127 |
+
stop_words_ids = [torch.tensor([self.processor.processor.tokenizer.eos_token_id]).to(device)]
|
| 128 |
+
self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
|
| 129 |
+
|
| 130 |
+
def ask(self,text,conv):
|
| 131 |
+
conv.messages.append([conv.roles[0], {"text": text, "type": "text"}])
|
| 132 |
+
return conv
|
| 133 |
+
|
| 134 |
+
def prepare_model_inputs(self, conv, n_frames=None):
|
| 135 |
+
# print(conv.messages)
|
| 136 |
+
data_dict = get_data_dict(conv, n_frames)
|
| 137 |
+
if self.debug:
|
| 138 |
+
# print(f"visual_data_file: {visual_data_file}", flush=True)
|
| 139 |
+
print(f"###Prompt:\n{get_prompt_from_data_dict(data_dict)}")
|
| 140 |
+
|
| 141 |
+
batch_data = self.processor(data_dict)
|
| 142 |
+
model_inputs = {}
|
| 143 |
+
for k, v in batch_data.items():
|
| 144 |
+
if not isinstance(v, torch.Tensor):
|
| 145 |
+
continue
|
| 146 |
+
model_inputs[k] = v.to(self.device)
|
| 147 |
+
return model_inputs, conv
|
| 148 |
+
|
| 149 |
+
def answer(self, conv, n_frames=None, max_new_tokens=256, num_beams=1, min_length=1, top_p=1.0,
|
| 150 |
+
repetition_penalty=1.0, length_penalty=1, temperature=0):
|
| 151 |
+
inputs, conv = self.prepare_model_inputs(conv, n_frames)
|
| 152 |
+
if self.model is not None:
|
| 153 |
+
outputs = self.model.generate(
|
| 154 |
+
**inputs,
|
| 155 |
+
max_new_tokens=max_new_tokens,
|
| 156 |
+
stopping_criteria=self.stopping_criteria,
|
| 157 |
+
num_beams=num_beams,
|
| 158 |
+
do_sample=True if temperature > 0 else False,
|
| 159 |
+
min_length=min_length,
|
| 160 |
+
top_p=top_p,
|
| 161 |
+
repetition_penalty=repetition_penalty,
|
| 162 |
+
length_penalty=length_penalty,
|
| 163 |
+
temperature=temperature,
|
| 164 |
+
)
|
| 165 |
+
output_text = self.processor.processor.tokenizer.decode(outputs[0][inputs['input_ids'][0].shape[0]:], skip_special_tokens=True)
|
| 166 |
+
else:
|
| 167 |
+
output_text = "Fake respone as launched in debug mode!"
|
| 168 |
+
conv.messages.append(
|
| 169 |
+
[conv.roles[1], {"text": output_text, "type": "text"}]
|
| 170 |
+
)
|
| 171 |
+
return output_text, conv
|
| 172 |
+
|
| 173 |
+
class EasyDict(dict):
|
| 174 |
+
"""
|
| 175 |
+
Get attributes
|
| 176 |
+
|
| 177 |
+
>>> d = EasyDict({'foo':3})
|
| 178 |
+
>>> d['foo']
|
| 179 |
+
3
|
| 180 |
+
>>> d.foo
|
| 181 |
+
3
|
| 182 |
+
>>> d.bar
|
| 183 |
+
Traceback (most recent call last):
|
| 184 |
+
...
|
| 185 |
+
AttributeError: 'EasyDict' object has no attribute 'bar'
|
| 186 |
+
|
| 187 |
+
Works recursively
|
| 188 |
+
|
| 189 |
+
>>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}})
|
| 190 |
+
>>> isinstance(d.bar, dict)
|
| 191 |
+
True
|
| 192 |
+
>>> d.bar.x
|
| 193 |
+
1
|
| 194 |
+
"""
|
| 195 |
+
|
| 196 |
+
def __init__(self, d=None, **kwargs):
|
| 197 |
+
if d is None:
|
| 198 |
+
d = {}
|
| 199 |
+
if kwargs:
|
| 200 |
+
d.update(**kwargs)
|
| 201 |
+
for k, v in d.items():
|
| 202 |
+
setattr(self, k, v)
|
| 203 |
+
# Class attributes
|
| 204 |
+
for k in self.__class__.__dict__.keys():
|
| 205 |
+
if not (k.startswith("__") and k.endswith("__")) and not k in ("update", "pop"):
|
| 206 |
+
setattr(self, k, getattr(self, k))
|
| 207 |
+
|
| 208 |
+
def __setattr__(self, name, value):
|
| 209 |
+
if isinstance(value, (list, tuple)):
|
| 210 |
+
value = [self.__class__(x) if isinstance(x, dict) else x for x in value]
|
| 211 |
+
elif isinstance(value, dict) and not isinstance(value, self.__class__):
|
| 212 |
+
value = self.__class__(value)
|
| 213 |
+
super(EasyDict, self).__setattr__(name, value)
|
| 214 |
+
super(EasyDict, self).__setitem__(name, value)
|
| 215 |
+
|
| 216 |
+
__setitem__ = __setattr__
|
| 217 |
+
|
| 218 |
+
def update(self, e=None, **f):
|
| 219 |
+
d = e or dict()
|
| 220 |
+
d.update(f)
|
| 221 |
+
for k in d:
|
| 222 |
+
setattr(self, k, d[k])
|
| 223 |
+
|
| 224 |
+
def pop(self, k, d=None):
|
| 225 |
+
if hasattr(self, k):
|
| 226 |
+
delattr(self, k)
|
| 227 |
+
return super(EasyDict, self).pop(k, d)
|
| 228 |
+
|
| 229 |
+
conv_tarsier = EasyDict({
|
| 230 |
+
"system": "",
|
| 231 |
+
"roles": ("USER", "ASSISTANT"),
|
| 232 |
+
"messages": [],
|
| 233 |
+
"sep1": " ",
|
| 234 |
+
"sep2": "</s>",
|
| 235 |
+
}
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
conv_tarsier_yi = EasyDict({
|
| 239 |
+
"system": "",
|
| 240 |
+
"roles": ("USER", "ASSISTANT"),
|
| 241 |
+
"messages": [],
|
| 242 |
+
"sep1": " ",
|
| 243 |
+
"sep2": "<|endoftext|>",
|
| 244 |
+
}
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
conv_tarsier_qwen2_vl = EasyDict({
|
| 248 |
+
"system": "",
|
| 249 |
+
"roles": ("user", "assistant"),
|
| 250 |
+
"messages": [],
|
| 251 |
+
}
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
conv_templates = {
|
| 255 |
+
"tarsier2-7b": conv_tarsier_qwen2_vl
|
| 256 |
+
}
|
eval_scripts/DREAM-1K/tarsier/tools/ptbtokenizer.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# File Name : ptbtokenizer.py
|
| 4 |
+
#
|
| 5 |
+
# Description : Do the PTB Tokenization and remove punctuations.
|
| 6 |
+
#
|
| 7 |
+
# Creation Date : 29-12-2014
|
| 8 |
+
# Last Modified : Thu Mar 19 09:53:35 2015
|
| 9 |
+
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
|
| 10 |
+
import os
|
| 11 |
+
import subprocess
|
| 12 |
+
import tempfile
|
| 13 |
+
|
| 14 |
+
# path to the stanford corenlp jar
|
| 15 |
+
STANFORD_CORENLP_3_4_1_JAR = os.path.dirname(os.path.abspath(__file__)) + '/stanford-corenlp-3.4.1.jar'
|
| 16 |
+
|
| 17 |
+
# punctuations to be removed from the sentences
|
| 18 |
+
PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
|
| 19 |
+
".", "?", "!", ",", ":", "-", "--", "...", ";"]
|
| 20 |
+
|
| 21 |
+
class PTBTokenizer:
|
| 22 |
+
"""Python wrapper of Stanford PTBTokenizer"""
|
| 23 |
+
|
| 24 |
+
def tokenize(self, captions_for_image):
|
| 25 |
+
cmd = [os.getenv("JAVA_HOME"), '-cp', STANFORD_CORENLP_3_4_1_JAR, \
|
| 26 |
+
'edu.stanford.nlp.process.PTBTokenizer', \
|
| 27 |
+
'-preserveLines', '-lowerCase']
|
| 28 |
+
|
| 29 |
+
# ======================================================
|
| 30 |
+
# prepare data for PTB Tokenizer
|
| 31 |
+
# ======================================================
|
| 32 |
+
final_tokenized_captions_for_image = {}
|
| 33 |
+
image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
|
| 34 |
+
sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
|
| 35 |
+
|
| 36 |
+
# ======================================================
|
| 37 |
+
# save sentences to temporary file
|
| 38 |
+
# ======================================================
|
| 39 |
+
path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
|
| 40 |
+
tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
|
| 41 |
+
tmp_file.write(sentences.encode())
|
| 42 |
+
tmp_file.close()
|
| 43 |
+
|
| 44 |
+
# ======================================================
|
| 45 |
+
# tokenize sentence
|
| 46 |
+
# ======================================================
|
| 47 |
+
cmd.append(os.path.basename(tmp_file.name))
|
| 48 |
+
p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
|
| 49 |
+
stdout=subprocess.PIPE)
|
| 50 |
+
token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
|
| 51 |
+
token_lines = token_lines.decode()
|
| 52 |
+
lines = token_lines.split('\n')
|
| 53 |
+
# remove temp file
|
| 54 |
+
os.remove(tmp_file.name)
|
| 55 |
+
|
| 56 |
+
# ======================================================
|
| 57 |
+
# create dictionary for tokenized captions
|
| 58 |
+
# ======================================================
|
| 59 |
+
for k, line in zip(image_id, lines):
|
| 60 |
+
if not k in final_tokenized_captions_for_image:
|
| 61 |
+
final_tokenized_captions_for_image[k] = []
|
| 62 |
+
tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
|
| 63 |
+
if w not in PUNCTUATIONS])
|
| 64 |
+
final_tokenized_captions_for_image[k].append(tokenized_caption)
|
| 65 |
+
|
| 66 |
+
return final_tokenized_captions_for_image
|
eval_scripts/DREAM-1K/tarsier/tools/rw_utils.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (2024) Bytedance Ltd. and/or its affiliates
|
| 2 |
+
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
import json
|
| 15 |
+
from json import JSONEncoder
|
| 16 |
+
import numpy
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
class NumpyArrayEncoder(JSONEncoder):
|
| 20 |
+
def default(self, obj):
|
| 21 |
+
if isinstance(obj, numpy.ndarray):
|
| 22 |
+
return obj.tolist()
|
| 23 |
+
return JSONEncoder.default(self, obj)
|
| 24 |
+
|
| 25 |
+
def write_txt(data, path):
|
| 26 |
+
with open(path, 'w', encoding='utf-8')as f:
|
| 27 |
+
for d in data:
|
| 28 |
+
f.write(f'{d}\n')
|
| 29 |
+
|
| 30 |
+
def read_txt(path):
|
| 31 |
+
with open(path, 'r', encoding='utf-8', errors='ignore') as f:
|
| 32 |
+
lines = [l.strip('\n') for l in f.readlines()]
|
| 33 |
+
return lines
|
| 34 |
+
|
| 35 |
+
def read_jsonlines(path):
|
| 36 |
+
objs = []
|
| 37 |
+
with open(path) as f:
|
| 38 |
+
for line in f:
|
| 39 |
+
line = json.loads(line)
|
| 40 |
+
objs.append(line)
|
| 41 |
+
return objs
|
| 42 |
+
|
| 43 |
+
def write_jsonlines(data, path, cls=None, ensure_ascii=False):
|
| 44 |
+
with open(path, 'w') as f:
|
| 45 |
+
for d in data:
|
| 46 |
+
d = json.dumps(d, ensure_ascii=ensure_ascii, cls=cls)
|
| 47 |
+
f.write(d)
|
| 48 |
+
f.write('\n')
|
| 49 |
+
|
| 50 |
+
def read_parquet(path):
|
| 51 |
+
data = pd.read_parquet(path)
|
| 52 |
+
return data.to_dict('records')
|
| 53 |
+
|
| 54 |
+
def write_parquet(data, path):
|
| 55 |
+
data = pd.DataFrame(data)
|
| 56 |
+
data.to_parquet(path)
|
| 57 |
+
|
| 58 |
+
def read_csv(path):
|
| 59 |
+
data = pd.read_csv(path)
|
| 60 |
+
return data.to_dict(orient='records')
|
| 61 |
+
|
| 62 |
+
def write_csv(data, path):
|
| 63 |
+
data = pd.DataFrame(data)
|
| 64 |
+
data.to_csv(path, index=False, sep='\t')
|
eval_scripts/Daily-Omni/Daily-Omni_pipeline.sh
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
MODEL_PATHS=(
|
| 3 |
+
"path_to_AVoCaDO"
|
| 4 |
+
)
|
| 5 |
+
RESULTS_DIR="$1"
|
| 6 |
+
|
| 7 |
+
ORIGINAL_FILE="eval_scripts/Daily-Omni/grouped_data.json"
|
| 8 |
+
MERGED_FILE="$RESULTS_DIR/captioned_results.json"
|
| 9 |
+
if [ ! -f "$MERGED_FILE" ]; then
|
| 10 |
+
echo "MERGED_FILE not found. Creating from ORIGINAL_FILE..."
|
| 11 |
+
cp "$ORIGINAL_FILE" "$MERGED_FILE"
|
| 12 |
+
fi
|
| 13 |
+
CAPTION_FILES_TO_MERGE=()
|
| 14 |
+
CAPTION_KEYS=()
|
| 15 |
+
|
| 16 |
+
# Step 1: caption geneartion
|
| 17 |
+
for model_path in "${MODEL_PATHS[@]}"; do
|
| 18 |
+
CLEAN_PATH="${model_path%/}"
|
| 19 |
+
model_name=$(basename "$CLEAN_PATH")
|
| 20 |
+
|
| 21 |
+
caption_file="$RESULTS_DIR/${model_name}_caption.jsonl"
|
| 22 |
+
echo "Output caption file will be: $caption_file"
|
| 23 |
+
|
| 24 |
+
python eval_scripts/Daily-Omni/generate_caption.py \
|
| 25 |
+
--model_path "$model_path" \
|
| 26 |
+
--fout_path "$caption_file"
|
| 27 |
+
|
| 28 |
+
if [ -f "$caption_file" ]; then
|
| 29 |
+
CAPTION_FILES_TO_MERGE+=("$caption_file")
|
| 30 |
+
CAPTION_KEYS+=("${model_name}_caption")
|
| 31 |
+
else
|
| 32 |
+
echo "Error: Caption file $caption_file not generated for model $model_path."
|
| 33 |
+
exit 1
|
| 34 |
+
fi
|
| 35 |
+
done
|
| 36 |
+
|
| 37 |
+
# Step 2: merge generated caption files
|
| 38 |
+
echo "Merging all generated caption files..."
|
| 39 |
+
python eval_scripts/Daily-Omni/merge_captions.py \
|
| 40 |
+
--original_file "$MERGED_FILE" \
|
| 41 |
+
--caption_files "${CAPTION_FILES_TO_MERGE[@]}" \
|
| 42 |
+
--merged_file "$MERGED_FILE"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Step 3: evaluation
|
| 46 |
+
python eval_scripts/Daily-Omni/evaluation.py \
|
| 47 |
+
--merged_file "$MERGED_FILE" \
|
| 48 |
+
--caption_keys "${CAPTION_KEYS[@]}"
|
| 49 |
+
|
| 50 |
+
# Step 4: analysis and save evaluation results
|
| 51 |
+
for caption_key in "${CAPTION_KEYS[@]}"; do
|
| 52 |
+
echo "Running analysis for caption key: $caption_key"
|
| 53 |
+
|
| 54 |
+
result_file="$RESULTS_DIR/${caption_key}_result.jsonl"
|
| 55 |
+
answer_key="${caption_key//_caption/_resp}"
|
| 56 |
+
|
| 57 |
+
if [ -f "$result_file" ]; then
|
| 58 |
+
python eval_scripts/Daily-Omni/analysis.py --result_file_path "$result_file" --answer_key "$answer_key"
|
| 59 |
+
else
|
| 60 |
+
echo "Warning: Result file '$result_file' not found for analysis."
|
| 61 |
+
fi
|
| 62 |
+
done
|
eval_scripts/Daily-Omni/analysis.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import argparse
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
if __name__ == "__main__":
|
| 6 |
+
parser = argparse.ArgumentParser(description="Analyze the evaluation results.")
|
| 7 |
+
parser.add_argument("--result_file_path", type=str, required=True, help="Path to the result file (.jsonl).")
|
| 8 |
+
parser.add_argument("--answer_key", type=str, required=True, help="The key for the model's response in the result file.")
|
| 9 |
+
|
| 10 |
+
args = parser.parse_args()
|
| 11 |
+
|
| 12 |
+
data = pd.read_json(args.result_file_path, lines=True)
|
| 13 |
+
|
| 14 |
+
acc = (data['answer'].str.upper() == data[args.answer_key].str.upper()).mean()
|
| 15 |
+
print(f"Accuracy for {args.answer_key} is: {acc:.2%}")
|
| 16 |
+
|
| 17 |
+
with open(f"{os.path.dirname(args.result_file_path)}/{args.answer_key}.log", "w", encoding='utf-8') as fout:
|
| 18 |
+
fout.write(f"Accuracy for {args.answer_key} is: {acc:.2%}")
|
eval_scripts/Daily-Omni/evaluation.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
###
|
| 2 |
+
# using a llm to answer questions regarding to the video with the specific caption
|
| 3 |
+
###
|
| 4 |
+
import os
|
| 5 |
+
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=''
|
| 6 |
+
LOCATION = "global"
|
| 7 |
+
user_info_path = ''
|
| 8 |
+
user_info = json.load(open(user_info_path))
|
| 9 |
+
PROJECT_ID = user_info['project_id']
|
| 10 |
+
MODEL = "gemini-2.5-pro"
|
| 11 |
+
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
import json
|
| 15 |
+
import traceback
|
| 16 |
+
import multiprocessing
|
| 17 |
+
import random
|
| 18 |
+
import numpy as np
|
| 19 |
+
import argparse
|
| 20 |
+
from google import genai
|
| 21 |
+
from google.genai import types
|
| 22 |
+
from IPython.display import HTML, Image, Markdown, display
|
| 23 |
+
from google import genai
|
| 24 |
+
from google.genai.types import (
|
| 25 |
+
FunctionDeclaration,
|
| 26 |
+
GenerateContentConfig,
|
| 27 |
+
GoogleSearch,
|
| 28 |
+
HarmBlockThreshold,
|
| 29 |
+
HarmCategory,
|
| 30 |
+
Part,
|
| 31 |
+
SafetySetting,
|
| 32 |
+
ThinkingConfig,
|
| 33 |
+
Tool,
|
| 34 |
+
ToolCodeExecution,
|
| 35 |
+
)
|
| 36 |
+
import subprocess
|
| 37 |
+
|
| 38 |
+
safety_settings = [
|
| 39 |
+
SafetySetting(category=HarmCategory.HARM_CATEGORY_HATE_SPEECH, threshold=HarmBlockThreshold.OFF),
|
| 40 |
+
SafetySetting(category=HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, threshold=HarmBlockThreshold.OFF),
|
| 41 |
+
SafetySetting(category=HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, threshold=HarmBlockThreshold.OFF),
|
| 42 |
+
SafetySetting(category=HarmCategory.HARM_CATEGORY_HARASSMENT, threshold=HarmBlockThreshold.OFF)
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
CONFIG = types.GenerateContentConfig(
|
| 46 |
+
temperature=0,
|
| 47 |
+
top_p=0.001,
|
| 48 |
+
thinking_config=types.ThinkingConfig(
|
| 49 |
+
include_thoughts=True,
|
| 50 |
+
thinking_budget=512
|
| 51 |
+
),
|
| 52 |
+
safety_settings=safety_settings,
|
| 53 |
+
seed=SEED,
|
| 54 |
+
system_instruction='''
|
| 55 |
+
You are a precise QA assistant. Your task is to answer multiple-choice questions based ONLY on the video caption provided.
|
| 56 |
+
Do not use any outside knowledge or assumptions—your answer must strictly reflect information from the caption.
|
| 57 |
+
Always output only the capital letter corresponding to your choice (e.g., A, B, C, D).
|
| 58 |
+
If the caption does not provide enough information to answer the question, output "N/A" instead.
|
| 59 |
+
'''
|
| 60 |
+
)
|
| 61 |
+
client = genai.Client(vertexai=True, project=PROJECT_ID, location=LOCATION)
|
| 62 |
+
|
| 63 |
+
def set_seed(seed):
|
| 64 |
+
np.random.seed(seed)
|
| 65 |
+
random.seed(seed)
|
| 66 |
+
|
| 67 |
+
SEED = 42
|
| 68 |
+
set_seed(SEED)
|
| 69 |
+
|
| 70 |
+
def caption2json(json_path, caption_path):
|
| 71 |
+
|
| 72 |
+
with open(json_path, 'r', encoding='utf-8') as f:
|
| 73 |
+
json_data = json.load(f)
|
| 74 |
+
model = os.path.basename(caption_path).split("_")[0]
|
| 75 |
+
|
| 76 |
+
captions = {}
|
| 77 |
+
with open(caption_path, 'r', encoding='utf-8') as f:
|
| 78 |
+
for line in f:
|
| 79 |
+
if not line.strip():
|
| 80 |
+
continue
|
| 81 |
+
item = json.loads(line)
|
| 82 |
+
for vid, cap in item.items():
|
| 83 |
+
captions[vid] = cap
|
| 84 |
+
|
| 85 |
+
for entry in json_data:
|
| 86 |
+
vid = entry.get("video_id")
|
| 87 |
+
if vid in captions:
|
| 88 |
+
entry[f"{model}_caption"] = captions[vid]
|
| 89 |
+
|
| 90 |
+
with open(f"{model}_merge_data.json", 'w', encoding='utf-8') as f:
|
| 91 |
+
json.dump(json_data, f, ensure_ascii=False, indent=2)
|
| 92 |
+
|
| 93 |
+
print(f"merged successfully, the output file is {model}_merge_data.json")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def generate(prompt):
|
| 97 |
+
contents = [prompt]
|
| 98 |
+
|
| 99 |
+
answer, thinking = None, None
|
| 100 |
+
max_retries = 10
|
| 101 |
+
|
| 102 |
+
for i in range(max_retries):
|
| 103 |
+
try:
|
| 104 |
+
response = client.models.generate_content(
|
| 105 |
+
model=MODEL,
|
| 106 |
+
contents=contents,
|
| 107 |
+
config=CONFIG
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
answer_parts, thought_parts = [], []
|
| 111 |
+
for part in response.candidates[0].content.parts:
|
| 112 |
+
if not getattr(part, "text", None):
|
| 113 |
+
continue
|
| 114 |
+
if getattr(part, "thought", False):
|
| 115 |
+
thought_parts.append(part.text)
|
| 116 |
+
else:
|
| 117 |
+
answer_parts.append(part.text)
|
| 118 |
+
answer = "\n".join(answer_parts).strip()
|
| 119 |
+
thinking = "\n".join(thought_parts).strip()
|
| 120 |
+
if answer:
|
| 121 |
+
break
|
| 122 |
+
else:
|
| 123 |
+
print(f"[WARN] Attempt {i+1}: empty answer, retrying ... ")
|
| 124 |
+
time.sleep(3)
|
| 125 |
+
except Exception as e:
|
| 126 |
+
print(f"[ERROR] Attempt {i+1} failed: {e}")
|
| 127 |
+
traceback.print_exc()
|
| 128 |
+
time.sleep(3)
|
| 129 |
+
if not answer:
|
| 130 |
+
return None, None
|
| 131 |
+
print(answer)
|
| 132 |
+
return answer, thinking
|
| 133 |
+
|
| 134 |
+
def worker(task):
|
| 135 |
+
vid, video_duration, question, choices, answer, caption_key, answer_key, caption = task
|
| 136 |
+
choices_text = "\n".join([f"{c}" for c in choices])
|
| 137 |
+
prompt_filled = f'''
|
| 138 |
+
Here is the video caption:
|
| 139 |
+
"{caption}"
|
| 140 |
+
|
| 141 |
+
Question: {question}
|
| 142 |
+
Choices:
|
| 143 |
+
{choices_text}'''
|
| 144 |
+
try:
|
| 145 |
+
resp, _ = generate(prompt_filled)
|
| 146 |
+
return {
|
| 147 |
+
"video_id": vid,
|
| 148 |
+
"video_duration": video_duration,
|
| 149 |
+
"question": question,
|
| 150 |
+
"choices": choices,
|
| 151 |
+
"answer": answer,
|
| 152 |
+
caption_key: caption,
|
| 153 |
+
answer_key: resp
|
| 154 |
+
}
|
| 155 |
+
except Exception as e:
|
| 156 |
+
traceback.print_exc()
|
| 157 |
+
return {
|
| 158 |
+
"video_id": vid,
|
| 159 |
+
"video_duration": video_duration,
|
| 160 |
+
"question": question,
|
| 161 |
+
"choices": choices,
|
| 162 |
+
"answer": answer,
|
| 163 |
+
caption_key: caption,
|
| 164 |
+
answer_key: None
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
def run_multiprocess_tasks(tasks, num_processes=None, fout_path=None):
|
| 168 |
+
if num_processes is None:
|
| 169 |
+
num_processes = multiprocessing.cpu_count()
|
| 170 |
+
|
| 171 |
+
with multiprocessing.Pool(processes=num_processes) as pool:
|
| 172 |
+
results = pool.map(worker, tasks)
|
| 173 |
+
|
| 174 |
+
if fout_path:
|
| 175 |
+
with open(fout_path, "w", encoding='utf-8') as f:
|
| 176 |
+
for item in results:
|
| 177 |
+
f.write(json.dumps(item, ensure_ascii=False) + '\n')
|
| 178 |
+
f.flush()
|
| 179 |
+
return results
|
| 180 |
+
|
| 181 |
+
def eval_dailyomni_caption_qas(file_path, caption_keys=["omni_caption"]):
|
| 182 |
+
with open(file_path, 'r', encoding='utf-8') as f:
|
| 183 |
+
data = json.load(f)
|
| 184 |
+
|
| 185 |
+
all_results = []
|
| 186 |
+
for caption_key in caption_keys:
|
| 187 |
+
answer_key = caption_key.replace("_caption", "_resp")
|
| 188 |
+
fout_path = f"{os.path.dirname(file_path)}/{caption_key}_result.jsonl"
|
| 189 |
+
|
| 190 |
+
tasks = []
|
| 191 |
+
for video_info in data:
|
| 192 |
+
vid = video_info["video_id"]
|
| 193 |
+
video_duration = video_info["video_duration"]
|
| 194 |
+
caption = video_info[caption_key]
|
| 195 |
+
for q in video_info["questions"]:
|
| 196 |
+
task_item = (
|
| 197 |
+
vid,
|
| 198 |
+
video_duration,
|
| 199 |
+
q["Question"],
|
| 200 |
+
q["Choice"],
|
| 201 |
+
q["Answer"],
|
| 202 |
+
caption_key,
|
| 203 |
+
answer_key,
|
| 204 |
+
caption
|
| 205 |
+
)
|
| 206 |
+
tasks.append(task_item)
|
| 207 |
+
|
| 208 |
+
results = run_multiprocess_tasks(tasks, num_processes=20, fout_path=fout_path)
|
| 209 |
+
all_results.extend(results)
|
| 210 |
+
|
| 211 |
+
return all_results
|
| 212 |
+
|
| 213 |
+
if __name__ == "__main__":
|
| 214 |
+
parser = argparse.ArgumentParser(description="Evaluate captions using Gemini.")
|
| 215 |
+
parser.add_argument("--merged_file", type=str, required=True, help="Path to the merged caption file.")
|
| 216 |
+
parser.add_argument(
|
| 217 |
+
"--caption_keys",
|
| 218 |
+
type=str,
|
| 219 |
+
nargs='+',
|
| 220 |
+
required=True,
|
| 221 |
+
help="A list of caption keys to evaluate"
|
| 222 |
+
)
|
| 223 |
+
args = parser.parse_args()
|
| 224 |
+
|
| 225 |
+
eval_dailyomni_caption_qas(args.merged_file, caption_keys=args.caption_keys)
|
eval_scripts/Daily-Omni/generate_caption.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
|
| 4 |
+
from qwen_omni_utils import process_mm_info
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
import multiprocessing as mp
|
| 10 |
+
import traceback
|
| 11 |
+
import random
|
| 12 |
+
import glob
|
| 13 |
+
|
| 14 |
+
VIDEO_MAX_PIXELS = 401408 # 512*28*28
|
| 15 |
+
VIDEO_TOTAL_PIXELS = 20070400 # 512*28*28*50
|
| 16 |
+
USE_AUDIO_IN_VIDEO = True
|
| 17 |
+
video_base_dir = "path_to_Daily-Omni_Videos"
|
| 18 |
+
os.environ['VIDEO_MAX_PIXELS'] = str(VIDEO_TOTAL_PIXELS)
|
| 19 |
+
|
| 20 |
+
def chat(file_path, prompt, model, processor, model_path, max_new_tokens=2048):
|
| 21 |
+
|
| 22 |
+
conversation = [
|
| 23 |
+
{
|
| 24 |
+
"role": "system",
|
| 25 |
+
"content": [
|
| 26 |
+
{"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
|
| 27 |
+
],
|
| 28 |
+
},
|
| 29 |
+
{
|
| 30 |
+
"role": "user",
|
| 31 |
+
"content": [
|
| 32 |
+
{
|
| 33 |
+
"type": "video",
|
| 34 |
+
"video": file_path,
|
| 35 |
+
"max_pixels": VIDEO_MAX_PIXELS,
|
| 36 |
+
"max_frames": 256
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"type": "text",
|
| 40 |
+
"text": prompt
|
| 41 |
+
},
|
| 42 |
+
],
|
| 43 |
+
},
|
| 44 |
+
]
|
| 45 |
+
|
| 46 |
+
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
|
| 47 |
+
audios, images, videos = process_mm_info(conversation, use_audio_in_video=USE_AUDIO_IN_VIDEO)
|
| 48 |
+
inputs = processor(text=text, audio=audios, images=images, videos=videos, return_tensors="pt", padding=True, use_audio_in_video=USE_AUDIO_IN_VIDEO)
|
| 49 |
+
inputs = inputs.to(model.device).to(model.dtype)
|
| 50 |
+
|
| 51 |
+
text_ids = model.generate(**inputs, use_audio_in_video=USE_AUDIO_IN_VIDEO, do_sample=False, thinker_max_new_tokens=max_new_tokens)
|
| 52 |
+
text = processor.batch_decode(text_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
|
| 53 |
+
model_generation = text.split("\nassistant\n")[-1]
|
| 54 |
+
|
| 55 |
+
return model_generation
|
| 56 |
+
|
| 57 |
+
def worker_proc(rank, gpu_id, model_path, video_paths, prompt, out_path):
|
| 58 |
+
device_map = {"": f"cuda:{gpu_id}"}
|
| 59 |
+
|
| 60 |
+
model = Qwen2_5OmniForConditionalGeneration.from_pretrained(
|
| 61 |
+
model_path,
|
| 62 |
+
torch_dtype=torch.bfloat16,
|
| 63 |
+
device_map=device_map,
|
| 64 |
+
attn_implementation="flash_attention_2",
|
| 65 |
+
)
|
| 66 |
+
model.disable_talker()
|
| 67 |
+
processor = Qwen2_5OmniProcessor.from_pretrained(model_path)
|
| 68 |
+
|
| 69 |
+
fout = open(out_path, "w", encoding="utf-8")
|
| 70 |
+
|
| 71 |
+
for video_path in tqdm(video_paths, desc=f"Worker-{rank}[GPU-{gpu_id}]"):
|
| 72 |
+
try:
|
| 73 |
+
model_generation = chat(video_path, prompt, model, processor, model_path)
|
| 74 |
+
|
| 75 |
+
video_id = os.path.basename(video_path).split(".mp4")[0]
|
| 76 |
+
|
| 77 |
+
out_data = {
|
| 78 |
+
"video_id": video_id,
|
| 79 |
+
"caption": model_generation,
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
fout.write(json.dumps(out_data, ensure_ascii=False) + "\n")
|
| 83 |
+
fout.flush()
|
| 84 |
+
except Exception as e:
|
| 85 |
+
print(f"[Worker-{rank}] Error on {video_path}: {e}")
|
| 86 |
+
traceback.print_exc()
|
| 87 |
+
|
| 88 |
+
fout.close()
|
| 89 |
+
print(f"[Worker-{rank}] Done, wrote results to {out_path}")
|
| 90 |
+
|
| 91 |
+
def run_multi_gpu(model_path, video_paths, prompt_list, final_out_path, num_gpus=8):
|
| 92 |
+
chunk_size = len(video_paths) // num_gpus + 1
|
| 93 |
+
chunks = [video_paths[i:i+chunk_size] for i in range(0, len(video_paths), chunk_size)]
|
| 94 |
+
|
| 95 |
+
processes = []
|
| 96 |
+
tmp_files = []
|
| 97 |
+
|
| 98 |
+
for rank, chunk in enumerate(chunks):
|
| 99 |
+
gpu_id = rank % num_gpus
|
| 100 |
+
tmp_out = final_out_path.replace(".jsonl", f".part{rank}.jsonl")
|
| 101 |
+
tmp_files.append(tmp_out)
|
| 102 |
+
prompt = random.choice(prompt_list)
|
| 103 |
+
|
| 104 |
+
p = mp.Process(
|
| 105 |
+
target=worker_proc,
|
| 106 |
+
args=(rank, gpu_id, model_path, chunk, prompt, tmp_out)
|
| 107 |
+
)
|
| 108 |
+
p.start()
|
| 109 |
+
processes.append(p)
|
| 110 |
+
|
| 111 |
+
for p in processes:
|
| 112 |
+
p.join()
|
| 113 |
+
|
| 114 |
+
with open(final_out_path, "w", encoding="utf-8") as fout:
|
| 115 |
+
for tmp in tmp_files:
|
| 116 |
+
with open(tmp, "r", encoding="utf-8") as fin:
|
| 117 |
+
for line in fin:
|
| 118 |
+
fout.write(line)
|
| 119 |
+
os.remove(tmp)
|
| 120 |
+
|
| 121 |
+
print(f"All results merged into {final_out_path}")
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
parser = argparse.ArgumentParser(description="Evaluate a model and save results.")
|
| 125 |
+
parser.add_argument("--model_path", type=str, required=True, help="Path to the model checkpoint.")
|
| 126 |
+
parser.add_argument("--fout_path", type=str, required=True, help="Path to the output caption file")
|
| 127 |
+
args = parser.parse_args()
|
| 128 |
+
mp.set_start_method("spawn", force=True)
|
| 129 |
+
|
| 130 |
+
video_paths = glob.glob(os.path.join(video_base_dir, "**", "*.mp4"), recursive=True)
|
| 131 |
+
|
| 132 |
+
prompt_list = [
|
| 133 |
+
"Provide a comprehensive description of all the content in the video, leaving out no details. Be sure to include as much of the audio information as possible, and ensure that your descriptions of the audio and video are closely aligned.",
|
| 134 |
+
"Thoroughly describe everything in the video, capturing every detail. Include as much information from the audio as possible, and ensure that the descriptions of both audio and video are well-coordinated.",
|
| 135 |
+
"Please describe all the information in the video without sparing every detail in it. As you describe, you should also describe as much of the information in the audio as possible, and pay attention to the synchronization between the audio and video descriptions.",
|
| 136 |
+
"Offer a detailed description of the video, making sure to include every detail. Also, incorporate as much information from the audio as you can, and ensure that your descriptions of the audio and video are in sync.",
|
| 137 |
+
"Describe every aspect of the video in full detail, covering all the information it contains. Additionally, include as much of the audio content as you can, and make sure your descriptions of the audio and video are synchronized.",
|
| 138 |
+
"Please provide a thorough description of all the content in the video, including every detail. As you describe, ensure that you also cover as much information from the audio as possible, and be mindful of the synchronization between the audio and video as you do so.",
|
| 139 |
+
"Give a detailed account of everything in the video, capturing all the specifics. While doing so, also include as much information from the audio as possible, ensuring that the descriptions of audio and video are well-synchronized."
|
| 140 |
+
]
|
| 141 |
+
|
| 142 |
+
run_multi_gpu(args.model_path, video_paths, prompt_list, args.fout_path, num_gpus=8)
|
eval_scripts/Daily-Omni/grouped_data.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|