Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +6 -0
- VILA/.ipynb_checkpoints/Dockerfile-checkpoint +18 -0
- VILA/.ipynb_checkpoints/README-checkpoint.md +341 -0
- VILA/.ipynb_checkpoints/environment_setup-checkpoint.sh +33 -0
- VILA/CIs/license_all.sh +1 -0
- VILA/CIs/license_commited.sh +6 -0
- VILA/data_prepare/.DS_Store +0 -0
- VILA/data_prepare/LICENSE +8 -0
- VILA/data_prepare/README.md +172 -0
- VILA/data_prepare/panda70m.sh +34 -0
- VILA/data_prepare/panda_split.py +117 -0
- VILA/data_prepare/parallel_shards.sh +29 -0
- VILA/demo_images/LongVILA-pipeline.png +3 -0
- VILA/demo_images/av.png +3 -0
- VILA/demo_images/demo_img_1.png +3 -0
- VILA/demo_images/demo_img_2.png +3 -0
- VILA/demo_images/demo_img_3.png +3 -0
- VILA/demo_images/longvila-logo.png +3 -0
- VILA/demo_images/vila-logo.jpg +0 -0
- VILA/demo_trt_llm/README.md +3 -0
- VILA/inference_test/inference_test.json +546 -0
- VILA/inference_test/inference_test.py +153 -0
- VILA/llava.egg-info/PKG-INFO +287 -0
- VILA/llava.egg-info/SOURCES.txt +154 -0
- VILA/llava.egg-info/dependency_links.txt +1 -0
- VILA/llava.egg-info/requires.txt +37 -0
- VILA/llava.egg-info/top_level.txt +7 -0
- VILA/llava/.DS_Store +0 -0
- VILA/llava/constants.py +31 -0
- VILA/llava/conversation.py +489 -0
- VILA/llava/entry.py +18 -0
- VILA/llava/mm_utils.py +407 -0
- VILA/llava/modals.py +26 -0
- VILA/scripts/convert_gqa_for_eval.py +33 -0
- VILA/scripts/convert_karpathy_to_anno.py +130 -0
- VILA/scripts/convert_mmbench_for_submission.py +46 -0
- VILA/scripts/convert_mmvet_for_eval.py +33 -0
- VILA/scripts/convert_seed_for_submission.py +88 -0
- VILA/scripts/convert_sqa_to_llava.py +104 -0
- VILA/scripts/convert_sqa_to_llava_base_prompt.py +327 -0
- VILA/scripts/convert_vizwiz_for_submission.py +60 -0
- VILA/scripts/convert_vqav2_for_submission.py +65 -0
- VILA/scripts/extract_mm_projector.py +57 -0
- VILA/scripts/zero2.json +23 -0
- VILA/scripts/zero3.json +28 -0
- VILA/scripts/zero3_mics_mini_fixed.json +30 -0
- VILA/scripts/zero3_mics_tiny_fixed.json +30 -0
- VILA/scripts/zero3_offload.json +56 -0
- VILA/scripts/zero3_offload_inference.json +21 -0
- VILA/scripts/zero3pp.json +29 -0
.gitattributes
CHANGED
|
@@ -370,3 +370,9 @@ groundingLMM/gradio-dev/demo/video_identity/video/video_sample.mp4 filter=lfs di
|
|
| 370 |
groundingLMM/gradio-dev/demo/video_subtitle/files/a.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 371 |
groundingLMM/gradio-dev/demo/video_subtitle/files/b.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 372 |
groundingLMM/gradio-dev/demo/unispeech-speaker-verification/samples/kirsten_dunst.wav filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 370 |
groundingLMM/gradio-dev/demo/video_subtitle/files/a.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 371 |
groundingLMM/gradio-dev/demo/video_subtitle/files/b.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 372 |
groundingLMM/gradio-dev/demo/unispeech-speaker-verification/samples/kirsten_dunst.wav filter=lfs diff=lfs merge=lfs -text
|
| 373 |
+
VILA/demo_images/demo_img_3.png filter=lfs diff=lfs merge=lfs -text
|
| 374 |
+
VILA/demo_images/LongVILA-pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 375 |
+
VILA/demo_images/longvila-logo.png filter=lfs diff=lfs merge=lfs -text
|
| 376 |
+
VILA/demo_images/demo_img_2.png filter=lfs diff=lfs merge=lfs -text
|
| 377 |
+
VILA/demo_images/demo_img_1.png filter=lfs diff=lfs merge=lfs -text
|
| 378 |
+
VILA/demo_images/av.png filter=lfs diff=lfs merge=lfs -text
|
VILA/.ipynb_checkpoints/Dockerfile-checkpoint
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvcr.io/nvidia/pytorch:24.06-py3
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
|
| 6 |
+
&& sh ~/miniconda.sh -b -p /opt/conda \
|
| 7 |
+
&& rm ~/miniconda.sh
|
| 8 |
+
|
| 9 |
+
ENV PATH /opt/conda/bin:$PATH
|
| 10 |
+
COPY pyproject.toml pyproject.toml
|
| 11 |
+
COPY llava llava
|
| 12 |
+
|
| 13 |
+
COPY environment_setup.sh environment_setup.sh
|
| 14 |
+
RUN bash environment_setup.sh vila
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
COPY server.py server.py
|
| 18 |
+
CMD ["conda", "run", "-n", "vila", "--no-capture-output", "python", "-u", "-W", "ignore", "server.py"]
|
VILA/.ipynb_checkpoints/README-checkpoint.md
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<p align="center">
|
| 2 |
+
<img src="demo_images/vila-logo.jpg" width="20%"/>
|
| 3 |
+
</p>
|
| 4 |
+
|
| 5 |
+
# VILA: On Pre-training for Visual Language Models
|
| 6 |
+
|
| 7 |
+
[](CODE_LICENSE)
|
| 8 |
+
[](MODEL_LICENSE)
|
| 9 |
+
[](https://www.python.org/downloads/release/python-3100/)
|
| 10 |
+
|
| 11 |
+
[VILA arxiv](https://arxiv.org/abs/2312.07533) / [VILA Demo](https://vila-demo.hanlab.ai/) / [VILA Huggingface](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e)
|
| 12 |
+
|
| 13 |
+
## 💡 Introduction
|
| 14 |
+
|
| 15 |
+
VILA is a visual language model (VLM) pretrained with interleaved image-text data at scale, enabling **video understanding** and **multi-image understanding** capabilities. VILA is deployable on the edge by [AWQ](https://arxiv.org/pdf/2306.00978.pdf) 4bit quantization and [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) framework. We find: (1) image-text pairs are not enough, interleaved image-text is essential; (2) unfreezing LLM during interleaved image-text pre-training enables in-context learning; (3)re-blending text-only instruction data is crucial to boost both VLM and text-only performance; (4) token compression extends #video frames. VILA unveils appealing capabilities, including: video reasoning, in-context learning, visual chain-of-thought, and better world knowledge.
|
| 16 |
+
|
| 17 |
+
## 💡 News
|
| 18 |
+
- [2024/08] We release [LongVILA](./LongVILA.md) that supports long video understanding (Captioning, QA, Needle-in-a-Haystack) up to 1024 frames.
|
| 19 |
+
- [2024/07] VILA1.5 also ranks 1st place (OSS model) on [MLVU test leaderboard](https://github.com/JUNJIE99/MLVU).
|
| 20 |
+
- [2024/06] VILA1.5 is now the best open sourced VLM on [MMMU leaderboard](https://mmmu-benchmark.github.io/#leaderboard) and [Video-MME](https://video-mme.github.io/home_page.html#leaderboard) leaderboard!
|
| 21 |
+
- [2024/05] We release VILA-1.5, which offers **video understanding capability**. VILA-1.5 comes with four model sizes: 3B/8B/13B/40B.
|
| 22 |
+
- [2024/05] We release [AWQ](https://arxiv.org/pdf/2306.00978.pdf)-quantized 4bit VILA-1.5 models. VILA-1.5 is efficiently deployable on diverse NVIDIA GPUs (A100, 4090, 4070 Laptop, Orin, Orin Nano) by [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) and [TensorRT-LLM](demo_trt_llm) backends.
|
| 23 |
+
- [2024/03] VILA has been accepted by CVPR 2024!
|
| 24 |
+
- [2024/02] We release [AWQ](https://arxiv.org/pdf/2306.00978.pdf)-quantized 4bit VILA models, deployable on Jetson Orin and laptops through [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) and [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine).
|
| 25 |
+
- [2024/02] VILA is released. We propose interleaved image-text pretraining that enables **multi-image** VLM. VILA comes with impressive in-context learning capabilities. We open source everything: including training code, evaluation code, datasets, model ckpts.
|
| 26 |
+
- [2023/12] [Paper](https://arxiv.org/abs/2312.07533) is on Arxiv!
|
| 27 |
+
|
| 28 |
+
## Performance
|
| 29 |
+
|
| 30 |
+
### Image QA Benchmarks
|
| 31 |
+
|
| 32 |
+
| $~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~$ | Prec. | VQAv2 | GQA | VizWiz | SQA-I | VQA-T | POPE | MME | MMB | MMB-CN | SEED | SEED-I | MMMU (val) | MMMU (test) | llava-bench | MM-Vet | Average |
|
| 33 |
+
| -------------------------------- | ----- | ----- | ---- | ------ | ----- | ----- | ---- | ------- | ---- | ------ | ---- | ------ | ---------- | ----------- | ----------- | ------ | ------- |
|
| 34 |
+
| VILA1.5-3B | fp16 | 80.4 | 61.5 | 53.5 | 69.0 | 60.4 | 85.9 | 1442.44 | 63.4 | 52.7 | 60.9 | 67.9 | 33.3 | 30.8 | 75.9 | 35.4 | 60.2 |
|
| 35 |
+
| VILA1.5-3B-AWQ | int4 | 80.0 | 61.1 | 53.8 | 67.8 | 60.4 | 85.9 | 1437.34 | 63.3 | 51.4 | 59.8 | 66.6 | 32.7 | 31.1 | 75.0 | 37.3 | 59.9 |
|
| 36 |
+
| VILA1.5-3B-S2 | fp16 | 79.8 | 61.4 | 61.3 | 69.6 | 63.4 | 85.3 | 1431.65 | 62.8 | 52.2 | 60.0 | 66.4 | 32.8 | 31.3 | 76.7 | 38.6 | 60.9 |
|
| 37 |
+
| VILA1.5-3B-S2-AWQ | int4 | 79.4 | 61.3 | 62.3 | 69.2 | 63.0 | 85.8 | 1417.06 | 61.6 | 51.5 | 59.1 | 65.7 | 33.4 | 30.4 | 77.1 | 36.7 | 60.5 |
|
| 38 |
+
| Llama-3-VILA1.5-8B | fp16 | 83.0 | 63.5 | 63.2 | 82.0 | 68.5 | 85.6 | 1634.91 | 75.3 | 69.9 | 66.4 | 73.8 | 38.6 | 32.7 | 71.9 | 43.2 | 66.6 |
|
| 39 |
+
| Llama-3-VILA1.5-8B-AWQ | int4 | 80.3 | 61.7 | 59.3 | 79.0 | 65.4 | 82.9 | 1593.65 | 71.0 | 64.9 | 64.0 | 71.1 | 36.0 | 36.1 | 79.0 | 37.2 | 64.5 |
|
| 40 |
+
| VILA1.5-13B | fp16 | 82.8 | 64.3 | 62.6 | 80.1 | 65.0 | 86.3 | 1569.55 | 74.9 | 66.3 | 65.1 | 72.6 | 37.9 | 33.6 | 80.8 | 44.3 | 66.3 |
|
| 41 |
+
| VILA1.5-13B-AWQ | int4 | 82.7 | 64.5 | 63.3 | 79.7 | 64.7 | 86.7 | 1531.35 | 74.7 | 66.7 | 65.1 | 72.6 | 37.8 | 34.0 | 81.9 | 46.4 | 66.5 |
|
| 42 |
+
| VILA1.5-40B | fp16 | 84.3 | 64.6 | 62.2 | 87.2 | 73.6 | 87.3 | 1726.82 | 82.4 | 80.2 | 69.1 | 75.8 | 51.9 | 46.9 | 81.3 | 53.0 | 72.4 |
|
| 43 |
+
| VILA1.5-40B-AWQ | int4 | 84.1 | 64.4 | 61.3 | 86.7 | 73.2 | 88.2 | 1714.79 | 83.2 | 79.6 | 68.9 | 75.6 | 49.3 | 46.2 | 83.0 | 51.4 | 72.1 |
|
| 44 |
+
|
| 45 |
+
<sup>NOTE: VQAV2 and VizWiz are test-dev, the average accuracy is calculated over all datasets and MME numbers are divided by 20.</sup>
|
| 46 |
+
|
| 47 |
+
### Video QA Benchmarks
|
| 48 |
+
|
| 49 |
+
| $~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~$ | Prec. | Perception Test | ActivityNet | MSVD | MSRVTT | TGIF | EgoSchema (test) | CinePile
|
| 50 |
+
| -------------------------------- | ----- | ----- | ---- | ------ | ----- | ----- | ----- | ----- |
|
| 51 |
+
| VILA1.5-3B | fp16 | 47 | 50.2 | 76.6 | 57.5 | 51.7 | 42.6 | 37.9
|
| 52 |
+
| VILA1.5-3B-S2 | fp16 | 49.7 | 50.7 | 76.9 | 57.6 | 51.7 |
|
| 53 |
+
| Llama-3-VILA1.5-8B | fp16 | 54.1 | 54.3 | 78.3 | 60.1 | 54.1 | 50.4 | 48.7
|
| 54 |
+
| VILA1.5-13B | fp16 | 53.6 | 54.7 | 77.9 | 60.2 | 56 | 52.2 | 50.1
|
| 55 |
+
| VILA1.5-40B | fp16 | 54 | 58 | 80.1 | 63 | 58.2 | 58.7 | 51.3
|
| 56 |
+
|
| 57 |
+
### Inference speed ( Token/sec )
|
| 58 |
+
|
| 59 |
+
| $~~~~~~$ | Precision | A100 | 4090 | Orin |
|
| 60 |
+
| ---------------------- | --------- | ----- | ----- | ---- |
|
| 61 |
+
| VILA1.5-3B | fp16 | 104.6 | 137.6 | 25.4 |
|
| 62 |
+
| VILA1.5-3B-AWQ | int4 | 182.8 | 215.5 | 42.5 |
|
| 63 |
+
| VILA1.5-3B-S2 | fp16 | 104.3 | 137.2 | 24.6 |
|
| 64 |
+
| VILA1.5-3B-S2-AWQ | int4 | 180.2 | 219.3 | 40.1 |
|
| 65 |
+
| Llama-3-VILA1.5-8B | fp16 | 74.9 | 57.4 | 10.2 |
|
| 66 |
+
| Llama-3-VILA1.5-8B-AWQ | int4 | 168.9 | 150.2 | 28.7 |
|
| 67 |
+
| VILA1.5-13B | fp16 | 50.9 | OOM | 6.1 |
|
| 68 |
+
| VILA1.5-13B-AWQ | int4 | 115.9 | 105.7 | 20.6 |
|
| 69 |
+
| VILA1.5-40B | fp16 | OOM | OOM | -- |
|
| 70 |
+
| VILA1.5-40B-AWQ | int4 | 57.0 | OOM | -- |
|
| 71 |
+
|
| 72 |
+
<sup>NOTE: Measured using the [TinyChat](https://github.com/mit-han-lab/llm-awq/tinychat) backend at batch size = 1.</sup>
|
| 73 |
+
|
| 74 |
+
## VILA Examples
|
| 75 |
+
|
| 76 |
+
### Video captioning
|
| 77 |
+
|
| 78 |
+
https://github.com/Efficient-Large-Model/VILA/assets/156256291/c9520943-2478-4f97-bc95-121d625018a6
|
| 79 |
+
|
| 80 |
+
Prompt: Elaborate on the visual and narrative elements of the video in detail.
|
| 81 |
+
|
| 82 |
+
Caption: The video shows a person's hands working on a white surface. They are folding a piece of fabric with a checkered pattern in shades of blue and white. The fabric is being folded into a smaller, more compact shape. The person's fingernails are painted red, and they are wearing a black and red garment. There are also a ruler and a pencil on the surface, suggesting that measurements and precision are involved in the process.
|
| 83 |
+
|
| 84 |
+
### In context learning
|
| 85 |
+
|
| 86 |
+
<img src="demo_images/demo_img_1.png" height="239">
|
| 87 |
+
<img src="demo_images/demo_img_2.png" height="250">
|
| 88 |
+
|
| 89 |
+
### Multi-image reasoning
|
| 90 |
+
|
| 91 |
+
<img src="demo_images/demo_img_3.png" height="193">
|
| 92 |
+
|
| 93 |
+
### VILA on Jetson Orin
|
| 94 |
+
|
| 95 |
+
https://github.com/Efficient-Large-Model/VILA/assets/7783214/6079374c-0787-4bc4-b9c6-e1524b4c9dc4
|
| 96 |
+
|
| 97 |
+
### VILA on RTX 4090
|
| 98 |
+
|
| 99 |
+
https://github.com/Efficient-Large-Model/VILA/assets/7783214/80c47742-e873-4080-ad7d-d17c4700539f
|
| 100 |
+
|
| 101 |
+
</details>
|
| 102 |
+
|
| 103 |
+
## Installation
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
./environment_setup.sh vila
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Training
|
| 110 |
+
|
| 111 |
+
VILA training contains three steps, for specific hyperparameters, please check out the [scripts/v1_5](scripts/v1_5) folder:
|
| 112 |
+
|
| 113 |
+
### Step-1: Alignment
|
| 114 |
+
|
| 115 |
+
We utilize LLaVA-CC3M-Pretrain-595K dataset to align the textual and visual modalities.
|
| 116 |
+
|
| 117 |
+
The stage 1 script takes in two parameters and it can run on a single 8xA100 node. `BASE_MODEL_PATH` points to a online or local huggingface repository, such as `NousResearch/Llama-2-7b-hf`. `OUTPUT_NAME` points to a target directory under `checkpoints`, which will save the trained multimodal projector afterwards.
|
| 118 |
+
|
| 119 |
+
```bash
|
| 120 |
+
bash scripts/v1_5/paper/1_mm_align.sh [BASE_MODEL_PATH] [OUTPUT_NAME]
|
| 121 |
+
```
|
| 122 |
+
|
| 123 |
+
### Step-2: Pretraining
|
| 124 |
+
|
| 125 |
+
We use MMC4 and Coyo dataset to train VLM with interleaved image-text pairs.
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
bash scripts/v1_5/paper/2_pretrain_mmc4_coyo.sh [CODE_PATH] [BASE_MODEL_PATH] [STAGE1_PATH] [OUTPUT_NAME]
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
The stage 2 script takes in four arguments. `CODE_PATH` is the absolute path to our VILA codebase, `BASE_MODEL_PATH` has similar meaning to what is presented in the stage 1 script. `STAGE1_PATH` points to the `OUTPUT_NAME` of stage 1 (i.e. where the stage 1 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that saves the pretraining checkpoint. The script we provided for this stage is executed on slurm, and we expect it to execute on 16 nodes (128 GPUs).
|
| 132 |
+
|
| 133 |
+
### Step-3: Supervised fine-tuning
|
| 134 |
+
|
| 135 |
+
This is the last stage of VILA training, in which we tune the model to follow multimodal instructions on a subset of M3IT, FLAN and ShareGPT4V. This stage runs on a 8xA100 node.
|
| 136 |
+
|
| 137 |
+
```bash
|
| 138 |
+
bash scripts/v1_5/paper/3_sft.sh [STAGE2_PATH] [OUTPUT_NAME]
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
The stage 3 script takes in two arguments. `STAGE2_PATH` points to the `OUTPUT_NAME` of the stage 2 script (i.e. where the stage 2 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that stores the final checkpoint.
|
| 142 |
+
|
| 143 |
+
## Evaluations
|
| 144 |
+
|
| 145 |
+
### Image Benchmarks
|
| 146 |
+
|
| 147 |
+
You can follow [Llava1.5 eval](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md) to download all datasets. After downloading all datasets, please put them under `playground/data/eval`.
|
| 148 |
+
|
| 149 |
+
Please make the following changes to the MME evaluation script. Please search for:
|
| 150 |
+
|
| 151 |
+
```python
|
| 152 |
+
data_path = "MME_Benchmark_release_version"
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
and replace it with:
|
| 156 |
+
|
| 157 |
+
```python
|
| 158 |
+
data_path = os.path.join(script_dir, "MME_Benchmark_release_version")
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
We provide a push-the-button script to perform evaluation on all 10 datasets that do not require GPT-assisted evaluation:
|
| 162 |
+
|
| 163 |
+
```bash
|
| 164 |
+
./scripts/v1_5/eval/eval_all.sh [CHECKPOINT_PATH] [MODEL_NAME] [CONV_MODE]
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
This script takes in two parameters, `CHECKPOINT_PATH` points to the stage 3 model checkpoint, and `MODEL_NAME` will be the name of evaluation results.
|
| 168 |
+
|
| 169 |
+
[VQAv2](https://eval.ai/web/challenges/challenge-page/830/my-submission) and [Vizwiz](https://eval.ai/web/challenges/challenge-page/2185/my-submission) evaluations are hosted on eval.ai. You need to register an account and create a team to be able to submit eval.
|
| 170 |
+
|
| 171 |
+
MMBench and MMBench_CN eval are hosted on another [evaluation server](https://opencompass.org.cn/leaderboard-multimodal). Make sure you change the name of the file before submitting, otherwise the server caches results and will always return wrong result to you.
|
| 172 |
+
|
| 173 |
+
We provide a quick script to automatically organize the prediction files that need to be submitted to servers:
|
| 174 |
+
|
| 175 |
+
```bash
|
| 176 |
+
python scripts/v1_5/eval/copy_predictions.py [MODEL_NAME]
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
You will be able to find the predictions under `playground/data/predictions_upload/[MODEL_NAME]` after executing this script.
|
| 180 |
+
|
| 181 |
+
### Video Benchmarks
|
| 182 |
+
|
| 183 |
+
Please follow the evaluation steps in [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA/blob/main/TRAIN_AND_VALIDATE.md#data-for-validating) for dataset preparation.
|
| 184 |
+
|
| 185 |
+
```bash
|
| 186 |
+
./scripts/v1_5/eval/video_chatgpt/run_all.sh [CHECKPOINT_PATH] [MODEL_NAME] [CONV_MODE]
|
| 187 |
+
./scripts/v1_5/eval/video_chatgpt/eval_all.sh [MODEL_NAME]
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
## Inference
|
| 191 |
+
|
| 192 |
+
We provide snippets for quick inference with user prompts and images.
|
| 193 |
+
|
| 194 |
+
Llama-3-VILA1.5-8B inference:
|
| 195 |
+
|
| 196 |
+
```bash
|
| 197 |
+
python -W ignore llava/eval/run_vila.py \
|
| 198 |
+
--model-path Efficient-Large-Model/Llama-3-VILA1.5-8b-Fix \
|
| 199 |
+
--conv-mode llama_3 \
|
| 200 |
+
--query "<image>\n Please describe the traffic condition." \
|
| 201 |
+
--image-file "av.png"
|
| 202 |
+
```
|
| 203 |
+
|
| 204 |
+
VILA1.5-40B inference:
|
| 205 |
+
|
| 206 |
+
```bash
|
| 207 |
+
python -W ignore llava/eval/run_vila.py \
|
| 208 |
+
--model-path Efficient-Large-Model/VILA1.5-40b \
|
| 209 |
+
--conv-mode hermes-2 \
|
| 210 |
+
--query "<image>\n Please describe the traffic condition." \
|
| 211 |
+
--image-file "av.png"
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
VILA1.5-3B video inference:
|
| 215 |
+
|
| 216 |
+
```bash
|
| 217 |
+
python -W ignore llava/eval/run_vila.py \
|
| 218 |
+
--model-path Efficient-Large-Model/VILA1.5-3b \
|
| 219 |
+
--conv-mode vicuna_v1 \
|
| 220 |
+
--query "<video>\n Please describe this video." \
|
| 221 |
+
--video-file "demo.mp4"
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
## Quantization and Deployment
|
| 225 |
+
|
| 226 |
+
Our VILA models are quantized by [AWQ](https://arxiv.org/abs/2306.00978) into 4 bits for efficient inference on the edge. We provide a push-the-button [script](https://github.com/mit-han-lab/llm-awq/blob/main/scripts/vila_example.sh) to quantize VILA with AWQ.
|
| 227 |
+
|
| 228 |
+
### Running VILA on desktop GPUs and edge GPUs
|
| 229 |
+
|
| 230 |
+
We support AWQ-quantized 4bit VILA on GPU platforms via [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat). We provide a [tutorial](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat#support-vlm-models-vila--llava) to run the model with TinyChat after quantization. We also provide an [instruction](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat/serve) to launch a Gradio server (powered by TinyChat and AWQ) to serve 4-bit quantized VILA models.
|
| 231 |
+
|
| 232 |
+
### Running VILA on laptops
|
| 233 |
+
|
| 234 |
+
We further support our AWQ-quantized 4bit VILA models on various CPU platforms with both x86 and ARM architectures with our [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine). We also provide a detailed [tutorial](https://github.com/mit-han-lab/TinyChatEngine/tree/main?tab=readme-ov-file#deploy-vision-language-model-vlm-chatbot-with-tinychatengine) to help the users deploy VILA on different CPUs.
|
| 235 |
+
|
| 236 |
+
### Running VILA API server
|
| 237 |
+
|
| 238 |
+
A simple API server has been provided to serve VILA models. The server is built on top of [FastAPI](https://fastapi.tiangolo.com/) and [Huggingface Transformers](https://huggingface.co/transformers/). The server can be run with the following command:
|
| 239 |
+
|
| 240 |
+
#### With CLI
|
| 241 |
+
|
| 242 |
+
```bash
|
| 243 |
+
python -W ignore server.py \
|
| 244 |
+
--port 8000 \
|
| 245 |
+
--model-path Efficient-Large-Model/VILA1.5-3B \
|
| 246 |
+
--conv-mode vicuna_v1
|
| 247 |
+
```
|
| 248 |
+
|
| 249 |
+
#### With Docker
|
| 250 |
+
|
| 251 |
+
```bash
|
| 252 |
+
docker build -t vila-server:latest .
|
| 253 |
+
docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
|
| 254 |
+
-v ./hub:/root/.cache/huggingface/hub \
|
| 255 |
+
-it --rm -p 8000:8000 \
|
| 256 |
+
-e VILA_MODEL_PATH=Efficient-Large-Model/VILA1.5-3B \
|
| 257 |
+
-e VILA_CONV_MODE=vicuna_v1 \
|
| 258 |
+
vila-server:latest
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
Then you can call the endpoint with the OpenAI SDK as follows:
|
| 262 |
+
|
| 263 |
+
```python
|
| 264 |
+
from openai import OpenAI
|
| 265 |
+
|
| 266 |
+
client = OpenAI(
|
| 267 |
+
base_url="http://localhost:8000",
|
| 268 |
+
api_key="fake-key",
|
| 269 |
+
)
|
| 270 |
+
response = client.chat.completions.create(
|
| 271 |
+
messages=[
|
| 272 |
+
{
|
| 273 |
+
"role": "user",
|
| 274 |
+
"content": [
|
| 275 |
+
{"type": "text", "text": "What’s in this image?"},
|
| 276 |
+
{
|
| 277 |
+
"type": "image_url",
|
| 278 |
+
"image_url": {
|
| 279 |
+
"url": "https://blog.logomyway.com/wp-content/uploads/2022/01/NVIDIA-logo.jpg",
|
| 280 |
+
# Or you can pass in a base64 encoded image
|
| 281 |
+
# "url": "data:image/png;base64,<base64_encoded_image>",
|
| 282 |
+
},
|
| 283 |
+
},
|
| 284 |
+
],
|
| 285 |
+
}
|
| 286 |
+
],
|
| 287 |
+
max_tokens=300,
|
| 288 |
+
model="VILA1.5-3B",
|
| 289 |
+
# You can pass in extra parameters as follows
|
| 290 |
+
extra_body={"num_beams": 1, "use_cache": False},
|
| 291 |
+
)
|
| 292 |
+
print(response.choices[0].message.content)
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
<sup>NOTE: This API server is intended for evaluation purposes only and has not been optimized for production use. It has only been tested on A100 and H100 GPUs.</sup>
|
| 296 |
+
|
| 297 |
+
## Checkpoints
|
| 298 |
+
|
| 299 |
+
We release [VILA1.5-3B](https://hf.co/Efficient-Large-Model/VILA1.5-3b), [VILA1.5-3B-S2](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2), [Llama-3-VILA1.5-8B](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8B-Fix), [VILA1.5-13B](https://hf.co/Efficient-Large-Model/VILA1.5-13b), [VILA1.5-40B](https://hf.co/Efficient-Large-Model/VILA1.5-40b) and the 4-bit [AWQ](https://arxiv.org/abs/2306.00978)-quantized models [VILA1.5-3B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-AWQ), [VILA1.5-3B-S2-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2-AWQ), [Llama-3-VILA1.5-8B-AWQ](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8B-Fix-AWQ), [VILA1.5-13B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-13b-AWQ), [VILA1.5-40B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-40b-AWQ).
|
| 300 |
+
|
| 301 |
+
## 🔒 License
|
| 302 |
+
|
| 303 |
+
- The code is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
|
| 304 |
+
- The pretrained weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
|
| 305 |
+
- The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
|
| 306 |
+
- [Model License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA. For LLAMA3-VILA checkpoints terms of use, please refer to the [LLAMA3 License](https://llama.meta.com/llama3/license/) for additional details.
|
| 307 |
+
- [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
|
| 308 |
+
- [Dataset Licenses](./data_prepare/LICENSE) for each one used during training.
|
| 309 |
+
|
| 310 |
+
## Team
|
| 311 |
+
|
| 312 |
+
| | | |
|
| 313 |
+
| --- | --- | ---|
|
| 314 |
+
[\*Yao Lu](https://scholar.google.com/citations?user=OI7zFmwAAAAJ&hl=en): Nvidia| [\*Hongxu Yin](https://hongxu-yin.github.io/): Nvidia | [\*Ji Lin](https://www.linji.me/): OpenAI (work done at Nvidia and MIT)
|
| 315 |
+
[Wei Ping](https://scholar.google.com/citations?user=6gKEYRgAAAAJ&hl=en): Nvidia | [Pavlo Molchanov](https://www.pmolchanov.com/): Nvidia | [Andrew Tao](https://scholar.google.com/citations?user=Wel9l1wAAAAJ&hl=en): Nvidia |
|
| 316 |
+
[Haotian Tang](http://kentang.net/): MIT | [Shang Yang](https://ys-2020.github.io/): MIT | [Ligeng Zhu](https://lzhu.me/): Nvidia, MIT |
|
| 317 |
+
[Wei-Chen Wang](https://weichenwang.me/): MIT | [Fuzhao Xue](https://xuefuzhao.github.io/): Nvidia, NUS | [Yunhao Fang](https://seerkfang.github.io/): Nvidia, UCSD |
|
| 318 |
+
[Yukang Chen](https://yukangchen.com/): Nvidia, CUHK | [Zhuoyang Zhang](https://openreview.net/profile?id=~Zhuoyang_Zhang1): Nvidia, Tsinghua Univ. | [Yue Shen](https://www.linkedin.com/in/yue-james-shen/): Nvidia |
|
| 319 |
+
[Wei-Ming Chen](https://scholar.google.com/citations?user=6xFvyJwAAAAJ&hl=en): Nvidia | [Huizi Mao](https://scholar.google.com/citations?user=r5WezOYAAAAJ&hl=zh-CN): Nvidia | [Baifeng Shi](https://bfshi.github.io/): Nvidia, UC Berkeley |
|
| 320 |
+
[Jan Kautz](https://jankautz.com/): Nvidia | [Mohammad Shoeybi](https://scholar.google.com/citations?user=62ElavIAAAAJ&hl=en): Nvidia | [Song Han](http://songhan.mit.edu/): Nvidia, MIT
|
| 321 |
+
|
| 322 |
+
## Citations
|
| 323 |
+
|
| 324 |
+
```
|
| 325 |
+
@misc{lin2023vila,
|
| 326 |
+
title={VILA: On Pre-training for Visual Language Models},
|
| 327 |
+
author={Ji Lin and Hongxu Yin and Wei Ping and Yao Lu and Pavlo Molchanov and Andrew Tao and Huizi Mao and Jan Kautz and Mohammad Shoeybi and Song Han},
|
| 328 |
+
year={2023},
|
| 329 |
+
eprint={2312.07533},
|
| 330 |
+
archivePrefix={arXiv},
|
| 331 |
+
primaryClass={cs.CV}
|
| 332 |
+
}
|
| 333 |
+
```
|
| 334 |
+
|
| 335 |
+
# Acknowledgement
|
| 336 |
+
|
| 337 |
+
- [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon. Thanks for their wonderful work.
|
| 338 |
+
- [InternVL](https://github.com/OpenGVLab/InternVL): for open-sourcing InternViT (used in VILA1.5-40b) and the [InternVL-SFT](https://github.com/OpenGVLab/InternVL/tree/main/internvl_chat#prepare-training-datasets) data blend (inspired by LLaVA-1.6) used in all VILA1.5 models.
|
| 339 |
+
- [Vicuna](https://github.com/lm-sys/FastChat): the amazing open-sourced large language model!
|
| 340 |
+
- [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT): we borrowed video evaluation script from this repository.
|
| 341 |
+
- [MMC4](https://github.com/allenai/mmc4), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), [M3IT](https://huggingface.co/datasets/MMInstruction/M3IT), [OpenORCA/FLAN](https://huggingface.co/datasets/Open-Orca/FLAN), [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V), [WIT](google-research-datasets/wit), [GSM8K-ScRel](https://github.com/OFA-Sys/gsm8k-ScRel/blob/main/data/train_use.jsonl), [VisualGenome](https://visualgenome.org/api/v0/api_home.html), [VCR](https://visualcommonsense.com/download/), [ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA), [Shot2Story](https://github.com/bytedance/Shot2Story/blob/master/DATA.md), [Youcook2](http://youcook2.eecs.umich.edu/), [Vatex](https://eric-xw.github.io/vatex-website/download.html), [ShareGPT-Video](https://huggingface.co/datasets/ShareGPTVideo/train_video_and_instruction) for providing datasets used in this research.
|
VILA/.ipynb_checkpoints/environment_setup-checkpoint.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
|
| 3 |
+
# This is required to activate conda environment
|
| 4 |
+
eval "$(conda shell.bash hook)"
|
| 5 |
+
|
| 6 |
+
# CONDA_ENV=${1:-""}
|
| 7 |
+
CONDA_ENV=vila
|
| 8 |
+
if [ -n "$CONDA_ENV" ]; then
|
| 9 |
+
conda create -n $CONDA_ENV python=3.10 -y
|
| 10 |
+
conda activate $CONDA_ENV
|
| 11 |
+
else
|
| 12 |
+
echo "Skipping conda environment creation. Make sure you have the correct environment activated."
|
| 13 |
+
fi
|
| 14 |
+
|
| 15 |
+
# This is required to enable PEP 660 support
|
| 16 |
+
pip install --upgrade pip
|
| 17 |
+
|
| 18 |
+
# This is optional if you prefer to use built-in nvcc
|
| 19 |
+
conda install -c nvidia cuda-toolkit -y
|
| 20 |
+
|
| 21 |
+
# Install FlashAttention2
|
| 22 |
+
pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 23 |
+
|
| 24 |
+
# Install VILA
|
| 25 |
+
pip install -e .
|
| 26 |
+
pip install -e ".[train]"
|
| 27 |
+
pip install -e ".[eval]"
|
| 28 |
+
|
| 29 |
+
# Install HF's Transformers
|
| 30 |
+
pip install git+https://github.com/huggingface/transformers@v4.37.2
|
| 31 |
+
site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
|
| 32 |
+
cp -rv ./llava/train/transformers_replace/* $site_pkg_path/transformers/
|
| 33 |
+
cp -rv ./llava/train/deepspeed_replace/* $site_pkg_path/deepspeed/
|
VILA/CIs/license_all.sh
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' -ignore "llava/eval/**" -ignore "**/*__init__.py" **/*.py
|
VILA/CIs/license_commited.sh
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PYFILES=$(git diff --name-only --diff-filter=ACMRT $commithash HEAD | grep .py | xargs)
|
| 2 |
+
|
| 3 |
+
for file in $PYFILES; do
|
| 4 |
+
echo $file
|
| 5 |
+
addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' $file
|
| 6 |
+
done
|
VILA/data_prepare/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
VILA/data_prepare/LICENSE
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
License information for datasets used during VILA Training
|
| 2 |
+
|
| 3 |
+
* LLaVA-1.5 Instruction Data: Apache 2.0
|
| 4 |
+
* Coyo: cc-by-4.0
|
| 5 |
+
* MMC4: ODC-By
|
| 6 |
+
* FLAN: cc-by-4.0
|
| 7 |
+
* M3IT: cc-by-4.0
|
| 8 |
+
* ShareGPT4V: cc-by-nc-4.0
|
VILA/data_prepare/README.md
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data Preparation for Training VILA
|
| 2 |
+
|
| 3 |
+
To train VILA, we used the following datasets:
|
| 4 |
+
|
| 5 |
+
| Stage | Datasets |
|
| 6 |
+
| ----------------------- | -------------------------------------------------------------------------------- |
|
| 7 |
+
| 1. Initialize projector | CC3M |
|
| 8 |
+
| 2. Pre-training | MMC4-core, COYO-700M subset |
|
| 9 |
+
| 3. SFT | LLaVA-1.5, VFLAN, ShareGPT, TextFLAN, WIT, GSM8K-ScRel-SFT, Sherlock, ScienceQA |
|
| 10 |
+
|
| 11 |
+
### LLaVa-CC3M-Pretrain
|
| 12 |
+
|
| 13 |
+
We use [LLaVA-CC3M-Pretrain-595K](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) to train the visual language projector
|
| 14 |
+
|
| 15 |
+
### MMC4-Core Dataset
|
| 16 |
+
|
| 17 |
+
Due to the limit of compute, we pre-train VILA on the smaller core set of MMC4 instead of the full set.
|
| 18 |
+
|
| 19 |
+
1. Firstly, download the annotations of the MMC4-core dataset here: https://github.com/allenai/mmc4. We used the non-fewer-face split, and you may need to request the access [here](https://forms.gle/VYtcNY8aYaUANK9f8).
|
| 20 |
+
|
| 21 |
+
1. Now modify the input and output path in `mmc4_downloader.py` and run the following script to scrawl the MMC4 images:
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
cd mmc4
|
| 25 |
+
python mmc4_downloader.py
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
Note that due to the expiration of image urls, you may end up getting a subset of the entire corpus.
|
| 29 |
+
|
| 30 |
+
The scrawling may take a long time. Optionally, you can also shard the workload over multiple jobs/machines concurrently to speed up the process:
|
| 31 |
+
|
| 32 |
+
```bash
|
| 33 |
+
# provide the start and end index of the jsonl shard. There are 23098 - 14 shards totally
|
| 34 |
+
# python mmc4_downloader.py <start_idx> <end_idx>
|
| 35 |
+
python mmc4_downloader.py 0 1000 # worker 1
|
| 36 |
+
python mmc4_downloader.py 1000 2000 # worker 2
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
3. Filter out invalid samples in MMC4:
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python mmc4_filter_and_counter.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
4. Merge images and text into a unified pickle file for each shard:
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
python mmc4_merger.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### COYO-700M Dataset
|
| 52 |
+
|
| 53 |
+
1. Download the metadata of COYO-700M:
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
huggingface-cli download kakaobrain/coyo-700m --repo-type dataset --local-dir coyo-700m --local-dir-use-symlinks False
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
2. Scrawl the COYO images. Note that here we only keep a 20% subset in each shard with the highest CLIP similarity, to balance compute budget and data quality.
|
| 60 |
+
|
| 61 |
+
There are totally 128 shards of annotations. Now download each one with the script:
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
cd coyo
|
| 65 |
+
for SHARD in {0..127}; do
|
| 66 |
+
python coyo_downloader.py $SHARD
|
| 67 |
+
done
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
3. Split downloaded COYO data into multiple shards:
|
| 71 |
+
|
| 72 |
+
```bash
|
| 73 |
+
python coyo_splitter.py
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
### LLaVA-1.5 Instruction Data
|
| 77 |
+
|
| 78 |
+
We use this [file](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json) in our experiments. Please download this dataset from LLaVA authors.
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
huggingface-cli download liuhaotian/LLaVA-Instruct-150K llava_v1_5_mix665k.json --repo-type dataset
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
### VFlan dataset
|
| 85 |
+
|
| 86 |
+
1. Download FLAN datasets:
|
| 87 |
+
|
| 88 |
+
```bash
|
| 89 |
+
huggingface-cli download Open-Orca/FLAN --repo-type dataset --local-dir FLAN --local-dir-use-symlinks False
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
2. Preprocess FLAN dataset (sample 1M data from 378M samples):
|
| 93 |
+
|
| 94 |
+
```bash
|
| 95 |
+
cd sft
|
| 96 |
+
python preprocess_flan.py
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
### M3IT Dataset
|
| 100 |
+
|
| 101 |
+
1. Download M3IT datasets:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
huggingface-cli download MMInstruction/M3IT --repo-type dataset --local-dir M3IT --local-dir-use-symlinks False
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
2. Preprocess M3IT dataset:
|
| 108 |
+
|
| 109 |
+
```bash
|
| 110 |
+
python preprocess_m3it.py
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
3. (Optional) Split FLAN+M3IT into multiple chunks to reduce CPU memory pressure during training:
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
python split_vflan.py
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
### ShareGPT4v
|
| 120 |
+
|
| 121 |
+
The ShareGPT data can be obtained [mit-han-lab/ShareGPT4V](https://huggingface.co/datasets/mit-han-lab/ShareGPT4V). * Note the original ShareGPT4v dataset contains some samples with file ids (sa_XXXX) and repeative response. We filter those bad examples and reduced the samples from 100K -> 96K (for caption) and 1.2m -> 1.17m (for pretraining). Then we re-combine them into a single file.
|
| 122 |
+
|
| 123 |
+
```bash
|
| 124 |
+
huggingface-cli download mit-han-lab/ShareGPT4V --repo-type dataset --local-dir coyo-700m --local-dir-use-symlinks False
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
### WIT
|
| 128 |
+
|
| 129 |
+
The original WIT data can be obtained [google-research-datasets/wit](https://github.com/google-research-datasets/wit/tree/main). * We subsample ~538K english data from the original WIT dataset and curate a llava conversation format JSON file.
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
huggingface-cli download Efficient-Large-Model/WIT_538K --repo-type dataset --local-dir WIT --local-dir-use-symlinks False
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
### GSM8K-ScRel-SFT
|
| 136 |
+
|
| 137 |
+
We add some math data [gsm8k-ScRel](https://github.com/OFA-Sys/gsm8k-ScRel/blob/main/data/train_use.jsonl) to our SFT stage.
|
| 138 |
+
|
| 139 |
+
### Sherlock
|
| 140 |
+
|
| 141 |
+
The image files of Sherlock can be obtained from [VisualGenome](https://visualgenome.org/api/v0/api_home.html) and [VCR](https://visualcommonsense.com/download/) separately. The llava conversation format JSON file can be downloaded with
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
huggingface-cli download Efficient-Large-Model/sherlock_317K --repo-type dataset --local-dir sherlock --local-dir-use-symlinks False
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
### ScienceQA
|
| 148 |
+
|
| 149 |
+
We use the train split of ScienceQA. The image data of the train split can be obtained from [ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA) or their [huggingface repo](https://huggingface.co/datasets/derek-thomas/ScienceQA). The llava conversation format JSON file can be downloaded with
|
| 150 |
+
|
| 151 |
+
```bash
|
| 152 |
+
huggingface-cli download Efficient-Large-Model/ScienceQA_train_12K --repo-type dataset --local-dir scienceqa --local-dir-use-symlinks False
|
| 153 |
+
```
|
| 154 |
+
|
| 155 |
+
### IDEFICS2-SFT dataset
|
| 156 |
+
|
| 157 |
+
We also provide scripts to preprocess IDEFICS2-SFT dataset into llava-SFT like format.
|
| 158 |
+
|
| 159 |
+
Please first download [HuggingFaceM4/the_cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) to `/home/jasonlu/workspace/idefics2-sft/the_cauldron`. Then, run the following scripts:
|
| 160 |
+
|
| 161 |
+
```bash
|
| 162 |
+
python preprocess_idefics2.py
|
| 163 |
+
python merge_idefics2.py
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
A sample in the preprocessed dataset file will look like this:
|
| 167 |
+
|
| 168 |
+
```json
|
| 169 |
+
{"id": 0, "images": ["images/chart2text/0_0.png"], "conversations": [{"from": "human", "value": "<image>\nPlease clarify the meaning conveyed by this graph."}, {"from": "gpt", "value": "This statistic presents the reach of the most popular social networks among female beauty consumers in the United States as of August 2016. During the survey period, 62 percent of respondents had an Instagram account."}]}
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
Haotian's Note: Datasets overlapping with VFLAN / ShareGPT4V-SFT are removed. I also remove `plotqa` since it is too large, `localized_narratives` seems to be a little bit overlapped with captioning efforts within VILA. `websight` and `datikz` are two datasets that target code generation. Since the output is very long, and including them might slow down training, I also temporarily removed these two datasets, but feel free to add them back.
|
VILA/data_prepare/panda70m.sh
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
JOBS_LIMIT=${1:-32} # Set your limit here
|
| 2 |
+
workdir=${2:-/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_training_10m}
|
| 3 |
+
|
| 4 |
+
wname=$(echo $workdir | rev | cut -d "/" -f 1 | rev)
|
| 5 |
+
|
| 6 |
+
echo "Parallely checking for all shards in $workdir / $wname"
|
| 7 |
+
parallel_size=32
|
| 8 |
+
idx_size=$(( parallel_size - 1 ))
|
| 9 |
+
|
| 10 |
+
mkdir -p slurm-logs/data
|
| 11 |
+
|
| 12 |
+
for idx in $(seq 0 $idx_size); do
|
| 13 |
+
while [ $(jobs -rp | wc -l) -ge $JOBS_LIMIT ]; do
|
| 14 |
+
sleep 1
|
| 15 |
+
done
|
| 16 |
+
echo "Running jobs $(jobs -rp | wc -l) $wname-$idx-of-$parallel_size";
|
| 17 |
+
|
| 18 |
+
srun -A llmservice_nlp_fm \
|
| 19 |
+
-p cpu,cpu_1,cpu_long -t 4:00:00 -J cleanup-$wname-$idx-of-$parallel_size \
|
| 20 |
+
--cpus-per-task 8 \
|
| 21 |
+
--mem-per-cpu 8G \
|
| 22 |
+
-e slurm-logs/data/$idx-of-$parallel_size.err \
|
| 23 |
+
-o slurm-logs/data/$idx-of-$parallel_size.txt \
|
| 24 |
+
python llava/data/dataset_impl/panda70m.py --workdir=$workdir --shards=$idx --total=$parallel_size &
|
| 25 |
+
|
| 26 |
+
done
|
| 27 |
+
|
| 28 |
+
# bash data_prepare/panda70m.sh 32 /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_training_10m;
|
| 29 |
+
# bash data_prepare/panda70m.sh 32 /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_training_2m;
|
| 30 |
+
# bash data_prepare/panda70m.sh 32 /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_testing;
|
| 31 |
+
|
| 32 |
+
# --exclusive \
|
| 33 |
+
# --cpus-per-task 8 \
|
| 34 |
+
# --mem-per-cpu 8G \
|
VILA/data_prepare/panda_split.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import base64
|
| 18 |
+
import copy
|
| 19 |
+
import glob
|
| 20 |
+
import io
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import os
|
| 24 |
+
import os.path as osp
|
| 25 |
+
import pathlib
|
| 26 |
+
import pickle
|
| 27 |
+
import random
|
| 28 |
+
import re
|
| 29 |
+
import shutil
|
| 30 |
+
import time
|
| 31 |
+
from collections import defaultdict
|
| 32 |
+
from dataclasses import dataclass, field
|
| 33 |
+
from datetime import datetime
|
| 34 |
+
from functools import lru_cache
|
| 35 |
+
from io import BytesIO
|
| 36 |
+
from typing import Dict, List, Optional, Sequence
|
| 37 |
+
|
| 38 |
+
import cv2
|
| 39 |
+
import decord
|
| 40 |
+
import numpy as np
|
| 41 |
+
import PIL
|
| 42 |
+
import torch
|
| 43 |
+
import transformers
|
| 44 |
+
from decord._ffi.base import DECORDError
|
| 45 |
+
from iopath.common.file_io import g_pathmgr
|
| 46 |
+
from PIL import Image
|
| 47 |
+
from pytorchvideo.data.decoder import DecoderType
|
| 48 |
+
from pytorchvideo.data.encoded_video import EncodedVideo, select_video_class
|
| 49 |
+
from pytorchvideo.data.video import Video
|
| 50 |
+
from torch.utils.data import ConcatDataset, Dataset
|
| 51 |
+
from torchvision.transforms import Resize
|
| 52 |
+
|
| 53 |
+
import llava.data.datasets_mixture as datasets_mixture
|
| 54 |
+
from llava import conversation as conversation_lib
|
| 55 |
+
from llava.constants import (
|
| 56 |
+
DEFAULT_IM_END_TOKEN,
|
| 57 |
+
DEFAULT_IM_START_TOKEN,
|
| 58 |
+
DEFAULT_IMAGE_TOKEN,
|
| 59 |
+
IGNORE_INDEX,
|
| 60 |
+
IMAGE_TOKEN_INDEX,
|
| 61 |
+
)
|
| 62 |
+
from llava.data.dataset import LazySupervisedDataset
|
| 63 |
+
from llava.data.dataset_impl.textocr import GenericDataset, preprocess_OCR
|
| 64 |
+
from llava.data.datasets_mixture import DATASETS
|
| 65 |
+
from llava.data.simple_vila_webdataset import VILAWebDataset
|
| 66 |
+
from llava.data.utils import VILAEncodedVideo
|
| 67 |
+
from llava.mm_utils import is_gemma_tokenizer, tokenizer_image_token
|
| 68 |
+
from llava.train.args import DataArguments, TrainingArguments
|
| 69 |
+
|
| 70 |
+
DEFAULT_HIERTEXT = "/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m"
|
| 71 |
+
SPLIT = "panda70m_testing"
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def with_opencv(filename):
|
| 75 |
+
video = cv2.VideoCapture(filename)
|
| 76 |
+
fps = video.get(cv2.CAP_PROP_FPS)
|
| 77 |
+
frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 78 |
+
duration = frame_count / fps
|
| 79 |
+
return duration, fps, frame_count
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def split_video_to_clips(
|
| 83 |
+
workdir=osp.expanduser("~/nvr_elm_llm/dataset/panda70m/panda70m_training_2m"),
|
| 84 |
+
shards=0,
|
| 85 |
+
total=-1,
|
| 86 |
+
):
|
| 87 |
+
video_list = glob.glob(f"{workdir}/*.mp4")
|
| 88 |
+
video_list = sorted(video_list)
|
| 89 |
+
if total > 0:
|
| 90 |
+
chunk = len(video_list) // total
|
| 91 |
+
begin_idx = shards * chunk
|
| 92 |
+
end_idx = (shards + 1) * chunk
|
| 93 |
+
if shards == total - 1:
|
| 94 |
+
end_idx = len(video_list)
|
| 95 |
+
video_list = video_list[begin_idx:end_idx]
|
| 96 |
+
print(f"Splitting total {len(video_list)} videos")
|
| 97 |
+
output_dir = workdir + "_clip"
|
| 98 |
+
debug_info = {}
|
| 99 |
+
for idx, video_path in enumerate(video_list):
|
| 100 |
+
print(f"[{idx}/{len(video_list)}]", video_path)
|
| 101 |
+
json_path = video_path.replace(".mp4", ".json")
|
| 102 |
+
assert osp.exists(json_path) and osp.exists(video_path)
|
| 103 |
+
jinfo = json.load(open(json_path))
|
| 104 |
+
print(jinfo)
|
| 105 |
+
info = with_opencv(video_path)
|
| 106 |
+
print(info)
|
| 107 |
+
video = VILAEncodedVideo.from_bytesio(video_path, decoder="decord", decode_audio=False)
|
| 108 |
+
|
| 109 |
+
return
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
if __name__ == "__main__":
|
| 113 |
+
# WORKDIR=osp.expanduser("~/nvr_elm_llm/dataset/panda70m/panda70m_testing")
|
| 114 |
+
# cleanup_corrupted_videos()
|
| 115 |
+
import fire
|
| 116 |
+
|
| 117 |
+
fire.Fire(split_video_to_clips)
|
VILA/data_prepare/parallel_shards.sh
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
JOBS_LIMIT=${1:-32} # Set your limit here
|
| 2 |
+
workdir=${2:-/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_training_10m}
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
workdir=/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/video_datasets_v2/internvid/video_data_tar
|
| 6 |
+
|
| 7 |
+
parallel_size=32
|
| 8 |
+
idx_size=$(( parallel_size - 1 ))
|
| 9 |
+
|
| 10 |
+
mkdir -p slurm-logs/data
|
| 11 |
+
|
| 12 |
+
for idx in $(seq 0 $idx_size); do
|
| 13 |
+
while [ $(jobs -rp | wc -l) -ge $JOBS_LIMIT ]; do
|
| 14 |
+
sleep 1
|
| 15 |
+
done
|
| 16 |
+
echo "Running jobs $(jobs -rp | wc -l) $idx-of-$parallel_size";
|
| 17 |
+
|
| 18 |
+
srun -A $SLURM_ACCOUNT \
|
| 19 |
+
-p cpu,cpu_1,cpu_long -t 4:00:00 -J creating-WDS-$idx-of-$parallel_size \
|
| 20 |
+
--cpus-per-task 8 \
|
| 21 |
+
--mem-per-cpu 8G \
|
| 22 |
+
--dependency singleton \
|
| 23 |
+
-e slurm-logs/data/$idx-of-$parallel_size.err \
|
| 24 |
+
-o slurm-logs/data/$idx-of-$parallel_size.txt \
|
| 25 |
+
python llava/data/simple_vila_webdataset.py $workdir --shards=$idx --total=$parallel_size &
|
| 26 |
+
done
|
| 27 |
+
wait
|
| 28 |
+
|
| 29 |
+
python llava/data/simple_vila_webdataset.py $workdir
|
VILA/demo_images/LongVILA-pipeline.png
ADDED
|
Git LFS Details
|
VILA/demo_images/av.png
ADDED
|
Git LFS Details
|
VILA/demo_images/demo_img_1.png
ADDED
|
Git LFS Details
|
VILA/demo_images/demo_img_2.png
ADDED
|
Git LFS Details
|
VILA/demo_images/demo_img_3.png
ADDED
|
Git LFS Details
|
VILA/demo_images/longvila-logo.png
ADDED
|
Git LFS Details
|
VILA/demo_images/vila-logo.jpg
ADDED
|
VILA/demo_trt_llm/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Deprecation Notice
|
| 2 |
+
|
| 3 |
+
This README is deprecated and is no longer being maintained. For the most up-to-date information and instructions, please refer to the [TensorRT-LLM example](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal#llava-and-vila) for VILA deployment.
|
VILA/inference_test/inference_test.json
ADDED
|
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"test_cases": [
|
| 3 |
+
{
|
| 4 |
+
"name": "top down view",
|
| 5 |
+
"image_paths": [
|
| 6 |
+
"more_samples/top_view.png"
|
| 7 |
+
],
|
| 8 |
+
"QAs": [
|
| 9 |
+
{
|
| 10 |
+
"question": "<image>\n What is unusual about this image?",
|
| 11 |
+
"expected_answer": "The unusual aspect of this image is that it is an aerial view of a busy freeway with many cars, and it appears to be taken from a helicopter. This perspective provides a unique and interesting perspective of the traffic, as it allows the viewer to see the entire freeway and all the cars on it from above. The image captures the bustling nature of the city and the movement of the vehicles, which is not easily visible from ground level."
|
| 12 |
+
}
|
| 13 |
+
]
|
| 14 |
+
},
|
| 15 |
+
|
| 16 |
+
{
|
| 17 |
+
"name": "deer crossing",
|
| 18 |
+
"image_paths": [
|
| 19 |
+
"more_samples/deer_crossing.png"
|
| 20 |
+
],
|
| 21 |
+
"QAs": [
|
| 22 |
+
{
|
| 23 |
+
"question": "<image>\n What is unusual about this image?",
|
| 24 |
+
"expected_answer": "The unusual aspect of this image is that a group of deer is crossing a road in front of a car. Typically, deer are not expected to be seen crossing roads, especially in urban or suburban areas. This situation can pose a risk to both the deer and the people in the car, as the deer might not be aware of the approaching vehicle, and the driver may not have enough time to react and stop safely. It is important for drivers to be cautious and patient in such situations to avoid accidents and ensure the safety of both the animals and the people involved."
|
| 25 |
+
}
|
| 26 |
+
]
|
| 27 |
+
},
|
| 28 |
+
|
| 29 |
+
{
|
| 30 |
+
"name": "windmine",
|
| 31 |
+
"image_paths": [
|
| 32 |
+
"more_samples/windmine.png"
|
| 33 |
+
],
|
| 34 |
+
"QAs": [
|
| 35 |
+
{
|
| 36 |
+
"question": "<image>\n What is in this image?",
|
| 37 |
+
"expected_answer": "The image features a large field with numerous wind turbines scattered throughout the area. The wind turbines are of varying heights and are positioned at different angles, creating a visually striking scene. The field is surrounded by a forest, adding to the natural beauty of the landscape."
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"question": "<image>\n Do facilities here look normal or need maintenance?",
|
| 41 |
+
"expected_answer": "The facilities in the image, which are wind turbines, appear to be in good condition and functioning properly. They are situated in a large field with a clear blue sky, indicating that they are well-maintained and not in need of immediate maintenance. The presence of multiple wind turbines in the field suggests that they are part of a larger wind farm, which is likely to be managed and maintained by professionals to ensure optimal performance and efficiency."
|
| 42 |
+
}
|
| 43 |
+
]
|
| 44 |
+
},
|
| 45 |
+
|
| 46 |
+
{
|
| 47 |
+
"name": "windworker",
|
| 48 |
+
"image_paths": [
|
| 49 |
+
"more_samples/windworker.png"
|
| 50 |
+
],
|
| 51 |
+
"QAs": [
|
| 52 |
+
{
|
| 53 |
+
"question": "<image>\n Can you describe the situation in this image?",
|
| 54 |
+
"expected_answer": "In the image, a person is climbing a large wind turbine, likely a windmill, to perform maintenance or repair work. The individual is using a harness and a rope to secure themselves to the turbine, ensuring their safety while they work. The scene captures the height and scale of the wind turbine, emphasizing the importance of the task being performed by the person."
|
| 55 |
+
},
|
| 56 |
+
{
|
| 57 |
+
"question": "<image>\n Is the worker well protected?",
|
| 58 |
+
"expected_answer": "Yes, the worker is well protected. In the image, the man is wearing a harness while climbing the side of a wind turbine. This harness provides safety and support, preventing the worker from falling and ensuring that they can safely perform their task. Additionally, the worker is likely using other safety equipment and following proper safety protocols to minimize the risk of accidents while working at such heights."
|
| 59 |
+
}
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
|
| 63 |
+
{
|
| 64 |
+
"name": "construction crack",
|
| 65 |
+
"image_paths": [
|
| 66 |
+
"more_samples/construction_crack.webp"
|
| 67 |
+
],
|
| 68 |
+
"QAs": [
|
| 69 |
+
{
|
| 70 |
+
"question": "<image>\n Is the building in health condition?",
|
| 71 |
+
"expected_answer": "The building appears to be in a state of disrepair, with visible cracks and chips in the paint. This suggests that the building may be old, poorly maintained, or has experienced some form of damage or wear over time. The presence of these structural issues could potentially pose safety risks or negatively impact the building's overall appearance and value."
|
| 72 |
+
},
|
| 73 |
+
{
|
| 74 |
+
"question": "<image>\n What are the likely service needed for this building?",
|
| 75 |
+
"expected_answer": "The building in the image has a cracked wall and a damaged roof, which indicates that it requires repair and maintenance services. The cracked wall may require repairs to ensure structural integrity and prevent further damage. The damaged roof may need to be replaced or repaired to prevent water leakage and potential damage to the building's interior. Additionally, the building may need to be inspected for any other potential issues that may arise due to the damage."
|
| 76 |
+
}
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
|
| 80 |
+
{
|
| 81 |
+
"name": "climate change",
|
| 82 |
+
"image_paths": [
|
| 83 |
+
"more_samples/climate_change.jpeg"
|
| 84 |
+
],
|
| 85 |
+
"QAs": [
|
| 86 |
+
{
|
| 87 |
+
"question": "<image>\n What is the implication of temperature based on this image?",
|
| 88 |
+
"expected_answer": "The image shows a comparison of the Arctic ice cover in 1984 and 2012. The implication of this image is that the Arctic ice cover has significantly decreased over the years, which is a result of climate change. The reduction in ice cover can have various implications, such as increased exposure of the Arctic environment to warmer temperatures, which can lead to the melting of sea ice, loss of habitat for polar bears and other Arctic species, and potential disruptions to the global climate system. This image serves as a visual representation of the impact of climate change on the Arctic environment and the need for action to mitigate its effects."
|
| 89 |
+
}
|
| 90 |
+
]
|
| 91 |
+
},
|
| 92 |
+
|
| 93 |
+
{
|
| 94 |
+
"name": "patient status",
|
| 95 |
+
"image_paths": [
|
| 96 |
+
"more_samples/patient_status.jpeg"
|
| 97 |
+
],
|
| 98 |
+
"QAs": [
|
| 99 |
+
{
|
| 100 |
+
"question": "<image>\n What is the patient status?",
|
| 101 |
+
"expected_answer": "The patient is in a critical condition, as indicated by the presence of a ventilator and the fact that she is hooked up to a lot of wires."
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"question": "<image>\n Is the patient asleep?",
|
| 105 |
+
"expected_answer": "Yes, the patient is asleep in the hospital bed."
|
| 106 |
+
}
|
| 107 |
+
]
|
| 108 |
+
},
|
| 109 |
+
|
| 110 |
+
{
|
| 111 |
+
"name": "patient care",
|
| 112 |
+
"image_paths": [
|
| 113 |
+
"more_samples/patient_care.png"
|
| 114 |
+
],
|
| 115 |
+
"QAs": [
|
| 116 |
+
{
|
| 117 |
+
"question": "<image>\n What is the training session about?",
|
| 118 |
+
"expected_answer": "The training session is about learning how to perform CPR (cardiopulmonary resuscitation) on a mannequin. The group of people, including nurses and possibly other medical professionals, are gathered around a mannequin to practice their skills and learn the proper techniques for administering CPR."
|
| 119 |
+
},
|
| 120 |
+
{
|
| 121 |
+
"question": "<image>\n Will the practice cause pain to the patient in the image?",
|
| 122 |
+
"expected_answer": "The image shows a group of people, including nurses and a mannequin, practicing CPR on a mannequin. While practicing CPR is essential for healthcare professionals, it is important to note that it is a simulation and not a real patient. The practice does not cause pain to the patient in the image, as it is a mannequin and not a living person."
|
| 123 |
+
}
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
|
| 127 |
+
{
|
| 128 |
+
"name": "Fuyu comparison",
|
| 129 |
+
"image_paths": [
|
| 130 |
+
"fuyu-capm.png"
|
| 131 |
+
],
|
| 132 |
+
"QAs": [
|
| 133 |
+
{
|
| 134 |
+
"question": "<image>\n Janet Ludlow’s firm requires all its analysts to use a two-stage dividend discount model (DDM) and the capital asset pricing model (CAPM) to value stocks. Using the CAPM and DDM, Ludlow has valued QuickBrush Company at $63 per share. She now must value SmileWhite Corporation. Calculate the required rate of return for SmileWhite by using the information in the table. Which rate is correct? A. 14% B. 15% C. 16%",
|
| 135 |
+
"expected_answer": "The required rate of return for SmileWhite can be calculated using the capital asset pricing model (CAPM), which is:\n\nRequired Return = Risk-free Rate + Beta x (Expected Market Return - Risk-free Rate)\n\nThe risk-free rate and expected market return are provided in the table, and the beta for SmileWhite is given as 1.15.\n\nLet’s calculate the required rate of return for SmileWhite:\n\nRequired Return = 4.50% + 1.15 x (14.50% - 4.50%)\nRequired Return = 4.50% + 1.15 x 10%\nRequired Return = 4.50% + 11.50%\nRequired Return = 16%\n\nTherefore, the answer is:\n\nC: 16%\nAnswer: C"
|
| 136 |
+
}
|
| 137 |
+
]
|
| 138 |
+
},
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
{
|
| 143 |
+
"name": "Fuyu comparison",
|
| 144 |
+
"image_paths": [
|
| 145 |
+
"fuyu-heavy-table.png"
|
| 146 |
+
],
|
| 147 |
+
"QAs": [
|
| 148 |
+
{
|
| 149 |
+
"question": "According to the following table, <image>\nwhich food is the most likely cause of the outbreak of food poisoning? (A) Cold chicken, (B) Potato salad, (C) Egg sandwiches, (D) Fruit pie and cream. Read the table in details and consider the percentage as the key factor. Think step-by-step.",
|
| 150 |
+
"expected_answer": "The table shows the number of people who ate different foods and the number of people who ate those foods and got sick. To determine which food is most likely the cause of the outbreak of food poisoning, we’re looking for the food with the highest percentage of people who ate that food and got sick.\n\nLet’s calculate the percentage of people who ate each food and got sick:\n\n- Cold chicken: 34 / 86 = 0.395, which is 39.5%\n- Potato salad: 38 / 54 = 0.704, which is 70.4%\n- Egg sandwiches: 40 / 76 = 0.526, which is 52.6%\n- Fruit pie and cream: 12 / 32 = 0.375, which is 37.5%\n- Cheese: 12 / 48 = 0.25, which is 25%\n\nThe food with the highest percentage of people who ate that food and got sick is potato salad, with 70.4%. Therefore, potato salad is the most likely cause of the outbreak of food poisoning.\n\nAnswer: B"
|
| 151 |
+
}
|
| 152 |
+
]
|
| 153 |
+
},
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
{
|
| 157 |
+
"name": "flamingo comparison",
|
| 158 |
+
"image_paths": [
|
| 159 |
+
"flamingo_1.png",
|
| 160 |
+
"flamingo_2.png",
|
| 161 |
+
"flamingo_3.png"
|
| 162 |
+
],
|
| 163 |
+
"QAs": [
|
| 164 |
+
{
|
| 165 |
+
"question": "Image 1: <image>\n Image 2: <image>\n Image 3: <image>\n What is the common thing about these three images?",
|
| 166 |
+
"expected_answer": "The common theme among these three images is the presence of a flamingo. In the first image, the flamingo is a 3D model, while in the second image, it is a pink cartoon. In the third image, it is a real flamingo standing in the water. All three images showcase the bird in different styles and formats, but they all share the same subject."
|
| 167 |
+
}
|
| 168 |
+
]
|
| 169 |
+
},
|
| 170 |
+
{
|
| 171 |
+
"name": "princess comparison",
|
| 172 |
+
"image_paths": [
|
| 173 |
+
"princess_1.png",
|
| 174 |
+
"princess_2.png"
|
| 175 |
+
],
|
| 176 |
+
"QAs": [
|
| 177 |
+
{
|
| 178 |
+
"question": "Image 1: <image>\n Image 2: <image>\n Tell me about the differences between image 1 and image 2.",
|
| 179 |
+
"expected_answer": "The two images feature a cartoon girl wearing a pink dress. In the first image, the girl is wearing a crown, while in the second image, she is wearing a tiara. Both images showcase the girl in a princess-like attire, but the difference lies in the type of headwear she is wearing. The crown is a more traditional and regal accessory, while the tiara is a more playful and common accessory associated with princesses in popular culture."
|
| 180 |
+
}
|
| 181 |
+
]
|
| 182 |
+
},
|
| 183 |
+
{
|
| 184 |
+
"name": "order price calculation",
|
| 185 |
+
"image_paths": [
|
| 186 |
+
"menu_1.png",
|
| 187 |
+
"menu_2.png"
|
| 188 |
+
],
|
| 189 |
+
"QAs": [
|
| 190 |
+
{
|
| 191 |
+
"question": "Photo: <image>\n Menu: <image>\n How much should I pay for all the beers on the table according to the price on the menu? Think step by step.",
|
| 192 |
+
"expected_answer": "The beer on the table is Magna, which is priced at $6.00. There are two bottles of Magna on the table. Therefore, the total cost of the beer should be $6.00 x 2 = $12.00."
|
| 193 |
+
}
|
| 194 |
+
]
|
| 195 |
+
},
|
| 196 |
+
{
|
| 197 |
+
"name": "pizza price calculation",
|
| 198 |
+
"image_paths": [
|
| 199 |
+
"pizza.png"
|
| 200 |
+
],
|
| 201 |
+
"QAs": [
|
| 202 |
+
{
|
| 203 |
+
"question": "<image>\nI am getting just two custom pizza for me and my friend. How much should I pay in total according to the image? Let's think step by step.",
|
| 204 |
+
"expected_answer": "The image shows a chalkboard menu with a special for today, which is to create your own pizza. The price for this special is $9.99. Since you are getting two custom pizzas, you should pay $9.99 x 2 = $19.98 for the two pizzas. So, the answer is $19.98."
|
| 205 |
+
}
|
| 206 |
+
]
|
| 207 |
+
},
|
| 208 |
+
{
|
| 209 |
+
"name": "sign OCR",
|
| 210 |
+
"image_paths": [
|
| 211 |
+
"sign_1.png",
|
| 212 |
+
"sign_2.png",
|
| 213 |
+
"sign_3.png"
|
| 214 |
+
],
|
| 215 |
+
"QAs": [
|
| 216 |
+
{
|
| 217 |
+
"question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1 is Underground. Image 2 is Congress. What is Image 3?",
|
| 218 |
+
"expected_answer": "Soulomes."
|
| 219 |
+
}
|
| 220 |
+
]
|
| 221 |
+
},
|
| 222 |
+
{
|
| 223 |
+
"name": "painting style",
|
| 224 |
+
"image_paths": [
|
| 225 |
+
"painting_1.png",
|
| 226 |
+
"painting_2.png",
|
| 227 |
+
"painting_3.png"
|
| 228 |
+
],
|
| 229 |
+
"QAs": [
|
| 230 |
+
{
|
| 231 |
+
"question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1 is Romanticism. Image 2 is Surrealism. What is Image 3?",
|
| 232 |
+
"expected_answer": "Impressionism"
|
| 233 |
+
}
|
| 234 |
+
]
|
| 235 |
+
},
|
| 236 |
+
{
|
| 237 |
+
"name": "handwritten calculation",
|
| 238 |
+
"image_paths": [
|
| 239 |
+
"handwritten_1.png",
|
| 240 |
+
"handwritten_2.png",
|
| 241 |
+
"handwritten_3.png"
|
| 242 |
+
],
|
| 243 |
+
"QAs": [
|
| 244 |
+
{
|
| 245 |
+
"question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1 is 2+1=3. Image 2 is 5+6=11. What is Image 3?",
|
| 246 |
+
"expected_answer": "3x6=18"
|
| 247 |
+
}
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"name": "landmark Taipei",
|
| 252 |
+
"image_paths": [
|
| 253 |
+
"landmark_taipei.png"
|
| 254 |
+
],
|
| 255 |
+
"QAs": [
|
| 256 |
+
{
|
| 257 |
+
"question": "<image>\nWhich city is this landmark in?",
|
| 258 |
+
"expected_answer": "The landmark in the image is located in Taipei, Taiwan."
|
| 259 |
+
}
|
| 260 |
+
]
|
| 261 |
+
},
|
| 262 |
+
{
|
| 263 |
+
"name": "landmark new york 1",
|
| 264 |
+
"image_paths": [
|
| 265 |
+
"landmark_new_york_1.png"
|
| 266 |
+
],
|
| 267 |
+
"QAs": [
|
| 268 |
+
{
|
| 269 |
+
"question": "<image>\nWhich city is this landmark in?",
|
| 270 |
+
"expected_answer": "This landmark is located in New York City."
|
| 271 |
+
}
|
| 272 |
+
]
|
| 273 |
+
},
|
| 274 |
+
{
|
| 275 |
+
"name": "landmark new york 2",
|
| 276 |
+
"image_paths": [
|
| 277 |
+
"landmark_new_york_2.png"
|
| 278 |
+
],
|
| 279 |
+
"QAs": [
|
| 280 |
+
{
|
| 281 |
+
"question": "<image>\nWhich city is this landmark in?",
|
| 282 |
+
"expected_answer": "The landmark in the image is located in New York City."
|
| 283 |
+
}
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"name": "landmark Kyoto",
|
| 288 |
+
"image_paths": [
|
| 289 |
+
"landmark_kyoto.png"
|
| 290 |
+
],
|
| 291 |
+
"QAs": [
|
| 292 |
+
{
|
| 293 |
+
"question": "<image>\nWhich city is this landmark in?",
|
| 294 |
+
"expected_answer": "This landmark is located in Kyoto, Japan."
|
| 295 |
+
}
|
| 296 |
+
]
|
| 297 |
+
},
|
| 298 |
+
{
|
| 299 |
+
"name": "visual inference",
|
| 300 |
+
"image_paths": [
|
| 301 |
+
"visual_inference.png"
|
| 302 |
+
],
|
| 303 |
+
"QAs": [
|
| 304 |
+
{
|
| 305 |
+
"question": "<image>\nWhat is in the circled glass? Object 1 or 2? Describe what is Object 1 and object 2. Then check what is in the circled glass.",
|
| 306 |
+
"expected_answer": "Object 1 is a bottle of water, and Object 2 is a bottle of beer. The circled glass contains Object 1, which is a bottle of water."
|
| 307 |
+
}
|
| 308 |
+
]
|
| 309 |
+
},
|
| 310 |
+
{
|
| 311 |
+
"name": "Gemini handdraw car",
|
| 312 |
+
"image_paths": [
|
| 313 |
+
"gemini_handdraw_car.png"
|
| 314 |
+
],
|
| 315 |
+
"QAs": [
|
| 316 |
+
{
|
| 317 |
+
"question": "<image>\nWhich of these cars is more aerodynamic? The one on the left or the right? Explain why, using specific visual details.",
|
| 318 |
+
"expected_answer": "The car on the right is more aerodynamic than the one on the left. This is because the car on the right has a more streamlined shape, which reduces air resistance and drag. The car on the left has a more traditional, boxy shape, which increases air resistance and drag. The aerodynamic car on the right would be more efficient and faster, while the car on the left would be slower and less efficient."
|
| 319 |
+
}
|
| 320 |
+
]
|
| 321 |
+
},
|
| 322 |
+
{
|
| 323 |
+
"name": "caption night view",
|
| 324 |
+
"image_paths": [
|
| 325 |
+
"caption_nightview.png"
|
| 326 |
+
],
|
| 327 |
+
"QAs": [
|
| 328 |
+
{
|
| 329 |
+
"question": "<image>\nDescribe the image with as much detail as possible.",
|
| 330 |
+
"expected_answer": "The image captures a beautiful nighttime scene of a city skyline with a large body of water in front of it. The city is illuminated by the lights of various buildings, creating a vibrant atmosphere. The water is filled with numerous boats of different sizes."
|
| 331 |
+
}
|
| 332 |
+
]
|
| 333 |
+
},
|
| 334 |
+
{
|
| 335 |
+
"name": "caption meat",
|
| 336 |
+
"image_paths": [
|
| 337 |
+
"caption_meat.jpeg"
|
| 338 |
+
],
|
| 339 |
+
"QAs": [
|
| 340 |
+
{
|
| 341 |
+
"question": "<image>\nDescribe the image with as much detail as possible.",
|
| 342 |
+
"expected_answer": "The image is a surreal and colorful scene featuring a landscape made entirely of meat. The landscape is filled with various types of meat, including bacon, ham, and sausage, creating a unique and eye-catching scene. The meat is arranged in the form of mountains, hills, and valleys, giving the impression of a fantasy world. In addition to the meat, there are several birds scattered throughout the scene."
|
| 343 |
+
}
|
| 344 |
+
]
|
| 345 |
+
},
|
| 346 |
+
{
|
| 347 |
+
"name": "company",
|
| 348 |
+
"image_paths": [
|
| 349 |
+
"company_1.png",
|
| 350 |
+
"company_2.png",
|
| 351 |
+
"company_3.png"
|
| 352 |
+
],
|
| 353 |
+
"QAs": [
|
| 354 |
+
{
|
| 355 |
+
"question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. In Image 1, The company is famous for its search engine. In Image 2, The company is famous for iPhone and Mac. What is the company in Image 3 famouns for?",
|
| 356 |
+
"expected_answer": "The company is famous for its graphics processing units."
|
| 357 |
+
}
|
| 358 |
+
]
|
| 359 |
+
},
|
| 360 |
+
{
|
| 361 |
+
"name": "count animal",
|
| 362 |
+
"image_paths": [
|
| 363 |
+
"count_panda_3.png",
|
| 364 |
+
"count_dog_2.png",
|
| 365 |
+
"count_giraff_4.png"
|
| 366 |
+
],
|
| 367 |
+
"QAs": [
|
| 368 |
+
{
|
| 369 |
+
"question": "<image> pandas: 3.\n <image> dogs:2. <image>",
|
| 370 |
+
"expected_answer": "giraffes: 4"
|
| 371 |
+
}
|
| 372 |
+
]
|
| 373 |
+
},
|
| 374 |
+
{
|
| 375 |
+
"name": "french",
|
| 376 |
+
"image_paths": [
|
| 377 |
+
"french_1.png",
|
| 378 |
+
"french_2.png",
|
| 379 |
+
"french_3.png"
|
| 380 |
+
],
|
| 381 |
+
"QAs": [
|
| 382 |
+
{
|
| 383 |
+
"question": "Image 1: <image> Les sanglots longs des violons de l’automne blessent mon coeur d’une langueur monotone. \n Image 2: <image> Pour qui sont ces serpents qui sifflent sur vos têtes? \n Image 3: <image>",
|
| 384 |
+
"expected_answer": "Les flamands roses s'embrassent avec passion, leurs cœurs se touchant, leur amour se partageant."
|
| 385 |
+
}
|
| 386 |
+
]
|
| 387 |
+
},
|
| 388 |
+
{
|
| 389 |
+
"name": "meme",
|
| 390 |
+
"image_paths": [
|
| 391 |
+
"meme.png"
|
| 392 |
+
],
|
| 393 |
+
"QAs": [
|
| 394 |
+
{
|
| 395 |
+
"question": "<image>\nCan you explain the meme?",
|
| 396 |
+
"expected_answer": "The meme depicts a man's reaction to the price of a computer graphics card. In the first image, the man is smiling and appears excited about the product. In the second image, he is shocked and disappointed by the high price of the graphics card, which is $1,200. The meme is a playful representation of the contrast between the man's initial enthusiasm and his subsequent disappointment upon learning the cost of the product."
|
| 397 |
+
}
|
| 398 |
+
]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"name": "flying chair",
|
| 402 |
+
"image_paths": [
|
| 403 |
+
"flying_chair.png"
|
| 404 |
+
],
|
| 405 |
+
"QAs": [
|
| 406 |
+
{
|
| 407 |
+
"question": "<image>\nWhat is unusual about this image?",
|
| 408 |
+
"expected_answer": "The unusual aspect of this image is that a chair is flying through the air on a highway, seemingly coming out of the back of a truck."
|
| 409 |
+
},
|
| 410 |
+
{
|
| 411 |
+
"question": "<image>\nWhat should you do if you encounter this?",
|
| 412 |
+
"expected_answer": "If you encounter this situation, you should immediately stop your vehicle and move to a safe distance from the truck and the flying chair. It is essential to avoid any potential hazards and contact the authorities to report the incident and ensure the safety of everyone involved."
|
| 413 |
+
}
|
| 414 |
+
]
|
| 415 |
+
},
|
| 416 |
+
{
|
| 417 |
+
"name": "palm_e",
|
| 418 |
+
"image_paths": [
|
| 419 |
+
"palm_e_1.png",
|
| 420 |
+
"palm_e_2.png",
|
| 421 |
+
"palm_e_3.png"
|
| 422 |
+
],
|
| 423 |
+
"QAs": [
|
| 424 |
+
{
|
| 425 |
+
"question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1: at 10:30 am. Image 2: at 12:45 pm. Image3: at 3:45 pm. What did I have for lunch, and what time was it?",
|
| 426 |
+
"expected_answer": "I had a sandwich for lunch, and it was at 12:45 pm."
|
| 427 |
+
}
|
| 428 |
+
]
|
| 429 |
+
},
|
| 430 |
+
{
|
| 431 |
+
"name": "orange price",
|
| 432 |
+
"image_paths": [
|
| 433 |
+
"orange_price.png"
|
| 434 |
+
],
|
| 435 |
+
"QAs": [
|
| 436 |
+
{
|
| 437 |
+
"question": "<image>\nWhat's the price for a single orange? Look at the price tag in details.",
|
| 438 |
+
"expected_answer": "$1.25"
|
| 439 |
+
}
|
| 440 |
+
]
|
| 441 |
+
},
|
| 442 |
+
{
|
| 443 |
+
"name": "tow car",
|
| 444 |
+
"image_paths": [
|
| 445 |
+
"tow_car.png"
|
| 446 |
+
],
|
| 447 |
+
"QAs": [
|
| 448 |
+
{
|
| 449 |
+
"question": "<image>\nWhat's the person doing?",
|
| 450 |
+
"expected_answer": "The person is lying on the ground next to a car, possibly working on it or inspecting it."
|
| 451 |
+
}
|
| 452 |
+
]
|
| 453 |
+
},
|
| 454 |
+
{
|
| 455 |
+
"name": "parking sign",
|
| 456 |
+
"image_paths": [
|
| 457 |
+
"parking_sign.png"
|
| 458 |
+
],
|
| 459 |
+
"QAs": [
|
| 460 |
+
{
|
| 461 |
+
"question": "<image>\nHow long can I park here 5pm on Mondays? Look at the traffic signs in details.",
|
| 462 |
+
"expected_answer": "After 5pm on Monday, you can park for 1 hour."
|
| 463 |
+
}
|
| 464 |
+
]
|
| 465 |
+
},
|
| 466 |
+
{
|
| 467 |
+
"name": "car block",
|
| 468 |
+
"image_paths": [
|
| 469 |
+
"car_blocker.png"
|
| 470 |
+
],
|
| 471 |
+
"QAs": [
|
| 472 |
+
{
|
| 473 |
+
"question": "<image>\n Look at the traffic condition, can the vehicle proceed now? Why?",
|
| 474 |
+
"expected_answer": "Based on the image, the vehicle cannot proceed through the traffic yet. There are multiple people and bicycles in the crosswalk, and the traffic light is red. The vehicle must wait for the traffic light to turn green before proceeding."
|
| 475 |
+
}
|
| 476 |
+
]
|
| 477 |
+
},
|
| 478 |
+
{
|
| 479 |
+
"name": "car safety",
|
| 480 |
+
"image_paths": [
|
| 481 |
+
"car_safety.jpg"
|
| 482 |
+
],
|
| 483 |
+
"QAs": [
|
| 484 |
+
{
|
| 485 |
+
"question": "<image>\nIs the driver on the phone? ",
|
| 486 |
+
"expected_answer": "Yes, the driver is on the phone."
|
| 487 |
+
},
|
| 488 |
+
{
|
| 489 |
+
"question": "<image>\nHow many people are in the car?",
|
| 490 |
+
"expected_answer": "There are two people in the car with one person driving and the other in the back of the car."
|
| 491 |
+
},
|
| 492 |
+
{
|
| 493 |
+
"question": "<image>\nIs the driver distracted?",
|
| 494 |
+
"expected_answer": "Yes, the driver is distracted as he is holding a cell phone interacting with it while sitting on the driver’s seat."
|
| 495 |
+
},
|
| 496 |
+
{
|
| 497 |
+
"question": "<image>\nWhere is the passenger sitting?",
|
| 498 |
+
"expected_answer": "The passenger is sitting on the right side of the car."
|
| 499 |
+
},
|
| 500 |
+
{
|
| 501 |
+
"question": "<image>\nWhat is on the passenger seat? Is it safe?",
|
| 502 |
+
"expected_answer": "There is a pair of scissors on the passenger seat. It is not safe."
|
| 503 |
+
}
|
| 504 |
+
]
|
| 505 |
+
},
|
| 506 |
+
{
|
| 507 |
+
"name": "factory",
|
| 508 |
+
"image_paths": [
|
| 509 |
+
"factory.jpg"
|
| 510 |
+
],
|
| 511 |
+
"QAs": [
|
| 512 |
+
{
|
| 513 |
+
"question": "<image>\nHow many cars are jacked up?",
|
| 514 |
+
"expected_answer": "There are two cars jacked up in the image."
|
| 515 |
+
},
|
| 516 |
+
{
|
| 517 |
+
"question": "<image>\nWhat is the person whose head is under the jacked up car doing?",
|
| 518 |
+
"expected_answer": "The person whose head is under the jacked up car is likely performing a task related to the maintenance or repair of the vehicle. They could be inspecting the suspension, brakes, or other components of the car that require attention. The other people in the scene are also working on the vehicles, suggesting that they are part of a team or a group of mechanics or technicians who are collaborating to fix or maintain the cars."
|
| 519 |
+
},
|
| 520 |
+
{
|
| 521 |
+
"question": "<image>\nHow many people are there whose head is under the jacked up car?",
|
| 522 |
+
"expected_answer": "There are two persons whose head is under the jacked up car."
|
| 523 |
+
}
|
| 524 |
+
]
|
| 525 |
+
},
|
| 526 |
+
{
|
| 527 |
+
"name": "factory count",
|
| 528 |
+
"image_paths": [
|
| 529 |
+
"factory_count_1.jpg",
|
| 530 |
+
"factory_count_2.jpg",
|
| 531 |
+
"factory_count_3.jpg",
|
| 532 |
+
"factory_count_4.jpg",
|
| 533 |
+
"factory_count_5.jpg",
|
| 534 |
+
"factory_count_6.jpg",
|
| 535 |
+
"factory_count_7.jpg",
|
| 536 |
+
"factory_count_8.jpg"
|
| 537 |
+
],
|
| 538 |
+
"QAs": [
|
| 539 |
+
{
|
| 540 |
+
"question": "Frame 1: <image>\n Frame 2: <image>\n Frame 2: <image>\n Frame 4: <image>\n Frame 5: <image>\n Frame 6: <image>\n Frame 7: <image>\n Frame 8: <image>\n Considering the video frames, how many chip bags are picked up?",
|
| 541 |
+
"expected_answer": "Two chip bags are picked up."
|
| 542 |
+
}
|
| 543 |
+
]
|
| 544 |
+
}
|
| 545 |
+
]
|
| 546 |
+
}
|
VILA/inference_test/inference_test.py
ADDED
|
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
Inference test to run all examples from the paper and compare w/ expected output.
|
| 19 |
+
Both the inference results and expected output will be printed out.
|
| 20 |
+
|
| 21 |
+
Currently do not support multi-turn chat. Each time an image and question are input and answer is output.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
import argparse
|
| 25 |
+
import json
|
| 26 |
+
import os
|
| 27 |
+
|
| 28 |
+
import torch
|
| 29 |
+
from PIL import Image
|
| 30 |
+
|
| 31 |
+
from llava.constants import IMAGE_TOKEN_INDEX
|
| 32 |
+
from llava.conversation import SeparatorStyle, conv_templates
|
| 33 |
+
from llava.mm_utils import (KeywordsStoppingCriteria, process_images,
|
| 34 |
+
tokenizer_image_token)
|
| 35 |
+
from llava.model import *
|
| 36 |
+
|
| 37 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
from llava.model.builder import load_pretrained_model
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def eval_model(args, model, tokenizer, image_processor):
|
| 44 |
+
# read json file
|
| 45 |
+
with open(args.test_json_path) as f:
|
| 46 |
+
all_test_cases = json.load(f)
|
| 47 |
+
|
| 48 |
+
result_list = []
|
| 49 |
+
print(len(all_test_cases["test_cases"]))
|
| 50 |
+
|
| 51 |
+
for test_case in all_test_cases["test_cases"]:
|
| 52 |
+
# read images first
|
| 53 |
+
image_file_list = test_case["image_paths"]
|
| 54 |
+
image_list = [
|
| 55 |
+
Image.open(os.path.join(args.test_image_path, image_file)).convert("RGB") for image_file in image_file_list
|
| 56 |
+
]
|
| 57 |
+
image_tensor = process_images(image_list, image_processor, model.config)
|
| 58 |
+
|
| 59 |
+
# image_tokens = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
|
| 60 |
+
|
| 61 |
+
for i in range(len(test_case["QAs"])):
|
| 62 |
+
query = test_case["QAs"][i]["question"]
|
| 63 |
+
query_text = query
|
| 64 |
+
|
| 65 |
+
if 1:
|
| 66 |
+
# query = query.replace("<image>", image_tokens)
|
| 67 |
+
if len(image_list) < 3:
|
| 68 |
+
conv = conv_templates["vicuna_v1"].copy()
|
| 69 |
+
else:
|
| 70 |
+
conv = conv_templates["vicuna_v1_nosys"].copy()
|
| 71 |
+
conv.append_message(conv.roles[0], query)
|
| 72 |
+
conv.append_message(conv.roles[1], None)
|
| 73 |
+
prompt = conv.get_prompt()
|
| 74 |
+
else:
|
| 75 |
+
conv = conv_templates[args.conv_mode].copy()
|
| 76 |
+
if not "<image>" in query:
|
| 77 |
+
assert "###" not in query # single query
|
| 78 |
+
query = image_tokens + "\n" + query # add <image>
|
| 79 |
+
query_list = [query]
|
| 80 |
+
else:
|
| 81 |
+
query_list = query.split("###")
|
| 82 |
+
assert len(query_list) % 2 == 1 # the last one is from human
|
| 83 |
+
|
| 84 |
+
new_query_list = []
|
| 85 |
+
for idx, query in enumerate(query_list):
|
| 86 |
+
if "<image>" in query:
|
| 87 |
+
assert idx % 2 == 0 # only from human
|
| 88 |
+
# assert query.startswith("<image>")
|
| 89 |
+
# query = query.replace("<image>", image_tokens)
|
| 90 |
+
new_query_list.append(query)
|
| 91 |
+
query_list = new_query_list
|
| 92 |
+
|
| 93 |
+
for idx, query in enumerate(query_list):
|
| 94 |
+
conv.append_message(conv.roles[idx % 2], query)
|
| 95 |
+
conv.append_message(conv.roles[1], None)
|
| 96 |
+
prompt = conv.get_prompt()
|
| 97 |
+
|
| 98 |
+
print("%" * 10 + " " * 5 + "VILA Response" + " " * 5 + "%" * 10)
|
| 99 |
+
|
| 100 |
+
# inputs = tokenizer([prompt])
|
| 101 |
+
inputs = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX)
|
| 102 |
+
input_ids = torch.as_tensor(inputs).cuda().unsqueeze(0)
|
| 103 |
+
|
| 104 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
| 105 |
+
keywords = [stop_str]
|
| 106 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
| 107 |
+
|
| 108 |
+
# outputs = run_llava.process_outputs(args, model, tokenizer, input_ids, image_tensor, stopping_criteria, stop_str)
|
| 109 |
+
with torch.inference_mode():
|
| 110 |
+
output_ids = model.generate(
|
| 111 |
+
input_ids,
|
| 112 |
+
images=image_tensor.to(dtype=torch.float16, device="cuda", non_blocking=True),
|
| 113 |
+
do_sample=True if args.temperature > 0 else False,
|
| 114 |
+
temperature=args.temperature,
|
| 115 |
+
top_p=0.7,
|
| 116 |
+
# top_p=args.top_p,
|
| 117 |
+
# num_beams=args.num_beams,
|
| 118 |
+
max_new_tokens=512,
|
| 119 |
+
# use_cache=True,
|
| 120 |
+
stopping_criteria=[stopping_criteria],
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
| 124 |
+
outputs = outputs.strip()
|
| 125 |
+
|
| 126 |
+
print(f"Question: {query_text}")
|
| 127 |
+
print(f"VILA output: {outputs}")
|
| 128 |
+
print(f'Expected output: {test_case["QAs"][i]["expected_answer"]}')
|
| 129 |
+
|
| 130 |
+
result_list.append(
|
| 131 |
+
dict(question=query_text, output=outputs, expected_output=test_case["QAs"][i]["expected_answer"])
|
| 132 |
+
)
|
| 133 |
+
return result_list
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
if __name__ == "__main__":
|
| 137 |
+
parser = argparse.ArgumentParser()
|
| 138 |
+
parser.add_argument("--model-name", type=str, default=None)
|
| 139 |
+
parser.add_argument("--test_json_path", type=str, default=None)
|
| 140 |
+
parser.add_argument("--test_image_path", type=str, default=None)
|
| 141 |
+
parser.add_argument("--conv-mode", type=str, default=None)
|
| 142 |
+
parser.add_argument("--temperature", type=float, default=0.2)
|
| 143 |
+
parser.add_argument("--pad", action="store_true")
|
| 144 |
+
|
| 145 |
+
args = parser.parse_args()
|
| 146 |
+
|
| 147 |
+
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_name, "llava_llama", None)
|
| 148 |
+
result_list = eval_model(args, model, tokenizer, image_processor)
|
| 149 |
+
save_name = f"inference-test_{args.model_name.split('/')[-1]}"
|
| 150 |
+
if "nosys" in args.conv_mode:
|
| 151 |
+
save_name += "_nosys"
|
| 152 |
+
save_name += ".json"
|
| 153 |
+
result_list_str = json.dumps(result_list, indent=2)
|
VILA/llava.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: llava
|
| 3 |
+
Version: 1.0.0
|
| 4 |
+
Summary: VILA: On Pre-training for Visual Language Models
|
| 5 |
+
Project-URL: Homepage, https://hanlab.mit.edu/projects/vila
|
| 6 |
+
Project-URL: Bug Tracker, https://github.com/Efficient-Large-Model/VILA-Internal/issues
|
| 7 |
+
Classifier: Programming Language :: Python :: 3
|
| 8 |
+
Classifier: License :: OSI Approved :: Apache Software License
|
| 9 |
+
Requires-Python: >=3.8
|
| 10 |
+
Description-Content-Type: text/markdown
|
| 11 |
+
License-File: LICENSE
|
| 12 |
+
Requires-Dist: torch==2.0.1
|
| 13 |
+
Requires-Dist: torchvision==0.15.2
|
| 14 |
+
Requires-Dist: transformers==4.31.0
|
| 15 |
+
Requires-Dist: tokenizers<0.14,>=0.12.1
|
| 16 |
+
Requires-Dist: sentencepiece==0.1.99
|
| 17 |
+
Requires-Dist: shortuuid
|
| 18 |
+
Requires-Dist: accelerate==0.27.2
|
| 19 |
+
Requires-Dist: peft==0.5.0
|
| 20 |
+
Requires-Dist: bitsandbytes==0.41.0
|
| 21 |
+
Requires-Dist: pydantic<2,>=1
|
| 22 |
+
Requires-Dist: markdown2[all]
|
| 23 |
+
Requires-Dist: numpy
|
| 24 |
+
Requires-Dist: scikit-learn==1.2.2
|
| 25 |
+
Requires-Dist: gradio==3.35.2
|
| 26 |
+
Requires-Dist: gradio_client==0.2.9
|
| 27 |
+
Requires-Dist: requests
|
| 28 |
+
Requires-Dist: httpx==0.24.0
|
| 29 |
+
Requires-Dist: uvicorn
|
| 30 |
+
Requires-Dist: fastapi
|
| 31 |
+
Requires-Dist: einops==0.6.1
|
| 32 |
+
Requires-Dist: einops-exts==0.0.4
|
| 33 |
+
Requires-Dist: timm==0.6.13
|
| 34 |
+
Requires-Dist: openpyxl==3.1.2
|
| 35 |
+
Requires-Dist: pytorchvideo==0.1.5
|
| 36 |
+
Requires-Dist: datasets==2.16.1
|
| 37 |
+
Requires-Dist: openai==1.8.0
|
| 38 |
+
Requires-Dist: webdataset==0.2.86
|
| 39 |
+
Provides-Extra: train
|
| 40 |
+
Requires-Dist: deepspeed==0.13.2; extra == "train"
|
| 41 |
+
Requires-Dist: ninja; extra == "train"
|
| 42 |
+
Requires-Dist: wandb; extra == "train"
|
| 43 |
+
Provides-Extra: eval
|
| 44 |
+
Requires-Dist: mmengine; extra == "eval"
|
| 45 |
+
Requires-Dist: word2number; extra == "eval"
|
| 46 |
+
Requires-Dist: Levenshtein; extra == "eval"
|
| 47 |
+
|
| 48 |
+
<p align="center">
|
| 49 |
+
<img src="demo_images/vila-logo.jpg" width="20%"/>
|
| 50 |
+
</p>
|
| 51 |
+
|
| 52 |
+
# VILA: On Pre-training for Visual Language Models
|
| 53 |
+
|
| 54 |
+
[](CODE_LICENSE)
|
| 55 |
+
[](MODEL_LICENSE)
|
| 56 |
+
[](https://www.python.org/downloads/release/python-3100/)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
[VILA arxiv](https://arxiv.org/abs/2312.07533) / [VILA Demo](https://vila-demo.hanlab.ai/) / [VILA Huggingface](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e)
|
| 60 |
+
|
| 61 |
+
## 💡 Introduction
|
| 62 |
+
VILA is a visual language model (VLM) pretrained with interleaved image-text data at scale, enabling multi-image VLM. VILA is deployable on the edge, including Jetson Orin and laptop by [AWQ](https://arxiv.org/pdf/2306.00978.pdf) 4bit quantization through [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) framework. We find: (1) image-text pairs are not enough, interleaved image-text is essential; (2) unfreezing LLM during interleaved image-text pre-training enables in-context learning; (3)re-blending text-only instruction data is crucial to boost both VLM and text-only performance. VILA unveils appealing capabilities, including: multi-image reasoning, in-context learning, visual chain-of-thought, and better world knowledge.
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## 💡 News
|
| 66 |
+
- [2024/02] We release [AWQ](https://arxiv.org/pdf/2306.00978.pdf)-quantized 4bit VILA models, deployable on Jetson Orin and laptops through [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) and [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine).
|
| 67 |
+
- [2024/02] VILA is released. We propose interleaved image-text pretraining that enables multi-image VLM. VILA comes with impressive in-context learning capabilities. We open source everything: including training code, evaluation code, datasets, model ckpts.
|
| 68 |
+
- [2023/12] [Paper](https://arxiv.org/abs/2312.07533) is on Arxiv!
|
| 69 |
+
|
| 70 |
+
## Performance
|
| 71 |
+
|
| 72 |
+
| $~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~$ | Prec. | VQAv2 | GQA | VizWiz | SQA-I | VQA-T | POPE | MME | MMB | MMB-CN | SEED | llava-bench | MM-Vet | Average (w/o MME) |
|
| 73 |
+
| ----------------- | ---------------- | ---------------- | ---------- | ----------- | ----------- | ----- | ----- | ------- | ---- | ------ | ---- | ----------- | ------ | ----------------- |
|
| 74 |
+
| VILA-7B | fp16 | 80.3 | 63.1 | 59.6 | 68.0 | 62.6 | 86.3 | 1489.4 | 69.8 | 61.0 | 61.7 | 75.2 | 35.1 | 65.7 |
|
| 75 |
+
| VILA-7B-AWQ | int4 | 80.1 | 63.0 | 57.8 | 68.0 | 61.9 | 85.3 | 1486.3 | 68.8 | 59.0 | 61.3 | 75.8 | 35.9 | 65.2 |
|
| 76 |
+
| VILA-13B | fp16| 80.5 | 63.6 | 63.1 | 70.5 | 64.0 | 86.3 | 1553.6 | 73.8 | 66.7 | 62.8 | 78.3 | 42.6 | 68.4 |
|
| 77 |
+
| VILA-13B-AWQ | int4 | 80.4 | 63.6 | 63.0 | 71.2 | 63.5 | 87.0 | 1552.9 | 73.6 | 66.3 | 62.2 | 77.6 | 42.0 | 68.2 |
|
| 78 |
+
|
| 79 |
+
<sup>NOTE: The benchmark results are slightly different from what we report in the paper due to refactoring of the codebase based on LLava-1.5 and re-train the model. VQAV2 and VizWiz are test-dev.</sup>
|
| 80 |
+
|
| 81 |
+
### Inference speed ( Token/sec )
|
| 82 |
+
|
| 83 |
+
| $~~~~~~$ | Precision | A100 | 4090 | Orin |
|
| 84 |
+
| --- | --- |--- | --- | --- |
|
| 85 |
+
| VILA-7B | fp16 | 81.6 | 58.5 | 11.5 |
|
| 86 |
+
| VILA-7B-AWQ| int4 |155.3| 168.1| 35.6 |
|
| 87 |
+
| VILA-13B | fp16 | 48.5 | OOM | 6.1 |
|
| 88 |
+
| VILA-13B-AWQ | int4 | 102.1| 99.0| 17.5 |
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
## VILA Examples
|
| 92 |
+
|
| 93 |
+
### In context learning
|
| 94 |
+
<img src="demo_images/demo_img_1.png" height="239">
|
| 95 |
+
<img src="demo_images/demo_img_2.png" height="250">
|
| 96 |
+
|
| 97 |
+
### Multi-image reasoning
|
| 98 |
+
<img src="demo_images/demo_img_3.png" height="193">
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
### VILA on Jetson Orin
|
| 102 |
+
|
| 103 |
+
https://github.com/Efficient-Large-Model/VILA/assets/7783214/6079374c-0787-4bc4-b9c6-e1524b4c9dc4
|
| 104 |
+
|
| 105 |
+
### VILA on RTX 4090
|
| 106 |
+
|
| 107 |
+
https://github.com/Efficient-Large-Model/VILA/assets/7783214/80c47742-e873-4080-ad7d-d17c4700539f
|
| 108 |
+
|
| 109 |
+
</details>
|
| 110 |
+
|
| 111 |
+
## Installation
|
| 112 |
+
|
| 113 |
+
```bash
|
| 114 |
+
./environment_setup.sh
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
or follow the instructions below in order.
|
| 118 |
+
|
| 119 |
+
```
|
| 120 |
+
conda create -n vila python=3.10 -y
|
| 121 |
+
conda activate vila
|
| 122 |
+
|
| 123 |
+
pip install --upgrade pip # enable PEP 660 support
|
| 124 |
+
wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.4.2/flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 125 |
+
pip install flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
| 126 |
+
pip install -e .
|
| 127 |
+
pip install -e ".[train]"
|
| 128 |
+
|
| 129 |
+
pip install git+https://github.com/huggingface/transformers@v4.38.1
|
| 130 |
+
cp -r ./llava/train/transformers_replace/* ~/anaconda3/envs/vila/lib/python3.10/site-packages/transformers/
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
## Training
|
| 134 |
+
|
| 135 |
+
VILA training contains three steps
|
| 136 |
+
|
| 137 |
+
### Step-1: Alignment
|
| 138 |
+
We utilize LLaVA-CC3M-Pretrain-595K dataset to align the textual and visual modalities.
|
| 139 |
+
|
| 140 |
+
The stage 1 script takes in two parameters and it can run on a single 8xA100 node. `BASE_MODEL_PATH` points to a online or local huggingface repository, such as `NousResearch/Llama-2-7b-hf`. `OUTPUT_NAME` points to a target directory under `checkpoints`, which will save the trained multimodal projector afterwards.
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
bash scripts/v1_5/paper/1_mm_align.sh [BASE_MODEL_PATH] [OUTPUT_NAME]
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
| 147 |
+
| --- | ---: | ---: | ---: | ---: | ---: |
|
| 148 |
+
| VILA-7B | 256 | 2e-5 | 1 | 4096 | 0 |
|
| 149 |
+
| VILA-13B | 256 | 2e-5 | 1 | 4096 | 0 |
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
### Step-2: Pretraining
|
| 153 |
+
We use MMC4 and Coyo dataset to train VLM with interleaved image-text pairs.
|
| 154 |
+
|
| 155 |
+
```bash
|
| 156 |
+
bash scripts/v1_5/paper/2_pretrain_mmc4_coyo.sh [CODE_PATH] [BASE_MODEL_PATH] [STAGE1_PATH] [OUTPUT_NAME]
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
The stage 2 script takes in four arguments. `CODE_PATH` is the absolute path to our VILA codebase, `BASE_MODEL_PATH` has similar meaning to what is presented in the stage 1 script. `STAGE1_PATH` points to the `OUTPUT_NAME` of stage 1 (i.e. where the stage 1 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that saves the pretraining checkpoint. The script we provided for this stage is executed on slurm, and we expect it to execute on 16 nodes (128 GPUs).
|
| 160 |
+
|
| 161 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
| 162 |
+
| --- | ---: | ---: | ---: | ---: | ---: |
|
| 163 |
+
| VILA-7B | 1024 | 5e-5 | 1 | 4096 | 0 |
|
| 164 |
+
| VILA-13B | 1024 | 5e-5 | 1 | 4096 | 0 |
|
| 165 |
+
|
| 166 |
+
### Step-3: Supervised fine-tuning
|
| 167 |
+
This is the last stage of VILA training, in which we tune the model to follow multimodal instructions on a subset of M3IT, FLAN and ShareGPT4V. This stage runs on a 8xA100 node.
|
| 168 |
+
|
| 169 |
+
```bash
|
| 170 |
+
bash scripts/v1_5/paper/3_sft.sh [STAGE2_PATH] [OUTPUT_NAME]
|
| 171 |
+
```
|
| 172 |
+
The stage 3 script takes in two arguments. `STAGE2_PATH` points to the `OUTPUT_NAME` of the stage 2 script (i.e. where the stage 2 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that stores the final checkpoint.
|
| 173 |
+
|
| 174 |
+
| Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
|
| 175 |
+
| --- | ---: | ---: | ---: | ---: | ---: |
|
| 176 |
+
| VILA-7B | 128 | 2e-5 | 1 | 4096 | 0 |
|
| 177 |
+
| VILA-13B | 128 | 2e-5 | 1 | 4096 | 0 |
|
| 178 |
+
|
| 179 |
+
### Training with fewer GPUs
|
| 180 |
+
To train with fewer GPUs/nodes, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. As long as the global batch size same (`per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`) are kept the same, the training precision will not be affected.
|
| 181 |
+
|
| 182 |
+
Stage 1 completes within 3.5 (7B) - 5.5 (13B) hours on 8xA100, Stage 2 completes within 30 hours on 128xA100 for VILA-7B, and stage 3 completes in 25 (7B) - 40 (13B) hours on 8xA100.
|
| 183 |
+
|
| 184 |
+
See [data_prepare/README.md](data_prepare/README.md) for more information about how to prepare datasets.
|
| 185 |
+
|
| 186 |
+
## Evaluations
|
| 187 |
+
|
| 188 |
+
You can follow [Llava1.5 eval](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md) to download all datasets. After downloading all datasets, please put them under `playground/data/eval`.
|
| 189 |
+
|
| 190 |
+
We provide a push-the-button script to perform evaluation on all 10 datasets that do not require GPT-assisted evaluation:
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
./scripts/v1_5/eval/eval_all.sh [CHECKPOINT_PATH] [MODEL_NAME]
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
This script takes in two parameters, `CHECKPOINT_PATH` points to the stage 3 model checkpoint, and `MODEL_NAME` will be the name of evaluation results.
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
[VQAv2](https://eval.ai/web/challenges/challenge-page/830/my-submission) and [Vizwiz](https://eval.ai/web/challenges/challenge-page/2185/my-submission) evaluations are hosted on eval.ai. You need to register an account and create a team to be able to submit eval.
|
| 200 |
+
|
| 201 |
+
MMBench and MMBench_CN eval are hosted on another [evaluation server](https://opencompass.org.cn/leaderboard-multimodal). Make sure you change the name of the file before submitting, otherwise the server caches results and will always return wrong result to you.
|
| 202 |
+
|
| 203 |
+
We provide a quick script to automatically organize the prediction files that need to be submitted to servers:
|
| 204 |
+
|
| 205 |
+
```bash
|
| 206 |
+
python scripts/v1_5/eval/copy_predictions.py [MODEL_NAME]
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
You will be able to find the predictions under `playground/data/predictions_upload/[MODEL_NAME]` after executing this script.
|
| 210 |
+
|
| 211 |
+
## Inference
|
| 212 |
+
|
| 213 |
+
We provide snippets for quick inference with user prompts and images.
|
| 214 |
+
|
| 215 |
+
VILA-7B inference:
|
| 216 |
+
```bash
|
| 217 |
+
python -W ignore llava/eval/run_llava.py \
|
| 218 |
+
--model-name Efficient-Large-Model/VILA-7B \
|
| 219 |
+
--conv-mode vicuna_v1 \
|
| 220 |
+
--query "<image>\n Please describe the traffic condition." \
|
| 221 |
+
--image-file "av.png"
|
| 222 |
+
```
|
| 223 |
+
|
| 224 |
+
VILA-13B inference:
|
| 225 |
+
```bash
|
| 226 |
+
python -W ignore llava/eval/run_llava.py \
|
| 227 |
+
--model-name Efficient-Large-Model/VILA-13B \
|
| 228 |
+
--conv-mode vicuna_v1 \
|
| 229 |
+
--query "<image>\n Please describe the traffic condition." \
|
| 230 |
+
--image-file "av.png"
|
| 231 |
+
```
|
| 232 |
+
|
| 233 |
+
## Quantization and Deployment
|
| 234 |
+
|
| 235 |
+
Our VILA models are quantized by [AWQ](https://arxiv.org/abs/2306.00978) into 4 bits for efficient inference on the edge. We provide a push-the-button [script](https://github.com/mit-han-lab/llm-awq/blob/main/scripts/vila_example.sh) to quantize VILA with AWQ.
|
| 236 |
+
|
| 237 |
+
### Running VILA on desktop GPUs and edge GPUs
|
| 238 |
+
|
| 239 |
+
We support AWQ-quantized 4bit VILA on GPU platforms via [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat). We provide a [tutorial](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat#support-vlm-models-vila--llava) to run the model with TinyChat after quantization. We also provide an [instruction](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat/serve) to launch a Gradio server (powered by TinyChat and AWQ) to serve 4-bit quantized VILA models.
|
| 240 |
+
|
| 241 |
+
### Running VILA on laptops
|
| 242 |
+
|
| 243 |
+
We further support our AWQ-quantized 4bit VILA models on various CPU platforms with both x86 and ARM architectures with our [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine). We also provide a detailed [tutorial](https://github.com/mit-han-lab/TinyChatEngine/tree/main?tab=readme-ov-file#deploy-vision-language-model-vlm-chatbot-with-tinychatengine) to help the users deploy VILA on different CPUs.
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
## Checkpoints
|
| 248 |
+
|
| 249 |
+
We release [VILA-7B](https://hf.co/Efficient-Large-Model/VILA-7b), [VILA-13B](https://hf.co/Efficient-Large-Model/VILA-13b), [VILA-7B-4bit-AWQ](https://hf.co/Efficient-Large-Model/VILA-7b-4bit-awq) and [VILA-13B-4bit-AWQ](https://hf.co/Efficient-Large-Model/VILA-13b-4bit-awq).
|
| 250 |
+
|
| 251 |
+
## 🔒 License
|
| 252 |
+
- The code is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
|
| 253 |
+
- The pretrained weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
|
| 254 |
+
- The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
|
| 255 |
+
- [Model License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
|
| 256 |
+
- [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
|
| 257 |
+
- [Dataset Licenses](./data_prepare/LICENSE) for each one used during training.
|
| 258 |
+
|
| 259 |
+
## Team
|
| 260 |
+
| | | |
|
| 261 |
+
| --- | --- | ---|
|
| 262 |
+
[*Ji Lin](https://www.linji.me/): OpenAI (work done at Nvidia and MIT) | [*Hongxu Yin](https://hongxu-yin.github.io/): Nvidia | [*Yao Lu](https://scholar.google.com/citations?user=OI7zFmwAAAAJ&hl=en): Nvidia
|
| 263 |
+
[Wei Ping](https://scholar.google.com/citations?user=6gKEYRgAAAAJ&hl=en): Nvidia | [Pavlo Molchanov](https://www.pmolchanov.com/): Nvidia | [Andrew Tao](https://scholar.google.com/citations?user=Wel9l1wAAAAJ&hl=en): Nvidia |
|
| 264 |
+
[Haotian Tang](http://kentang.net/): MIT | [Shang Yang](https://ys-2020.github.io/): MIT | [Ligeng Zhu](https://lzhu.me/): Nvidia, MIT |
|
| 265 |
+
[Wei-Chen Wang](https://weichenwang.me/): MIT | [Fuzhao Xue](https://xuefuzhao.github.io/): Nvidia, NUS | [Yunhao Fang](https://seerkfang.github.io/): Nvidia, UCSD |
|
| 266 |
+
[Yukang Chen](https://yukangchen.com/): Nvidia, CUHK | [Yue Shen](https://www.linkedin.com/in/yue-james-shen/): Nvidia | [Huizi Mao](https://scholar.google.com/citations?user=r5WezOYAAAAJ&hl=zh-CN): Nvidia |
|
| 267 |
+
[Jan Kautz](https://jankautz.com/): Nvidia | [Mohammad Shoeybi](https://scholar.google.com/citations?user=62ElavIAAAAJ&hl=en): Nvidia | [Song Han](http://songhan.mit.edu/): Nvidia, MIT
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
## Citations
|
| 271 |
+
|
| 272 |
+
```
|
| 273 |
+
@misc{lin2023vila,
|
| 274 |
+
title={VILA: On Pre-training for Visual Language Models},
|
| 275 |
+
author={Ji Lin and Hongxu Yin and Wei Ping and Yao Lu and Pavlo Molchanov and Andrew Tao and Huizi Mao and Jan Kautz and Mohammad Shoeybi and Song Han},
|
| 276 |
+
year={2023},
|
| 277 |
+
eprint={2312.07533},
|
| 278 |
+
archivePrefix={arXiv},
|
| 279 |
+
primaryClass={cs.CV}
|
| 280 |
+
}
|
| 281 |
+
```
|
| 282 |
+
|
| 283 |
+
# Acknowledgement
|
| 284 |
+
- [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon. Thanks for their wonderful work.
|
| 285 |
+
- [Vicuna](https://github.com/lm-sys/FastChat): the amazing open-sourced large language model!
|
| 286 |
+
- [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT): we borrowed video evaluation script from this repository.
|
| 287 |
+
- [MMC4](https://github.com/allenai/mmc4), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), [M3IT](https://huggingface.co/datasets/MMInstruction/M3IT), [OpenORCA/FLAN](https://huggingface.co/datasets/Open-Orca/FLAN), [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V) for providing datasets used in this research.
|
VILA/llava.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
+
README.md
|
| 3 |
+
pyproject.toml
|
| 4 |
+
CIs/send_email.py
|
| 5 |
+
data_prepare/coyo/coyo_downloader.py
|
| 6 |
+
data_prepare/coyo/coyo_splitter.py
|
| 7 |
+
data_prepare/mmc4/mmc4_downloader.py
|
| 8 |
+
data_prepare/mmc4/mmc4_filter_and_counter.py
|
| 9 |
+
data_prepare/mmc4/mmc4_merger.py
|
| 10 |
+
data_prepare/sft/preprocess_flan.py
|
| 11 |
+
data_prepare/sft/preprocess_m3it.py
|
| 12 |
+
data_prepare/sft/split_vflan.py
|
| 13 |
+
demo_trt_llm/llava.py
|
| 14 |
+
demo_trt_llm/test_vila.py
|
| 15 |
+
inference_test/dataset_test.py
|
| 16 |
+
inference_test/inference_test.py
|
| 17 |
+
llava/__init__.py
|
| 18 |
+
llava/constants.py
|
| 19 |
+
llava/conversation.py
|
| 20 |
+
llava/mm_utils.py
|
| 21 |
+
llava/unit_test_utils.py
|
| 22 |
+
llava/utils.py
|
| 23 |
+
llava.egg-info/PKG-INFO
|
| 24 |
+
llava.egg-info/SOURCES.txt
|
| 25 |
+
llava.egg-info/dependency_links.txt
|
| 26 |
+
llava.egg-info/requires.txt
|
| 27 |
+
llava.egg-info/top_level.txt
|
| 28 |
+
llava/data/__init__.py
|
| 29 |
+
llava/data/dataset.py
|
| 30 |
+
llava/data/dataset_tar.py
|
| 31 |
+
llava/data/datasets_mixture.py
|
| 32 |
+
llava/data/simple_video_dataset.py
|
| 33 |
+
llava/data/simple_vila_webdataset.py
|
| 34 |
+
llava/data/dataset_impl/coyo_recap.py
|
| 35 |
+
llava/data/dataset_impl/sam.py
|
| 36 |
+
llava/data_aug/caption2qa.py
|
| 37 |
+
llava/data_aug/dev.py
|
| 38 |
+
llava/data_aug/reformat_tar.py
|
| 39 |
+
llava/eval/eval_gpt_review.py
|
| 40 |
+
llava/eval/eval_gpt_review_bench.py
|
| 41 |
+
llava/eval/eval_gpt_review_visual.py
|
| 42 |
+
llava/eval/eval_mathvista.py
|
| 43 |
+
llava/eval/eval_mmmu.py
|
| 44 |
+
llava/eval/eval_mmvet.py
|
| 45 |
+
llava/eval/eval_pope.py
|
| 46 |
+
llava/eval/eval_science_qa.py
|
| 47 |
+
llava/eval/eval_science_qa_gpt4.py
|
| 48 |
+
llava/eval/eval_science_qa_gpt4_requery.py
|
| 49 |
+
llava/eval/eval_textvqa.py
|
| 50 |
+
llava/eval/evaluate_vqa.py
|
| 51 |
+
llava/eval/generate_webpage_data_from_table.py
|
| 52 |
+
llava/eval/m4c_evaluator.py
|
| 53 |
+
llava/eval/model_qa.py
|
| 54 |
+
llava/eval/model_vqa.py
|
| 55 |
+
llava/eval/model_vqa_loader.py
|
| 56 |
+
llava/eval/model_vqa_mmbench.py
|
| 57 |
+
llava/eval/model_vqa_mmmu.py
|
| 58 |
+
llava/eval/model_vqa_qbench.py
|
| 59 |
+
llava/eval/model_vqa_science.py
|
| 60 |
+
llava/eval/model_vqa_video.py
|
| 61 |
+
llava/eval/qa_baseline_gpt35.py
|
| 62 |
+
llava/eval/run_llava.py
|
| 63 |
+
llava/eval/summarize_gpt_review.py
|
| 64 |
+
llava/eval/mathvista_utils/calculate_score.py
|
| 65 |
+
llava/eval/mathvista_utils/extract_answer.py
|
| 66 |
+
llava/eval/mathvista_utils/utilities.py
|
| 67 |
+
llava/eval/mathvista_utils/prompts/ext_ans.py
|
| 68 |
+
llava/eval/mmmu_utils/data_utils.py
|
| 69 |
+
llava/eval/mmmu_utils/eval_utils.py
|
| 70 |
+
llava/eval/mmmu_utils/model_utils.py
|
| 71 |
+
llava/eval/video/eval_benchmark_1_correctness.py
|
| 72 |
+
llava/eval/video/eval_benchmark_2_detailed_orientation.py
|
| 73 |
+
llava/eval/video/eval_benchmark_3_context.py
|
| 74 |
+
llava/eval/video/eval_benchmark_4_temporal.py
|
| 75 |
+
llava/eval/video/eval_benchmark_5_consistency.py
|
| 76 |
+
llava/eval/video/eval_video_qa.py
|
| 77 |
+
llava/model/__init__.py
|
| 78 |
+
llava/model/apply_delta.py
|
| 79 |
+
llava/model/builder.py
|
| 80 |
+
llava/model/consolidate.py
|
| 81 |
+
llava/model/llava_arch.py
|
| 82 |
+
llava/model/make_delta.py
|
| 83 |
+
llava/model/utils.py
|
| 84 |
+
llava/model/language_model/llava_gemma.py
|
| 85 |
+
llava/model/language_model/llava_llama.py
|
| 86 |
+
llava/model/language_model/llava_mistral.py
|
| 87 |
+
llava/model/language_model/llava_mixtral.py
|
| 88 |
+
llava/model/language_model/llava_mpt.py
|
| 89 |
+
llava/model/language_model/mpt/adapt_tokenizer.py
|
| 90 |
+
llava/model/language_model/mpt/attention.py
|
| 91 |
+
llava/model/language_model/mpt/blocks.py
|
| 92 |
+
llava/model/language_model/mpt/configuration_mpt.py
|
| 93 |
+
llava/model/language_model/mpt/custom_embedding.py
|
| 94 |
+
llava/model/language_model/mpt/flash_attn_triton.py
|
| 95 |
+
llava/model/language_model/mpt/hf_prefixlm_converter.py
|
| 96 |
+
llava/model/language_model/mpt/meta_init_context.py
|
| 97 |
+
llava/model/language_model/mpt/modeling_mpt.py
|
| 98 |
+
llava/model/language_model/mpt/norm.py
|
| 99 |
+
llava/model/language_model/mpt/param_init_fns.py
|
| 100 |
+
llava/model/multimodal_encoder/builder.py
|
| 101 |
+
llava/model/multimodal_encoder/clip_encoder.py
|
| 102 |
+
llava/model/multimodal_encoder/siglip_encoder.py
|
| 103 |
+
llava/model/multimodal_encoder/vision_encoder.py
|
| 104 |
+
llava/model/multimodal_encoder/radio/__init__.py
|
| 105 |
+
llava/model/multimodal_encoder/radio/cls_token.py
|
| 106 |
+
llava/model/multimodal_encoder/radio/create_model.py
|
| 107 |
+
llava/model/multimodal_encoder/radio/enable_cpe_support.py
|
| 108 |
+
llava/model/multimodal_encoder/radio/enable_spectral_reparam.py
|
| 109 |
+
llava/model/multimodal_encoder/radio/extra_timm_models.py
|
| 110 |
+
llava/model/multimodal_encoder/radio/radio_encoder.py
|
| 111 |
+
llava/model/multimodal_encoder/radio/token_merging.py
|
| 112 |
+
llava/model/multimodal_encoder/radio/vit_patch_generator.py
|
| 113 |
+
llava/model/multimodal_projector/builder.py
|
| 114 |
+
llava/train/args.py
|
| 115 |
+
llava/train/llava_trainer.py
|
| 116 |
+
llava/train/short_video_filter.py
|
| 117 |
+
llava/train/slurm_utils.py
|
| 118 |
+
llava/train/train.py
|
| 119 |
+
llava/train/train_mem.py
|
| 120 |
+
llava/train/train_xformers.py
|
| 121 |
+
llava/train/transformer_normalize_monkey_patch.py
|
| 122 |
+
llava/train/utils.py
|
| 123 |
+
llava/train/transformers_replace/trainer.py
|
| 124 |
+
llava/train/transformers_replace/models/gemma/__init__.py
|
| 125 |
+
llava/train/transformers_replace/models/gemma/configuration_gemma.py
|
| 126 |
+
llava/train/transformers_replace/models/gemma/modeling_gemma.py
|
| 127 |
+
llava/train/transformers_replace/models/llama/configuring_llama.py
|
| 128 |
+
llava/train/transformers_replace/models/llama/modeling_llama.py
|
| 129 |
+
llava/train/transformers_replace/models/llama/tokenization_llama.py
|
| 130 |
+
llava/train/transformers_replace/models/mistral/__init__.py
|
| 131 |
+
llava/train/transformers_replace/models/mistral/configuration_mistral.py
|
| 132 |
+
llava/train/transformers_replace/models/mistral/modeling_mistral.py
|
| 133 |
+
llava/train/transformers_replace/models/mixtral/__init__.py
|
| 134 |
+
llava/train/transformers_replace/models/mixtral/configuration_mixtral.py
|
| 135 |
+
llava/train/transformers_replace/models/mixtral/modeling_mixtral.py
|
| 136 |
+
llava/train/transformers_replace/models/siglip/__init__.py
|
| 137 |
+
llava/train/transformers_replace/models/siglip/configuration_siglip.py
|
| 138 |
+
llava/train/transformers_replace/models/siglip/convert_siglip_to_hf.py
|
| 139 |
+
llava/train/transformers_replace/models/siglip/image_processing_siglip.py
|
| 140 |
+
llava/train/transformers_replace/models/siglip/modeling_siglip.py
|
| 141 |
+
llava/train/transformers_replace/models/siglip/processing_siglip.py
|
| 142 |
+
llava/train/transformers_replace/models/siglip/tokenization_siglip.py
|
| 143 |
+
llava/wids/__init__.py
|
| 144 |
+
llava/wids/wids.py
|
| 145 |
+
llava/wids/wids_bench.py
|
| 146 |
+
llava/wids/wids_cleanup.py
|
| 147 |
+
llava/wids/wids_dir.py
|
| 148 |
+
llava/wids/wids_dl.py
|
| 149 |
+
llava/wids/wids_index.py
|
| 150 |
+
llava/wids/wids_lru.py
|
| 151 |
+
llava/wids/wids_mmtar.py
|
| 152 |
+
llava/wids/wids_specs.py
|
| 153 |
+
llava/wids/wids_tar.py
|
| 154 |
+
tests/test_tokenizer.py
|
VILA/llava.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
VILA/llava.egg-info/requires.txt
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.0.1
|
| 2 |
+
torchvision==0.15.2
|
| 3 |
+
transformers==4.31.0
|
| 4 |
+
tokenizers<0.14,>=0.12.1
|
| 5 |
+
sentencepiece==0.1.99
|
| 6 |
+
shortuuid
|
| 7 |
+
accelerate==0.27.2
|
| 8 |
+
peft==0.5.0
|
| 9 |
+
bitsandbytes==0.41.0
|
| 10 |
+
pydantic<2,>=1
|
| 11 |
+
markdown2[all]
|
| 12 |
+
numpy
|
| 13 |
+
scikit-learn==1.2.2
|
| 14 |
+
gradio==3.35.2
|
| 15 |
+
gradio_client==0.2.9
|
| 16 |
+
requests
|
| 17 |
+
httpx==0.24.0
|
| 18 |
+
uvicorn
|
| 19 |
+
fastapi
|
| 20 |
+
einops==0.6.1
|
| 21 |
+
einops-exts==0.0.4
|
| 22 |
+
timm==0.6.13
|
| 23 |
+
openpyxl==3.1.2
|
| 24 |
+
pytorchvideo==0.1.5
|
| 25 |
+
datasets==2.16.1
|
| 26 |
+
openai==1.8.0
|
| 27 |
+
webdataset==0.2.86
|
| 28 |
+
|
| 29 |
+
[eval]
|
| 30 |
+
mmengine
|
| 31 |
+
word2number
|
| 32 |
+
Levenshtein
|
| 33 |
+
|
| 34 |
+
[train]
|
| 35 |
+
deepspeed==0.13.2
|
| 36 |
+
ninja
|
| 37 |
+
wandb
|
VILA/llava.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
CIs
|
| 2 |
+
data
|
| 3 |
+
data_prepare
|
| 4 |
+
demo_images
|
| 5 |
+
demo_trt_llm
|
| 6 |
+
inference_test
|
| 7 |
+
llava
|
VILA/llava/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
VILA/llava/constants.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 18 |
+
|
| 19 |
+
CONTROLLER_HEART_BEAT_EXPIRATION = 30
|
| 20 |
+
WORKER_HEART_BEAT_INTERVAL = 15
|
| 21 |
+
|
| 22 |
+
LOGDIR = "."
|
| 23 |
+
|
| 24 |
+
# Model Constants
|
| 25 |
+
IGNORE_INDEX = -100
|
| 26 |
+
IMAGE_TOKEN_INDEX = -200
|
| 27 |
+
DEFAULT_IMAGE_TOKEN = "<image>"
|
| 28 |
+
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
|
| 29 |
+
DEFAULT_IM_START_TOKEN = "<im_start>"
|
| 30 |
+
DEFAULT_IM_END_TOKEN = "<im_end>"
|
| 31 |
+
IMAGE_PLACEHOLDER = "<image-placeholder>"
|
VILA/llava/conversation.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
# This file is modified from https://github.com/haotian-liu/LLaVA/
|
| 17 |
+
|
| 18 |
+
import dataclasses
|
| 19 |
+
from enum import Enum, auto
|
| 20 |
+
from typing import List
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class SeparatorStyle(Enum):
|
| 24 |
+
"""Different separator style."""
|
| 25 |
+
|
| 26 |
+
AUTO = auto()
|
| 27 |
+
SINGLE = auto()
|
| 28 |
+
TWO = auto()
|
| 29 |
+
MPT = auto()
|
| 30 |
+
PLAIN = auto()
|
| 31 |
+
LLAMA_2 = auto()
|
| 32 |
+
MISTRAL = auto()
|
| 33 |
+
LLAMA_3 = auto()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@dataclasses.dataclass
|
| 37 |
+
class Conversation:
|
| 38 |
+
"""A class that keeps all conversation history."""
|
| 39 |
+
|
| 40 |
+
system: str
|
| 41 |
+
roles: List[str]
|
| 42 |
+
messages: List[List[str]]
|
| 43 |
+
offset: int
|
| 44 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
| 45 |
+
sep: str = "###"
|
| 46 |
+
sep2: str = None
|
| 47 |
+
version: str = "Unknown"
|
| 48 |
+
|
| 49 |
+
skip_next: bool = False
|
| 50 |
+
|
| 51 |
+
def get_prompt(self):
|
| 52 |
+
messages = self.messages
|
| 53 |
+
if len(messages) > 0 and type(messages[0][1]) is tuple:
|
| 54 |
+
messages = self.messages.copy()
|
| 55 |
+
init_role, init_msg = messages[0].copy()
|
| 56 |
+
init_msg = init_msg[0].replace("<image>", "").strip()
|
| 57 |
+
if "mmtag" in self.version:
|
| 58 |
+
messages[0] = (init_role, init_msg)
|
| 59 |
+
messages.insert(0, (self.roles[0], "<Image><image></Image>"))
|
| 60 |
+
messages.insert(1, (self.roles[1], "Received."))
|
| 61 |
+
else:
|
| 62 |
+
messages[0] = (init_role, "<image>\n" + init_msg)
|
| 63 |
+
|
| 64 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
| 65 |
+
ret = self.system + self.sep
|
| 66 |
+
for role, message in messages:
|
| 67 |
+
if message:
|
| 68 |
+
if type(message) is tuple:
|
| 69 |
+
message, _, _ = message
|
| 70 |
+
ret += role + ": " + message + self.sep
|
| 71 |
+
else:
|
| 72 |
+
ret += role + ":"
|
| 73 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
| 74 |
+
seps = [self.sep, self.sep2]
|
| 75 |
+
ret = self.system + seps[0]
|
| 76 |
+
for i, (role, message) in enumerate(messages):
|
| 77 |
+
if message:
|
| 78 |
+
if type(message) is tuple:
|
| 79 |
+
message, _, _ = message
|
| 80 |
+
ret += role + ": " + message + seps[i % 2]
|
| 81 |
+
else:
|
| 82 |
+
ret += role + ":"
|
| 83 |
+
elif self.sep_style == SeparatorStyle.LLAMA_3:
|
| 84 |
+
ret = self.system + self.sep
|
| 85 |
+
for rid, (role, message) in enumerate(messages):
|
| 86 |
+
if message:
|
| 87 |
+
if type(message) is tuple:
|
| 88 |
+
message = message[0]
|
| 89 |
+
sep = self.sep if rid < len(messages) - 1 else self.sep2
|
| 90 |
+
ret += role + message + sep
|
| 91 |
+
else:
|
| 92 |
+
ret += role
|
| 93 |
+
elif self.sep_style == SeparatorStyle.MPT:
|
| 94 |
+
ret = self.system + self.sep
|
| 95 |
+
for role, message in messages:
|
| 96 |
+
if message:
|
| 97 |
+
if type(message) is tuple:
|
| 98 |
+
message, _, _ = message
|
| 99 |
+
ret += role + message + self.sep
|
| 100 |
+
else:
|
| 101 |
+
ret += role
|
| 102 |
+
elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
|
| 103 |
+
if self.sep_style == SeparatorStyle.LLAMA_2:
|
| 104 |
+
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
|
| 105 |
+
else:
|
| 106 |
+
wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
|
| 107 |
+
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
|
| 108 |
+
ret = ""
|
| 109 |
+
if self.sep_style == SeparatorStyle.MISTRAL:
|
| 110 |
+
ret += "<s>"
|
| 111 |
+
|
| 112 |
+
for i, (role, message) in enumerate(messages):
|
| 113 |
+
if i == 0:
|
| 114 |
+
assert message, "first message should not be none"
|
| 115 |
+
assert role == self.roles[0], "first message should come from user"
|
| 116 |
+
if message:
|
| 117 |
+
if type(message) is tuple:
|
| 118 |
+
message, _, _ = message
|
| 119 |
+
if i == 0:
|
| 120 |
+
message = wrap_sys(self.system) + message
|
| 121 |
+
if i % 2 == 0:
|
| 122 |
+
message = wrap_inst(message)
|
| 123 |
+
ret += self.sep + message
|
| 124 |
+
else:
|
| 125 |
+
if self.sep_style == SeparatorStyle.LLAMA_2:
|
| 126 |
+
ret += " " + message + " " + self.sep2
|
| 127 |
+
else:
|
| 128 |
+
ret += message + self.sep2
|
| 129 |
+
else:
|
| 130 |
+
ret += ""
|
| 131 |
+
ret = ret.lstrip(self.sep)
|
| 132 |
+
elif self.sep_style == SeparatorStyle.PLAIN:
|
| 133 |
+
seps = [self.sep, self.sep2]
|
| 134 |
+
ret = self.system
|
| 135 |
+
for i, (role, message) in enumerate(messages):
|
| 136 |
+
if message:
|
| 137 |
+
if type(message) is tuple:
|
| 138 |
+
message, _, _ = message
|
| 139 |
+
ret += message + seps[i % 2]
|
| 140 |
+
else:
|
| 141 |
+
ret += ""
|
| 142 |
+
else:
|
| 143 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
| 144 |
+
|
| 145 |
+
return ret
|
| 146 |
+
|
| 147 |
+
def append_message(self, role, message):
|
| 148 |
+
self.messages.append([role, message])
|
| 149 |
+
|
| 150 |
+
def get_images(self, return_pil=False):
|
| 151 |
+
images = []
|
| 152 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
| 153 |
+
if i % 2 == 0:
|
| 154 |
+
if type(msg) is tuple:
|
| 155 |
+
import base64
|
| 156 |
+
from io import BytesIO
|
| 157 |
+
|
| 158 |
+
from PIL import Image
|
| 159 |
+
|
| 160 |
+
msg, image, image_process_mode = msg
|
| 161 |
+
if image_process_mode == "Pad":
|
| 162 |
+
|
| 163 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
| 164 |
+
width, height = pil_img.size
|
| 165 |
+
if width == height:
|
| 166 |
+
return pil_img
|
| 167 |
+
elif width > height:
|
| 168 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 169 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 170 |
+
return result
|
| 171 |
+
else:
|
| 172 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 173 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 174 |
+
return result
|
| 175 |
+
|
| 176 |
+
image = expand2square(image)
|
| 177 |
+
elif image_process_mode in ["Default", "Crop"]:
|
| 178 |
+
pass
|
| 179 |
+
elif image_process_mode == "Resize":
|
| 180 |
+
image = image.resize((336, 336))
|
| 181 |
+
else:
|
| 182 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
| 183 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
| 184 |
+
aspect_ratio = max_hw / min_hw
|
| 185 |
+
max_len, min_len = 800, 400
|
| 186 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
| 187 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
| 188 |
+
W, H = image.size
|
| 189 |
+
if longest_edge != max(image.size):
|
| 190 |
+
if H > W:
|
| 191 |
+
H, W = longest_edge, shortest_edge
|
| 192 |
+
else:
|
| 193 |
+
H, W = shortest_edge, longest_edge
|
| 194 |
+
image = image.resize((W, H))
|
| 195 |
+
if return_pil:
|
| 196 |
+
images.append(image)
|
| 197 |
+
else:
|
| 198 |
+
buffered = BytesIO()
|
| 199 |
+
image.save(buffered, format="PNG")
|
| 200 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
| 201 |
+
images.append(img_b64_str)
|
| 202 |
+
return images
|
| 203 |
+
|
| 204 |
+
def to_gradio_chatbot(self):
|
| 205 |
+
ret = []
|
| 206 |
+
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
| 207 |
+
if i % 2 == 0:
|
| 208 |
+
if type(msg) is tuple:
|
| 209 |
+
import base64
|
| 210 |
+
from io import BytesIO
|
| 211 |
+
|
| 212 |
+
msg, image, image_process_mode = msg
|
| 213 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
| 214 |
+
aspect_ratio = max_hw / min_hw
|
| 215 |
+
max_len, min_len = 800, 400
|
| 216 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
| 217 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
| 218 |
+
W, H = image.size
|
| 219 |
+
if H > W:
|
| 220 |
+
H, W = longest_edge, shortest_edge
|
| 221 |
+
else:
|
| 222 |
+
H, W = shortest_edge, longest_edge
|
| 223 |
+
image = image.resize((W, H))
|
| 224 |
+
buffered = BytesIO()
|
| 225 |
+
image.save(buffered, format="JPEG")
|
| 226 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
| 227 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
| 228 |
+
msg = img_str + msg.replace("<image>", "").strip()
|
| 229 |
+
ret.append([msg, None])
|
| 230 |
+
else:
|
| 231 |
+
ret.append([msg, None])
|
| 232 |
+
else:
|
| 233 |
+
ret[-1][-1] = msg
|
| 234 |
+
return ret
|
| 235 |
+
|
| 236 |
+
def copy(self):
|
| 237 |
+
return Conversation(
|
| 238 |
+
system=self.system,
|
| 239 |
+
roles=self.roles,
|
| 240 |
+
messages=[[x, y] for x, y in self.messages],
|
| 241 |
+
offset=self.offset,
|
| 242 |
+
sep_style=self.sep_style,
|
| 243 |
+
sep=self.sep,
|
| 244 |
+
sep2=self.sep2,
|
| 245 |
+
version=self.version,
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
def dict(self):
|
| 249 |
+
if len(self.get_images()) > 0:
|
| 250 |
+
return {
|
| 251 |
+
"system": self.system,
|
| 252 |
+
"roles": self.roles,
|
| 253 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
| 254 |
+
"offset": self.offset,
|
| 255 |
+
"sep": self.sep,
|
| 256 |
+
"sep2": self.sep2,
|
| 257 |
+
}
|
| 258 |
+
return {
|
| 259 |
+
"system": self.system,
|
| 260 |
+
"roles": self.roles,
|
| 261 |
+
"messages": self.messages,
|
| 262 |
+
"offset": self.offset,
|
| 263 |
+
"sep": self.sep,
|
| 264 |
+
"sep2": self.sep2,
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
conv_auto = Conversation(
|
| 269 |
+
system="",
|
| 270 |
+
roles=("", ""),
|
| 271 |
+
messages=(),
|
| 272 |
+
offset=0,
|
| 273 |
+
sep_style=SeparatorStyle.AUTO,
|
| 274 |
+
sep="\n",
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
conv_vicuna_v0 = Conversation(
|
| 278 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
| 279 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
| 280 |
+
roles=("Human", "Assistant"),
|
| 281 |
+
messages=(
|
| 282 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
| 283 |
+
(
|
| 284 |
+
"Assistant",
|
| 285 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
| 286 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
| 287 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
| 288 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
| 289 |
+
"renewable and non-renewable energy sources:\n"
|
| 290 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
| 291 |
+
"energy sources are finite and will eventually run out.\n"
|
| 292 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
| 293 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
| 294 |
+
"and other negative effects.\n"
|
| 295 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
| 296 |
+
"have lower operational costs than non-renewable sources.\n"
|
| 297 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
| 298 |
+
"locations than non-renewable sources.\n"
|
| 299 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
| 300 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
| 301 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
| 302 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
|
| 303 |
+
),
|
| 304 |
+
),
|
| 305 |
+
offset=2,
|
| 306 |
+
sep_style=SeparatorStyle.SINGLE,
|
| 307 |
+
sep="###",
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
conv_vicuna_v1 = Conversation(
|
| 311 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 312 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
| 313 |
+
roles=("USER", "ASSISTANT"),
|
| 314 |
+
version="v1",
|
| 315 |
+
messages=(),
|
| 316 |
+
offset=0,
|
| 317 |
+
sep_style=SeparatorStyle.TWO,
|
| 318 |
+
sep=" ",
|
| 319 |
+
sep2="</s>",
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# kentang-mit@: This conversation template is designed for SFT on VFLAN.
|
| 323 |
+
conv_vicuna_v1_nosys = Conversation(
|
| 324 |
+
system="",
|
| 325 |
+
roles=("USER", "ASSISTANT"),
|
| 326 |
+
version="v1_nosys",
|
| 327 |
+
messages=(),
|
| 328 |
+
offset=0,
|
| 329 |
+
sep_style=SeparatorStyle.TWO,
|
| 330 |
+
sep=" ",
|
| 331 |
+
sep2="</s>",
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
conv_llama_2 = Conversation(
|
| 335 |
+
system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
| 336 |
+
|
| 337 |
+
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
|
| 338 |
+
roles=("USER", "ASSISTANT"),
|
| 339 |
+
version="llama_v2",
|
| 340 |
+
messages=(),
|
| 341 |
+
offset=0,
|
| 342 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
| 343 |
+
sep="<s>",
|
| 344 |
+
sep2="</s>",
|
| 345 |
+
)
|
| 346 |
+
|
| 347 |
+
conv_mistral = Conversation(
|
| 348 |
+
system="",
|
| 349 |
+
roles=("USER", "ASSISTANT"),
|
| 350 |
+
version="mistral",
|
| 351 |
+
messages=(),
|
| 352 |
+
offset=0,
|
| 353 |
+
sep_style=SeparatorStyle.MISTRAL,
|
| 354 |
+
sep="",
|
| 355 |
+
sep2="</s>",
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
conv_llava_llama_2 = Conversation(
|
| 359 |
+
system="You are a helpful language and vision assistant. "
|
| 360 |
+
"You are able to understand the visual content that the user provides, "
|
| 361 |
+
"and assist the user with a variety of tasks using natural language.",
|
| 362 |
+
roles=("USER", "ASSISTANT"),
|
| 363 |
+
version="llama_v2",
|
| 364 |
+
messages=(),
|
| 365 |
+
offset=0,
|
| 366 |
+
sep_style=SeparatorStyle.LLAMA_2,
|
| 367 |
+
sep="<s>",
|
| 368 |
+
sep2="</s>",
|
| 369 |
+
)
|
| 370 |
+
|
| 371 |
+
conv_mpt = Conversation(
|
| 372 |
+
system="""<|im_start|>system
|
| 373 |
+
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
|
| 374 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 375 |
+
version="mpt",
|
| 376 |
+
messages=(),
|
| 377 |
+
offset=0,
|
| 378 |
+
sep_style=SeparatorStyle.MPT,
|
| 379 |
+
sep="<|im_end|>",
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
conv_llava_plain = Conversation(
|
| 383 |
+
system="",
|
| 384 |
+
roles=("", ""),
|
| 385 |
+
messages=(),
|
| 386 |
+
offset=0,
|
| 387 |
+
sep_style=SeparatorStyle.PLAIN,
|
| 388 |
+
sep="\n",
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
conv_llava_v0 = Conversation(
|
| 392 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
| 393 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
| 394 |
+
roles=("Human", "Assistant"),
|
| 395 |
+
messages=(),
|
| 396 |
+
offset=0,
|
| 397 |
+
sep_style=SeparatorStyle.SINGLE,
|
| 398 |
+
sep="###",
|
| 399 |
+
)
|
| 400 |
+
|
| 401 |
+
conv_llava_v0_mmtag = Conversation(
|
| 402 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 403 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
| 404 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
| 405 |
+
roles=("Human", "Assistant"),
|
| 406 |
+
messages=(),
|
| 407 |
+
offset=0,
|
| 408 |
+
sep_style=SeparatorStyle.SINGLE,
|
| 409 |
+
sep="###",
|
| 410 |
+
version="v0_mmtag",
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
conv_llava_v1 = Conversation(
|
| 414 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
| 415 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
| 416 |
+
roles=("USER", "ASSISTANT"),
|
| 417 |
+
version="v1",
|
| 418 |
+
messages=(),
|
| 419 |
+
offset=0,
|
| 420 |
+
sep_style=SeparatorStyle.TWO,
|
| 421 |
+
sep=" ",
|
| 422 |
+
sep2="</s>",
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
conv_llava_v1_mmtag = Conversation(
|
| 427 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
| 428 |
+
"The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
| 429 |
+
"The visual content will be provided with the following format: <Image>visual content</Image>.",
|
| 430 |
+
roles=("USER", "ASSISTANT"),
|
| 431 |
+
messages=(),
|
| 432 |
+
offset=0,
|
| 433 |
+
sep_style=SeparatorStyle.TWO,
|
| 434 |
+
sep=" ",
|
| 435 |
+
sep2="</s>",
|
| 436 |
+
version="v1_mmtag",
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
hermes_2 = Conversation(
|
| 440 |
+
system="<|im_start|>system\nAnswer the questions.",
|
| 441 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
| 442 |
+
sep_style=SeparatorStyle.MPT,
|
| 443 |
+
sep="<|im_end|>",
|
| 444 |
+
messages=(),
|
| 445 |
+
offset=0,
|
| 446 |
+
version="hermes-2",
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
|
| 451 |
+
llama_3_chat = Conversation(
|
| 452 |
+
system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
|
| 453 |
+
"You are able to understand the visual content that the user provides, "
|
| 454 |
+
"and assist the user with a variety of tasks using natural language.",
|
| 455 |
+
roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
|
| 456 |
+
version="llama_v3",
|
| 457 |
+
messages=(),
|
| 458 |
+
offset=0,
|
| 459 |
+
sep_style=SeparatorStyle.LLAMA_3,
|
| 460 |
+
sep="<|eot_id|>",
|
| 461 |
+
sep2="<|end_of_text|>",
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
default_conversation = conv_auto
|
| 466 |
+
conv_templates = {
|
| 467 |
+
"auto": conv_auto,
|
| 468 |
+
"default": conv_vicuna_v0,
|
| 469 |
+
"hermes-2": hermes_2,
|
| 470 |
+
"llama_3": llama_3_chat,
|
| 471 |
+
"v0": conv_vicuna_v0,
|
| 472 |
+
"v1": conv_vicuna_v1,
|
| 473 |
+
"vicuna_v1": conv_vicuna_v1,
|
| 474 |
+
"vicuna_v1_nosys": conv_vicuna_v1_nosys,
|
| 475 |
+
"llama_2": conv_llama_2,
|
| 476 |
+
"mistral": conv_mistral,
|
| 477 |
+
"plain": conv_llava_plain,
|
| 478 |
+
"v0_plain": conv_llava_plain,
|
| 479 |
+
"llava_v0": conv_llava_v0,
|
| 480 |
+
"v0_mmtag": conv_llava_v0_mmtag,
|
| 481 |
+
"llava_v1": conv_llava_v1,
|
| 482 |
+
"v1_mmtag": conv_llava_v1_mmtag,
|
| 483 |
+
"llava_llama_2": conv_llava_llama_2,
|
| 484 |
+
"mpt": conv_mpt,
|
| 485 |
+
}
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
if __name__ == "__main__":
|
| 489 |
+
print(default_conversation.get_prompt())
|
VILA/llava/entry.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
from transformers import PreTrainedModel
|
| 5 |
+
|
| 6 |
+
from llava.mm_utils import get_model_name_from_path
|
| 7 |
+
from llava.model.builder import load_pretrained_model
|
| 8 |
+
|
| 9 |
+
__all__ = ["load"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load(model_path: str, model_base: Optional[str] = None) -> PreTrainedModel:
|
| 13 |
+
model_path = os.path.expanduser(model_path)
|
| 14 |
+
model_name = get_model_name_from_path(model_path)
|
| 15 |
+
if os.path.exists(os.path.join(model_path, "model")):
|
| 16 |
+
model_path = os.path.join(model_path, "model")
|
| 17 |
+
_, model, _, _ = load_pretrained_model(model_path, model_name, model_base)
|
| 18 |
+
return model
|
VILA/llava/mm_utils.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import base64
|
| 18 |
+
import os
|
| 19 |
+
import tempfile
|
| 20 |
+
from io import BytesIO
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from transformers import StoppingCriteria
|
| 26 |
+
|
| 27 |
+
from llava.constants import IMAGE_TOKEN_INDEX
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
|
| 31 |
+
import cv2
|
| 32 |
+
|
| 33 |
+
if fps == None or frame_count == None:
|
| 34 |
+
# if one of fps or frame_count is None, still recompute
|
| 35 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
| 36 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 37 |
+
if fps == 0 or frame_count == 0:
|
| 38 |
+
print(f"Video file not found. return empty images. {video_file_name}")
|
| 39 |
+
return [
|
| 40 |
+
Image.new("RGB", (720, 720)),
|
| 41 |
+
] * num_frames, 0
|
| 42 |
+
|
| 43 |
+
duration = frame_count / fps
|
| 44 |
+
frame_interval = frame_count // num_frames
|
| 45 |
+
if frame_interval == 0 and frame_count <= 1:
|
| 46 |
+
print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
|
| 47 |
+
return [
|
| 48 |
+
Image.new("RGB", (720, 720)),
|
| 49 |
+
] * num_frames, 0
|
| 50 |
+
# print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)
|
| 51 |
+
|
| 52 |
+
images = []
|
| 53 |
+
count = 0
|
| 54 |
+
success = True
|
| 55 |
+
frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
|
| 56 |
+
while success:
|
| 57 |
+
# print("frame_count:", frame_count, "count:", count, "num_frames:", num_frames, "frame_interval:", frame_interval)
|
| 58 |
+
if frame_count >= num_frames:
|
| 59 |
+
success, frame = vidcap.read()
|
| 60 |
+
if count in frame_indices:
|
| 61 |
+
try:
|
| 62 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 63 |
+
im_pil = Image.fromarray(img)
|
| 64 |
+
images.append(im_pil)
|
| 65 |
+
except BaseException:
|
| 66 |
+
continue
|
| 67 |
+
if len(images) >= num_frames:
|
| 68 |
+
return images, num_frames
|
| 69 |
+
count += 1
|
| 70 |
+
else:
|
| 71 |
+
# Left padding frames if the video is not long enough
|
| 72 |
+
success, frame = vidcap.read()
|
| 73 |
+
if success:
|
| 74 |
+
try:
|
| 75 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 76 |
+
im_pil = Image.fromarray(img)
|
| 77 |
+
images.append(im_pil)
|
| 78 |
+
except BaseException:
|
| 79 |
+
continue
|
| 80 |
+
count += 1
|
| 81 |
+
else:
|
| 82 |
+
break
|
| 83 |
+
if len(images) == 0:
|
| 84 |
+
raise ValueError("Did not find enough frames in the video. return empty image.")
|
| 85 |
+
|
| 86 |
+
return images, len(images)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
|
| 90 |
+
"""
|
| 91 |
+
num_frames is the max number of frames the model can support.
|
| 92 |
+
frame_count is the number of frames in the input video.
|
| 93 |
+
max_fps is the max FPS of the model can support.
|
| 94 |
+
fps is the fps of the input video.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
import random
|
| 98 |
+
|
| 99 |
+
import cv2
|
| 100 |
+
|
| 101 |
+
if fps == None or frame_count == None:
|
| 102 |
+
# if one of fps or frame_count is None, still recompute
|
| 103 |
+
fps = vidcap.get(cv2.CAP_PROP_FPS)
|
| 104 |
+
frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 105 |
+
|
| 106 |
+
if fps == 0 or frame_count == 0:
|
| 107 |
+
print(f"Video file not found. return empty images. {video_file_name}")
|
| 108 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 109 |
+
return [
|
| 110 |
+
Image.new("RGB", (720, 720)),
|
| 111 |
+
] * empty_video_frames, 0
|
| 112 |
+
|
| 113 |
+
duration = frame_count / fps
|
| 114 |
+
# print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
|
| 115 |
+
# If the video is too long (longer than max_fps and num_frames can support),
|
| 116 |
+
# we will use lower fps to sample frames.
|
| 117 |
+
if duration >= num_frames / max_fps:
|
| 118 |
+
frame_interval = frame_count // num_frames
|
| 119 |
+
|
| 120 |
+
# If the video is too short, we will skip the video if there is only one frame.
|
| 121 |
+
if frame_interval == 0 and frame_count <= 1:
|
| 122 |
+
print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
|
| 123 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 124 |
+
return [
|
| 125 |
+
Image.new("RGB", (720, 720)),
|
| 126 |
+
] * empty_video_frames, 0
|
| 127 |
+
|
| 128 |
+
images = []
|
| 129 |
+
count = 0
|
| 130 |
+
success = True
|
| 131 |
+
frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
|
| 132 |
+
|
| 133 |
+
while success:
|
| 134 |
+
if frame_count >= num_frames:
|
| 135 |
+
# success, frame = vidcap.read()
|
| 136 |
+
if count in frame_indices:
|
| 137 |
+
success, frame = vidcap.read()
|
| 138 |
+
try:
|
| 139 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 140 |
+
im_pil = Image.fromarray(img)
|
| 141 |
+
images.append(im_pil)
|
| 142 |
+
except:
|
| 143 |
+
# print("Failed to read frame:", count)
|
| 144 |
+
continue
|
| 145 |
+
if len(images) >= num_frames:
|
| 146 |
+
return images, num_frames
|
| 147 |
+
else:
|
| 148 |
+
success = vidcap.grab()
|
| 149 |
+
count += 1
|
| 150 |
+
else:
|
| 151 |
+
# Left padding frames if the video is not long enough
|
| 152 |
+
success, frame = vidcap.read()
|
| 153 |
+
if success:
|
| 154 |
+
try:
|
| 155 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 156 |
+
im_pil = Image.fromarray(img)
|
| 157 |
+
images.append(im_pil)
|
| 158 |
+
except:
|
| 159 |
+
# print("Failed to read frame:", count)
|
| 160 |
+
continue
|
| 161 |
+
count += 1
|
| 162 |
+
else:
|
| 163 |
+
break
|
| 164 |
+
else:
|
| 165 |
+
frames_required = int(duration * max_fps)
|
| 166 |
+
frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
|
| 167 |
+
if frames_required == 0:
|
| 168 |
+
print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
|
| 169 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 170 |
+
return [
|
| 171 |
+
Image.new("RGB", (720, 720)),
|
| 172 |
+
] * empty_video_frames, 0
|
| 173 |
+
elif frames_required == 1:
|
| 174 |
+
frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
|
| 175 |
+
images = []
|
| 176 |
+
count = 0
|
| 177 |
+
looked = 0
|
| 178 |
+
success = True
|
| 179 |
+
|
| 180 |
+
while success:
|
| 181 |
+
success, frame = vidcap.read()
|
| 182 |
+
if success and (looked in frame_indices):
|
| 183 |
+
try:
|
| 184 |
+
img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 185 |
+
im_pil = Image.fromarray(img)
|
| 186 |
+
images.append(im_pil)
|
| 187 |
+
except:
|
| 188 |
+
continue
|
| 189 |
+
count += 1
|
| 190 |
+
looked += 1
|
| 191 |
+
|
| 192 |
+
if len(images) == 0:
|
| 193 |
+
empty_video_frames = int(random.uniform(2, 8 * max_fps))
|
| 194 |
+
return [
|
| 195 |
+
Image.new("RGB", (720, 720)),
|
| 196 |
+
] * empty_video_frames, 0
|
| 197 |
+
else:
|
| 198 |
+
return images, len(images)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
|
| 202 |
+
"""
|
| 203 |
+
Extract frames from a video using OpenCV.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
|
| 207 |
+
frames (int): Number of frames to extract from the video.
|
| 208 |
+
fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
|
| 209 |
+
|
| 210 |
+
Returns:
|
| 211 |
+
list: List of PIL Images extracted from the video.
|
| 212 |
+
|
| 213 |
+
Raises:
|
| 214 |
+
NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
|
| 215 |
+
"""
|
| 216 |
+
import cv2
|
| 217 |
+
|
| 218 |
+
if isinstance(vpath_or_bytesio, str):
|
| 219 |
+
vidcap = cv2.VideoCapture(vpath_or_bytesio)
|
| 220 |
+
if max_fps > 0.0:
|
| 221 |
+
return get_frame_from_vcap_with_fps(
|
| 222 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
|
| 223 |
+
)
|
| 224 |
+
return get_frame_from_vcap(
|
| 225 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
|
| 226 |
+
)
|
| 227 |
+
elif isinstance(vpath_or_bytesio, (BytesIO,)):
|
| 228 |
+
# assuming mp4
|
| 229 |
+
with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
|
| 230 |
+
temp_video.write(vpath_or_bytesio.read())
|
| 231 |
+
temp_video_name = temp_video.name
|
| 232 |
+
vidcap = cv2.VideoCapture(temp_video_name)
|
| 233 |
+
if max_fps > 0.0:
|
| 234 |
+
return get_frame_from_vcap_with_fps(
|
| 235 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
|
| 236 |
+
)
|
| 237 |
+
return get_frame_from_vcap(
|
| 238 |
+
vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
|
| 239 |
+
)
|
| 240 |
+
else:
|
| 241 |
+
raise NotImplementedError(type(vpath_or_bytesio))
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def load_image_from_base64(image):
|
| 245 |
+
return Image.open(BytesIO(base64.b64decode(image)))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def expand2square(pil_img, background_color):
|
| 249 |
+
"""
|
| 250 |
+
Expand the given PIL image to a square shape by adding padding.
|
| 251 |
+
|
| 252 |
+
Parameters:
|
| 253 |
+
- pil_img: The PIL image to be expanded.
|
| 254 |
+
- background_color: The color of the padding to be added.
|
| 255 |
+
|
| 256 |
+
Returns:
|
| 257 |
+
- The expanded PIL image.
|
| 258 |
+
|
| 259 |
+
If the image is already square, it is returned as is.
|
| 260 |
+
If the image is wider than it is tall, padding is added to the top and bottom.
|
| 261 |
+
If the image is taller than it is wide, padding is added to the left and right.
|
| 262 |
+
"""
|
| 263 |
+
width, height = pil_img.size
|
| 264 |
+
if pil_img.mode == "L":
|
| 265 |
+
background_color = background_color[0]
|
| 266 |
+
if width == height:
|
| 267 |
+
return pil_img
|
| 268 |
+
elif width > height:
|
| 269 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 270 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 271 |
+
return result
|
| 272 |
+
else:
|
| 273 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 274 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 275 |
+
return result
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def process_image(image_file, data_args, image_folder):
|
| 279 |
+
processor = data_args.image_processor
|
| 280 |
+
if isinstance(image_file, str):
|
| 281 |
+
if image_folder is not None:
|
| 282 |
+
image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
|
| 283 |
+
else:
|
| 284 |
+
image = Image.open(image_file).convert("RGB")
|
| 285 |
+
else:
|
| 286 |
+
# image is stored in bytearray
|
| 287 |
+
image = image_file
|
| 288 |
+
image = image.convert("RGB")
|
| 289 |
+
if data_args.image_aspect_ratio == "resize":
|
| 290 |
+
if hasattr(data_args.image_processor, "crop_size"):
|
| 291 |
+
# CLIP vision tower
|
| 292 |
+
crop_size = data_args.image_processor.crop_size
|
| 293 |
+
else:
|
| 294 |
+
# SIGLIP vision tower
|
| 295 |
+
assert hasattr(data_args.image_processor, "size")
|
| 296 |
+
crop_size = data_args.image_processor.size
|
| 297 |
+
image = image.resize((crop_size["height"], crop_size["width"]))
|
| 298 |
+
if data_args.image_aspect_ratio == "pad":
|
| 299 |
+
|
| 300 |
+
def expand2square(pil_img, background_color):
|
| 301 |
+
width, height = pil_img.size
|
| 302 |
+
if width == height:
|
| 303 |
+
return pil_img
|
| 304 |
+
elif width > height:
|
| 305 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
| 306 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
| 307 |
+
return result
|
| 308 |
+
else:
|
| 309 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
| 310 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
| 311 |
+
return result
|
| 312 |
+
|
| 313 |
+
image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
|
| 314 |
+
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
| 315 |
+
else:
|
| 316 |
+
# Using default behavior of the vision encoder
|
| 317 |
+
# For CLIP, default is central crop
|
| 318 |
+
# For Radio, default is central crop
|
| 319 |
+
# For Siglip, default is resize
|
| 320 |
+
# For InternVIT, default is resize
|
| 321 |
+
image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
|
| 322 |
+
return image
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def process_images(images, image_processor, model_cfg):
|
| 326 |
+
|
| 327 |
+
model_cfg.image_processor = image_processor
|
| 328 |
+
new_images = [process_image(image, model_cfg, None) for image in images]
|
| 329 |
+
|
| 330 |
+
if all(x.shape == new_images[0].shape for x in new_images):
|
| 331 |
+
new_images = torch.stack(new_images, dim=0)
|
| 332 |
+
return new_images
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None, lstrip=False):
|
| 336 |
+
prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]
|
| 337 |
+
|
| 338 |
+
def insert_separator(X, sep):
|
| 339 |
+
return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]
|
| 340 |
+
|
| 341 |
+
input_ids = []
|
| 342 |
+
offset = 0
|
| 343 |
+
if lstrip:
|
| 344 |
+
offset = 1
|
| 345 |
+
else:
|
| 346 |
+
if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
|
| 347 |
+
offset = 1
|
| 348 |
+
input_ids.append(prompt_chunks[0][0])
|
| 349 |
+
|
| 350 |
+
for chunk_id, x in enumerate(insert_separator(prompt_chunks, [image_token_index] * (offset + 1))):
|
| 351 |
+
if chunk_id == 0 and lstrip:
|
| 352 |
+
input_ids.extend(x)
|
| 353 |
+
else:
|
| 354 |
+
input_ids.extend(x[offset:])
|
| 355 |
+
|
| 356 |
+
if return_tensors is not None:
|
| 357 |
+
if return_tensors == "pt":
|
| 358 |
+
return torch.tensor(input_ids, dtype=torch.long)
|
| 359 |
+
raise ValueError(f"Unsupported tensor type: {return_tensors}")
|
| 360 |
+
return input_ids
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def is_gemma_tokenizer(tokenizer):
|
| 364 |
+
return "gemma" in tokenizer.__class__.__name__.lower()
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def get_model_name_from_path(model_path):
|
| 368 |
+
model_path = model_path.strip("/")
|
| 369 |
+
model_paths = model_path.split("/")
|
| 370 |
+
if model_paths[-1].startswith("checkpoint-"):
|
| 371 |
+
return model_paths[-2] + "_" + model_paths[-1]
|
| 372 |
+
else:
|
| 373 |
+
return model_paths[-1]
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class KeywordsStoppingCriteria(StoppingCriteria):
|
| 377 |
+
def __init__(self, keywords, tokenizer, input_ids):
|
| 378 |
+
self.keywords = keywords
|
| 379 |
+
self.keyword_ids = []
|
| 380 |
+
self.max_keyword_len = 0
|
| 381 |
+
for keyword in keywords:
|
| 382 |
+
cur_keyword_ids = tokenizer(keyword).input_ids
|
| 383 |
+
if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
|
| 384 |
+
cur_keyword_ids = cur_keyword_ids[1:]
|
| 385 |
+
if len(cur_keyword_ids) > self.max_keyword_len:
|
| 386 |
+
self.max_keyword_len = len(cur_keyword_ids)
|
| 387 |
+
self.keyword_ids.append(torch.tensor(cur_keyword_ids))
|
| 388 |
+
self.tokenizer = tokenizer
|
| 389 |
+
self.start_len = input_ids.shape[1]
|
| 390 |
+
|
| 391 |
+
def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 392 |
+
offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
|
| 393 |
+
self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
|
| 394 |
+
for keyword_id in self.keyword_ids:
|
| 395 |
+
if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
|
| 396 |
+
return True
|
| 397 |
+
outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
|
| 398 |
+
for keyword in self.keywords:
|
| 399 |
+
if keyword in outputs:
|
| 400 |
+
return True
|
| 401 |
+
return False
|
| 402 |
+
|
| 403 |
+
def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
| 404 |
+
outputs = []
|
| 405 |
+
for i in range(output_ids.shape[0]):
|
| 406 |
+
outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
|
| 407 |
+
return all(outputs)
|
VILA/llava/modals.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
__all__ = ["Modal", "Image", "Video"]
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Modal:
|
| 7 |
+
pass
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class File(Modal):
|
| 11 |
+
EXTENSIONS = None
|
| 12 |
+
|
| 13 |
+
def __init__(self, path: str) -> None:
|
| 14 |
+
self.path = path
|
| 15 |
+
if not os.path.exists(path):
|
| 16 |
+
raise FileNotFoundError(f"File not found: {path}")
|
| 17 |
+
if self.EXTENSIONS is not None and not any(path.endswith(ext) for ext in self.EXTENSIONS):
|
| 18 |
+
raise ValueError(f"Unsupported file extension: {os.path.splitext(path)[1]}")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Image(File):
|
| 22 |
+
EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp", ".mp4", ".mov", ".avi", ".mkv", ".webm"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Video(File):
|
| 26 |
+
EXTENSIONS = [".mp4"]
|
VILA/scripts/convert_gqa_for_eval.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
|
| 20 |
+
parser = argparse.ArgumentParser()
|
| 21 |
+
parser.add_argument("--src", type=str)
|
| 22 |
+
parser.add_argument("--dst", type=str)
|
| 23 |
+
args = parser.parse_args()
|
| 24 |
+
|
| 25 |
+
all_answers = []
|
| 26 |
+
for line_idx, line in enumerate(open(args.src)):
|
| 27 |
+
res = json.loads(line)
|
| 28 |
+
question_id = res["question_id"]
|
| 29 |
+
text = res["text"].rstrip(".").lower()
|
| 30 |
+
all_answers.append({"questionId": question_id, "prediction": text})
|
| 31 |
+
|
| 32 |
+
with open(args.dst, "w") as f:
|
| 33 |
+
json.dump(all_answers, f)
|
VILA/scripts/convert_karpathy_to_anno.py
ADDED
|
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
# import json
|
| 18 |
+
|
| 19 |
+
# # coco for reference
|
| 20 |
+
# # dict_keys(['images', 'dataset']) -> ["images"]
|
| 21 |
+
# karpathy = json.load(open("/home/jil/datasets/karpathy_json/dataset_coco.json"))
|
| 22 |
+
# # dict_keys(['info', 'images', 'licenses', 'annotations']) -> ['images', 'annotations']]
|
| 23 |
+
# anno = json.load(open("/tmp/coco/annotations/captions_val2014.json"))
|
| 24 |
+
# # assert len(karpathy["images"]) == len(anno["images"]) == len(anno["annotations"]), (
|
| 25 |
+
# # len(karpathy["images"]), len(anno["images"]), len(anno["annotations"]) # (123287, 40504, 202654)
|
| 26 |
+
# # )
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# karpathy_flickr = json.load(open("/home/jil/datasets/karpathy_json/dataset_coco.json"))
|
| 30 |
+
# anno_flickr = {
|
| 31 |
+
# "images": [],
|
| 32 |
+
# "annotations": [],
|
| 33 |
+
# }
|
| 34 |
+
|
| 35 |
+
# print(karpathy["images"][0])
|
| 36 |
+
# print(anno["images"][0])
|
| 37 |
+
# print(anno["annotations"][:3])
|
| 38 |
+
|
| 39 |
+
# image_id_set = set([_["id"] for _ in anno["images"]])
|
| 40 |
+
# anno_set = set([_["id"] for _ in anno["annotations"]])
|
| 41 |
+
|
| 42 |
+
# print(len(anno_set))
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
import argparse
|
| 46 |
+
import json
|
| 47 |
+
|
| 48 |
+
from tqdm import tqdm
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def main(input_json, output_json, split):
|
| 52 |
+
annot_format = {
|
| 53 |
+
"info": {
|
| 54 |
+
"year": 2014,
|
| 55 |
+
"version": "1.0",
|
| 56 |
+
"description": "This is stable 1.0 version of the 2014 MS COCO dataset.",
|
| 57 |
+
"contributor": "Microsoft COCO group",
|
| 58 |
+
"url": "http://mscoco.org",
|
| 59 |
+
"date_created": "2015-01-27 09:11:52.357475",
|
| 60 |
+
},
|
| 61 |
+
"licenses": [
|
| 62 |
+
{
|
| 63 |
+
"url": "http://creativecommons.org/licenses/by-nc-sa/2.0/",
|
| 64 |
+
"id": 1,
|
| 65 |
+
"name": "Attribution-NonCommercial-ShareAlike License",
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"url": "http://creativecommons.org/licenses/by-nc/2.0/",
|
| 69 |
+
"id": 2,
|
| 70 |
+
"name": "Attribution-NonCommercial License",
|
| 71 |
+
},
|
| 72 |
+
{
|
| 73 |
+
"url": "http://creativecommons.org/licenses/by-nc-nd/2.0/",
|
| 74 |
+
"id": 3,
|
| 75 |
+
"name": "Attribution-NonCommercial-NoDerivs License",
|
| 76 |
+
},
|
| 77 |
+
{"url": "http://creativecommons.org/licenses/by/2.0/", "id": 4, "name": "Attribution License"},
|
| 78 |
+
{
|
| 79 |
+
"url": "http://creativecommons.org/licenses/by-sa/2.0/",
|
| 80 |
+
"id": 5,
|
| 81 |
+
"name": "Attribution-ShareAlike License",
|
| 82 |
+
},
|
| 83 |
+
{"url": "http://creativecommons.org/licenses/by-nd/2.0/", "id": 6, "name": "Attribution-NoDerivs License"},
|
| 84 |
+
{"url": "http://flickr.com/commons/usage/", "id": 7, "name": "No known copyright restrictions"},
|
| 85 |
+
{"url": "http://www.usa.gov/copyright.shtml", "id": 8, "name": "United States Government Work"},
|
| 86 |
+
],
|
| 87 |
+
"type": "captions",
|
| 88 |
+
"images": [],
|
| 89 |
+
"annotations": [],
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
with open(input_json) as f:
|
| 93 |
+
dataset = json.load(f)
|
| 94 |
+
annotations = dataset["images"]
|
| 95 |
+
dataset_name = dataset["dataset"]
|
| 96 |
+
|
| 97 |
+
count = 0
|
| 98 |
+
print(f"Converting Karpathy {dataset_name} {split} to COCO Format...")
|
| 99 |
+
for annot in tqdm(annotations):
|
| 100 |
+
if split == "all" or (annot["split"] == split):
|
| 101 |
+
image_id = str(annot["filename"].split(".")[0]) # annot['imgid']
|
| 102 |
+
annot_format["images"].append(
|
| 103 |
+
{
|
| 104 |
+
"id": image_id,
|
| 105 |
+
"width": 512,
|
| 106 |
+
"height": 512,
|
| 107 |
+
"filename": annot["filename"],
|
| 108 |
+
"license": 1,
|
| 109 |
+
"flickr_url": "",
|
| 110 |
+
"coco_url": "",
|
| 111 |
+
"date_captured": "",
|
| 112 |
+
}
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
for sent in annot["sentences"]:
|
| 116 |
+
annot_format["annotations"].append({"id": sent["sentid"], "image_id": image_id, "caption": sent["raw"]})
|
| 117 |
+
count += 1
|
| 118 |
+
|
| 119 |
+
with open(output_json, "w") as f:
|
| 120 |
+
json.dump(annot_format, f)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
parser = argparse.ArgumentParser()
|
| 125 |
+
parser.add_argument("--input-json", type=str, default="/home/jil/datasets/karpathy_json/dataset_flickr30k.json")
|
| 126 |
+
parser.add_argument("--output-json", type=str, default="/home/jil/datasets/flickr30k/flickr30k_coco_all.json")
|
| 127 |
+
parser.add_argument("--split", type=str, default="all")
|
| 128 |
+
args = parser.parse_args()
|
| 129 |
+
|
| 130 |
+
main(args.input_json, args.output_json, args.split)
|
VILA/scripts/convert_mmbench_for_submission.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
import pandas as pd
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_args():
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument("--annotation-file", type=str, required=True)
|
| 27 |
+
parser.add_argument("--result-dir", type=str, required=True)
|
| 28 |
+
parser.add_argument("--upload-dir", type=str, required=True)
|
| 29 |
+
parser.add_argument("--experiment", type=str, required=True)
|
| 30 |
+
|
| 31 |
+
return parser.parse_args()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
if __name__ == "__main__":
|
| 35 |
+
args = get_args()
|
| 36 |
+
|
| 37 |
+
df = pd.read_table(args.annotation_file)
|
| 38 |
+
|
| 39 |
+
cur_df = df.copy()
|
| 40 |
+
cur_df = cur_df.drop(columns=["hint", "category", "source", "image", "comment", "l2-category"])
|
| 41 |
+
cur_df.insert(6, "prediction", None)
|
| 42 |
+
for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")):
|
| 43 |
+
pred = json.loads(pred)
|
| 44 |
+
cur_df.loc[df["index"] == pred["question_id"], "prediction"] = pred["text"]
|
| 45 |
+
|
| 46 |
+
cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}_upload.xlsx"), index=False, engine="openpyxl")
|
VILA/scripts/convert_mmvet_for_eval.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
|
| 20 |
+
parser = argparse.ArgumentParser()
|
| 21 |
+
parser.add_argument("--src", type=str)
|
| 22 |
+
parser.add_argument("--dst", type=str)
|
| 23 |
+
args = parser.parse_args()
|
| 24 |
+
|
| 25 |
+
cur_result = {}
|
| 26 |
+
|
| 27 |
+
for line in open(args.src):
|
| 28 |
+
data = json.loads(line)
|
| 29 |
+
qid = data["question_id"]
|
| 30 |
+
cur_result[f"v1_{qid}"] = data["text"]
|
| 31 |
+
|
| 32 |
+
with open(args.dst, "w") as f:
|
| 33 |
+
json.dump(cur_result, f, indent=2)
|
VILA/scripts/convert_seed_for_submission.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_args():
|
| 22 |
+
parser = argparse.ArgumentParser()
|
| 23 |
+
parser.add_argument("--annotation-file", type=str)
|
| 24 |
+
parser.add_argument("--result-file", type=str)
|
| 25 |
+
parser.add_argument("--result-upload-file", type=str)
|
| 26 |
+
return parser.parse_args()
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def eval_single(result_file, eval_only_type=None):
|
| 30 |
+
results = {}
|
| 31 |
+
for line in open(result_file):
|
| 32 |
+
row = json.loads(line)
|
| 33 |
+
results[row["question_id"]] = row
|
| 34 |
+
|
| 35 |
+
type_counts = {}
|
| 36 |
+
correct_counts = {}
|
| 37 |
+
for question_data in data["questions"]:
|
| 38 |
+
if eval_only_type is not None and question_data["data_type"] != eval_only_type:
|
| 39 |
+
continue
|
| 40 |
+
data_type = question_data["question_type_id"]
|
| 41 |
+
type_counts[data_type] = type_counts.get(data_type, 0) + 1
|
| 42 |
+
try:
|
| 43 |
+
question_id = int(question_data["question_id"])
|
| 44 |
+
except BaseException:
|
| 45 |
+
question_id = question_data["question_id"]
|
| 46 |
+
if question_id not in results:
|
| 47 |
+
correct_counts[data_type] = correct_counts.get(data_type, 0)
|
| 48 |
+
continue
|
| 49 |
+
row = results[question_id]
|
| 50 |
+
if row["text"] == question_data["answer"]:
|
| 51 |
+
correct_counts[data_type] = correct_counts.get(data_type, 0) + 1
|
| 52 |
+
|
| 53 |
+
total_count = 0
|
| 54 |
+
total_correct = 0
|
| 55 |
+
for data_type in sorted(type_counts.keys()):
|
| 56 |
+
accuracy = correct_counts[data_type] / type_counts[data_type] * 100
|
| 57 |
+
if eval_only_type is None:
|
| 58 |
+
print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%")
|
| 59 |
+
|
| 60 |
+
total_count += type_counts[data_type]
|
| 61 |
+
total_correct += correct_counts[data_type]
|
| 62 |
+
|
| 63 |
+
total_accuracy = total_correct / total_count * 100
|
| 64 |
+
if eval_only_type is None:
|
| 65 |
+
print(f"Total accuracy: {total_accuracy:.2f}%")
|
| 66 |
+
else:
|
| 67 |
+
print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%")
|
| 68 |
+
|
| 69 |
+
return results
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
if __name__ == "__main__":
|
| 73 |
+
args = get_args()
|
| 74 |
+
data = json.load(open(args.annotation_file))
|
| 75 |
+
ques_type_id_to_name = {id: n for n, id in data["question_type"].items()}
|
| 76 |
+
|
| 77 |
+
results = eval_single(args.result_file)
|
| 78 |
+
eval_single(args.result_file, eval_only_type="image")
|
| 79 |
+
eval_single(args.result_file, eval_only_type="video")
|
| 80 |
+
|
| 81 |
+
with open(args.result_upload_file, "w") as fp:
|
| 82 |
+
for question in data["questions"]:
|
| 83 |
+
qid = question["question_id"]
|
| 84 |
+
if qid in results:
|
| 85 |
+
result = results[qid]
|
| 86 |
+
else:
|
| 87 |
+
result = results[int(qid)]
|
| 88 |
+
fp.write(json.dumps({"question_id": qid, "prediction": result["text"]}) + "\n")
|
VILA/scripts/convert_sqa_to_llava.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
|
| 20 |
+
import fire
|
| 21 |
+
from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
|
| 25 |
+
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
|
| 26 |
+
problems = json.load(open(os.path.join(base_dir, "problems.json")))
|
| 27 |
+
|
| 28 |
+
split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False)
|
| 29 |
+
|
| 30 |
+
target_format = []
|
| 31 |
+
for prob_id, (input, output) in split_problems.items():
|
| 32 |
+
if input.startswith("Question: "):
|
| 33 |
+
input = input.replace("Question: ", "")
|
| 34 |
+
if output.startswith("Answer: "):
|
| 35 |
+
output = output.replace("Answer: ", "")
|
| 36 |
+
|
| 37 |
+
raw_prob_data = problems[prob_id]
|
| 38 |
+
if raw_prob_data["image"] is None:
|
| 39 |
+
target_format.append(
|
| 40 |
+
{
|
| 41 |
+
"id": prob_id,
|
| 42 |
+
"conversations": [
|
| 43 |
+
{"from": "human", "value": f"{input}"},
|
| 44 |
+
{"from": "gpt", "value": f"{output}"},
|
| 45 |
+
],
|
| 46 |
+
}
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
else:
|
| 50 |
+
target_format.append(
|
| 51 |
+
{
|
| 52 |
+
"id": prob_id,
|
| 53 |
+
"image": os.path.join(prob_id, raw_prob_data["image"]),
|
| 54 |
+
"conversations": [
|
| 55 |
+
{"from": "human", "value": f"{input}\n<image>"},
|
| 56 |
+
{"from": "gpt", "value": f"{output}"},
|
| 57 |
+
],
|
| 58 |
+
}
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
print(f"Number of samples: {len(target_format)}")
|
| 62 |
+
|
| 63 |
+
with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
|
| 64 |
+
json.dump(target_format, f, indent=2)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
|
| 68 |
+
split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
|
| 69 |
+
problems = json.load(open(os.path.join(base_dir, "problems.json")))
|
| 70 |
+
|
| 71 |
+
split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False)
|
| 72 |
+
|
| 73 |
+
writer = open(os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl"), "w")
|
| 74 |
+
for prob_id, (input, output) in split_problems.items():
|
| 75 |
+
if input.startswith("Question: "):
|
| 76 |
+
input = input.replace("Question: ", "")
|
| 77 |
+
if output.startswith("Answer: "):
|
| 78 |
+
output = output.replace("Answer: ", "")
|
| 79 |
+
|
| 80 |
+
raw_prob_data = problems[prob_id]
|
| 81 |
+
if raw_prob_data["image"] is None:
|
| 82 |
+
data = {
|
| 83 |
+
"id": prob_id,
|
| 84 |
+
"instruction": f"{input}",
|
| 85 |
+
"output": f"{output}",
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
else:
|
| 89 |
+
data = {
|
| 90 |
+
"id": prob_id,
|
| 91 |
+
"image": os.path.join(prob_id, raw_prob_data["image"]),
|
| 92 |
+
"instruction": f"{input}\n<image>",
|
| 93 |
+
"output": f"{output}",
|
| 94 |
+
}
|
| 95 |
+
writer.write(json.dumps(data) + "\n")
|
| 96 |
+
writer.close()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main(task, **kwargs):
|
| 100 |
+
globals()[task](**kwargs)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
if __name__ == "__main__":
|
| 104 |
+
fire.Fire(main)
|
VILA/scripts/convert_sqa_to_llava_base_prompt.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_question_text(problem):
|
| 19 |
+
question = problem["question"]
|
| 20 |
+
return question
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_context_text(problem, use_caption):
|
| 24 |
+
txt_context = problem["hint"]
|
| 25 |
+
img_context = problem["caption"] if use_caption else ""
|
| 26 |
+
context = " ".join([txt_context, img_context]).strip()
|
| 27 |
+
if context == "":
|
| 28 |
+
context = "N/A"
|
| 29 |
+
return context
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_choice_text(probelm, options):
|
| 33 |
+
choices = probelm["choices"]
|
| 34 |
+
choice_list = []
|
| 35 |
+
for i, c in enumerate(choices):
|
| 36 |
+
choice_list.append(f"({options[i]}) {c}")
|
| 37 |
+
choice_txt = " ".join(choice_list)
|
| 38 |
+
# print(choice_txt)
|
| 39 |
+
return choice_txt
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_answer(problem, options):
|
| 43 |
+
return options[problem["answer"]]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_lecture_text(problem):
|
| 47 |
+
# \\n: GPT-3 can generate the lecture with more tokens.
|
| 48 |
+
lecture = problem["lecture"].replace("\n", "\\n")
|
| 49 |
+
return lecture
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_solution_text(problem):
|
| 53 |
+
# \\n: GPT-3 can generate the solution with more tokens
|
| 54 |
+
solution = problem["solution"].replace("\n", "\\n")
|
| 55 |
+
return solution
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
|
| 59 |
+
|
| 60 |
+
input_format, output_format = format.split("-")
|
| 61 |
+
|
| 62 |
+
## Inputs
|
| 63 |
+
if input_format == "CQM":
|
| 64 |
+
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
|
| 65 |
+
elif input_format == "QCM":
|
| 66 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
|
| 67 |
+
# upper bound experiment
|
| 68 |
+
elif input_format == "QCML":
|
| 69 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
|
| 70 |
+
elif input_format == "QCME":
|
| 71 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
|
| 72 |
+
elif input_format == "QCMLE":
|
| 73 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
|
| 74 |
+
|
| 75 |
+
elif input_format == "QCLM":
|
| 76 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
|
| 77 |
+
elif input_format == "QCEM":
|
| 78 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
|
| 79 |
+
elif input_format == "QCLEM":
|
| 80 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
|
| 81 |
+
|
| 82 |
+
# Outputs
|
| 83 |
+
if test_example:
|
| 84 |
+
output = "Answer:"
|
| 85 |
+
elif output_format == "A":
|
| 86 |
+
output = f"Answer: The answer is {answer}."
|
| 87 |
+
|
| 88 |
+
elif output_format == "AL":
|
| 89 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
|
| 90 |
+
elif output_format == "AE":
|
| 91 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
|
| 92 |
+
elif output_format == "ALE":
|
| 93 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
|
| 94 |
+
elif output_format == "AEL":
|
| 95 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
|
| 96 |
+
|
| 97 |
+
elif output_format == "LA":
|
| 98 |
+
output = f"Answer: {lecture} The answer is {answer}."
|
| 99 |
+
elif output_format == "EA":
|
| 100 |
+
output = f"Answer: {solution} The answer is {answer}."
|
| 101 |
+
elif output_format == "LEA":
|
| 102 |
+
output = f"Answer: {lecture} {solution} The answer is {answer}."
|
| 103 |
+
elif output_format == "ELA":
|
| 104 |
+
output = f"Answer: {solution} {lecture} The answer is {answer}."
|
| 105 |
+
elif output_format == "LEPA":
|
| 106 |
+
output = ""
|
| 107 |
+
if len(lecture.strip()) > 0:
|
| 108 |
+
output += f"LECTURE: {lecture}\n"
|
| 109 |
+
if len(solution.strip()) > 0:
|
| 110 |
+
output += f"SOLUTION: {solution}\n"
|
| 111 |
+
output += "###\n"
|
| 112 |
+
output += f"ANSWER: {answer}."
|
| 113 |
+
|
| 114 |
+
input = input.replace(" ", " ").strip()
|
| 115 |
+
output = output.replace(" ", " ").strip()
|
| 116 |
+
if input.endswith("BECAUSE:"):
|
| 117 |
+
input = input.replace("BECAUSE:", "").strip()
|
| 118 |
+
if output.endswith("BECAUSE:"):
|
| 119 |
+
output = output.replace("BECAUSE:", "").strip()
|
| 120 |
+
return input, output
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
|
| 124 |
+
|
| 125 |
+
input_format, output_format = format.split("-")
|
| 126 |
+
|
| 127 |
+
## Inputs
|
| 128 |
+
if input_format == "CQM":
|
| 129 |
+
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
|
| 130 |
+
elif input_format == "QCM":
|
| 131 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
|
| 132 |
+
# upper bound experiment
|
| 133 |
+
elif input_format == "QCML":
|
| 134 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
|
| 135 |
+
elif input_format == "QCME":
|
| 136 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
|
| 137 |
+
elif input_format == "QCMLE":
|
| 138 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
|
| 139 |
+
|
| 140 |
+
elif input_format == "QCLM":
|
| 141 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
|
| 142 |
+
elif input_format == "QCEM":
|
| 143 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
|
| 144 |
+
elif input_format == "QCLEM":
|
| 145 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
|
| 146 |
+
|
| 147 |
+
# Outputs
|
| 148 |
+
if test_example:
|
| 149 |
+
output = "Answer:"
|
| 150 |
+
elif output_format == "A":
|
| 151 |
+
output = f"Answer: The answer is {answer}."
|
| 152 |
+
|
| 153 |
+
elif output_format == "AL":
|
| 154 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
|
| 155 |
+
elif output_format == "AE":
|
| 156 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
|
| 157 |
+
elif output_format == "ALE":
|
| 158 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
|
| 159 |
+
elif output_format == "AEL":
|
| 160 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
|
| 161 |
+
|
| 162 |
+
elif output_format == "LA":
|
| 163 |
+
output = f"Answer: {lecture} The answer is {answer}."
|
| 164 |
+
elif output_format == "EA":
|
| 165 |
+
output = f"Answer: {solution} The answer is {answer}."
|
| 166 |
+
elif output_format == "LEA":
|
| 167 |
+
output = f"Answer: {lecture} {solution} The answer is {answer}."
|
| 168 |
+
elif output_format == "ELA":
|
| 169 |
+
output = f"Answer: {solution} {lecture} The answer is {answer}."
|
| 170 |
+
|
| 171 |
+
text = input + output
|
| 172 |
+
text = text.replace(" ", " ").strip()
|
| 173 |
+
if text.endswith("BECAUSE:"):
|
| 174 |
+
text = text.replace("BECAUSE:", "").strip()
|
| 175 |
+
return text
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
|
| 179 |
+
|
| 180 |
+
input_format, output_format = format.split("-")
|
| 181 |
+
|
| 182 |
+
## Inputs
|
| 183 |
+
if input_format == "CQM":
|
| 184 |
+
input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
|
| 185 |
+
elif input_format == "QCM":
|
| 186 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
|
| 187 |
+
# upper bound experiment
|
| 188 |
+
elif input_format == "QCML":
|
| 189 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
|
| 190 |
+
elif input_format == "QCME":
|
| 191 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
|
| 192 |
+
elif input_format == "QCMLE":
|
| 193 |
+
input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
|
| 194 |
+
|
| 195 |
+
elif input_format == "QCLM":
|
| 196 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
|
| 197 |
+
elif input_format == "QCEM":
|
| 198 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
|
| 199 |
+
elif input_format == "QCLEM":
|
| 200 |
+
input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
|
| 201 |
+
|
| 202 |
+
# Outputs
|
| 203 |
+
if test_example:
|
| 204 |
+
output = "Answer:"
|
| 205 |
+
elif output_format == "A":
|
| 206 |
+
output = f"Answer: The answer is {answer}."
|
| 207 |
+
|
| 208 |
+
elif output_format == "AL":
|
| 209 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
|
| 210 |
+
elif output_format == "AE":
|
| 211 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
|
| 212 |
+
elif output_format == "ALE":
|
| 213 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
|
| 214 |
+
elif output_format == "AEL":
|
| 215 |
+
output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
|
| 216 |
+
|
| 217 |
+
elif output_format == "LA":
|
| 218 |
+
output = f"Answer: {lecture} The answer is {answer}."
|
| 219 |
+
elif output_format == "EA":
|
| 220 |
+
output = f"Answer: {solution} The answer is {answer}."
|
| 221 |
+
elif output_format == "LEA":
|
| 222 |
+
output = f"Answer: {lecture} {solution} The answer is {answer}."
|
| 223 |
+
elif output_format == "ELA":
|
| 224 |
+
output = f"Answer: {solution} {lecture} The answer is {answer}."
|
| 225 |
+
|
| 226 |
+
input = input.replace(" ", " ").strip()
|
| 227 |
+
output = output.replace(" ", " ").strip()
|
| 228 |
+
if output.endswith("BECAUSE:"):
|
| 229 |
+
output = output.replace("BECAUSE:", "").strip()
|
| 230 |
+
|
| 231 |
+
user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
|
| 232 |
+
assistant_prompt = {"role": "assistant", "content": f"{output}"}
|
| 233 |
+
|
| 234 |
+
return user_prompt, assistant_prompt
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def build_prompt_chatbot(
|
| 238 |
+
problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False
|
| 239 |
+
):
|
| 240 |
+
examples = {}
|
| 241 |
+
|
| 242 |
+
for qid in shot_qids:
|
| 243 |
+
question = get_question_text(problems[qid])
|
| 244 |
+
context = get_context_text(problems[qid], use_caption)
|
| 245 |
+
choice = get_choice_text(problems[qid], options)
|
| 246 |
+
answer = get_answer(problems[qid], options)
|
| 247 |
+
lecture = get_lecture_text(problems[qid]).replace("\\n", "\n")
|
| 248 |
+
solution = get_solution_text(problems[qid]).replace("\\n", "\n")
|
| 249 |
+
|
| 250 |
+
train_example = create_one_example_chatbot(
|
| 251 |
+
prompt_format, question, context, choice, answer, lecture, solution, test_example=is_test
|
| 252 |
+
)
|
| 253 |
+
examples[qid] = train_example
|
| 254 |
+
return examples
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def build_prompt(problems, shot_qids, test_qid, args):
|
| 258 |
+
|
| 259 |
+
examples = []
|
| 260 |
+
|
| 261 |
+
# n-shot training examples
|
| 262 |
+
for qid in shot_qids:
|
| 263 |
+
question = get_question_text(problems[qid])
|
| 264 |
+
context = get_context_text(problems[qid], args.use_caption)
|
| 265 |
+
choice = get_choice_text(problems[qid], args.options)
|
| 266 |
+
answer = get_answer(problems[qid], args.options)
|
| 267 |
+
lecture = get_lecture_text(problems[qid])
|
| 268 |
+
solution = get_solution_text(problems[qid])
|
| 269 |
+
|
| 270 |
+
train_example = create_one_example(
|
| 271 |
+
args.prompt_format, question, context, choice, answer, lecture, solution, test_example=False
|
| 272 |
+
)
|
| 273 |
+
examples.append(train_example)
|
| 274 |
+
|
| 275 |
+
# test example
|
| 276 |
+
question = get_question_text(problems[test_qid])
|
| 277 |
+
context = get_context_text(problems[test_qid], args.use_caption)
|
| 278 |
+
choice = get_choice_text(problems[test_qid], args.options)
|
| 279 |
+
answer = get_answer(problems[test_qid], args.options)
|
| 280 |
+
lecture = get_lecture_text(problems[test_qid])
|
| 281 |
+
solution = get_solution_text(problems[test_qid])
|
| 282 |
+
|
| 283 |
+
test_example = create_one_example(
|
| 284 |
+
args.prompt_format, question, context, choice, answer, lecture, solution, test_example=True
|
| 285 |
+
)
|
| 286 |
+
examples.append(test_example)
|
| 287 |
+
|
| 288 |
+
# create the prompt input
|
| 289 |
+
prompt_input = "\n\n".join(examples)
|
| 290 |
+
|
| 291 |
+
return prompt_input
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def build_prompt_gpt4(problems, shot_qids, test_qid, args):
|
| 295 |
+
|
| 296 |
+
prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
|
| 297 |
+
|
| 298 |
+
# n-shot training examples
|
| 299 |
+
for qid in shot_qids:
|
| 300 |
+
question = get_question_text(problems[qid])
|
| 301 |
+
context = get_context_text(problems[qid], args.use_caption)
|
| 302 |
+
choice = get_choice_text(problems[qid], args.options)
|
| 303 |
+
answer = get_answer(problems[qid], args.options)
|
| 304 |
+
lecture = get_lecture_text(problems[qid])
|
| 305 |
+
solution = get_solution_text(problems[qid])
|
| 306 |
+
|
| 307 |
+
user_prompt, assistant_prompt = create_one_example_gpt4(
|
| 308 |
+
args.prompt_format, question, context, choice, answer, lecture, solution, test_example=False
|
| 309 |
+
)
|
| 310 |
+
prompt_array.append(user_prompt)
|
| 311 |
+
prompt_array.append(assistant_prompt)
|
| 312 |
+
|
| 313 |
+
# test example
|
| 314 |
+
question = get_question_text(problems[test_qid])
|
| 315 |
+
context = get_context_text(problems[test_qid], args.use_caption)
|
| 316 |
+
choice = get_choice_text(problems[test_qid], args.options)
|
| 317 |
+
answer = get_answer(problems[test_qid], args.options)
|
| 318 |
+
lecture = get_lecture_text(problems[test_qid])
|
| 319 |
+
solution = get_solution_text(problems[test_qid])
|
| 320 |
+
|
| 321 |
+
user_prompt, assistant_prompt = create_one_example_gpt4(
|
| 322 |
+
args.prompt_format, question, context, choice, answer, lecture, solution, test_example=True
|
| 323 |
+
)
|
| 324 |
+
prompt_array.append(user_prompt)
|
| 325 |
+
prompt_array.append(assistant_prompt)
|
| 326 |
+
|
| 327 |
+
return prompt_array
|
VILA/scripts/convert_vizwiz_for_submission.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_args():
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument("--annotation-file", type=str, required=True)
|
| 27 |
+
parser.add_argument("--result-file", type=str, required=True)
|
| 28 |
+
parser.add_argument("--result-upload-file", type=str, required=True)
|
| 29 |
+
return parser.parse_args()
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == "__main__":
|
| 33 |
+
|
| 34 |
+
args = parse_args()
|
| 35 |
+
|
| 36 |
+
os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True)
|
| 37 |
+
|
| 38 |
+
results = []
|
| 39 |
+
error_line = 0
|
| 40 |
+
for line_idx, line in enumerate(open(args.result_file)):
|
| 41 |
+
try:
|
| 42 |
+
results.append(json.loads(line))
|
| 43 |
+
except BaseException:
|
| 44 |
+
error_line += 1
|
| 45 |
+
results = {x["question_id"]: x["text"] for x in results}
|
| 46 |
+
test_split = [json.loads(line) for line in open(args.annotation_file)]
|
| 47 |
+
split_ids = {x["question_id"] for x in test_split}
|
| 48 |
+
|
| 49 |
+
print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}")
|
| 50 |
+
|
| 51 |
+
all_answers = []
|
| 52 |
+
|
| 53 |
+
answer_processor = EvalAIAnswerProcessor()
|
| 54 |
+
|
| 55 |
+
for x in test_split:
|
| 56 |
+
assert x["question_id"] in results
|
| 57 |
+
all_answers.append({"image": x["image"], "answer": answer_processor(results[x["question_id"]])})
|
| 58 |
+
|
| 59 |
+
with open(args.result_upload_file, "w") as f:
|
| 60 |
+
json.dump(all_answers, f)
|
VILA/scripts/convert_vqav2_for_submission.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
|
| 21 |
+
from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_args():
|
| 25 |
+
parser = argparse.ArgumentParser()
|
| 26 |
+
parser.add_argument("--dir", type=str, default="./playground/data/eval/vqav2")
|
| 27 |
+
parser.add_argument("--split", type=str, required=True)
|
| 28 |
+
return parser.parse_args()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
|
| 33 |
+
args = parse_args()
|
| 34 |
+
|
| 35 |
+
src = os.path.join(args.dir, args.split, "answers", "merge.jsonl")
|
| 36 |
+
test_split = os.path.join(args.dir, "llava_vqav2_mscoco_test2015.jsonl")
|
| 37 |
+
dst = os.path.join(args.dir, args.split, f"{args.split}_answers_upload.json")
|
| 38 |
+
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
| 39 |
+
|
| 40 |
+
results = []
|
| 41 |
+
error_line = 0
|
| 42 |
+
for line_idx, line in enumerate(open(src)):
|
| 43 |
+
try:
|
| 44 |
+
results.append(json.loads(line))
|
| 45 |
+
except:
|
| 46 |
+
error_line += 1
|
| 47 |
+
|
| 48 |
+
results = {x["question_id"]: x["text"] for x in results}
|
| 49 |
+
test_split = [json.loads(line) for line in open(test_split)]
|
| 50 |
+
split_ids = {x["question_id"] for x in test_split}
|
| 51 |
+
|
| 52 |
+
print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}")
|
| 53 |
+
|
| 54 |
+
all_answers = []
|
| 55 |
+
|
| 56 |
+
answer_processor = EvalAIAnswerProcessor()
|
| 57 |
+
|
| 58 |
+
for x in test_split:
|
| 59 |
+
if x["question_id"] not in results:
|
| 60 |
+
all_answers.append({"question_id": x["question_id"], "answer": ""})
|
| 61 |
+
else:
|
| 62 |
+
all_answers.append({"question_id": x["question_id"], "answer": answer_processor(results[x["question_id"]])})
|
| 63 |
+
|
| 64 |
+
with open(dst, "w") as f:
|
| 65 |
+
json.dump(all_answers, open(dst, "w"))
|
VILA/scripts/extract_mm_projector.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
#
|
| 15 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def parse_args():
|
| 26 |
+
parser = argparse.ArgumentParser(description="Extract MMProjector weights")
|
| 27 |
+
parser.add_argument("--model_name_or_path", type=str, help="model folder")
|
| 28 |
+
parser.add_argument("--output", type=str, help="output file")
|
| 29 |
+
args = parser.parse_args()
|
| 30 |
+
return args
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
if __name__ == "__main__":
|
| 34 |
+
args = parse_args()
|
| 35 |
+
|
| 36 |
+
keys_to_match = ["mm_projector", "embed_tokens", "transformer.wte"]
|
| 37 |
+
ckpt_to_key = defaultdict(list)
|
| 38 |
+
try:
|
| 39 |
+
model_indices = json.load(open(os.path.join(args.model_name_or_path, "pytorch_model.bin.index.json")))
|
| 40 |
+
for k, v in model_indices["weight_map"].items():
|
| 41 |
+
if any(key_match in k for key_match in keys_to_match):
|
| 42 |
+
ckpt_to_key[v].append(k)
|
| 43 |
+
except FileNotFoundError:
|
| 44 |
+
# Smaller models or model checkpoints saved by DeepSpeed.
|
| 45 |
+
v = "pytorch_model.bin"
|
| 46 |
+
for k in torch.load(os.path.join(args.model_name_or_path, v), map_location="cpu").keys():
|
| 47 |
+
if any(key_match in k for key_match in keys_to_match):
|
| 48 |
+
ckpt_to_key[v].append(k)
|
| 49 |
+
|
| 50 |
+
loaded_weights = {}
|
| 51 |
+
|
| 52 |
+
for ckpt_name, weight_keys in ckpt_to_key.items():
|
| 53 |
+
ckpt = torch.load(os.path.join(args.model_name_or_path, ckpt_name), map_location="cpu")
|
| 54 |
+
for k in weight_keys:
|
| 55 |
+
loaded_weights[k] = ckpt[k]
|
| 56 |
+
|
| 57 |
+
torch.save(loaded_weights, args.output)
|
VILA/scripts/zero2.json
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 2,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": "auto"
|
| 22 |
+
}
|
| 23 |
+
}
|
VILA/scripts/zero3.json
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 3,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": "auto",
|
| 22 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 23 |
+
"stage3_param_persistence_threshold": "auto",
|
| 24 |
+
"stage3_max_live_parameters": 1e9,
|
| 25 |
+
"stage3_max_reuse_distance": 1e9,
|
| 26 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 27 |
+
}
|
| 28 |
+
}
|
VILA/scripts/zero3_mics_mini_fixed.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 3,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": 4e8,
|
| 22 |
+
"stage3_prefetch_bucket_size": 4e8,
|
| 23 |
+
"stage3_param_persistence_threshold": 1e4,
|
| 24 |
+
"stage3_max_live_parameters": 1e9,
|
| 25 |
+
"stage3_max_reuse_distance": 1e9,
|
| 26 |
+
"stage3_gather_16bit_weights_on_model_save": true,
|
| 27 |
+
"mics_shard_size": 64,
|
| 28 |
+
"mics_hierarchical_params_gather": false
|
| 29 |
+
}
|
| 30 |
+
}
|
VILA/scripts/zero3_mics_tiny_fixed.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 3,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": 4e8,
|
| 22 |
+
"stage3_prefetch_bucket_size": 4e8,
|
| 23 |
+
"stage3_param_persistence_threshold": 1e4,
|
| 24 |
+
"stage3_max_live_parameters": 1e9,
|
| 25 |
+
"stage3_max_reuse_distance": 1e9,
|
| 26 |
+
"stage3_gather_16bit_weights_on_model_save": true,
|
| 27 |
+
"mics_shard_size": 16,
|
| 28 |
+
"mics_hierarchical_params_gather": false
|
| 29 |
+
}
|
| 30 |
+
}
|
VILA/scripts/zero3_offload.json
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"optimizer": {
|
| 14 |
+
"type": "AdamW",
|
| 15 |
+
"params": {
|
| 16 |
+
"lr": "auto",
|
| 17 |
+
"betas": "auto",
|
| 18 |
+
"eps": "auto",
|
| 19 |
+
"weight_decay": "auto"
|
| 20 |
+
}
|
| 21 |
+
},
|
| 22 |
+
"scheduler": {
|
| 23 |
+
"type": "WarmupLR",
|
| 24 |
+
"params": {
|
| 25 |
+
"warmup_min_lr": "auto",
|
| 26 |
+
"warmup_max_lr": "auto",
|
| 27 |
+
"warmup_num_steps": "auto"
|
| 28 |
+
}
|
| 29 |
+
},
|
| 30 |
+
"zero_optimization": {
|
| 31 |
+
"stage": 3,
|
| 32 |
+
"offload_optimizer": {
|
| 33 |
+
"device": "cpu",
|
| 34 |
+
"pin_memory": true
|
| 35 |
+
},
|
| 36 |
+
"offload_param": {
|
| 37 |
+
"device": "cpu",
|
| 38 |
+
"pin_memory": true
|
| 39 |
+
},
|
| 40 |
+
"overlap_comm": true,
|
| 41 |
+
"contiguous_gradients": true,
|
| 42 |
+
"sub_group_size": 1e9,
|
| 43 |
+
"reduce_bucket_size": "auto",
|
| 44 |
+
"stage3_prefetch_bucket_size": "auto",
|
| 45 |
+
"stage3_param_persistence_threshold": "auto",
|
| 46 |
+
"stage3_max_live_parameters": 1e9,
|
| 47 |
+
"stage3_max_reuse_distance": 1e9,
|
| 48 |
+
"gather_16bit_weights_on_model_save": true
|
| 49 |
+
},
|
| 50 |
+
"gradient_accumulation_steps": "auto",
|
| 51 |
+
"gradient_clipping": "auto",
|
| 52 |
+
"train_batch_size": "auto",
|
| 53 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 54 |
+
"steps_per_print": 1e5,
|
| 55 |
+
"wall_clock_breakdown": false
|
| 56 |
+
}
|
VILA/scripts/zero3_offload_inference.json
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": "auto"
|
| 4 |
+
},
|
| 5 |
+
"fp16": {
|
| 6 |
+
"enabled": "auto"
|
| 7 |
+
},
|
| 8 |
+
"zero_optimization": {
|
| 9 |
+
"stage": 3,
|
| 10 |
+
"stage3_prefetch_bucket_size": 33554432,
|
| 11 |
+
"stage3_param_persistence_threshold": 4096,
|
| 12 |
+
"stage3_max_live_parameters":33554432,
|
| 13 |
+
"offload_param": {
|
| 14 |
+
"device": "cpu",
|
| 15 |
+
"pin_memory": true
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"train_batch_size": 8,
|
| 19 |
+
"train_micro_batch_size_per_gpu": 1,
|
| 20 |
+
"wall_clock_breakdown": false
|
| 21 |
+
}
|
VILA/scripts/zero3pp.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"fp16": {
|
| 3 |
+
"enabled": "auto",
|
| 4 |
+
"loss_scale": 0,
|
| 5 |
+
"loss_scale_window": 1000,
|
| 6 |
+
"initial_scale_power": 16,
|
| 7 |
+
"hysteresis": 2,
|
| 8 |
+
"min_loss_scale": 1
|
| 9 |
+
},
|
| 10 |
+
"bf16": {
|
| 11 |
+
"enabled": "auto"
|
| 12 |
+
},
|
| 13 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 14 |
+
"train_batch_size": "auto",
|
| 15 |
+
"gradient_accumulation_steps": "auto",
|
| 16 |
+
"zero_optimization": {
|
| 17 |
+
"stage": 3,
|
| 18 |
+
"overlap_comm": true,
|
| 19 |
+
"contiguous_gradients": true,
|
| 20 |
+
"sub_group_size": 1e9,
|
| 21 |
+
"reduce_bucket_size": 1e6,
|
| 22 |
+
"stage3_prefetch_bucket_size": 1e6,
|
| 23 |
+
"stage3_param_persistence_threshold": 1e4,
|
| 24 |
+
"stage3_max_live_parameters": 1e9,
|
| 25 |
+
"stage3_max_reuse_distance": 1e9,
|
| 26 |
+
"stage3_gather_16bit_weights_on_model_save": true,
|
| 27 |
+
"zero_hpz_partition_size": 8
|
| 28 |
+
}
|
| 29 |
+
}
|