tuandunghcmut commited on
Commit
1c3d47d
·
verified ·
1 Parent(s): 637b6eb

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +6 -0
  2. VILA/.ipynb_checkpoints/Dockerfile-checkpoint +18 -0
  3. VILA/.ipynb_checkpoints/README-checkpoint.md +341 -0
  4. VILA/.ipynb_checkpoints/environment_setup-checkpoint.sh +33 -0
  5. VILA/CIs/license_all.sh +1 -0
  6. VILA/CIs/license_commited.sh +6 -0
  7. VILA/data_prepare/.DS_Store +0 -0
  8. VILA/data_prepare/LICENSE +8 -0
  9. VILA/data_prepare/README.md +172 -0
  10. VILA/data_prepare/panda70m.sh +34 -0
  11. VILA/data_prepare/panda_split.py +117 -0
  12. VILA/data_prepare/parallel_shards.sh +29 -0
  13. VILA/demo_images/LongVILA-pipeline.png +3 -0
  14. VILA/demo_images/av.png +3 -0
  15. VILA/demo_images/demo_img_1.png +3 -0
  16. VILA/demo_images/demo_img_2.png +3 -0
  17. VILA/demo_images/demo_img_3.png +3 -0
  18. VILA/demo_images/longvila-logo.png +3 -0
  19. VILA/demo_images/vila-logo.jpg +0 -0
  20. VILA/demo_trt_llm/README.md +3 -0
  21. VILA/inference_test/inference_test.json +546 -0
  22. VILA/inference_test/inference_test.py +153 -0
  23. VILA/llava.egg-info/PKG-INFO +287 -0
  24. VILA/llava.egg-info/SOURCES.txt +154 -0
  25. VILA/llava.egg-info/dependency_links.txt +1 -0
  26. VILA/llava.egg-info/requires.txt +37 -0
  27. VILA/llava.egg-info/top_level.txt +7 -0
  28. VILA/llava/.DS_Store +0 -0
  29. VILA/llava/constants.py +31 -0
  30. VILA/llava/conversation.py +489 -0
  31. VILA/llava/entry.py +18 -0
  32. VILA/llava/mm_utils.py +407 -0
  33. VILA/llava/modals.py +26 -0
  34. VILA/scripts/convert_gqa_for_eval.py +33 -0
  35. VILA/scripts/convert_karpathy_to_anno.py +130 -0
  36. VILA/scripts/convert_mmbench_for_submission.py +46 -0
  37. VILA/scripts/convert_mmvet_for_eval.py +33 -0
  38. VILA/scripts/convert_seed_for_submission.py +88 -0
  39. VILA/scripts/convert_sqa_to_llava.py +104 -0
  40. VILA/scripts/convert_sqa_to_llava_base_prompt.py +327 -0
  41. VILA/scripts/convert_vizwiz_for_submission.py +60 -0
  42. VILA/scripts/convert_vqav2_for_submission.py +65 -0
  43. VILA/scripts/extract_mm_projector.py +57 -0
  44. VILA/scripts/zero2.json +23 -0
  45. VILA/scripts/zero3.json +28 -0
  46. VILA/scripts/zero3_mics_mini_fixed.json +30 -0
  47. VILA/scripts/zero3_mics_tiny_fixed.json +30 -0
  48. VILA/scripts/zero3_offload.json +56 -0
  49. VILA/scripts/zero3_offload_inference.json +21 -0
  50. VILA/scripts/zero3pp.json +29 -0
.gitattributes CHANGED
@@ -370,3 +370,9 @@ groundingLMM/gradio-dev/demo/video_identity/video/video_sample.mp4 filter=lfs di
370
  groundingLMM/gradio-dev/demo/video_subtitle/files/a.mp4 filter=lfs diff=lfs merge=lfs -text
371
  groundingLMM/gradio-dev/demo/video_subtitle/files/b.mp4 filter=lfs diff=lfs merge=lfs -text
372
  groundingLMM/gradio-dev/demo/unispeech-speaker-verification/samples/kirsten_dunst.wav filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
370
  groundingLMM/gradio-dev/demo/video_subtitle/files/a.mp4 filter=lfs diff=lfs merge=lfs -text
371
  groundingLMM/gradio-dev/demo/video_subtitle/files/b.mp4 filter=lfs diff=lfs merge=lfs -text
372
  groundingLMM/gradio-dev/demo/unispeech-speaker-verification/samples/kirsten_dunst.wav filter=lfs diff=lfs merge=lfs -text
373
+ VILA/demo_images/demo_img_3.png filter=lfs diff=lfs merge=lfs -text
374
+ VILA/demo_images/LongVILA-pipeline.png filter=lfs diff=lfs merge=lfs -text
375
+ VILA/demo_images/longvila-logo.png filter=lfs diff=lfs merge=lfs -text
376
+ VILA/demo_images/demo_img_2.png filter=lfs diff=lfs merge=lfs -text
377
+ VILA/demo_images/demo_img_1.png filter=lfs diff=lfs merge=lfs -text
378
+ VILA/demo_images/av.png filter=lfs diff=lfs merge=lfs -text
VILA/.ipynb_checkpoints/Dockerfile-checkpoint ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/pytorch:24.06-py3
2
+
3
+ WORKDIR /app
4
+
5
+ RUN curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -o ~/miniconda.sh \
6
+ && sh ~/miniconda.sh -b -p /opt/conda \
7
+ && rm ~/miniconda.sh
8
+
9
+ ENV PATH /opt/conda/bin:$PATH
10
+ COPY pyproject.toml pyproject.toml
11
+ COPY llava llava
12
+
13
+ COPY environment_setup.sh environment_setup.sh
14
+ RUN bash environment_setup.sh vila
15
+
16
+
17
+ COPY server.py server.py
18
+ CMD ["conda", "run", "-n", "vila", "--no-capture-output", "python", "-u", "-W", "ignore", "server.py"]
VILA/.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <p align="center">
2
+ <img src="demo_images/vila-logo.jpg" width="20%"/>
3
+ </p>
4
+
5
+ # VILA: On Pre-training for Visual Language Models
6
+
7
+ [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](CODE_LICENSE)
8
+ [![Model License](https://img.shields.io/badge/MODEL%20License-CC%20By%20NC%204.0-red.svg)](MODEL_LICENSE)
9
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-3100/)
10
+
11
+ [VILA arxiv](https://arxiv.org/abs/2312.07533) / [VILA Demo](https://vila-demo.hanlab.ai/) / [VILA Huggingface](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e)
12
+
13
+ ## 💡 Introduction
14
+
15
+ VILA is a visual language model (VLM) pretrained with interleaved image-text data at scale, enabling **video understanding** and **multi-image understanding** capabilities. VILA is deployable on the edge by [AWQ](https://arxiv.org/pdf/2306.00978.pdf) 4bit quantization and [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) framework. We find: (1) image-text pairs are not enough, interleaved image-text is essential; (2) unfreezing LLM during interleaved image-text pre-training enables in-context learning; (3)re-blending text-only instruction data is crucial to boost both VLM and text-only performance; (4) token compression extends #video frames. VILA unveils appealing capabilities, including: video reasoning, in-context learning, visual chain-of-thought, and better world knowledge.
16
+
17
+ ## 💡 News
18
+ - [2024/08] We release [LongVILA](./LongVILA.md) that supports long video understanding (Captioning, QA, Needle-in-a-Haystack) up to 1024 frames.
19
+ - [2024/07] VILA1.5 also ranks 1st place (OSS model) on [MLVU test leaderboard](https://github.com/JUNJIE99/MLVU).
20
+ - [2024/06] VILA1.5 is now the best open sourced VLM on [MMMU leaderboard](https://mmmu-benchmark.github.io/#leaderboard) and [Video-MME](https://video-mme.github.io/home_page.html#leaderboard) leaderboard!
21
+ - [2024/05] We release VILA-1.5, which offers **video understanding capability**. VILA-1.5 comes with four model sizes: 3B/8B/13B/40B.
22
+ - [2024/05] We release [AWQ](https://arxiv.org/pdf/2306.00978.pdf)-quantized 4bit VILA-1.5 models. VILA-1.5 is efficiently deployable on diverse NVIDIA GPUs (A100, 4090, 4070 Laptop, Orin, Orin Nano) by [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) and [TensorRT-LLM](demo_trt_llm) backends.
23
+ - [2024/03] VILA has been accepted by CVPR 2024!
24
+ - [2024/02] We release [AWQ](https://arxiv.org/pdf/2306.00978.pdf)-quantized 4bit VILA models, deployable on Jetson Orin and laptops through [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) and [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine).
25
+ - [2024/02] VILA is released. We propose interleaved image-text pretraining that enables **multi-image** VLM. VILA comes with impressive in-context learning capabilities. We open source everything: including training code, evaluation code, datasets, model ckpts.
26
+ - [2023/12] [Paper](https://arxiv.org/abs/2312.07533) is on Arxiv!
27
+
28
+ ## Performance
29
+
30
+ ### Image QA Benchmarks
31
+
32
+ | $~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~$ | Prec. | VQAv2 | GQA | VizWiz | SQA-I | VQA-T | POPE | MME | MMB | MMB-CN | SEED | SEED-I | MMMU (val) | MMMU (test) | llava-bench | MM-Vet | Average |
33
+ | -------------------------------- | ----- | ----- | ---- | ------ | ----- | ----- | ---- | ------- | ---- | ------ | ---- | ------ | ---------- | ----------- | ----------- | ------ | ------- |
34
+ | VILA1.5-3B | fp16 | 80.4 | 61.5 | 53.5 | 69.0 | 60.4 | 85.9 | 1442.44 | 63.4 | 52.7 | 60.9 | 67.9 | 33.3 | 30.8 | 75.9 | 35.4 | 60.2 |
35
+ | VILA1.5-3B-AWQ | int4 | 80.0 | 61.1 | 53.8 | 67.8 | 60.4 | 85.9 | 1437.34 | 63.3 | 51.4 | 59.8 | 66.6 | 32.7 | 31.1 | 75.0 | 37.3 | 59.9 |
36
+ | VILA1.5-3B-S2 | fp16 | 79.8 | 61.4 | 61.3 | 69.6 | 63.4 | 85.3 | 1431.65 | 62.8 | 52.2 | 60.0 | 66.4 | 32.8 | 31.3 | 76.7 | 38.6 | 60.9 |
37
+ | VILA1.5-3B-S2-AWQ | int4 | 79.4 | 61.3 | 62.3 | 69.2 | 63.0 | 85.8 | 1417.06 | 61.6 | 51.5 | 59.1 | 65.7 | 33.4 | 30.4 | 77.1 | 36.7 | 60.5 |
38
+ | Llama-3-VILA1.5-8B | fp16 | 83.0 | 63.5 | 63.2 | 82.0 | 68.5 | 85.6 | 1634.91 | 75.3 | 69.9 | 66.4 | 73.8 | 38.6 | 32.7 | 71.9 | 43.2 | 66.6 |
39
+ | Llama-3-VILA1.5-8B-AWQ | int4 | 80.3 | 61.7 | 59.3 | 79.0 | 65.4 | 82.9 | 1593.65 | 71.0 | 64.9 | 64.0 | 71.1 | 36.0 | 36.1 | 79.0 | 37.2 | 64.5 |
40
+ | VILA1.5-13B | fp16 | 82.8 | 64.3 | 62.6 | 80.1 | 65.0 | 86.3 | 1569.55 | 74.9 | 66.3 | 65.1 | 72.6 | 37.9 | 33.6 | 80.8 | 44.3 | 66.3 |
41
+ | VILA1.5-13B-AWQ | int4 | 82.7 | 64.5 | 63.3 | 79.7 | 64.7 | 86.7 | 1531.35 | 74.7 | 66.7 | 65.1 | 72.6 | 37.8 | 34.0 | 81.9 | 46.4 | 66.5 |
42
+ | VILA1.5-40B | fp16 | 84.3 | 64.6 | 62.2 | 87.2 | 73.6 | 87.3 | 1726.82 | 82.4 | 80.2 | 69.1 | 75.8 | 51.9 | 46.9 | 81.3 | 53.0 | 72.4 |
43
+ | VILA1.5-40B-AWQ | int4 | 84.1 | 64.4 | 61.3 | 86.7 | 73.2 | 88.2 | 1714.79 | 83.2 | 79.6 | 68.9 | 75.6 | 49.3 | 46.2 | 83.0 | 51.4 | 72.1 |
44
+
45
+ <sup>NOTE: VQAV2 and VizWiz are test-dev, the average accuracy is calculated over all datasets and MME numbers are divided by 20.</sup>
46
+
47
+ ### Video QA Benchmarks
48
+
49
+ | $~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~$ | Prec. | Perception Test | ActivityNet | MSVD | MSRVTT | TGIF | EgoSchema (test) | CinePile
50
+ | -------------------------------- | ----- | ----- | ---- | ------ | ----- | ----- | ----- | ----- |
51
+ | VILA1.5-3B | fp16 | 47 | 50.2 | 76.6 | 57.5 | 51.7 | 42.6 | 37.9
52
+ | VILA1.5-3B-S2 | fp16 | 49.7 | 50.7 | 76.9 | 57.6 | 51.7 |
53
+ | Llama-3-VILA1.5-8B | fp16 | 54.1 | 54.3 | 78.3 | 60.1 | 54.1 | 50.4 | 48.7
54
+ | VILA1.5-13B | fp16 | 53.6 | 54.7 | 77.9 | 60.2 | 56 | 52.2 | 50.1
55
+ | VILA1.5-40B | fp16 | 54 | 58 | 80.1 | 63 | 58.2 | 58.7 | 51.3
56
+
57
+ ### Inference speed ( Token/sec )
58
+
59
+ | $~~~~~~$ | Precision | A100 | 4090 | Orin |
60
+ | ---------------------- | --------- | ----- | ----- | ---- |
61
+ | VILA1.5-3B | fp16 | 104.6 | 137.6 | 25.4 |
62
+ | VILA1.5-3B-AWQ | int4 | 182.8 | 215.5 | 42.5 |
63
+ | VILA1.5-3B-S2 | fp16 | 104.3 | 137.2 | 24.6 |
64
+ | VILA1.5-3B-S2-AWQ | int4 | 180.2 | 219.3 | 40.1 |
65
+ | Llama-3-VILA1.5-8B | fp16 | 74.9 | 57.4 | 10.2 |
66
+ | Llama-3-VILA1.5-8B-AWQ | int4 | 168.9 | 150.2 | 28.7 |
67
+ | VILA1.5-13B | fp16 | 50.9 | OOM | 6.1 |
68
+ | VILA1.5-13B-AWQ | int4 | 115.9 | 105.7 | 20.6 |
69
+ | VILA1.5-40B | fp16 | OOM | OOM | -- |
70
+ | VILA1.5-40B-AWQ | int4 | 57.0 | OOM | -- |
71
+
72
+ <sup>NOTE: Measured using the [TinyChat](https://github.com/mit-han-lab/llm-awq/tinychat) backend at batch size = 1.</sup>
73
+
74
+ ## VILA Examples
75
+
76
+ ### Video captioning
77
+
78
+ https://github.com/Efficient-Large-Model/VILA/assets/156256291/c9520943-2478-4f97-bc95-121d625018a6
79
+
80
+ Prompt: Elaborate on the visual and narrative elements of the video in detail.
81
+
82
+ Caption: The video shows a person's hands working on a white surface. They are folding a piece of fabric with a checkered pattern in shades of blue and white. The fabric is being folded into a smaller, more compact shape. The person's fingernails are painted red, and they are wearing a black and red garment. There are also a ruler and a pencil on the surface, suggesting that measurements and precision are involved in the process.
83
+
84
+ ### In context learning
85
+
86
+ <img src="demo_images/demo_img_1.png" height="239">
87
+ <img src="demo_images/demo_img_2.png" height="250">
88
+
89
+ ### Multi-image reasoning
90
+
91
+ <img src="demo_images/demo_img_3.png" height="193">
92
+
93
+ ### VILA on Jetson Orin
94
+
95
+ https://github.com/Efficient-Large-Model/VILA/assets/7783214/6079374c-0787-4bc4-b9c6-e1524b4c9dc4
96
+
97
+ ### VILA on RTX 4090
98
+
99
+ https://github.com/Efficient-Large-Model/VILA/assets/7783214/80c47742-e873-4080-ad7d-d17c4700539f
100
+
101
+ </details>
102
+
103
+ ## Installation
104
+
105
+ ```bash
106
+ ./environment_setup.sh vila
107
+ ```
108
+
109
+ ## Training
110
+
111
+ VILA training contains three steps, for specific hyperparameters, please check out the [scripts/v1_5](scripts/v1_5) folder:
112
+
113
+ ### Step-1: Alignment
114
+
115
+ We utilize LLaVA-CC3M-Pretrain-595K dataset to align the textual and visual modalities.
116
+
117
+ The stage 1 script takes in two parameters and it can run on a single 8xA100 node. `BASE_MODEL_PATH` points to a online or local huggingface repository, such as `NousResearch/Llama-2-7b-hf`. `OUTPUT_NAME` points to a target directory under `checkpoints`, which will save the trained multimodal projector afterwards.
118
+
119
+ ```bash
120
+ bash scripts/v1_5/paper/1_mm_align.sh [BASE_MODEL_PATH] [OUTPUT_NAME]
121
+ ```
122
+
123
+ ### Step-2: Pretraining
124
+
125
+ We use MMC4 and Coyo dataset to train VLM with interleaved image-text pairs.
126
+
127
+ ```bash
128
+ bash scripts/v1_5/paper/2_pretrain_mmc4_coyo.sh [CODE_PATH] [BASE_MODEL_PATH] [STAGE1_PATH] [OUTPUT_NAME]
129
+ ```
130
+
131
+ The stage 2 script takes in four arguments. `CODE_PATH` is the absolute path to our VILA codebase, `BASE_MODEL_PATH` has similar meaning to what is presented in the stage 1 script. `STAGE1_PATH` points to the `OUTPUT_NAME` of stage 1 (i.e. where the stage 1 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that saves the pretraining checkpoint. The script we provided for this stage is executed on slurm, and we expect it to execute on 16 nodes (128 GPUs).
132
+
133
+ ### Step-3: Supervised fine-tuning
134
+
135
+ This is the last stage of VILA training, in which we tune the model to follow multimodal instructions on a subset of M3IT, FLAN and ShareGPT4V. This stage runs on a 8xA100 node.
136
+
137
+ ```bash
138
+ bash scripts/v1_5/paper/3_sft.sh [STAGE2_PATH] [OUTPUT_NAME]
139
+ ```
140
+
141
+ The stage 3 script takes in two arguments. `STAGE2_PATH` points to the `OUTPUT_NAME` of the stage 2 script (i.e. where the stage 2 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that stores the final checkpoint.
142
+
143
+ ## Evaluations
144
+
145
+ ### Image Benchmarks
146
+
147
+ You can follow [Llava1.5 eval](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md) to download all datasets. After downloading all datasets, please put them under `playground/data/eval`.
148
+
149
+ Please make the following changes to the MME evaluation script. Please search for:
150
+
151
+ ```python
152
+ data_path = "MME_Benchmark_release_version"
153
+ ```
154
+
155
+ and replace it with:
156
+
157
+ ```python
158
+ data_path = os.path.join(script_dir, "MME_Benchmark_release_version")
159
+ ```
160
+
161
+ We provide a push-the-button script to perform evaluation on all 10 datasets that do not require GPT-assisted evaluation:
162
+
163
+ ```bash
164
+ ./scripts/v1_5/eval/eval_all.sh [CHECKPOINT_PATH] [MODEL_NAME] [CONV_MODE]
165
+ ```
166
+
167
+ This script takes in two parameters, `CHECKPOINT_PATH` points to the stage 3 model checkpoint, and `MODEL_NAME` will be the name of evaluation results.
168
+
169
+ [VQAv2](https://eval.ai/web/challenges/challenge-page/830/my-submission) and [Vizwiz](https://eval.ai/web/challenges/challenge-page/2185/my-submission) evaluations are hosted on eval.ai. You need to register an account and create a team to be able to submit eval.
170
+
171
+ MMBench and MMBench_CN eval are hosted on another [evaluation server](https://opencompass.org.cn/leaderboard-multimodal). Make sure you change the name of the file before submitting, otherwise the server caches results and will always return wrong result to you.
172
+
173
+ We provide a quick script to automatically organize the prediction files that need to be submitted to servers:
174
+
175
+ ```bash
176
+ python scripts/v1_5/eval/copy_predictions.py [MODEL_NAME]
177
+ ```
178
+
179
+ You will be able to find the predictions under `playground/data/predictions_upload/[MODEL_NAME]` after executing this script.
180
+
181
+ ### Video Benchmarks
182
+
183
+ Please follow the evaluation steps in [Video-LLaVA](https://github.com/PKU-YuanGroup/Video-LLaVA/blob/main/TRAIN_AND_VALIDATE.md#data-for-validating) for dataset preparation.
184
+
185
+ ```bash
186
+ ./scripts/v1_5/eval/video_chatgpt/run_all.sh [CHECKPOINT_PATH] [MODEL_NAME] [CONV_MODE]
187
+ ./scripts/v1_5/eval/video_chatgpt/eval_all.sh [MODEL_NAME]
188
+ ```
189
+
190
+ ## Inference
191
+
192
+ We provide snippets for quick inference with user prompts and images.
193
+
194
+ Llama-3-VILA1.5-8B inference:
195
+
196
+ ```bash
197
+ python -W ignore llava/eval/run_vila.py \
198
+ --model-path Efficient-Large-Model/Llama-3-VILA1.5-8b-Fix \
199
+ --conv-mode llama_3 \
200
+ --query "<image>\n Please describe the traffic condition." \
201
+ --image-file "av.png"
202
+ ```
203
+
204
+ VILA1.5-40B inference:
205
+
206
+ ```bash
207
+ python -W ignore llava/eval/run_vila.py \
208
+ --model-path Efficient-Large-Model/VILA1.5-40b \
209
+ --conv-mode hermes-2 \
210
+ --query "<image>\n Please describe the traffic condition." \
211
+ --image-file "av.png"
212
+ ```
213
+
214
+ VILA1.5-3B video inference:
215
+
216
+ ```bash
217
+ python -W ignore llava/eval/run_vila.py \
218
+ --model-path Efficient-Large-Model/VILA1.5-3b \
219
+ --conv-mode vicuna_v1 \
220
+ --query "<video>\n Please describe this video." \
221
+ --video-file "demo.mp4"
222
+ ```
223
+
224
+ ## Quantization and Deployment
225
+
226
+ Our VILA models are quantized by [AWQ](https://arxiv.org/abs/2306.00978) into 4 bits for efficient inference on the edge. We provide a push-the-button [script](https://github.com/mit-han-lab/llm-awq/blob/main/scripts/vila_example.sh) to quantize VILA with AWQ.
227
+
228
+ ### Running VILA on desktop GPUs and edge GPUs
229
+
230
+ We support AWQ-quantized 4bit VILA on GPU platforms via [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat). We provide a [tutorial](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat#support-vlm-models-vila--llava) to run the model with TinyChat after quantization. We also provide an [instruction](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat/serve) to launch a Gradio server (powered by TinyChat and AWQ) to serve 4-bit quantized VILA models.
231
+
232
+ ### Running VILA on laptops
233
+
234
+ We further support our AWQ-quantized 4bit VILA models on various CPU platforms with both x86 and ARM architectures with our [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine). We also provide a detailed [tutorial](https://github.com/mit-han-lab/TinyChatEngine/tree/main?tab=readme-ov-file#deploy-vision-language-model-vlm-chatbot-with-tinychatengine) to help the users deploy VILA on different CPUs.
235
+
236
+ ### Running VILA API server
237
+
238
+ A simple API server has been provided to serve VILA models. The server is built on top of [FastAPI](https://fastapi.tiangolo.com/) and [Huggingface Transformers](https://huggingface.co/transformers/). The server can be run with the following command:
239
+
240
+ #### With CLI
241
+
242
+ ```bash
243
+ python -W ignore server.py \
244
+ --port 8000 \
245
+ --model-path Efficient-Large-Model/VILA1.5-3B \
246
+ --conv-mode vicuna_v1
247
+ ```
248
+
249
+ #### With Docker
250
+
251
+ ```bash
252
+ docker build -t vila-server:latest .
253
+ docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 \
254
+ -v ./hub:/root/.cache/huggingface/hub \
255
+ -it --rm -p 8000:8000 \
256
+ -e VILA_MODEL_PATH=Efficient-Large-Model/VILA1.5-3B \
257
+ -e VILA_CONV_MODE=vicuna_v1 \
258
+ vila-server:latest
259
+ ```
260
+
261
+ Then you can call the endpoint with the OpenAI SDK as follows:
262
+
263
+ ```python
264
+ from openai import OpenAI
265
+
266
+ client = OpenAI(
267
+ base_url="http://localhost:8000",
268
+ api_key="fake-key",
269
+ )
270
+ response = client.chat.completions.create(
271
+ messages=[
272
+ {
273
+ "role": "user",
274
+ "content": [
275
+ {"type": "text", "text": "What’s in this image?"},
276
+ {
277
+ "type": "image_url",
278
+ "image_url": {
279
+ "url": "https://blog.logomyway.com/wp-content/uploads/2022/01/NVIDIA-logo.jpg",
280
+ # Or you can pass in a base64 encoded image
281
+ # "url": "data:image/png;base64,<base64_encoded_image>",
282
+ },
283
+ },
284
+ ],
285
+ }
286
+ ],
287
+ max_tokens=300,
288
+ model="VILA1.5-3B",
289
+ # You can pass in extra parameters as follows
290
+ extra_body={"num_beams": 1, "use_cache": False},
291
+ )
292
+ print(response.choices[0].message.content)
293
+ ```
294
+
295
+ <sup>NOTE: This API server is intended for evaluation purposes only and has not been optimized for production use. It has only been tested on A100 and H100 GPUs.</sup>
296
+
297
+ ## Checkpoints
298
+
299
+ We release [VILA1.5-3B](https://hf.co/Efficient-Large-Model/VILA1.5-3b), [VILA1.5-3B-S2](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2), [Llama-3-VILA1.5-8B](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8B-Fix), [VILA1.5-13B](https://hf.co/Efficient-Large-Model/VILA1.5-13b), [VILA1.5-40B](https://hf.co/Efficient-Large-Model/VILA1.5-40b) and the 4-bit [AWQ](https://arxiv.org/abs/2306.00978)-quantized models [VILA1.5-3B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-AWQ), [VILA1.5-3B-S2-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-3b-s2-AWQ), [Llama-3-VILA1.5-8B-AWQ](https://hf.co/Efficient-Large-Model/Llama-3-VILA1.5-8B-Fix-AWQ), [VILA1.5-13B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-13b-AWQ), [VILA1.5-40B-AWQ](https://hf.co/Efficient-Large-Model/VILA1.5-40b-AWQ).
300
+
301
+ ## 🔒 License
302
+
303
+ - The code is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
304
+ - The pretrained weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
305
+ - The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
306
+ - [Model License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA. For LLAMA3-VILA checkpoints terms of use, please refer to the [LLAMA3 License](https://llama.meta.com/llama3/license/) for additional details.
307
+ - [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
308
+ - [Dataset Licenses](./data_prepare/LICENSE) for each one used during training.
309
+
310
+ ## Team
311
+
312
+ | | | |
313
+ | --- | --- | ---|
314
+ [\*Yao Lu](https://scholar.google.com/citations?user=OI7zFmwAAAAJ&hl=en): Nvidia| [\*Hongxu Yin](https://hongxu-yin.github.io/): Nvidia | [\*Ji Lin](https://www.linji.me/): OpenAI (work done at Nvidia and MIT)
315
+ [Wei Ping](https://scholar.google.com/citations?user=6gKEYRgAAAAJ&hl=en): Nvidia | [Pavlo Molchanov](https://www.pmolchanov.com/): Nvidia | [Andrew Tao](https://scholar.google.com/citations?user=Wel9l1wAAAAJ&hl=en): Nvidia |
316
+ [Haotian Tang](http://kentang.net/): MIT | [Shang Yang](https://ys-2020.github.io/): MIT | [Ligeng Zhu](https://lzhu.me/): Nvidia, MIT |
317
+ [Wei-Chen Wang](https://weichenwang.me/): MIT | [Fuzhao Xue](https://xuefuzhao.github.io/): Nvidia, NUS | [Yunhao Fang](https://seerkfang.github.io/): Nvidia, UCSD |
318
+ [Yukang Chen](https://yukangchen.com/): Nvidia, CUHK | [Zhuoyang Zhang](https://openreview.net/profile?id=~Zhuoyang_Zhang1): Nvidia, Tsinghua Univ. | [Yue Shen](https://www.linkedin.com/in/yue-james-shen/): Nvidia |
319
+ [Wei-Ming Chen](https://scholar.google.com/citations?user=6xFvyJwAAAAJ&hl=en): Nvidia | [Huizi Mao](https://scholar.google.com/citations?user=r5WezOYAAAAJ&hl=zh-CN): Nvidia | [Baifeng Shi](https://bfshi.github.io/): Nvidia, UC Berkeley |
320
+ [Jan Kautz](https://jankautz.com/): Nvidia | [Mohammad Shoeybi](https://scholar.google.com/citations?user=62ElavIAAAAJ&hl=en): Nvidia | [Song Han](http://songhan.mit.edu/): Nvidia, MIT
321
+
322
+ ## Citations
323
+
324
+ ```
325
+ @misc{lin2023vila,
326
+ title={VILA: On Pre-training for Visual Language Models},
327
+ author={Ji Lin and Hongxu Yin and Wei Ping and Yao Lu and Pavlo Molchanov and Andrew Tao and Huizi Mao and Jan Kautz and Mohammad Shoeybi and Song Han},
328
+ year={2023},
329
+ eprint={2312.07533},
330
+ archivePrefix={arXiv},
331
+ primaryClass={cs.CV}
332
+ }
333
+ ```
334
+
335
+ # Acknowledgement
336
+
337
+ - [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon. Thanks for their wonderful work.
338
+ - [InternVL](https://github.com/OpenGVLab/InternVL): for open-sourcing InternViT (used in VILA1.5-40b) and the [InternVL-SFT](https://github.com/OpenGVLab/InternVL/tree/main/internvl_chat#prepare-training-datasets) data blend (inspired by LLaVA-1.6) used in all VILA1.5 models.
339
+ - [Vicuna](https://github.com/lm-sys/FastChat): the amazing open-sourced large language model!
340
+ - [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT): we borrowed video evaluation script from this repository.
341
+ - [MMC4](https://github.com/allenai/mmc4), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), [M3IT](https://huggingface.co/datasets/MMInstruction/M3IT), [OpenORCA/FLAN](https://huggingface.co/datasets/Open-Orca/FLAN), [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V), [WIT](google-research-datasets/wit), [GSM8K-ScRel](https://github.com/OFA-Sys/gsm8k-ScRel/blob/main/data/train_use.jsonl), [VisualGenome](https://visualgenome.org/api/v0/api_home.html), [VCR](https://visualcommonsense.com/download/), [ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA), [Shot2Story](https://github.com/bytedance/Shot2Story/blob/master/DATA.md), [Youcook2](http://youcook2.eecs.umich.edu/), [Vatex](https://eric-xw.github.io/vatex-website/download.html), [ShareGPT-Video](https://huggingface.co/datasets/ShareGPTVideo/train_video_and_instruction) for providing datasets used in this research.
VILA/.ipynb_checkpoints/environment_setup-checkpoint.sh ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # This is required to activate conda environment
4
+ eval "$(conda shell.bash hook)"
5
+
6
+ # CONDA_ENV=${1:-""}
7
+ CONDA_ENV=vila
8
+ if [ -n "$CONDA_ENV" ]; then
9
+ conda create -n $CONDA_ENV python=3.10 -y
10
+ conda activate $CONDA_ENV
11
+ else
12
+ echo "Skipping conda environment creation. Make sure you have the correct environment activated."
13
+ fi
14
+
15
+ # This is required to enable PEP 660 support
16
+ pip install --upgrade pip
17
+
18
+ # This is optional if you prefer to use built-in nvcc
19
+ conda install -c nvidia cuda-toolkit -y
20
+
21
+ # Install FlashAttention2
22
+ pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
23
+
24
+ # Install VILA
25
+ pip install -e .
26
+ pip install -e ".[train]"
27
+ pip install -e ".[eval]"
28
+
29
+ # Install HF's Transformers
30
+ pip install git+https://github.com/huggingface/transformers@v4.37.2
31
+ site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
32
+ cp -rv ./llava/train/transformers_replace/* $site_pkg_path/transformers/
33
+ cp -rv ./llava/train/deepspeed_replace/* $site_pkg_path/deepspeed/
VILA/CIs/license_all.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' -ignore "llava/eval/**" -ignore "**/*__init__.py" **/*.py
VILA/CIs/license_commited.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ PYFILES=$(git diff --name-only --diff-filter=ACMRT $commithash HEAD | grep .py | xargs)
2
+
3
+ for file in $PYFILES; do
4
+ echo $file
5
+ addlicense -s -c 'NVIDIA CORPORATION & AFFILIATES' $file
6
+ done
VILA/data_prepare/.DS_Store ADDED
Binary file (6.15 kB). View file
 
VILA/data_prepare/LICENSE ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ License information for datasets used during VILA Training
2
+
3
+ * LLaVA-1.5 Instruction Data: Apache 2.0
4
+ * Coyo: cc-by-4.0
5
+ * MMC4: ODC-By
6
+ * FLAN: cc-by-4.0
7
+ * M3IT: cc-by-4.0
8
+ * ShareGPT4V: cc-by-nc-4.0
VILA/data_prepare/README.md ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Preparation for Training VILA
2
+
3
+ To train VILA, we used the following datasets:
4
+
5
+ | Stage | Datasets |
6
+ | ----------------------- | -------------------------------------------------------------------------------- |
7
+ | 1. Initialize projector | CC3M |
8
+ | 2. Pre-training | MMC4-core, COYO-700M subset |
9
+ | 3. SFT | LLaVA-1.5, VFLAN, ShareGPT, TextFLAN, WIT, GSM8K-ScRel-SFT, Sherlock, ScienceQA |
10
+
11
+ ### LLaVa-CC3M-Pretrain
12
+
13
+ We use [LLaVA-CC3M-Pretrain-595K](https://huggingface.co/datasets/liuhaotian/LLaVA-CC3M-Pretrain-595K/blob/main/chat.json) to train the visual language projector
14
+
15
+ ### MMC4-Core Dataset
16
+
17
+ Due to the limit of compute, we pre-train VILA on the smaller core set of MMC4 instead of the full set.
18
+
19
+ 1. Firstly, download the annotations of the MMC4-core dataset here: https://github.com/allenai/mmc4. We used the non-fewer-face split, and you may need to request the access [here](https://forms.gle/VYtcNY8aYaUANK9f8).
20
+
21
+ 1. Now modify the input and output path in `mmc4_downloader.py` and run the following script to scrawl the MMC4 images:
22
+
23
+ ```bash
24
+ cd mmc4
25
+ python mmc4_downloader.py
26
+ ```
27
+
28
+ Note that due to the expiration of image urls, you may end up getting a subset of the entire corpus.
29
+
30
+ The scrawling may take a long time. Optionally, you can also shard the workload over multiple jobs/machines concurrently to speed up the process:
31
+
32
+ ```bash
33
+ # provide the start and end index of the jsonl shard. There are 23098 - 14 shards totally
34
+ # python mmc4_downloader.py <start_idx> <end_idx>
35
+ python mmc4_downloader.py 0 1000 # worker 1
36
+ python mmc4_downloader.py 1000 2000 # worker 2
37
+ ```
38
+
39
+ 3. Filter out invalid samples in MMC4:
40
+
41
+ ```bash
42
+ python mmc4_filter_and_counter.py
43
+ ```
44
+
45
+ 4. Merge images and text into a unified pickle file for each shard:
46
+
47
+ ```bash
48
+ python mmc4_merger.py
49
+ ```
50
+
51
+ ### COYO-700M Dataset
52
+
53
+ 1. Download the metadata of COYO-700M:
54
+
55
+ ```bash
56
+ huggingface-cli download kakaobrain/coyo-700m --repo-type dataset --local-dir coyo-700m --local-dir-use-symlinks False
57
+ ```
58
+
59
+ 2. Crawl the COYO images. Note that here we only keep a 20% subset in each shard with the highest CLIP similarity, to balance compute budget and data quality.
60
+
61
+ There are totally 128 shards of annotations. Now download each one with the script:
62
+
63
+ ```bash
64
+ cd coyo
65
+ for SHARD in {0..127}; do
66
+ python coyo_downloader.py $SHARD
67
+ done
68
+ ```
69
+
70
+ 3. Split downloaded COYO data into multiple shards:
71
+
72
+ ```bash
73
+ python coyo_splitter.py
74
+ ```
75
+
76
+ ### LLaVA-1.5 Instruction Data
77
+
78
+ We use this [file](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K/blob/main/llava_v1_5_mix665k.json) in our experiments. Please download this dataset from LLaVA authors.
79
+
80
+ ```bash
81
+ huggingface-cli download liuhaotian/LLaVA-Instruct-150K llava_v1_5_mix665k.json --repo-type dataset
82
+ ```
83
+
84
+ ### VFlan dataset
85
+
86
+ 1. Download FLAN datasets:
87
+
88
+ ```bash
89
+ huggingface-cli download Open-Orca/FLAN --repo-type dataset --local-dir FLAN --local-dir-use-symlinks False
90
+ ```
91
+
92
+ 2. Preprocess FLAN dataset (sample 1M data from 378M samples):
93
+
94
+ ```bash
95
+ cd sft
96
+ python preprocess_flan.py
97
+ ```
98
+
99
+ ### M3IT Dataset
100
+
101
+ 1. Download M3IT datasets:
102
+
103
+ ```bash
104
+ huggingface-cli download MMInstruction/M3IT --repo-type dataset --local-dir M3IT --local-dir-use-symlinks False
105
+ ```
106
+
107
+ 2. Preprocess M3IT dataset:
108
+
109
+ ```bash
110
+ python preprocess_m3it.py
111
+ ```
112
+
113
+ 3. (Optional) Split FLAN+M3IT into multiple chunks to reduce CPU memory pressure during training:
114
+
115
+ ```bash
116
+ python split_vflan.py
117
+ ```
118
+
119
+ ### ShareGPT4v
120
+
121
+ The ShareGPT data can be obtained [mit-han-lab/ShareGPT4V](https://huggingface.co/datasets/mit-han-lab/ShareGPT4V). * Note the original ShareGPT4v dataset contains some samples with file ids (sa_XXXX) and repetitive responses. We filter those bad examples and reduced the samples from 100K -> 96K (for caption) and 1.2m -> 1.17m (for pretraining). Then we re-combine them into a single file.
122
+
123
+ ```bash
124
+ huggingface-cli download mit-han-lab/ShareGPT4V --repo-type dataset --local-dir ShareGPT4V --local-dir-use-symlinks False
125
+ ```
126
+
127
+ ### WIT
128
+
129
+ The original WIT data can be obtained [google-research-datasets/wit](https://github.com/google-research-datasets/wit/tree/main). * We subsample ~538K english data from the original WIT dataset and curate a llava conversation format JSON file.
130
+
131
+ ```bash
132
+ huggingface-cli download Efficient-Large-Model/WIT_538K --repo-type dataset --local-dir WIT --local-dir-use-symlinks False
133
+ ```
134
+
135
+ ### GSM8K-ScRel-SFT
136
+
137
+ We add some math data [gsm8k-ScRel](https://github.com/OFA-Sys/gsm8k-ScRel/blob/main/data/train_use.jsonl) to our SFT stage.
138
+
139
+ ### Sherlock
140
+
141
+ The image files of Sherlock can be obtained from [VisualGenome](https://visualgenome.org/api/v0/api_home.html) and [VCR](https://visualcommonsense.com/download/) separately. The llava conversation format JSON file can be downloaded with
142
+
143
+ ```bash
144
+ huggingface-cli download Efficient-Large-Model/sherlock_317K --repo-type dataset --local-dir sherlock --local-dir-use-symlinks False
145
+ ```
146
+
147
+ ### ScienceQA
148
+
149
+ We use the train split of ScienceQA. The image data of the train split can be obtained from [ScienceQA](https://huggingface.co/datasets/derek-thomas/ScienceQA) or their [huggingface repo](https://huggingface.co/datasets/derek-thomas/ScienceQA). The llava conversation format JSON file can be downloaded with
150
+
151
+ ```bash
152
+ huggingface-cli download Efficient-Large-Model/ScienceQA_train_12K --repo-type dataset --local-dir scienceqa --local-dir-use-symlinks False
153
+ ```
154
+
155
+ ### IDEFICS2-SFT dataset
156
+
157
+ We also provide scripts to preprocess IDEFICS2-SFT dataset into llava-SFT like format.
158
+
159
+ Please first download [HuggingFaceM4/the_cauldron](https://huggingface.co/datasets/HuggingFaceM4/the_cauldron) to `/home/jasonlu/workspace/idefics2-sft/the_cauldron`. Then, run the following scripts:
160
+
161
+ ```bash
162
+ python preprocess_idefics2.py
163
+ python merge_idefics2.py
164
+ ```
165
+
166
+ A sample in the preprocessed dataset file will look like this:
167
+
168
+ ```json
169
+ {"id": 0, "images": ["images/chart2text/0_0.png"], "conversations": [{"from": "human", "value": "<image>\nPlease clarify the meaning conveyed by this graph."}, {"from": "gpt", "value": "This statistic presents the reach of the most popular social networks among female beauty consumers in the United States as of August 2016. During the survey period, 62 percent of respondents had an Instagram account."}]}
170
+ ```
171
+
172
+ Haotian's Note: Datasets overlapping with VFLAN / ShareGPT4V-SFT are removed. I also remove `plotqa` since it is too large, `localized_narratives` seems to be a little bit overlapped with captioning efforts within VILA. `websight` and `datikz` are two datasets that target code generation. Since the output is very long, and including them might slow down training, I also temporarily removed these two datasets, but feel free to add them back.
VILA/data_prepare/panda70m.sh ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Launch parallel slurm jobs that clean up / process panda70m shards.
#
# Usage: bash data_prepare/panda70m.sh [jobs_limit] [workdir]

JOBS_LIMIT=${1:-32}  # max number of concurrently backgrounded srun jobs
workdir=${2:-/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_training_10m}

# Last path component of the workdir, used to label the slurm jobs.
wname=$(echo $workdir | rev | cut -d "/" -f 1 | rev)

echo "Parallely checking for all shards in $workdir / $wname"
parallel_size=32
idx_size=$(( parallel_size - 1 ))

mkdir -p slurm-logs/data

for idx in $(seq 0 $idx_size); do
    # Throttle: do not keep more than JOBS_LIMIT background jobs alive at once.
    while [ $(jobs -rp | wc -l) -ge $JOBS_LIMIT ]; do
        sleep 1
    done
    echo "Running jobs $(jobs -rp | wc -l) $wname-$idx-of-$parallel_size";

    srun -A llmservice_nlp_fm \
        -p cpu,cpu_1,cpu_long -t 4:00:00 -J cleanup-$wname-$idx-of-$parallel_size \
        --cpus-per-task 8 \
        --mem-per-cpu 8G \
        -e slurm-logs/data/$idx-of-$parallel_size.err \
        -o slurm-logs/data/$idx-of-$parallel_size.txt \
        python llava/data/dataset_impl/panda70m.py --workdir=$workdir --shards=$idx --total=$parallel_size &

done

# Fix: block until all backgrounded srun jobs finish. Previously the script
# returned while shard jobs were still running (parallel_shards.sh already
# ends with `wait`; this makes the two scripts consistent).
wait

# bash data_prepare/panda70m.sh 32 /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_training_10m;
# bash data_prepare/panda70m.sh 32 /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_training_2m;
# bash data_prepare/panda70m.sh 32 /lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m/panda70m_testing;

# --exclusive \
# --cpus-per-task 8 \
# --mem-per-cpu 8G \
VILA/data_prepare/panda_split.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import base64
18
+ import copy
19
+ import glob
20
+ import io
21
+ import json
22
+ import logging
23
+ import os
24
+ import os.path as osp
25
+ import pathlib
26
+ import pickle
27
+ import random
28
+ import re
29
+ import shutil
30
+ import time
31
+ from collections import defaultdict
32
+ from dataclasses import dataclass, field
33
+ from datetime import datetime
34
+ from functools import lru_cache
35
+ from io import BytesIO
36
+ from typing import Dict, List, Optional, Sequence
37
+
38
+ import cv2
39
+ import decord
40
+ import numpy as np
41
+ import PIL
42
+ import torch
43
+ import transformers
44
+ from decord._ffi.base import DECORDError
45
+ from iopath.common.file_io import g_pathmgr
46
+ from PIL import Image
47
+ from pytorchvideo.data.decoder import DecoderType
48
+ from pytorchvideo.data.encoded_video import EncodedVideo, select_video_class
49
+ from pytorchvideo.data.video import Video
50
+ from torch.utils.data import ConcatDataset, Dataset
51
+ from torchvision.transforms import Resize
52
+
53
+ import llava.data.datasets_mixture as datasets_mixture
54
+ from llava import conversation as conversation_lib
55
+ from llava.constants import (
56
+ DEFAULT_IM_END_TOKEN,
57
+ DEFAULT_IM_START_TOKEN,
58
+ DEFAULT_IMAGE_TOKEN,
59
+ IGNORE_INDEX,
60
+ IMAGE_TOKEN_INDEX,
61
+ )
62
+ from llava.data.dataset import LazySupervisedDataset
63
+ from llava.data.dataset_impl.textocr import GenericDataset, preprocess_OCR
64
+ from llava.data.datasets_mixture import DATASETS
65
+ from llava.data.simple_vila_webdataset import VILAWebDataset
66
+ from llava.data.utils import VILAEncodedVideo
67
+ from llava.mm_utils import is_gemma_tokenizer, tokenizer_image_token
68
+ from llava.train.args import DataArguments, TrainingArguments
69
+
70
# Root directory that holds the panda70m dataset splits on the cluster.
DEFAULT_HIERTEXT = "/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/panda70m"
# Dataset split processed by default.
SPLIT = "panda70m_testing"
72
+
73
+
74
def with_opencv(filename):
    """Probe a video file's basic timing metadata with OpenCV.

    Args:
        filename: Path to a video file readable by ``cv2.VideoCapture``.

    Returns:
        Tuple of ``(duration_seconds, fps, frame_count)``.
    """
    video = cv2.VideoCapture(filename)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        frame_count = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # NOTE(review): fps can be 0.0 for unreadable/corrupted files, which
        # raises ZeroDivisionError here — same fail-fast behavior as before.
        duration = frame_count / fps
    finally:
        # Fix: release the capture handle; the original leaked it on every call.
        video.release()
    return duration, fps, frame_count
80
+
81
+
82
def split_video_to_clips(
    workdir=osp.expanduser("~/nvr_elm_llm/dataset/panda70m/panda70m_training_2m"),
    shards=0,
    total=-1,
):
    """Iterate over the ``.mp4`` files of one shard of a panda70m folder.

    Args:
        workdir: Folder containing paired ``<name>.mp4`` / ``<name>.json`` files.
        shards: Index of the shard this worker processes.
        total: Total number of shards; ``<= 0`` means process the full list.
    """
    video_list = glob.glob(f"{workdir}/*.mp4")
    video_list = sorted(video_list)
    if total > 0:
        # Evenly partition the sorted list; the last shard absorbs the remainder.
        chunk = len(video_list) // total
        begin_idx = shards * chunk
        end_idx = (shards + 1) * chunk
        if shards == total - 1:
            end_idx = len(video_list)
        video_list = video_list[begin_idx:end_idx]
    print(f"Splitting total {len(video_list)} videos")
    output_dir = workdir + "_clip"  # NOTE(review): computed but not used yet
    debug_info = {}
    for idx, video_path in enumerate(video_list):
        print(f"[{idx}/{len(video_list)}]", video_path)
        json_path = video_path.replace(".mp4", ".json")
        assert osp.exists(json_path) and osp.exists(video_path)
        # Fix: close the annotation file deterministically; the original
        # `json.load(open(json_path))` leaked the file handle.
        with open(json_path) as f:
            jinfo = json.load(f)
        print(jinfo)
        info = with_opencv(video_path)
        print(info)
        video = VILAEncodedVideo.from_bytesio(video_path, decoder="decord", decode_audio=False)

        # NOTE(review): unconditional return after the first video — this looks
        # like a debug stub left in place; confirm before relying on full runs.
        return
110
+
111
+
112
if __name__ == "__main__":
    # Expose split_video_to_clips as a CLI, e.g.:
    #   python data_prepare/panda_split.py --workdir=... --shards=0 --total=32
    # WORKDIR=osp.expanduser("~/nvr_elm_llm/dataset/panda70m/panda70m_testing")
    # cleanup_corrupted_videos()
    import fire

    fire.Fire(split_video_to_clips)
VILA/data_prepare/parallel_shards.sh ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Shard-parallel creation of a VILA webdataset index via slurm.
#
# Usage: bash data_prepare/parallel_shards.sh [jobs_limit] [workdir]

JOBS_LIMIT=${1:-32}  # max number of concurrently backgrounded srun jobs

# Fix: the original unconditionally overwrote $2 with this hard-coded path,
# making the second CLI argument dead. Keep the same path, but only as the
# default so callers can still override the workdir.
workdir=${2:-/lustre/fsw/portfolios/nvr/projects/nvr_elm_llm/dataset/video_datasets_v2/internvid/video_data_tar}

parallel_size=32
idx_size=$(( parallel_size - 1 ))

mkdir -p slurm-logs/data

for idx in $(seq 0 $idx_size); do
    # Throttle: keep at most JOBS_LIMIT jobs in flight.
    while [ $(jobs -rp | wc -l) -ge $JOBS_LIMIT ]; do
        sleep 1
    done
    echo "Running jobs $(jobs -rp | wc -l) $idx-of-$parallel_size";

    srun -A $SLURM_ACCOUNT \
        -p cpu,cpu_1,cpu_long -t 4:00:00 -J creating-WDS-$idx-of-$parallel_size \
        --cpus-per-task 8 \
        --mem-per-cpu 8G \
        --dependency singleton \
        -e slurm-logs/data/$idx-of-$parallel_size.err \
        -o slurm-logs/data/$idx-of-$parallel_size.txt \
        python llava/data/simple_vila_webdataset.py $workdir --shards=$idx --total=$parallel_size &
done
wait

# Final pass over the full dataset once all shard jobs have finished.
python llava/data/simple_vila_webdataset.py $workdir
VILA/demo_images/LongVILA-pipeline.png ADDED

Git LFS Details

  • SHA256: d29fdbb1cdf908a8053cf9ca19262aaf4823d51cd2c04567f8375af951f6cdd8
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB
VILA/demo_images/av.png ADDED

Git LFS Details

  • SHA256: 093f0838b946c86d932ca76ad5b0fc871609d1c49dba359a9380545d31b67ed3
  • Pointer size: 131 Bytes
  • Size of remote file: 384 kB
VILA/demo_images/demo_img_1.png ADDED

Git LFS Details

  • SHA256: 85765d45ea665ac4afbafbc5ce03fdcc23fd958d64b6da2038a1f6cce85a1541
  • Pointer size: 131 Bytes
  • Size of remote file: 142 kB
VILA/demo_images/demo_img_2.png ADDED

Git LFS Details

  • SHA256: 81b278a341259c01bc01b55effd6f61b6a2b12657305d644473a8ba5371861b9
  • Pointer size: 131 Bytes
  • Size of remote file: 715 kB
VILA/demo_images/demo_img_3.png ADDED

Git LFS Details

  • SHA256: 1e26e812858c4610bfebc33a5f42751db4b88cf948adab19a67e67d4865d1271
  • Pointer size: 131 Bytes
  • Size of remote file: 568 kB
VILA/demo_images/longvila-logo.png ADDED

Git LFS Details

  • SHA256: 41046d75a3bb9d3dde39781e0d204a4f9c58e5353feff6e712590bb8d1fb000d
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
VILA/demo_images/vila-logo.jpg ADDED
VILA/demo_trt_llm/README.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ## Deprecation Notice
2
+
3
+ This README is deprecated and is no longer being maintained. For the most up-to-date information and instructions, please refer to the [TensorRT-LLM example](https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/multimodal#llava-and-vila) for VILA deployment.
VILA/inference_test/inference_test.json ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "test_cases": [
3
+ {
4
+ "name": "top down view",
5
+ "image_paths": [
6
+ "more_samples/top_view.png"
7
+ ],
8
+ "QAs": [
9
+ {
10
+ "question": "<image>\n What is unusual about this image?",
11
+ "expected_answer": "The unusual aspect of this image is that it is an aerial view of a busy freeway with many cars, and it appears to be taken from a helicopter. This perspective provides a unique and interesting perspective of the traffic, as it allows the viewer to see the entire freeway and all the cars on it from above. The image captures the bustling nature of the city and the movement of the vehicles, which is not easily visible from ground level."
12
+ }
13
+ ]
14
+ },
15
+
16
+ {
17
+ "name": "deer crossing",
18
+ "image_paths": [
19
+ "more_samples/deer_crossing.png"
20
+ ],
21
+ "QAs": [
22
+ {
23
+ "question": "<image>\n What is unusual about this image?",
24
+ "expected_answer": "The unusual aspect of this image is that a group of deer is crossing a road in front of a car. Typically, deer are not expected to be seen crossing roads, especially in urban or suburban areas. This situation can pose a risk to both the deer and the people in the car, as the deer might not be aware of the approaching vehicle, and the driver may not have enough time to react and stop safely. It is important for drivers to be cautious and patient in such situations to avoid accidents and ensure the safety of both the animals and the people involved."
25
+ }
26
+ ]
27
+ },
28
+
29
+ {
30
+ "name": "windmine",
31
+ "image_paths": [
32
+ "more_samples/windmine.png"
33
+ ],
34
+ "QAs": [
35
+ {
36
+ "question": "<image>\n What is in this image?",
37
+ "expected_answer": "The image features a large field with numerous wind turbines scattered throughout the area. The wind turbines are of varying heights and are positioned at different angles, creating a visually striking scene. The field is surrounded by a forest, adding to the natural beauty of the landscape."
38
+ },
39
+ {
40
+ "question": "<image>\n Do facilities here look normal or need maintenance?",
41
+ "expected_answer": "The facilities in the image, which are wind turbines, appear to be in good condition and functioning properly. They are situated in a large field with a clear blue sky, indicating that they are well-maintained and not in need of immediate maintenance. The presence of multiple wind turbines in the field suggests that they are part of a larger wind farm, which is likely to be managed and maintained by professionals to ensure optimal performance and efficiency."
42
+ }
43
+ ]
44
+ },
45
+
46
+ {
47
+ "name": "windworker",
48
+ "image_paths": [
49
+ "more_samples/windworker.png"
50
+ ],
51
+ "QAs": [
52
+ {
53
+ "question": "<image>\n Can you describe the situation in this image?",
54
+ "expected_answer": "In the image, a person is climbing a large wind turbine, likely a windmill, to perform maintenance or repair work. The individual is using a harness and a rope to secure themselves to the turbine, ensuring their safety while they work. The scene captures the height and scale of the wind turbine, emphasizing the importance of the task being performed by the person."
55
+ },
56
+ {
57
+ "question": "<image>\n Is the worker well protected?",
58
+ "expected_answer": "Yes, the worker is well protected. In the image, the man is wearing a harness while climbing the side of a wind turbine. This harness provides safety and support, preventing the worker from falling and ensuring that they can safely perform their task. Additionally, the worker is likely using other safety equipment and following proper safety protocols to minimize the risk of accidents while working at such heights."
59
+ }
60
+ ]
61
+ },
62
+
63
+ {
64
+ "name": "construction crack",
65
+ "image_paths": [
66
+ "more_samples/construction_crack.webp"
67
+ ],
68
+ "QAs": [
69
+ {
70
+ "question": "<image>\n Is the building in health condition?",
71
+ "expected_answer": "The building appears to be in a state of disrepair, with visible cracks and chips in the paint. This suggests that the building may be old, poorly maintained, or has experienced some form of damage or wear over time. The presence of these structural issues could potentially pose safety risks or negatively impact the building's overall appearance and value."
72
+ },
73
+ {
74
+ "question": "<image>\n What are the likely service needed for this building?",
75
+ "expected_answer": "The building in the image has a cracked wall and a damaged roof, which indicates that it requires repair and maintenance services. The cracked wall may require repairs to ensure structural integrity and prevent further damage. The damaged roof may need to be replaced or repaired to prevent water leakage and potential damage to the building's interior. Additionally, the building may need to be inspected for any other potential issues that may arise due to the damage."
76
+ }
77
+ ]
78
+ },
79
+
80
+ {
81
+ "name": "climate change",
82
+ "image_paths": [
83
+ "more_samples/climate_change.jpeg"
84
+ ],
85
+ "QAs": [
86
+ {
87
+ "question": "<image>\n What is the implication of temperature based on this image?",
88
+ "expected_answer": "The image shows a comparison of the Arctic ice cover in 1984 and 2012. The implication of this image is that the Arctic ice cover has significantly decreased over the years, which is a result of climate change. The reduction in ice cover can have various implications, such as increased exposure of the Arctic environment to warmer temperatures, which can lead to the melting of sea ice, loss of habitat for polar bears and other Arctic species, and potential disruptions to the global climate system. This image serves as a visual representation of the impact of climate change on the Arctic environment and the need for action to mitigate its effects."
89
+ }
90
+ ]
91
+ },
92
+
93
+ {
94
+ "name": "patient status",
95
+ "image_paths": [
96
+ "more_samples/patient_status.jpeg"
97
+ ],
98
+ "QAs": [
99
+ {
100
+ "question": "<image>\n What is the patient status?",
101
+ "expected_answer": "The patient is in a critical condition, as indicated by the presence of a ventilator and the fact that she is hooked up to a lot of wires."
102
+ },
103
+ {
104
+ "question": "<image>\n Is the patient asleep?",
105
+ "expected_answer": "Yes, the patient is asleep in the hospital bed."
106
+ }
107
+ ]
108
+ },
109
+
110
+ {
111
+ "name": "patient care",
112
+ "image_paths": [
113
+ "more_samples/patient_care.png"
114
+ ],
115
+ "QAs": [
116
+ {
117
+ "question": "<image>\n What is the training session about?",
118
+ "expected_answer": "The training session is about learning how to perform CPR (cardiopulmonary resuscitation) on a mannequin. The group of people, including nurses and possibly other medical professionals, are gathered around a mannequin to practice their skills and learn the proper techniques for administering CPR."
119
+ },
120
+ {
121
+ "question": "<image>\n Will the practice cause pain to the patient in the image?",
122
+ "expected_answer": "The image shows a group of people, including nurses and a mannequin, practicing CPR on a mannequin. While practicing CPR is essential for healthcare professionals, it is important to note that it is a simulation and not a real patient. The practice does not cause pain to the patient in the image, as it is a mannequin and not a living person."
123
+ }
124
+ ]
125
+ },
126
+
127
+ {
128
+ "name": "Fuyu comparison",
129
+ "image_paths": [
130
+ "fuyu-capm.png"
131
+ ],
132
+ "QAs": [
133
+ {
134
+ "question": "<image>\n Janet Ludlow’s firm requires all its analysts to use a two-stage dividend discount model (DDM) and the capital asset pricing model (CAPM) to value stocks. Using the CAPM and DDM, Ludlow has valued QuickBrush Company at $63 per share. She now must value SmileWhite Corporation. Calculate the required rate of return for SmileWhite by using the information in the table. Which rate is correct? A. 14% B. 15% C. 16%",
135
+ "expected_answer": "The required rate of return for SmileWhite can be calculated using the capital asset pricing model (CAPM), which is:\n\nRequired Return = Risk-free Rate + Beta x (Expected Market Return - Risk-free Rate)\n\nThe risk-free rate and expected market return are provided in the table, and the beta for SmileWhite is given as 1.15.\n\nLet’s calculate the required rate of return for SmileWhite:\n\nRequired Return = 4.50% + 1.15 x (14.50% - 4.50%)\nRequired Return = 4.50% + 1.15 x 10%\nRequired Return = 4.50% + 11.50%\nRequired Return = 16%\n\nTherefore, the answer is:\n\nC: 16%\nAnswer: C"
136
+ }
137
+ ]
138
+ },
139
+
140
+
141
+
142
+ {
143
+ "name": "Fuyu comparison",
144
+ "image_paths": [
145
+ "fuyu-heavy-table.png"
146
+ ],
147
+ "QAs": [
148
+ {
149
+ "question": "According to the following table, <image>\nwhich food is the most likely cause of the outbreak of food poisoning? (A) Cold chicken, (B) Potato salad, (C) Egg sandwiches, (D) Fruit pie and cream. Read the table in details and consider the percentage as the key factor. Think step-by-step.",
150
+ "expected_answer": "The table shows the number of people who ate different foods and the number of people who ate those foods and got sick. To determine which food is most likely the cause of the outbreak of food poisoning, we’re looking for the food with the highest percentage of people who ate that food and got sick.\n\nLet’s calculate the percentage of people who ate each food and got sick:\n\n- Cold chicken: 34 / 86 = 0.395, which is 39.5%\n- Potato salad: 38 / 54 = 0.704, which is 70.4%\n- Egg sandwiches: 40 / 76 = 0.526, which is 52.6%\n- Fruit pie and cream: 12 / 32 = 0.375, which is 37.5%\n- Cheese: 12 / 48 = 0.25, which is 25%\n\nThe food with the highest percentage of people who ate that food and got sick is potato salad, with 70.4%. Therefore, potato salad is the most likely cause of the outbreak of food poisoning.\n\nAnswer: B"
151
+ }
152
+ ]
153
+ },
154
+
155
+
156
+ {
157
+ "name": "flamingo comparison",
158
+ "image_paths": [
159
+ "flamingo_1.png",
160
+ "flamingo_2.png",
161
+ "flamingo_3.png"
162
+ ],
163
+ "QAs": [
164
+ {
165
+ "question": "Image 1: <image>\n Image 2: <image>\n Image 3: <image>\n What is the common thing about these three images?",
166
+ "expected_answer": "The common theme among these three images is the presence of a flamingo. In the first image, the flamingo is a 3D model, while in the second image, it is a pink cartoon. In the third image, it is a real flamingo standing in the water. All three images showcase the bird in different styles and formats, but they all share the same subject."
167
+ }
168
+ ]
169
+ },
170
+ {
171
+ "name": "princess comparison",
172
+ "image_paths": [
173
+ "princess_1.png",
174
+ "princess_2.png"
175
+ ],
176
+ "QAs": [
177
+ {
178
+ "question": "Image 1: <image>\n Image 2: <image>\n Tell me about the differences between image 1 and image 2.",
179
+ "expected_answer": "The two images feature a cartoon girl wearing a pink dress. In the first image, the girl is wearing a crown, while in the second image, she is wearing a tiara. Both images showcase the girl in a princess-like attire, but the difference lies in the type of headwear she is wearing. The crown is a more traditional and regal accessory, while the tiara is a more playful and common accessory associated with princesses in popular culture."
180
+ }
181
+ ]
182
+ },
183
+ {
184
+ "name": "order price calculation",
185
+ "image_paths": [
186
+ "menu_1.png",
187
+ "menu_2.png"
188
+ ],
189
+ "QAs": [
190
+ {
191
+ "question": "Photo: <image>\n Menu: <image>\n How much should I pay for all the beers on the table according to the price on the menu? Think step by step.",
192
+ "expected_answer": "The beer on the table is Magna, which is priced at $6.00. There are two bottles of Magna on the table. Therefore, the total cost of the beer should be $6.00 x 2 = $12.00."
193
+ }
194
+ ]
195
+ },
196
+ {
197
+ "name": "pizza price calculation",
198
+ "image_paths": [
199
+ "pizza.png"
200
+ ],
201
+ "QAs": [
202
+ {
203
+ "question": "<image>\nI am getting just two custom pizza for me and my friend. How much should I pay in total according to the image? Let's think step by step.",
204
+ "expected_answer": "The image shows a chalkboard menu with a special for today, which is to create your own pizza. The price for this special is $9.99. Since you are getting two custom pizzas, you should pay $9.99 x 2 = $19.98 for the two pizzas. So, the answer is $19.98."
205
+ }
206
+ ]
207
+ },
208
+ {
209
+ "name": "sign OCR",
210
+ "image_paths": [
211
+ "sign_1.png",
212
+ "sign_2.png",
213
+ "sign_3.png"
214
+ ],
215
+ "QAs": [
216
+ {
217
+ "question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1 is Underground. Image 2 is Congress. What is Image 3?",
218
+ "expected_answer": "Soulomes."
219
+ }
220
+ ]
221
+ },
222
+ {
223
+ "name": "painting style",
224
+ "image_paths": [
225
+ "painting_1.png",
226
+ "painting_2.png",
227
+ "painting_3.png"
228
+ ],
229
+ "QAs": [
230
+ {
231
+ "question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1 is Romanticism. Image 2 is Surrealism. What is Image 3?",
232
+ "expected_answer": "Impressionism"
233
+ }
234
+ ]
235
+ },
236
+ {
237
+ "name": "handwritten calculation",
238
+ "image_paths": [
239
+ "handwritten_1.png",
240
+ "handwritten_2.png",
241
+ "handwritten_3.png"
242
+ ],
243
+ "QAs": [
244
+ {
245
+ "question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1 is 2+1=3. Image 2 is 5+6=11. What is Image 3?",
246
+ "expected_answer": "3x6=18"
247
+ }
248
+ ]
249
+ },
250
+ {
251
+ "name": "landmark Taipei",
252
+ "image_paths": [
253
+ "landmark_taipei.png"
254
+ ],
255
+ "QAs": [
256
+ {
257
+ "question": "<image>\nWhich city is this landmark in?",
258
+ "expected_answer": "The landmark in the image is located in Taipei, Taiwan."
259
+ }
260
+ ]
261
+ },
262
+ {
263
+ "name": "landmark new york 1",
264
+ "image_paths": [
265
+ "landmark_new_york_1.png"
266
+ ],
267
+ "QAs": [
268
+ {
269
+ "question": "<image>\nWhich city is this landmark in?",
270
+ "expected_answer": "This landmark is located in New York City."
271
+ }
272
+ ]
273
+ },
274
+ {
275
+ "name": "landmark new york 2",
276
+ "image_paths": [
277
+ "landmark_new_york_2.png"
278
+ ],
279
+ "QAs": [
280
+ {
281
+ "question": "<image>\nWhich city is this landmark in?",
282
+ "expected_answer": "The landmark in the image is located in New York City."
283
+ }
284
+ ]
285
+ },
286
+ {
287
+ "name": "landmark Kyoto",
288
+ "image_paths": [
289
+ "landmark_kyoto.png"
290
+ ],
291
+ "QAs": [
292
+ {
293
+ "question": "<image>\nWhich city is this landmark in?",
294
+ "expected_answer": "This landmark is located in Kyoto, Japan."
295
+ }
296
+ ]
297
+ },
298
+ {
299
+ "name": "visual inference",
300
+ "image_paths": [
301
+ "visual_inference.png"
302
+ ],
303
+ "QAs": [
304
+ {
305
+ "question": "<image>\nWhat is in the circled glass? Object 1 or 2? Describe what is Object 1 and object 2. Then check what is in the circled glass.",
306
+ "expected_answer": "Object 1 is a bottle of water, and Object 2 is a bottle of beer. The circled glass contains Object 1, which is a bottle of water."
307
+ }
308
+ ]
309
+ },
310
+ {
311
+ "name": "Gemini handdraw car",
312
+ "image_paths": [
313
+ "gemini_handdraw_car.png"
314
+ ],
315
+ "QAs": [
316
+ {
317
+ "question": "<image>\nWhich of these cars is more aerodynamic? The one on the left or the right? Explain why, using specific visual details.",
318
+ "expected_answer": "The car on the right is more aerodynamic than the one on the left. This is because the car on the right has a more streamlined shape, which reduces air resistance and drag. The car on the left has a more traditional, boxy shape, which increases air resistance and drag. The aerodynamic car on the right would be more efficient and faster, while the car on the left would be slower and less efficient."
319
+ }
320
+ ]
321
+ },
322
+ {
323
+ "name": "caption night view",
324
+ "image_paths": [
325
+ "caption_nightview.png"
326
+ ],
327
+ "QAs": [
328
+ {
329
+ "question": "<image>\nDescribe the image with as much detail as possible.",
330
+ "expected_answer": "The image captures a beautiful nighttime scene of a city skyline with a large body of water in front of it. The city is illuminated by the lights of various buildings, creating a vibrant atmosphere. The water is filled with numerous boats of different sizes."
331
+ }
332
+ ]
333
+ },
334
+ {
335
+ "name": "caption meat",
336
+ "image_paths": [
337
+ "caption_meat.jpeg"
338
+ ],
339
+ "QAs": [
340
+ {
341
+ "question": "<image>\nDescribe the image with as much detail as possible.",
342
+ "expected_answer": "The image is a surreal and colorful scene featuring a landscape made entirely of meat. The landscape is filled with various types of meat, including bacon, ham, and sausage, creating a unique and eye-catching scene. The meat is arranged in the form of mountains, hills, and valleys, giving the impression of a fantasy world. In addition to the meat, there are several birds scattered throughout the scene."
343
+ }
344
+ ]
345
+ },
346
+ {
347
+ "name": "company",
348
+ "image_paths": [
349
+ "company_1.png",
350
+ "company_2.png",
351
+ "company_3.png"
352
+ ],
353
+ "QAs": [
354
+ {
355
+ "question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. In Image 1, The company is famous for its search engine. In Image 2, The company is famous for iPhone and Mac. What is the company in Image 3 famous for?",
356
+ "expected_answer": "The company is famous for its graphics processing units."
357
+ }
358
+ ]
359
+ },
360
+ {
361
+ "name": "count animal",
362
+ "image_paths": [
363
+ "count_panda_3.png",
364
+ "count_dog_2.png",
365
+ "count_giraff_4.png"
366
+ ],
367
+ "QAs": [
368
+ {
369
+ "question": "<image> pandas: 3.\n <image> dogs:2. <image>",
370
+ "expected_answer": "giraffes: 4"
371
+ }
372
+ ]
373
+ },
374
+ {
375
+ "name": "french",
376
+ "image_paths": [
377
+ "french_1.png",
378
+ "french_2.png",
379
+ "french_3.png"
380
+ ],
381
+ "QAs": [
382
+ {
383
+ "question": "Image 1: <image> Les sanglots longs des violons de l’automne blessent mon coeur d’une langueur monotone. \n Image 2: <image> Pour qui sont ces serpents qui sifflent sur vos têtes? \n Image 3: <image>",
384
+ "expected_answer": "Les flamands roses s'embrassent avec passion, leurs cœurs se touchant, leur amour se partageant."
385
+ }
386
+ ]
387
+ },
388
+ {
389
+ "name": "meme",
390
+ "image_paths": [
391
+ "meme.png"
392
+ ],
393
+ "QAs": [
394
+ {
395
+ "question": "<image>\nCan you explain the meme?",
396
+ "expected_answer": "The meme depicts a man's reaction to the price of a computer graphics card. In the first image, the man is smiling and appears excited about the product. In the second image, he is shocked and disappointed by the high price of the graphics card, which is $1,200. The meme is a playful representation of the contrast between the man's initial enthusiasm and his subsequent disappointment upon learning the cost of the product."
397
+ }
398
+ ]
399
+ },
400
+ {
401
+ "name": "flying chair",
402
+ "image_paths": [
403
+ "flying_chair.png"
404
+ ],
405
+ "QAs": [
406
+ {
407
+ "question": "<image>\nWhat is unusual about this image?",
408
+ "expected_answer": "The unusual aspect of this image is that a chair is flying through the air on a highway, seemingly coming out of the back of a truck."
409
+ },
410
+ {
411
+ "question": "<image>\nWhat should you do if you encounter this?",
412
+ "expected_answer": "If you encounter this situation, you should immediately stop your vehicle and move to a safe distance from the truck and the flying chair. It is essential to avoid any potential hazards and contact the authorities to report the incident and ensure the safety of everyone involved."
413
+ }
414
+ ]
415
+ },
416
+ {
417
+ "name": "palm_e",
418
+ "image_paths": [
419
+ "palm_e_1.png",
420
+ "palm_e_2.png",
421
+ "palm_e_3.png"
422
+ ],
423
+ "QAs": [
424
+ {
425
+ "question": "Image 1: <image>\nImage 2: <image>\nImage 3: <image>. Image 1: at 10:30 am. Image 2: at 12:45 pm. Image3: at 3:45 pm. What did I have for lunch, and what time was it?",
426
+ "expected_answer": "I had a sandwich for lunch, and it was at 12:45 pm."
427
+ }
428
+ ]
429
+ },
430
+ {
431
+ "name": "orange price",
432
+ "image_paths": [
433
+ "orange_price.png"
434
+ ],
435
+ "QAs": [
436
+ {
437
+ "question": "<image>\nWhat's the price for a single orange? Look at the price tag in details.",
438
+ "expected_answer": "$1.25"
439
+ }
440
+ ]
441
+ },
442
+ {
443
+ "name": "tow car",
444
+ "image_paths": [
445
+ "tow_car.png"
446
+ ],
447
+ "QAs": [
448
+ {
449
+ "question": "<image>\nWhat's the person doing?",
450
+ "expected_answer": "The person is lying on the ground next to a car, possibly working on it or inspecting it."
451
+ }
452
+ ]
453
+ },
454
+ {
455
+ "name": "parking sign",
456
+ "image_paths": [
457
+ "parking_sign.png"
458
+ ],
459
+ "QAs": [
460
+ {
461
+ "question": "<image>\nHow long can I park here 5pm on Mondays? Look at the traffic signs in details.",
462
+ "expected_answer": "After 5pm on Monday, you can park for 1 hour."
463
+ }
464
+ ]
465
+ },
466
+ {
467
+ "name": "car block",
468
+ "image_paths": [
469
+ "car_blocker.png"
470
+ ],
471
+ "QAs": [
472
+ {
473
+ "question": "<image>\n Look at the traffic condition, can the vehicle proceed now? Why?",
474
+ "expected_answer": "Based on the image, the vehicle cannot proceed through the traffic yet. There are multiple people and bicycles in the crosswalk, and the traffic light is red. The vehicle must wait for the traffic light to turn green before proceeding."
475
+ }
476
+ ]
477
+ },
478
+ {
479
+ "name": "car safety",
480
+ "image_paths": [
481
+ "car_safety.jpg"
482
+ ],
483
+ "QAs": [
484
+ {
485
+ "question": "<image>\nIs the driver on the phone? ",
486
+ "expected_answer": "Yes, the driver is on the phone."
487
+ },
488
+ {
489
+ "question": "<image>\nHow many people are in the car?",
490
+ "expected_answer": "There are two people in the car with one person driving and the other in the back of the car."
491
+ },
492
+ {
493
+ "question": "<image>\nIs the driver distracted?",
494
+ "expected_answer": "Yes, the driver is distracted as he is holding a cell phone interacting with it while sitting on the driver’s seat."
495
+ },
496
+ {
497
+ "question": "<image>\nWhere is the passenger sitting?",
498
+ "expected_answer": "The passenger is sitting on the right side of the car."
499
+ },
500
+ {
501
+ "question": "<image>\nWhat is on the passenger seat? Is it safe?",
502
+ "expected_answer": "There is a pair of scissors on the passenger seat. It is not safe."
503
+ }
504
+ ]
505
+ },
506
+ {
507
+ "name": "factory",
508
+ "image_paths": [
509
+ "factory.jpg"
510
+ ],
511
+ "QAs": [
512
+ {
513
+ "question": "<image>\nHow many cars are jacked up?",
514
+ "expected_answer": "There are two cars jacked up in the image."
515
+ },
516
+ {
517
+ "question": "<image>\nWhat is the person whose head is under the jacked up car doing?",
518
+ "expected_answer": "The person whose head is under the jacked up car is likely performing a task related to the maintenance or repair of the vehicle. They could be inspecting the suspension, brakes, or other components of the car that require attention. The other people in the scene are also working on the vehicles, suggesting that they are part of a team or a group of mechanics or technicians who are collaborating to fix or maintain the cars."
519
+ },
520
+ {
521
+ "question": "<image>\nHow many people are there whose head is under the jacked up car?",
522
+ "expected_answer": "There are two persons whose head is under the jacked up car."
523
+ }
524
+ ]
525
+ },
526
+ {
527
+ "name": "factory count",
528
+ "image_paths": [
529
+ "factory_count_1.jpg",
530
+ "factory_count_2.jpg",
531
+ "factory_count_3.jpg",
532
+ "factory_count_4.jpg",
533
+ "factory_count_5.jpg",
534
+ "factory_count_6.jpg",
535
+ "factory_count_7.jpg",
536
+ "factory_count_8.jpg"
537
+ ],
538
+ "QAs": [
539
+ {
540
+ "question": "Frame 1: <image>\n Frame 2: <image>\n Frame 3: <image>\n Frame 4: <image>\n Frame 5: <image>\n Frame 6: <image>\n Frame 7: <image>\n Frame 8: <image>\n Considering the video frames, how many chip bags are picked up?",
541
+ "expected_answer": "Two chip bags are picked up."
542
+ }
543
+ ]
544
+ }
545
+ ]
546
+ }
VILA/inference_test/inference_test.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ """
18
+ Inference test to run all examples from the paper and compare w/ expected output.
19
+ Both the inference results and expected output will be printed out.
20
+
21
+ Currently do not support multi-turn chat. Each time an image and question are input and answer is output.
22
+ """
23
+
24
+ import argparse
25
+ import json
26
+ import os
27
+
28
+ import torch
29
+ from PIL import Image
30
+
31
+ from llava.constants import IMAGE_TOKEN_INDEX
32
+ from llava.conversation import SeparatorStyle, conv_templates
33
+ from llava.mm_utils import (KeywordsStoppingCriteria, process_images,
34
+ tokenizer_image_token)
35
+ from llava.model import *
36
+
37
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
38
+
39
+
40
+ from llava.model.builder import load_pretrained_model
41
+
42
+
43
def eval_model(args, model, tokenizer, image_processor):
    """Run every test case from a JSON spec through the model and collect results.

    Reads ``args.test_json_path`` (a dict with a ``"test_cases"`` list; each case
    has ``"image_paths"`` and a ``"QAs"`` list of question/expected_answer pairs),
    loads the images from ``args.test_image_path``, generates one answer per
    question, and prints model output next to the expected answer.

    Fix over the original: the dead ``if 1: ... else: ...`` scaffold was removed.
    The ``else`` branch was unreachable and referenced an undefined name
    (``image_tokens`` — its only definition is commented out), so only the live
    branch is kept; behavior is unchanged.

    Args:
        args: parsed CLI namespace (needs test_json_path, test_image_path, temperature).
        model: loaded VILA/LLaVA model (must support ``generate`` with ``images=``).
        tokenizer: tokenizer matching the model.
        image_processor: image preprocessor matching the model's vision tower.

    Returns:
        list[dict]: one dict per question with keys ``question``, ``output``,
        ``expected_output``.
    """
    # Load the test-case spec.
    with open(args.test_json_path) as f:
        all_test_cases = json.load(f)

    result_list = []
    print(len(all_test_cases["test_cases"]))

    for test_case in all_test_cases["test_cases"]:
        # Read and preprocess all images for this test case up front; the same
        # tensor batch is reused for every question in the case.
        image_file_list = test_case["image_paths"]
        image_list = [
            Image.open(os.path.join(args.test_image_path, image_file)).convert("RGB") for image_file in image_file_list
        ]
        image_tensor = process_images(image_list, image_processor, model.config)

        for i in range(len(test_case["QAs"])):
            query = test_case["QAs"][i]["question"]
            query_text = query

            # Cases with 3+ images use the no-system-prompt template; fewer use
            # the standard vicuna_v1 template. (Heuristic carried over from the
            # original — presumably to save context length; verify if changed.)
            if len(image_list) < 3:
                conv = conv_templates["vicuna_v1"].copy()
            else:
                conv = conv_templates["vicuna_v1_nosys"].copy()
            conv.append_message(conv.roles[0], query)
            conv.append_message(conv.roles[1], None)
            prompt = conv.get_prompt()

            print("%" * 10 + " " * 5 + "VILA Response" + " " * 5 + "%" * 10)

            # Tokenize with <image> placeholders mapped to IMAGE_TOKEN_INDEX.
            inputs = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX)
            input_ids = torch.as_tensor(inputs).cuda().unsqueeze(0)

            # Stop generation at the conversation separator.
            stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
            keywords = [stop_str]
            stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

            with torch.inference_mode():
                output_ids = model.generate(
                    input_ids,
                    images=image_tensor.to(dtype=torch.float16, device="cuda", non_blocking=True),
                    do_sample=True if args.temperature > 0 else False,
                    temperature=args.temperature,
                    top_p=0.7,
                    max_new_tokens=512,
                    stopping_criteria=[stopping_criteria],
                )

            outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
            outputs = outputs.strip()

            print(f"Question: {query_text}")
            print(f"VILA output: {outputs}")
            print(f'Expected output: {test_case["QAs"][i]["expected_answer"]}')

            result_list.append(
                dict(question=query_text, output=outputs, expected_output=test_case["QAs"][i]["expected_answer"])
            )
    return result_list
134
+
135
+
136
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-name", type=str, default=None)
    parser.add_argument("--test_json_path", type=str, default=None)
    parser.add_argument("--test_image_path", type=str, default=None)
    parser.add_argument("--conv-mode", type=str, default=None)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--pad", action="store_true")

    args = parser.parse_args()

    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_name, "llava_llama", None)
    result_list = eval_model(args, model, tokenizer, image_processor)

    # Build the output filename from the model name (last path component).
    save_name = f"inference-test_{args.model_name.split('/')[-1]}"
    # Bug fix: --conv-mode defaults to None, so the bare `"nosys" in
    # args.conv_mode` raised TypeError whenever the flag was omitted.
    if args.conv_mode and "nosys" in args.conv_mode:
        save_name += "_nosys"
    save_name += ".json"
    result_list_str = json.dumps(result_list, indent=2)
    # Bug fix: the original serialized the results but never wrote them to
    # disk; persist them so the run actually produces its report file.
    with open(save_name, "w") as f:
        f.write(result_list_str)
    print(f"Saved {len(result_list)} results to {save_name}")
VILA/llava.egg-info/PKG-INFO ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: llava
3
+ Version: 1.0.0
4
+ Summary: VILA: On Pre-training for Visual Language Models
5
+ Project-URL: Homepage, https://hanlab.mit.edu/projects/vila
6
+ Project-URL: Bug Tracker, https://github.com/Efficient-Large-Model/VILA-Internal/issues
7
+ Classifier: Programming Language :: Python :: 3
8
+ Classifier: License :: OSI Approved :: Apache Software License
9
+ Requires-Python: >=3.8
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: torch==2.0.1
13
+ Requires-Dist: torchvision==0.15.2
14
+ Requires-Dist: transformers==4.31.0
15
+ Requires-Dist: tokenizers<0.14,>=0.12.1
16
+ Requires-Dist: sentencepiece==0.1.99
17
+ Requires-Dist: shortuuid
18
+ Requires-Dist: accelerate==0.27.2
19
+ Requires-Dist: peft==0.5.0
20
+ Requires-Dist: bitsandbytes==0.41.0
21
+ Requires-Dist: pydantic<2,>=1
22
+ Requires-Dist: markdown2[all]
23
+ Requires-Dist: numpy
24
+ Requires-Dist: scikit-learn==1.2.2
25
+ Requires-Dist: gradio==3.35.2
26
+ Requires-Dist: gradio_client==0.2.9
27
+ Requires-Dist: requests
28
+ Requires-Dist: httpx==0.24.0
29
+ Requires-Dist: uvicorn
30
+ Requires-Dist: fastapi
31
+ Requires-Dist: einops==0.6.1
32
+ Requires-Dist: einops-exts==0.0.4
33
+ Requires-Dist: timm==0.6.13
34
+ Requires-Dist: openpyxl==3.1.2
35
+ Requires-Dist: pytorchvideo==0.1.5
36
+ Requires-Dist: datasets==2.16.1
37
+ Requires-Dist: openai==1.8.0
38
+ Requires-Dist: webdataset==0.2.86
39
+ Provides-Extra: train
40
+ Requires-Dist: deepspeed==0.13.2; extra == "train"
41
+ Requires-Dist: ninja; extra == "train"
42
+ Requires-Dist: wandb; extra == "train"
43
+ Provides-Extra: eval
44
+ Requires-Dist: mmengine; extra == "eval"
45
+ Requires-Dist: word2number; extra == "eval"
46
+ Requires-Dist: Levenshtein; extra == "eval"
47
+
48
+ <p align="center">
49
+ <img src="demo_images/vila-logo.jpg" width="20%"/>
50
+ </p>
51
+
52
+ # VILA: On Pre-training for Visual Language Models
53
+
54
+ [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](CODE_LICENSE)
55
+ [![Model License](https://img.shields.io/badge/MODEL%20License-CC%20By%20NC%204.0-red.svg)](MODEL_LICENSE)
56
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-3100/)
57
+
58
+
59
+ [VILA arxiv](https://arxiv.org/abs/2312.07533) / [VILA Demo](https://vila-demo.hanlab.ai/) / [VILA Huggingface](https://huggingface.co/collections/Efficient-Large-Model/vila-on-pre-training-for-visual-language-models-65d8022a3a52cd9bcd62698e)
60
+
61
+ ## 💡 Introduction
62
+ VILA is a visual language model (VLM) pretrained with interleaved image-text data at scale, enabling multi-image VLM. VILA is deployable on the edge, including Jetson Orin and laptop by [AWQ](https://arxiv.org/pdf/2306.00978.pdf) 4bit quantization through [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) framework. We find: (1) image-text pairs are not enough, interleaved image-text is essential; (2) unfreezing LLM during interleaved image-text pre-training enables in-context learning; (3)re-blending text-only instruction data is crucial to boost both VLM and text-only performance. VILA unveils appealing capabilities, including: multi-image reasoning, in-context learning, visual chain-of-thought, and better world knowledge.
63
+
64
+
65
+ ## 💡 News
66
+ - [2024/02] We release [AWQ](https://arxiv.org/pdf/2306.00978.pdf)-quantized 4bit VILA models, deployable on Jetson Orin and laptops through [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat) and [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine).
67
+ - [2024/02] VILA is released. We propose interleaved image-text pretraining that enables multi-image VLM. VILA comes with impressive in-context learning capabilities. We open source everything: including training code, evaluation code, datasets, model ckpts.
68
+ - [2023/12] [Paper](https://arxiv.org/abs/2312.07533) is on Arxiv!
69
+
70
+ ## Performance
71
+
72
+ | $~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~$ | Prec. | VQAv2 | GQA | VizWiz | SQA-I | VQA-T | POPE | MME | MMB | MMB-CN | SEED | llava-bench | MM-Vet | Average (w/o MME) |
73
+ | ----------------- | ---------------- | ---------------- | ---------- | ----------- | ----------- | ----- | ----- | ------- | ---- | ------ | ---- | ----------- | ------ | ----------------- |
74
+ | VILA-7B | fp16 | 80.3 | 63.1 | 59.6 | 68.0 | 62.6 | 86.3 | 1489.4 | 69.8 | 61.0 | 61.7 | 75.2 | 35.1 | 65.7 |
75
+ | VILA-7B-AWQ | int4 | 80.1 | 63.0 | 57.8 | 68.0 | 61.9 | 85.3 | 1486.3 | 68.8 | 59.0 | 61.3 | 75.8 | 35.9 | 65.2 |
76
+ | VILA-13B | fp16| 80.5 | 63.6 | 63.1 | 70.5 | 64.0 | 86.3 | 1553.6 | 73.8 | 66.7 | 62.8 | 78.3 | 42.6 | 68.4 |
77
+ | VILA-13B-AWQ | int4 | 80.4 | 63.6 | 63.0 | 71.2 | 63.5 | 87.0 | 1552.9 | 73.6 | 66.3 | 62.2 | 77.6 | 42.0 | 68.2 |
78
+
79
+ <sup>NOTE: The benchmark results are slightly different from what we report in the paper due to refactoring of the codebase based on LLava-1.5 and re-train the model. VQAV2 and VizWiz are test-dev.</sup>
80
+
81
+ ### Inference speed ( Token/sec )
82
+
83
+ | $~~~~~~$ | Precision | A100 | 4090 | Orin |
84
+ | --- | --- |--- | --- | --- |
85
+ | VILA-7B | fp16 | 81.6 | 58.5 | 11.5 |
86
+ | VILA-7B-AWQ| int4 |155.3| 168.1| 35.6 |
87
+ | VILA-13B | fp16 | 48.5 | OOM | 6.1 |
88
+ | VILA-13B-AWQ | int4 | 102.1| 99.0| 17.5 |
89
+
90
+
91
+ ## VILA Examples
92
+
93
+ ### In context learning
94
+ <img src="demo_images/demo_img_1.png" height="239">
95
+ <img src="demo_images/demo_img_2.png" height="250">
96
+
97
+ ### Multi-image reasoning
98
+ <img src="demo_images/demo_img_3.png" height="193">
99
+
100
+
101
+ ### VILA on Jetson Orin
102
+
103
+ https://github.com/Efficient-Large-Model/VILA/assets/7783214/6079374c-0787-4bc4-b9c6-e1524b4c9dc4
104
+
105
+ ### VILA on RTX 4090
106
+
107
+ https://github.com/Efficient-Large-Model/VILA/assets/7783214/80c47742-e873-4080-ad7d-d17c4700539f
108
+
109
+ </details>
110
+
111
+ ## Installation
112
+
113
+ ```bash
114
+ ./environment_setup.sh
115
+ ```
116
+
117
+ or follow the instructions below in order.
118
+
119
+ ```
120
+ conda create -n vila python=3.10 -y
121
+ conda activate vila
122
+
123
+ pip install --upgrade pip # enable PEP 660 support
124
+ wget https://github.com/Dao-AILab/flash-attention/releases/download/v2.4.2/flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
125
+ pip install flash_attn-2.4.2+cu118torch2.0cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
126
+ pip install -e .
127
+ pip install -e ".[train]"
128
+
129
+ pip install git+https://github.com/huggingface/transformers@v4.38.1
130
+ cp -r ./llava/train/transformers_replace/* ~/anaconda3/envs/vila/lib/python3.10/site-packages/transformers/
131
+ ```
132
+
133
+ ## Training
134
+
135
+ VILA training contains three steps
136
+
137
+ ### Step-1: Alignment
138
+ We utilize LLaVA-CC3M-Pretrain-595K dataset to align the textual and visual modalities.
139
+
140
+ The stage 1 script takes in two parameters and it can run on a single 8xA100 node. `BASE_MODEL_PATH` points to a online or local huggingface repository, such as `NousResearch/Llama-2-7b-hf`. `OUTPUT_NAME` points to a target directory under `checkpoints`, which will save the trained multimodal projector afterwards.
141
+
142
+ ```bash
143
+ bash scripts/v1_5/paper/1_mm_align.sh [BASE_MODEL_PATH] [OUTPUT_NAME]
144
+ ```
145
+
146
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
147
+ | --- | ---: | ---: | ---: | ---: | ---: |
148
+ | VILA-7B | 256 | 2e-5 | 1 | 4096 | 0 |
149
+ | VILA-13B | 256 | 2e-5 | 1 | 4096 | 0 |
150
+
151
+
152
+ ### Step-2: Pretraining
153
+ We use MMC4 and Coyo dataset to train VLM with interleaved image-text pairs.
154
+
155
+ ```bash
156
+ bash scripts/v1_5/paper/2_pretrain_mmc4_coyo.sh [CODE_PATH] [BASE_MODEL_PATH] [STAGE1_PATH] [OUTPUT_NAME]
157
+ ```
158
+
159
+ The stage 2 script takes in four arguments. `CODE_PATH` is the absolute path to our VILA codebase, `BASE_MODEL_PATH` has similar meaning to what is presented in the stage 1 script. `STAGE1_PATH` points to the `OUTPUT_NAME` of stage 1 (i.e. where the stage 1 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that saves the pretraining checkpoint. The script we provided for this stage is executed on slurm, and we expect it to execute on 16 nodes (128 GPUs).
160
+
161
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
162
+ | --- | ---: | ---: | ---: | ---: | ---: |
163
+ | VILA-7B | 1024 | 5e-5 | 1 | 4096 | 0 |
164
+ | VILA-13B | 1024 | 5e-5 | 1 | 4096 | 0 |
165
+
166
+ ### Step-3: Supervised fine-tuning
167
+ This is the last stage of VILA training, in which we tune the model to follow multimodal instructions on a subset of M3IT, FLAN and ShareGPT4V. This stage runs on a 8xA100 node.
168
+
169
+ ```bash
170
+ bash scripts/v1_5/paper/3_sft.sh [STAGE2_PATH] [OUTPUT_NAME]
171
+ ```
172
+ The stage 3 script takes in two arguments. `STAGE2_PATH` points to the `OUTPUT_NAME` of the stage 2 script (i.e. where the stage 2 checkpoint is stored). `OUTPUT_NAME` is the desired folder name under `checkpoints` that stores the final checkpoint.
173
+
174
+ | Hyperparameter | Global Batch Size | Learning rate | Epochs | Max length | Weight decay |
175
+ | --- | ---: | ---: | ---: | ---: | ---: |
176
+ | VILA-7B | 128 | 2e-5 | 1 | 4096 | 0 |
177
+ | VILA-13B | 128 | 2e-5 | 1 | 4096 | 0 |
178
+
179
+ ### Training with fewer GPUs
180
+ To train with fewer GPUs/nodes, you can reduce the `per_device_train_batch_size` and increase the `gradient_accumulation_steps` accordingly. As long as the global batch size (`per_device_train_batch_size` x `gradient_accumulation_steps` x `num_gpus`) is kept the same, the training precision will not be affected.
181
+
182
+ Stage 1 completes within 3.5 (7B) - 5.5 (13B) hours on 8xA100, Stage 2 completes within 30 hours on 128xA100 for VILA-7B, and stage 3 completes in 25 (7B) - 40 (13B) hours on 8xA100.
183
+
184
+ See [data_prepare/README.md](data_prepare/README.md) for more information about how to prepare datasets.
185
+
186
+ ## Evaluations
187
+
188
+ You can follow [Llava1.5 eval](https://github.com/haotian-liu/LLaVA/blob/main/docs/Evaluation.md) to download all datasets. After downloading all datasets, please put them under `playground/data/eval`.
189
+
190
+ We provide a push-the-button script to perform evaluation on all 10 datasets that do not require GPT-assisted evaluation:
191
+
192
+ ```bash
193
+ ./scripts/v1_5/eval/eval_all.sh [CHECKPOINT_PATH] [MODEL_NAME]
194
+ ```
195
+
196
+ This script takes in two parameters, `CHECKPOINT_PATH` points to the stage 3 model checkpoint, and `MODEL_NAME` will be the name of evaluation results.
197
+
198
+
199
+ [VQAv2](https://eval.ai/web/challenges/challenge-page/830/my-submission) and [Vizwiz](https://eval.ai/web/challenges/challenge-page/2185/my-submission) evaluations are hosted on eval.ai. You need to register an account and create a team to be able to submit eval.
200
+
201
+ MMBench and MMBench_CN eval are hosted on another [evaluation server](https://opencompass.org.cn/leaderboard-multimodal). Make sure you change the name of the file before submitting, otherwise the server caches results and will always return the wrong result to you.
202
+
203
+ We provide a quick script to automatically organize the prediction files that need to be submitted to servers:
204
+
205
+ ```bash
206
+ python scripts/v1_5/eval/copy_predictions.py [MODEL_NAME]
207
+ ```
208
+
209
+ You will be able to find the predictions under `playground/data/predictions_upload/[MODEL_NAME]` after executing this script.
210
+
211
+ ## Inference
212
+
213
+ We provide snippets for quick inference with user prompts and images.
214
+
215
+ VILA-7B inference:
216
+ ```bash
217
+ python -W ignore llava/eval/run_llava.py \
218
+ --model-name Efficient-Large-Model/VILA-7B \
219
+ --conv-mode vicuna_v1 \
220
+ --query "<image>\n Please describe the traffic condition." \
221
+ --image-file "av.png"
222
+ ```
223
+
224
+ VILA-13B inference:
225
+ ```bash
226
+ python -W ignore llava/eval/run_llava.py \
227
+ --model-name Efficient-Large-Model/VILA-13B \
228
+ --conv-mode vicuna_v1 \
229
+ --query "<image>\n Please describe the traffic condition." \
230
+ --image-file "av.png"
231
+ ```
232
+
233
+ ## Quantization and Deployment
234
+
235
+ Our VILA models are quantized by [AWQ](https://arxiv.org/abs/2306.00978) into 4 bits for efficient inference on the edge. We provide a push-the-button [script](https://github.com/mit-han-lab/llm-awq/blob/main/scripts/vila_example.sh) to quantize VILA with AWQ.
236
+
237
+ ### Running VILA on desktop GPUs and edge GPUs
238
+
239
+ We support AWQ-quantized 4bit VILA on GPU platforms via [TinyChat](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat). We provide a [tutorial](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat#support-vlm-models-vila--llava) to run the model with TinyChat after quantization. We also provide an [instruction](https://github.com/mit-han-lab/llm-awq/tree/main/tinychat/serve) to launch a Gradio server (powered by TinyChat and AWQ) to serve 4-bit quantized VILA models.
240
+
241
+ ### Running VILA on laptops
242
+
243
+ We further support our AWQ-quantized 4bit VILA models on various CPU platforms with both x86 and ARM architectures with our [TinyChatEngine](https://github.com/mit-han-lab/TinyChatEngine). We also provide a detailed [tutorial](https://github.com/mit-han-lab/TinyChatEngine/tree/main?tab=readme-ov-file#deploy-vision-language-model-vlm-chatbot-with-tinychatengine) to help the users deploy VILA on different CPUs.
244
+
245
+
246
+
247
+ ## Checkpoints
248
+
249
+ We release [VILA-7B](https://hf.co/Efficient-Large-Model/VILA-7b), [VILA-13B](https://hf.co/Efficient-Large-Model/VILA-13b), [VILA-7B-4bit-AWQ](https://hf.co/Efficient-Large-Model/VILA-7b-4bit-awq) and [VILA-13B-4bit-AWQ](https://hf.co/Efficient-Large-Model/VILA-13b-4bit-awq).
250
+
251
+ ## 🔒 License
252
+ - The code is released under the Apache 2.0 license as found in the [LICENSE](./LICENSE) file.
253
+ - The pretrained weights are released under the [CC-BY-NC-SA-4.0 license](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en).
254
+ - The service is a research preview intended for non-commercial use only, and is subject to the following licenses and terms:
255
+ - [Model License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA
256
+ - [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI
257
+ - [Dataset Licenses](./data_prepare/LICENSE) for each one used during training.
258
+
259
+ ## Team
260
+ | | | |
261
+ | --- | --- | ---|
262
+ [*Ji Lin](https://www.linji.me/): OpenAI (work done at Nvidia and MIT) | [*Hongxu Yin](https://hongxu-yin.github.io/): Nvidia | [*Yao Lu](https://scholar.google.com/citations?user=OI7zFmwAAAAJ&hl=en): Nvidia
263
+ [Wei Ping](https://scholar.google.com/citations?user=6gKEYRgAAAAJ&hl=en): Nvidia | [Pavlo Molchanov](https://www.pmolchanov.com/): Nvidia | [Andrew Tao](https://scholar.google.com/citations?user=Wel9l1wAAAAJ&hl=en): Nvidia |
264
+ [Haotian Tang](http://kentang.net/): MIT | [Shang Yang](https://ys-2020.github.io/): MIT | [Ligeng Zhu](https://lzhu.me/): Nvidia, MIT |
265
+ [Wei-Chen Wang](https://weichenwang.me/): MIT | [Fuzhao Xue](https://xuefuzhao.github.io/): Nvidia, NUS | [Yunhao Fang](https://seerkfang.github.io/): Nvidia, UCSD |
266
+ [Yukang Chen](https://yukangchen.com/): Nvidia, CUHK | [Yue Shen](https://www.linkedin.com/in/yue-james-shen/): Nvidia | [Huizi Mao](https://scholar.google.com/citations?user=r5WezOYAAAAJ&hl=zh-CN): Nvidia |
267
+ [Jan Kautz](https://jankautz.com/): Nvidia | [Mohammad Shoeybi](https://scholar.google.com/citations?user=62ElavIAAAAJ&hl=en): Nvidia | [Song Han](http://songhan.mit.edu/): Nvidia, MIT
268
+
269
+
270
+ ## Citations
271
+
272
+ ```
273
+ @misc{lin2023vila,
274
+ title={VILA: On Pre-training for Visual Language Models},
275
+ author={Ji Lin and Hongxu Yin and Wei Ping and Yao Lu and Pavlo Molchanov and Andrew Tao and Huizi Mao and Jan Kautz and Mohammad Shoeybi and Song Han},
276
+ year={2023},
277
+ eprint={2312.07533},
278
+ archivePrefix={arXiv},
279
+ primaryClass={cs.CV}
280
+ }
281
+ ```
282
+
283
+ # Acknowledgement
284
+ - [LLaVA](https://github.com/haotian-liu/LLaVA): the codebase we built upon. Thanks for their wonderful work.
285
+ - [Vicuna](https://github.com/lm-sys/FastChat): the amazing open-sourced large language model!
286
+ - [Video-ChatGPT](https://github.com/mbzuai-oryx/Video-ChatGPT): we borrowed video evaluation script from this repository.
287
+ - [MMC4](https://github.com/allenai/mmc4), [COYO-700M](https://github.com/kakaobrain/coyo-dataset), [M3IT](https://huggingface.co/datasets/MMInstruction/M3IT), [OpenORCA/FLAN](https://huggingface.co/datasets/Open-Orca/FLAN), [ShareGPT4V](https://github.com/InternLM/InternLM-XComposer/tree/main/projects/ShareGPT4V) for providing datasets used in this research.
VILA/llava.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ LICENSE
2
+ README.md
3
+ pyproject.toml
4
+ CIs/send_email.py
5
+ data_prepare/coyo/coyo_downloader.py
6
+ data_prepare/coyo/coyo_splitter.py
7
+ data_prepare/mmc4/mmc4_downloader.py
8
+ data_prepare/mmc4/mmc4_filter_and_counter.py
9
+ data_prepare/mmc4/mmc4_merger.py
10
+ data_prepare/sft/preprocess_flan.py
11
+ data_prepare/sft/preprocess_m3it.py
12
+ data_prepare/sft/split_vflan.py
13
+ demo_trt_llm/llava.py
14
+ demo_trt_llm/test_vila.py
15
+ inference_test/dataset_test.py
16
+ inference_test/inference_test.py
17
+ llava/__init__.py
18
+ llava/constants.py
19
+ llava/conversation.py
20
+ llava/mm_utils.py
21
+ llava/unit_test_utils.py
22
+ llava/utils.py
23
+ llava.egg-info/PKG-INFO
24
+ llava.egg-info/SOURCES.txt
25
+ llava.egg-info/dependency_links.txt
26
+ llava.egg-info/requires.txt
27
+ llava.egg-info/top_level.txt
28
+ llava/data/__init__.py
29
+ llava/data/dataset.py
30
+ llava/data/dataset_tar.py
31
+ llava/data/datasets_mixture.py
32
+ llava/data/simple_video_dataset.py
33
+ llava/data/simple_vila_webdataset.py
34
+ llava/data/dataset_impl/coyo_recap.py
35
+ llava/data/dataset_impl/sam.py
36
+ llava/data_aug/caption2qa.py
37
+ llava/data_aug/dev.py
38
+ llava/data_aug/reformat_tar.py
39
+ llava/eval/eval_gpt_review.py
40
+ llava/eval/eval_gpt_review_bench.py
41
+ llava/eval/eval_gpt_review_visual.py
42
+ llava/eval/eval_mathvista.py
43
+ llava/eval/eval_mmmu.py
44
+ llava/eval/eval_mmvet.py
45
+ llava/eval/eval_pope.py
46
+ llava/eval/eval_science_qa.py
47
+ llava/eval/eval_science_qa_gpt4.py
48
+ llava/eval/eval_science_qa_gpt4_requery.py
49
+ llava/eval/eval_textvqa.py
50
+ llava/eval/evaluate_vqa.py
51
+ llava/eval/generate_webpage_data_from_table.py
52
+ llava/eval/m4c_evaluator.py
53
+ llava/eval/model_qa.py
54
+ llava/eval/model_vqa.py
55
+ llava/eval/model_vqa_loader.py
56
+ llava/eval/model_vqa_mmbench.py
57
+ llava/eval/model_vqa_mmmu.py
58
+ llava/eval/model_vqa_qbench.py
59
+ llava/eval/model_vqa_science.py
60
+ llava/eval/model_vqa_video.py
61
+ llava/eval/qa_baseline_gpt35.py
62
+ llava/eval/run_llava.py
63
+ llava/eval/summarize_gpt_review.py
64
+ llava/eval/mathvista_utils/calculate_score.py
65
+ llava/eval/mathvista_utils/extract_answer.py
66
+ llava/eval/mathvista_utils/utilities.py
67
+ llava/eval/mathvista_utils/prompts/ext_ans.py
68
+ llava/eval/mmmu_utils/data_utils.py
69
+ llava/eval/mmmu_utils/eval_utils.py
70
+ llava/eval/mmmu_utils/model_utils.py
71
+ llava/eval/video/eval_benchmark_1_correctness.py
72
+ llava/eval/video/eval_benchmark_2_detailed_orientation.py
73
+ llava/eval/video/eval_benchmark_3_context.py
74
+ llava/eval/video/eval_benchmark_4_temporal.py
75
+ llava/eval/video/eval_benchmark_5_consistency.py
76
+ llava/eval/video/eval_video_qa.py
77
+ llava/model/__init__.py
78
+ llava/model/apply_delta.py
79
+ llava/model/builder.py
80
+ llava/model/consolidate.py
81
+ llava/model/llava_arch.py
82
+ llava/model/make_delta.py
83
+ llava/model/utils.py
84
+ llava/model/language_model/llava_gemma.py
85
+ llava/model/language_model/llava_llama.py
86
+ llava/model/language_model/llava_mistral.py
87
+ llava/model/language_model/llava_mixtral.py
88
+ llava/model/language_model/llava_mpt.py
89
+ llava/model/language_model/mpt/adapt_tokenizer.py
90
+ llava/model/language_model/mpt/attention.py
91
+ llava/model/language_model/mpt/blocks.py
92
+ llava/model/language_model/mpt/configuration_mpt.py
93
+ llava/model/language_model/mpt/custom_embedding.py
94
+ llava/model/language_model/mpt/flash_attn_triton.py
95
+ llava/model/language_model/mpt/hf_prefixlm_converter.py
96
+ llava/model/language_model/mpt/meta_init_context.py
97
+ llava/model/language_model/mpt/modeling_mpt.py
98
+ llava/model/language_model/mpt/norm.py
99
+ llava/model/language_model/mpt/param_init_fns.py
100
+ llava/model/multimodal_encoder/builder.py
101
+ llava/model/multimodal_encoder/clip_encoder.py
102
+ llava/model/multimodal_encoder/siglip_encoder.py
103
+ llava/model/multimodal_encoder/vision_encoder.py
104
+ llava/model/multimodal_encoder/radio/__init__.py
105
+ llava/model/multimodal_encoder/radio/cls_token.py
106
+ llava/model/multimodal_encoder/radio/create_model.py
107
+ llava/model/multimodal_encoder/radio/enable_cpe_support.py
108
+ llava/model/multimodal_encoder/radio/enable_spectral_reparam.py
109
+ llava/model/multimodal_encoder/radio/extra_timm_models.py
110
+ llava/model/multimodal_encoder/radio/radio_encoder.py
111
+ llava/model/multimodal_encoder/radio/token_merging.py
112
+ llava/model/multimodal_encoder/radio/vit_patch_generator.py
113
+ llava/model/multimodal_projector/builder.py
114
+ llava/train/args.py
115
+ llava/train/llava_trainer.py
116
+ llava/train/short_video_filter.py
117
+ llava/train/slurm_utils.py
118
+ llava/train/train.py
119
+ llava/train/train_mem.py
120
+ llava/train/train_xformers.py
121
+ llava/train/transformer_normalize_monkey_patch.py
122
+ llava/train/utils.py
123
+ llava/train/transformers_replace/trainer.py
124
+ llava/train/transformers_replace/models/gemma/__init__.py
125
+ llava/train/transformers_replace/models/gemma/configuration_gemma.py
126
+ llava/train/transformers_replace/models/gemma/modeling_gemma.py
127
+ llava/train/transformers_replace/models/llama/configuring_llama.py
128
+ llava/train/transformers_replace/models/llama/modeling_llama.py
129
+ llava/train/transformers_replace/models/llama/tokenization_llama.py
130
+ llava/train/transformers_replace/models/mistral/__init__.py
131
+ llava/train/transformers_replace/models/mistral/configuration_mistral.py
132
+ llava/train/transformers_replace/models/mistral/modeling_mistral.py
133
+ llava/train/transformers_replace/models/mixtral/__init__.py
134
+ llava/train/transformers_replace/models/mixtral/configuration_mixtral.py
135
+ llava/train/transformers_replace/models/mixtral/modeling_mixtral.py
136
+ llava/train/transformers_replace/models/siglip/__init__.py
137
+ llava/train/transformers_replace/models/siglip/configuration_siglip.py
138
+ llava/train/transformers_replace/models/siglip/convert_siglip_to_hf.py
139
+ llava/train/transformers_replace/models/siglip/image_processing_siglip.py
140
+ llava/train/transformers_replace/models/siglip/modeling_siglip.py
141
+ llava/train/transformers_replace/models/siglip/processing_siglip.py
142
+ llava/train/transformers_replace/models/siglip/tokenization_siglip.py
143
+ llava/wids/__init__.py
144
+ llava/wids/wids.py
145
+ llava/wids/wids_bench.py
146
+ llava/wids/wids_cleanup.py
147
+ llava/wids/wids_dir.py
148
+ llava/wids/wids_dl.py
149
+ llava/wids/wids_index.py
150
+ llava/wids/wids_lru.py
151
+ llava/wids/wids_mmtar.py
152
+ llava/wids/wids_specs.py
153
+ llava/wids/wids_tar.py
154
+ tests/test_tokenizer.py
VILA/llava.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
VILA/llava.egg-info/requires.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ transformers==4.31.0
4
+ tokenizers<0.14,>=0.12.1
5
+ sentencepiece==0.1.99
6
+ shortuuid
7
+ accelerate==0.27.2
8
+ peft==0.5.0
9
+ bitsandbytes==0.41.0
10
+ pydantic<2,>=1
11
+ markdown2[all]
12
+ numpy
13
+ scikit-learn==1.2.2
14
+ gradio==3.35.2
15
+ gradio_client==0.2.9
16
+ requests
17
+ httpx==0.24.0
18
+ uvicorn
19
+ fastapi
20
+ einops==0.6.1
21
+ einops-exts==0.0.4
22
+ timm==0.6.13
23
+ openpyxl==3.1.2
24
+ pytorchvideo==0.1.5
25
+ datasets==2.16.1
26
+ openai==1.8.0
27
+ webdataset==0.2.86
28
+
29
+ [eval]
30
+ mmengine
31
+ word2number
32
+ Levenshtein
33
+
34
+ [train]
35
+ deepspeed==0.13.2
36
+ ninja
37
+ wandb
VILA/llava.egg-info/top_level.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ CIs
2
+ data
3
+ data_prepare
4
+ demo_images
5
+ demo_trt_llm
6
+ inference_test
7
+ llava
VILA/llava/.DS_Store ADDED
Binary file (6.15 kB). View file
 
VILA/llava/constants.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/haotian-liu/LLaVA/

# Heartbeat timing for the controller/worker serving stack
# (units presumably seconds — consumed by the server code, not this module).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

# Directory where log files are written.
LOGDIR = "."

# ---- Model constants ----
IGNORE_INDEX = -100  # matches PyTorch CrossEntropyLoss's default ignore_index
IMAGE_TOKEN_INDEX = -200  # sentinel id marking an image position in input_ids
DEFAULT_IMAGE_TOKEN = "<image>"  # textual placeholder for an image in a prompt
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"
VILA/llava/conversation.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # This file is modified from https://github.com/haotian-liu/LLaVA/
17
+
18
+ import dataclasses
19
+ from enum import Enum, auto
20
+ from typing import List
21
+
22
+
23
class SeparatorStyle(Enum):
    """Separator conventions used by Conversation.get_prompt to render history.

    Values are explicit but equivalent to the 1-based sequence ``auto()``
    would produce; only member identity is compared elsewhere.
    """

    AUTO = 1  # no rendering branch in get_prompt; presumably resolved to a concrete style elsewhere — TODO confirm
    SINGLE = 2  # one separator string after every turn
    TWO = 3  # alternating pair of separators (sep / sep2)
    MPT = 4  # ChatML-like: role header + message + sep
    PLAIN = 5  # no role names; messages joined with alternating separators
    LLAMA_2 = 6  # [INST] ... [/INST] wrapping with a <<SYS>> system block
    MISTRAL = 7  # [INST] wrapping with a leading <s>, no <<SYS>> block
    LLAMA_3 = 8  # role header + message; sep2 closes the final turn
34
+
35
+
36
@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history.

    Messages are ``[role, message]`` pairs; for turns that carry an image the
    message is a tuple ``(text, PIL.Image, image_process_mode)`` instead of a
    plain string (see get_images / to_gradio_chatbot below).
    """

    system: str  # system prompt prepended to the rendered conversation
    roles: List[str]  # (user_role, assistant_role) display/markup names
    messages: List[List[str]]  # [role, message] pairs; message may be a (text, image, mode) tuple
    offset: int  # number of leading few-shot messages skipped by get_images/to_gradio_chatbot
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"  # primary separator
    sep2: str = None  # secondary separator (used by TWO/PLAIN/LLAMA_2/MISTRAL/LLAMA_3)
    version: str = "Unknown"  # template version tag; "mmtag" in it changes image-token handling

    # Flag read by external code only; not used within this class.
    skip_next: bool = False

    def get_prompt(self):
        """Render the message history into a single prompt string per sep_style.

        If the first message carries an image tuple, its "<image>" token is
        normalized: either moved to the front of the text, or (for "mmtag"
        versions) replaced by an <Image>...</Image> preamble exchange.

        Raises:
            ValueError: for styles with no rendering branch here (e.g. AUTO —
                presumably resolved to a concrete style before this is called;
                TODO confirm against model code).
        """
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace("<image>", "").strip()
            if "mmtag" in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, "<image>\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            # "system###role: msg###role: msg###"; a role with no message yet
            # ends the prompt with "role:" to cue generation.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            # Alternates sep (after user turns) and sep2 (after assistant turns).
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.LLAMA_3:
            # Role strings already carry the header markup; sep2 closes only
            # the final turn.
            ret = self.system + self.sep
            for rid, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message = message[0]
                    sep = self.sep if rid < len(messages) - 1 else self.sep2
                    ret += role + message + sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.MPT:
            # ChatML-like: role header + message + sep; no ": " joiner.
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
            # LLAMA_2 wraps the system prompt in <<SYS>>; MISTRAL inlines it and
            # starts the prompt with "<s>". User turns are wrapped in [INST].
            if self.sep_style == SeparatorStyle.LLAMA_2:
                wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
            else:
                wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""
            if self.sep_style == SeparatorStyle.MISTRAL:
                ret += "<s>"

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        if self.sep_style == SeparatorStyle.LLAMA_2:
                            ret += " " + message + " " + self.sep2
                        else:
                            ret += message + self.sep2
                else:
                    ret += ""
            # Drop the leading sep characters introduced before the first turn.
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.PLAIN:
            # Roles are ignored; messages are concatenated with alternating seps.
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        """Append a [role, message] pair to the history."""
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        """Collect images from user turns (every other message past offset).

        Applies the per-message image_process_mode ("Pad" squares with a gray
        background, "Resize" forces 336x336, "Default"/"Crop" leave it alone),
        then downscales so the shortest edge is at most 400px and the longest
        at most 800px. Returns PIL images if return_pil, else base64 PNG strings.
        """
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    from PIL import Image

                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":

                        def expand2square(pil_img, background_color=(122, 116, 104)):
                            # Pad the short side so the image becomes square,
                            # centered on a neutral background.
                            width, height = pil_img.size
                            if width == height:
                                return pil_img
                            elif width > height:
                                result = Image.new(pil_img.mode, (width, width), background_color)
                                result.paste(pil_img, (0, (width - height) // 2))
                                return result
                            else:
                                result = Image.new(pil_img.mode, (height, height), background_color)
                                result.paste(pil_img, ((height - width) // 2, 0))
                                return result

                        image = expand2square(image)
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    # Cap resolution while preserving aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        """Convert history (past offset) into Gradio chatbot [user, bot] pairs.

        User-turn image tuples are rendered as an inline <img> data URI followed
        by the message text with its "<image>" token stripped.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    # Same downscaling policy as get_images, but applied
                    # unconditionally (no longest_edge check here).
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    # NOTE(review): saved as JPEG but the data URI below claims
                    # image/png — mismatch; confirm which format is intended.
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace("<image>", "").strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                # Assistant turn: fill in the reply slot of the last pair.
                ret[-1][-1] = msg
        return ret

    def copy(self):
        """Return a deep-enough copy: message pairs are re-listed, scalars shared."""
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
        )

    def dict(self):
        """Serialize to a plain dict; image tuples collapse to their text part.

        Note: the method name shadows the builtin ``dict`` inside this class body.
        """
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }
266
+
267
+
268
# Pass-through template; AUTO has no branch in get_prompt, so this is
# presumably resolved to a concrete template elsewhere — TODO confirm.
conv_auto = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.AUTO,
    sep="\n",
)

# Vicuna v0: "###"-separated turns with a two-message few-shot example
# (offset=2 hides it from get_images/to_gradio_chatbot).
conv_vicuna_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
        (
            "Assistant",
            "Renewable energy sources are those that can be replenished naturally in a relatively "
            "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
            "Non-renewable energy sources, on the other hand, are finite and will eventually be "
            "depleted, such as coal, oil, and natural gas. Here are some key differences between "
            "renewable and non-renewable energy sources:\n"
            "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
            "energy sources are finite and will eventually run out.\n"
            "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
            "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
            "and other negative effects.\n"
            "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
            "have lower operational costs than non-renewable sources.\n"
            "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
            "locations than non-renewable sources.\n"
            "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
            "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
            "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
            "non-renewable sources are not, and their depletion can lead to economic and social instability.\n",
        ),
    ),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# Vicuna v1: alternating " " / "</s>" separators.
conv_vicuna_v1 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# kentang-mit@: This conversation template is designed for SFT on VFLAN.
conv_vicuna_v1_nosys = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="v1_nosys",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# Llama-2 chat template with its stock safety system prompt.
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# Mistral [INST] template: empty system prompt, "<s>" emitted by the style.
conv_mistral = Conversation(
    system="",
    roles=("USER", "ASSISTANT"),
    version="mistral",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MISTRAL,
    sep="",
    sep2="</s>",
)
357
+
358
# Llama-2 style with a vision-assistant system prompt.
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# MPT / ChatML-style template (<|im_start|> role headers, <|im_end|> separator).
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Roleless template: messages joined with "\n" / sep2 (PLAIN style).
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# "mmtag" version: images are announced via <Image>...</Image> markup
# (see the "mmtag" branch in Conversation.get_prompt).
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)


conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
    "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
    "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

# Hermes-2 (ChatML-style) template.
hermes_2 = Conversation(
    system="<|im_start|>system\nAnswer the questions.",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
    messages=(),
    offset=0,
    version="hermes-2",
)


# Template added by Yukang. Note (kentang-mit@): sep is <|eot_id|> for official template.
llama_3_chat = Conversation(
    system="<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
    "You are able to understand the visual content that the user provides, "
    "and assist the user with a variety of tasks using natural language.",
    roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
    version="llama_v3",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_3,
    sep="<|eot_id|>",
    sep2="<|end_of_text|>",
)
463
+
464
+
465
# Registry mapping template names (as used by CLI/--conv-mode flags) to
# Conversation instances. Several names intentionally alias the same template.
default_conversation = conv_auto
conv_templates = {
    "auto": conv_auto,
    "default": conv_vicuna_v0,
    "hermes-2": hermes_2,
    "llama_3": llama_3_chat,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "vicuna_v1_nosys": conv_vicuna_v1_nosys,
    "llama_2": conv_llama_2,
    "mistral": conv_mistral,
    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,
    "mpt": conv_mpt,
}


if __name__ == "__main__":
    # Smoke test: renders the default (AUTO) template's prompt.
    print(default_conversation.get_prompt())
VILA/llava/entry.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
from typing import Optional

from transformers import PreTrainedModel

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

__all__ = ["load"]


def load(model_path: str, model_base: Optional[str] = None) -> PreTrainedModel:
    """Load a pretrained VILA/LLaVA model from a checkpoint path.

    Args:
        model_path: Path to the checkpoint directory; "~" is expanded. The
            model name is derived from this path before any nesting is resolved.
        model_base: Optional base model identifier, passed through to
            load_pretrained_model (e.g. for delta/LoRA checkpoints —
            TODO confirm semantics against llava.model.builder).

    Returns:
        The loaded model (second element of load_pretrained_model's 4-tuple;
        tokenizer, image processor and context length are discarded).
    """
    model_path = os.path.expanduser(model_path)
    model_name = get_model_name_from_path(model_path)
    # Some checkpoints keep the weights under a nested "model" subdirectory;
    # compute the candidate path once instead of joining twice.
    nested_path = os.path.join(model_path, "model")
    if os.path.exists(nested_path):
        model_path = nested_path
    _, model, _, _ = load_pretrained_model(model_path, model_name, model_base)
    return model
VILA/llava/mm_utils.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import base64
18
+ import os
19
+ import tempfile
20
+ from io import BytesIO
21
+
22
+ import numpy as np
23
+ import torch
24
+ from PIL import Image
25
+ from transformers import StoppingCriteria
26
+
27
+ from llava.constants import IMAGE_TOKEN_INDEX
28
+
29
+
30
def get_frame_from_vcap(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
    """Sample up to ``num_frames`` RGB frames evenly from an opened cv2 VideoCapture.

    Args:
        vidcap: An opened ``cv2.VideoCapture``.
        num_frames: Number of frames to sample.
        max_fps: Unused in this variant (kept for signature parity with
            get_frame_from_vcap_with_fps).
        fps: FPS of the input video; queried from ``vidcap`` if None.
        frame_count: Frame count of the input video; queried from ``vidcap`` if None.
        video_file_name: Used only in diagnostic messages.

    Returns:
        ``(images, n)`` where ``images`` is a list of PIL images and ``n`` is the
        number of valid frames. On unreadable/empty videos, returns ``num_frames``
        blank 720x720 images and ``n == 0``.

    Raises:
        ValueError: if no frame could be decoded from a non-empty video.
    """
    import cv2

    if fps is None or frame_count is None:
        # if one of fps or frame_count is None, still recompute both
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))
    if fps == 0 or frame_count == 0:
        print(f"Video file not found. return empty images. {video_file_name}")
        return [
            Image.new("RGB", (720, 720)),
        ] * num_frames, 0

    duration = frame_count / fps
    frame_interval = frame_count // num_frames
    if frame_interval == 0 and frame_count <= 1:
        print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
        return [
            Image.new("RGB", (720, 720)),
        ] * num_frames, 0
    # print("duration:", duration, "frames:", frame_count, "intervals:", frame_interval)

    images = []
    count = 0
    success = True
    # Evenly spaced frame indices to keep.
    frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)
    while success:
        if frame_count >= num_frames:
            # Video long enough: decode sequentially, keep frames at the
            # precomputed indices.
            success, frame = vidcap.read()
            if count in frame_indices:
                try:
                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    im_pil = Image.fromarray(img)
                    images.append(im_pil)
                except BaseException:
                    # NOTE(review): this `continue` skips `count += 1`, so a
                    # frame that fails to decode desynchronizes count from the
                    # reader position — confirm this is the intended recovery.
                    continue
                if len(images) >= num_frames:
                    return images, num_frames
            count += 1
        else:
            # Left padding frames if the video is not long enough: keep every
            # decodable frame until the stream ends.
            success, frame = vidcap.read()
            if success:
                try:
                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    im_pil = Image.fromarray(img)
                    images.append(im_pil)
                except BaseException:
                    continue
                count += 1
            else:
                break
    if len(images) == 0:
        raise ValueError("Did not find enough frames in the video. return empty image.")

    return images, len(images)
87
+
88
+
89
def get_frame_from_vcap_with_fps(vidcap, num_frames=10, max_fps=0.0, fps=None, frame_count=None, video_file_name=None):
    """
    num_frames is the max number of frames the model can support.
    frame_count is the number of frames in the input video.
    max_fps is the max FPS of the model can support.
    fps is the fps of the input video.

    Returns (frames, n): a list of PIL RGB frames and the count of real frames.
    Unreadable or degenerate videos yield a random-length list of blank
    720x720 images and n == 0.
    """

    import random

    import cv2

    if fps == None or frame_count == None:
        # if one of fps or frame_count is None, still recompute
        fps = vidcap.get(cv2.CAP_PROP_FPS)
        frame_count = int(vidcap.get(cv2.CAP_PROP_FRAME_COUNT))

    if fps == 0 or frame_count == 0:
        print(f"Video file not found. return empty images. {video_file_name}")
        # Random blank-frame count — presumably to vary batch shapes; TODO confirm.
        empty_video_frames = int(random.uniform(2, 8 * max_fps))
        return [
            Image.new("RGB", (720, 720)),
        ] * empty_video_frames, 0

    duration = frame_count / fps
    # print("duration:", duration, "frames:", frame_count, "fps:", fps, "num_frames:", num_frames, "max_fps:", max_fps)
    # If the video is too long (longer than max_fps and num_frames can support),
    # we will use lower fps to sample frames.
    if duration >= num_frames / max_fps:
        frame_interval = frame_count // num_frames

        # If the video is too short, we will skip the video if there is only one frame.
        if frame_interval == 0 and frame_count <= 1:
            print(f"frame_interval is equal to 0. return empty image. {video_file_name}")
            empty_video_frames = int(random.uniform(2, 8 * max_fps))
            return [
                Image.new("RGB", (720, 720)),
            ] * empty_video_frames, 0

        images = []
        count = 0
        success = True
        # num_frames indices spread uniformly over the clip.
        frame_indices = np.linspace(0, frame_count - 1, num_frames, dtype=int)

        while success:
            if frame_count >= num_frames:
                if count in frame_indices:
                    # Fully decode only the frames we intend to keep.
                    success, frame = vidcap.read()
                    try:
                        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        im_pil = Image.fromarray(img)
                        images.append(im_pil)
                    except:
                        # NOTE(review): bare except + continue skips `count += 1`,
                        # which can stall indexing if a selected frame fails to decode.
                        continue
                    if len(images) >= num_frames:
                        return images, num_frames
                else:
                    # grab() advances the stream without decoding — cheaper than read().
                    success = vidcap.grab()
                count += 1
            else:
                # Left padding frames if the video is not long enough
                success, frame = vidcap.read()
                if success:
                    try:
                        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        im_pil = Image.fromarray(img)
                        images.append(im_pil)
                    except:
                        continue
                    count += 1
                else:
                    break
    else:
        # Short video: sample at max_fps rather than stretching to num_frames.
        frames_required = int(duration * max_fps)
        frame_indices = np.linspace(0, frame_count - 1, frames_required, dtype=int)
        if frames_required == 0:
            print(f"frames_required is fewer than 2. Duration {duration}, return empty image.")
            empty_video_frames = int(random.uniform(2, 8 * max_fps))
            return [
                Image.new("RGB", (720, 720)),
            ] * empty_video_frames, 0
        elif frames_required == 1:
            # A single frame cannot form a clip; take the two endpoint frames.
            frame_indices = np.linspace(0, frame_count - 1, 2, dtype=int)
        images = []
        count = 0
        looked = 0
        success = True

        while success:
            success, frame = vidcap.read()
            if success and (looked in frame_indices):
                try:
                    img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    im_pil = Image.fromarray(img)
                    images.append(im_pil)
                except:
                    continue
                count += 1
            looked += 1

    if len(images) == 0:
        empty_video_frames = int(random.uniform(2, 8 * max_fps))
        return [
            Image.new("RGB", (720, 720)),
        ] * empty_video_frames, 0
    else:
        return images, len(images)
199
+
200
+
201
def opencv_extract_frames(vpath_or_bytesio, frames=6, max_fps=0.0, fps=None, frame_count=None):
    """
    Extract frames from a video using OpenCV.

    Args:
        vpath_or_bytesio (str or BytesIO): Path to the video file or BytesIO object containing the video.
        frames (int): Number of frames to extract from the video.
        max_fps (float): When > 0.0, dispatch to the fps-aware sampler
            (get_frame_from_vcap_with_fps) instead of uniform sampling.
        fps (float): Frames per second of the video. If 0.0, the function will extract frames at equal intervals.
        frame_count (int): Cached total frame count; recomputed downstream when None.

    Returns:
        list: List of PIL Images extracted from the video.

    Raises:
        NotImplementedError: If the type of `vpath_or_bytesio` is not supported.
    """
    import cv2

    if isinstance(vpath_or_bytesio, str):
        vidcap = cv2.VideoCapture(vpath_or_bytesio)
        if max_fps > 0.0:
            return get_frame_from_vcap_with_fps(
                vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
            )
        return get_frame_from_vcap(
            vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=vpath_or_bytesio
        )
    elif isinstance(vpath_or_bytesio, (BytesIO,)):
        # assuming mp4
        # Spill the in-memory bytes to a temp file because cv2.VideoCapture
        # only accepts filesystem paths; all reads happen inside the `with`.
        with tempfile.NamedTemporaryFile(delete=True, suffix=".mp4") as temp_video:
            temp_video.write(vpath_or_bytesio.read())
            temp_video_name = temp_video.name
            vidcap = cv2.VideoCapture(temp_video_name)
            if max_fps > 0.0:
                return get_frame_from_vcap_with_fps(
                    vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
                )
            return get_frame_from_vcap(
                vidcap, frames, max_fps, fps=fps, frame_count=frame_count, video_file_name=temp_video_name
            )
    else:
        raise NotImplementedError(type(vpath_or_bytesio))
242
+
243
+
244
def load_image_from_base64(image):
    """Decode a base64-encoded image payload into a PIL Image."""
    raw_bytes = base64.b64decode(image)
    return Image.open(BytesIO(raw_bytes))
246
+
247
+
248
def expand2square(pil_img, background_color):
    """
    Pad a PIL image out to a square canvas.

    Parameters:
    - pil_img: the PIL image to pad.
    - background_color: fill color for the padding (tuple; first channel
      is used for single-channel "L" images).

    Returns the original image unchanged if it is already square; otherwise
    a new square image with the original centered along its short axis.
    """
    width, height = pil_img.size
    # Grayscale images take a scalar fill value rather than a tuple.
    if pil_img.mode == "L":
        background_color = background_color[0]
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    if width > height:
        # Wider than tall: pad top and bottom.
        result.paste(pil_img, (0, (side - height) // 2))
    else:
        # Taller than wide: pad left and right.
        result.paste(pil_img, ((side - width) // 2, 0))
    return result
276
+
277
+
278
def process_image(image_file, data_args, image_folder):
    """Load and preprocess one image according to data_args.image_aspect_ratio.

    Args:
        image_file: path string (optionally relative to image_folder) or an
            already-decoded PIL image.
        data_args: config object carrying `image_processor` and
            `image_aspect_ratio`.
        image_folder: optional root directory for relative paths; may be None.

    Returns:
        The processor's pixel_values tensor for the single image.
    """
    processor = data_args.image_processor
    if isinstance(image_file, str):
        path = image_file if image_folder is None else os.path.join(image_folder, image_file)
        image = Image.open(path).convert("RGB")
    else:
        # image is stored in bytearray / already a PIL image
        image = image_file.convert("RGB")

    if data_args.image_aspect_ratio == "resize":
        if hasattr(data_args.image_processor, "crop_size"):
            # CLIP vision tower
            crop_size = data_args.image_processor.crop_size
        else:
            # SIGLIP vision tower
            assert hasattr(data_args.image_processor, "size")
            crop_size = data_args.image_processor.size
        # NOTE(review): PIL resize() takes (width, height) but this passes
        # (height, width); harmless for square crops — confirm otherwise.
        image = image.resize((crop_size["height"], crop_size["width"]))

    if data_args.image_aspect_ratio == "pad":
        # Reuse the module-level expand2square helper instead of redefining it;
        # the image is RGB here, so behavior is identical to the old nested copy.
        image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
        image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    else:
        # Using default behavior of the vision encoder:
        # CLIP/Radio default to central crop; Siglip/InternVIT default to resize.
        image = processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
    return image
323
+
324
+
325
def process_images(images, image_processor, model_cfg):
    """Preprocess a list of images into tensors via process_image().

    NOTE: mutates `model_cfg` in place by attaching `image_processor` so the
    per-image helper can read it off the config object.

    Returns a stacked tensor when all images share one shape, otherwise the
    list of per-image tensors unchanged.
    """
    model_cfg.image_processor = image_processor
    new_images = [process_image(image, model_cfg, None) for image in images]

    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images
333
+
334
+
335
def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None, lstrip=False):
    """Tokenize a prompt containing "<image>" placeholders, splicing
    `image_token_index` in place of each placeholder.

    Keeps a single BOS token up front (unless `lstrip`) and drops the
    duplicate BOS each tokenizer call prepends to later chunks.
    Returns a list of ids, or a long tensor when return_tensors == "pt".
    """
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def interleave(chunks, sep):
        # chunk0, sep, chunk1, sep, ..., chunkN — separator between each pair.
        woven = []
        for chunk in chunks:
            woven.append(chunk)
            woven.append(sep)
        return woven[:-1]

    input_ids = []
    offset = 0
    if lstrip:
        offset = 1
    elif len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    separator = [image_token_index] * (offset + 1)
    for chunk_id, chunk in enumerate(interleave(prompt_chunks, separator)):
        if chunk_id == 0 and lstrip:
            input_ids.extend(chunk)
        else:
            input_ids.extend(chunk[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids
361
+
362
+
363
def is_gemma_tokenizer(tokenizer):
    """Return True if the tokenizer's class name contains 'gemma' (case-insensitive)."""
    class_name = type(tokenizer).__name__
    return "gemma" in class_name.lower()
365
+
366
+
367
def get_model_name_from_path(model_path):
    """Derive a model name from a filesystem path.

    For checkpoint directories ("…/model/checkpoint-N") the parent name is
    prepended so the result is unique; otherwise the last path component.
    """
    parts = model_path.strip("/").split("/")
    if parts[-1].startswith("checkpoint-"):
        return f"{parts[-2]}_{parts[-1]}"
    return parts[-1]
374
+
375
+
376
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any keyword string appears in the generated
    tokens. A batch stops only when every sequence has hit a keyword."""

    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop the BOS the tokenizer prepends so we match raw keyword ids.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        # Length of the prompt; only tokens past this index are generated output.
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Check one sequence (shape [1, seq_len]) for any stop keyword."""
        # Only the last few generated tokens can contain a keyword.
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        # Fast path: exact token-id suffix match.
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0] :] == keyword_id).all():
                return True
        # Slow path: decode the tail and search for the keyword as text.
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True (stop) only when every batch element matched a keyword."""
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)
VILA/llava/modals.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ __all__ = ["Modal", "Image", "Video"]
4
+
5
+
6
class Modal:
    """Marker base class for input modalities."""

    pass


class File(Modal):
    """A modality backed by a file on disk; existence and extension are
    validated at construction time."""

    # Subclasses restrict allowed file extensions; None means "accept any".
    EXTENSIONS = None

    def __init__(self, path: str) -> None:
        self.path = path
        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")
        if self.EXTENSIONS is not None and not any(path.endswith(ext) for ext in self.EXTENSIONS):
            raise ValueError(f"Unsupported file extension: {os.path.splitext(path)[1]}")


class Image(File):
    """An image input.

    NOTE(review): the extension list also accepts video formats
    (.mp4/.mov/.avi/.mkv/.webm) — presumably intentional so videos can be
    fed through the image path; confirm before tightening."""

    EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".webp", ".mp4", ".mov", ".avi", ".mkv", ".webm"]


class Video(File):
    """A video input restricted to .mp4 files."""

    EXTENSIONS = [".mp4"]
VILA/scripts/convert_gqa_for_eval.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import argparse
18
+ import json
19
+
20
# Convert a predictions .jsonl ({"question_id", "text"} per line) into the
# list-of-dicts JSON format expected by the GQA evaluation script.
parser = argparse.ArgumentParser()
parser.add_argument("--src", type=str)
parser.add_argument("--dst", type=str)
args = parser.parse_args()

all_answers = []
# `with` closes the source file deterministically (the original leaked the
# handle); the unused enumerate index is dropped.
with open(args.src) as src_file:
    for line in src_file:
        res = json.loads(line)
        # GQA's grader lowercases answers and ignores a trailing period.
        text = res["text"].rstrip(".").lower()
        all_answers.append({"questionId": res["question_id"], "prediction": text})

with open(args.dst, "w") as f:
    json.dump(all_answers, f)
VILA/scripts/convert_karpathy_to_anno.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ # import json
18
+
19
+ # # coco for reference
20
+ # # dict_keys(['images', 'dataset']) -> ["images"]
21
+ # karpathy = json.load(open("/home/jil/datasets/karpathy_json/dataset_coco.json"))
22
+ # # dict_keys(['info', 'images', 'licenses', 'annotations']) -> ['images', 'annotations']]
23
+ # anno = json.load(open("/tmp/coco/annotations/captions_val2014.json"))
24
+ # # assert len(karpathy["images"]) == len(anno["images"]) == len(anno["annotations"]), (
25
+ # # len(karpathy["images"]), len(anno["images"]), len(anno["annotations"]) # (123287, 40504, 202654)
26
+ # # )
27
+
28
+
29
+ # karpathy_flickr = json.load(open("/home/jil/datasets/karpathy_json/dataset_coco.json"))
30
+ # anno_flickr = {
31
+ # "images": [],
32
+ # "annotations": [],
33
+ # }
34
+
35
+ # print(karpathy["images"][0])
36
+ # print(anno["images"][0])
37
+ # print(anno["annotations"][:3])
38
+
39
+ # image_id_set = set([_["id"] for _ in anno["images"]])
40
+ # anno_set = set([_["id"] for _ in anno["annotations"]])
41
+
42
+ # print(len(anno_set))
43
+
44
+
45
+ import argparse
46
+ import json
47
+
48
+ from tqdm import tqdm
49
+
50
+
51
def main(input_json, output_json, split):
    """Convert a Karpathy-format caption dataset JSON into COCO captions format.

    Args:
        input_json: path to the Karpathy split file (keys: "images", "dataset").
        output_json: destination path for the COCO-style annotation file.
        split: which Karpathy split to keep ("train"/"val"/"test"), or "all".
    """
    # Skeleton of a COCO captions annotation file; "images" and "annotations"
    # are filled in below. The info/licenses boilerplate mirrors COCO 2014.
    annot_format = {
        "info": {
            "year": 2014,
            "version": "1.0",
            "description": "This is stable 1.0 version of the 2014 MS COCO dataset.",
            "contributor": "Microsoft COCO group",
            "url": "http://mscoco.org",
            "date_created": "2015-01-27 09:11:52.357475",
        },
        "licenses": [
            {
                "url": "http://creativecommons.org/licenses/by-nc-sa/2.0/",
                "id": 1,
                "name": "Attribution-NonCommercial-ShareAlike License",
            },
            {
                "url": "http://creativecommons.org/licenses/by-nc/2.0/",
                "id": 2,
                "name": "Attribution-NonCommercial License",
            },
            {
                "url": "http://creativecommons.org/licenses/by-nc-nd/2.0/",
                "id": 3,
                "name": "Attribution-NonCommercial-NoDerivs License",
            },
            {"url": "http://creativecommons.org/licenses/by/2.0/", "id": 4, "name": "Attribution License"},
            {
                "url": "http://creativecommons.org/licenses/by-sa/2.0/",
                "id": 5,
                "name": "Attribution-ShareAlike License",
            },
            {"url": "http://creativecommons.org/licenses/by-nd/2.0/", "id": 6, "name": "Attribution-NoDerivs License"},
            {"url": "http://flickr.com/commons/usage/", "id": 7, "name": "No known copyright restrictions"},
            {"url": "http://www.usa.gov/copyright.shtml", "id": 8, "name": "United States Government Work"},
        ],
        "type": "captions",
        "images": [],
        "annotations": [],
    }

    with open(input_json) as f:
        dataset = json.load(f)
    annotations = dataset["images"]
    dataset_name = dataset["dataset"]

    count = 0
    print(f"Converting Karpathy {dataset_name} {split} to COCO Format...")
    for annot in tqdm(annotations):
        if split == "all" or (annot["split"] == split):
            # Use the filename stem as the image id (not Karpathy's 'imgid').
            image_id = str(annot["filename"].split(".")[0])  # annot['imgid']
            annot_format["images"].append(
                {
                    "id": image_id,
                    # Width/height are placeholders; presumably unused by the
                    # caption scorer — confirm before relying on them.
                    "width": 512,
                    "height": 512,
                    "filename": annot["filename"],
                    "license": 1,
                    "flickr_url": "",
                    "coco_url": "",
                    "date_captured": "",
                }
            )

            # One COCO annotation per reference sentence.
            for sent in annot["sentences"]:
                annot_format["annotations"].append({"id": sent["sentid"], "image_id": image_id, "caption": sent["raw"]})
            count += 1

    with open(output_json, "w") as f:
        json.dump(annot_format, f)
+
122
+
123
if __name__ == "__main__":
    # CLI wrapper around main(); defaults target the Flickr30k Karpathy split.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input-json", type=str, default="/home/jil/datasets/karpathy_json/dataset_flickr30k.json")
    arg_parser.add_argument("--output-json", type=str, default="/home/jil/datasets/flickr30k/flickr30k_coco_all.json")
    arg_parser.add_argument("--split", type=str, default="all")
    cli_args = arg_parser.parse_args()

    main(cli_args.input_json, cli_args.output_json, cli_args.split)
VILA/scripts/convert_mmbench_for_submission.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+
21
+ import pandas as pd
22
+
23
+
24
def get_args():
    """Parse CLI arguments for the MMBench submission converter."""
    parser = argparse.ArgumentParser()
    for flag in ("--annotation-file", "--result-dir", "--upload-dir", "--experiment"):
        parser.add_argument(flag, type=str, required=True)
    return parser.parse_args()
32
+
33
+
34
if __name__ == "__main__":
    args = get_args()

    # MMBench annotations ship as a TSV.
    df = pd.read_table(args.annotation_file)

    cur_df = df.copy()
    # Keep only the columns the upload format expects.
    cur_df = cur_df.drop(columns=["hint", "category", "source", "image", "comment", "l2-category"])
    cur_df.insert(6, "prediction", None)
    # Fill predictions by matching each jsonl record's question_id against the
    # TSV's "index" column.
    for pred in open(os.path.join(args.result_dir, f"{args.experiment}.jsonl")):
        pred = json.loads(pred)
        cur_df.loc[df["index"] == pred["question_id"], "prediction"] = pred["text"]

    # MMBench's evaluation server expects an .xlsx upload.
    cur_df.to_excel(os.path.join(args.upload_dir, f"{args.experiment}_upload.xlsx"), index=False, engine="openpyxl")
VILA/scripts/convert_mmvet_for_eval.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import argparse
18
+ import json
19
+
20
# Convert a predictions .jsonl into the {"v1_<question_id>": text} mapping
# expected by the MM-Vet grader.
parser = argparse.ArgumentParser()
parser.add_argument("--src", type=str)
parser.add_argument("--dst", type=str)
args = parser.parse_args()

cur_result = {}

for line in open(args.src):
    data = json.loads(line)
    qid = data["question_id"]
    # MM-Vet keys are prefixed with the dataset version ("v1_").
    cur_result[f"v1_{qid}"] = data["text"]

with open(args.dst, "w") as f:
    json.dump(cur_result, f, indent=2)
VILA/scripts/convert_seed_for_submission.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import argparse
18
+ import json
19
+
20
+
21
def get_args():
    """Parse CLI arguments for the SEED-Bench submission converter."""
    parser = argparse.ArgumentParser()
    for flag in ("--annotation-file", "--result-file", "--result-upload-file"):
        parser.add_argument(flag, type=str)
    return parser.parse_args()
27
+
28
+
29
def eval_single(result_file, eval_only_type=None):
    """Score one SEED-Bench result file and print per-type / total accuracy.

    NOTE(review): relies on module-level globals `data` (annotation JSON) and
    `ques_type_id_to_name`, both defined in the __main__ block below — this
    function only works when called from there.

    Args:
        result_file: path to a .jsonl of {"question_id", "text"} predictions.
        eval_only_type: restrict scoring to one data_type (e.g. "image");
            None scores everything and prints per-type breakdowns.

    Returns:
        dict mapping question_id -> prediction row, reused by the caller.
    """
    results = {}
    for line in open(result_file):
        row = json.loads(line)
        results[row["question_id"]] = row

    type_counts = {}
    correct_counts = {}
    for question_data in data["questions"]:
        if eval_only_type is not None and question_data["data_type"] != eval_only_type:
            continue
        data_type = question_data["question_type_id"]
        type_counts[data_type] = type_counts.get(data_type, 0) + 1
        # Result files may store question ids as ints or strings; try int first.
        try:
            question_id = int(question_data["question_id"])
        except BaseException:
            question_id = question_data["question_id"]
        if question_id not in results:
            # Missing prediction counts as incorrect (seeds the type's entry).
            correct_counts[data_type] = correct_counts.get(data_type, 0)
            continue
        row = results[question_id]
        if row["text"] == question_data["answer"]:
            correct_counts[data_type] = correct_counts.get(data_type, 0) + 1

    total_count = 0
    total_correct = 0
    for data_type in sorted(type_counts.keys()):
        # NOTE(review): KeyError is possible here if every prediction for a
        # type was present but wrong — correct_counts is never seeded then.
        accuracy = correct_counts[data_type] / type_counts[data_type] * 100
        if eval_only_type is None:
            print(f"{ques_type_id_to_name[data_type]}: {accuracy:.2f}%")

        total_count += type_counts[data_type]
        total_correct += correct_counts[data_type]

    total_accuracy = total_correct / total_count * 100
    if eval_only_type is None:
        print(f"Total accuracy: {total_accuracy:.2f}%")
    else:
        print(f"{eval_only_type} accuracy: {total_accuracy:.2f}%")

    return results
70
+
71
+
72
if __name__ == "__main__":
    args = get_args()
    data = json.load(open(args.annotation_file))
    # Invert {name: type_id} so per-type accuracies can be printed by name.
    ques_type_id_to_name = {id: n for n, id in data["question_type"].items()}

    # Print overall, image-only, and video-only accuracy.
    results = eval_single(args.result_file)
    eval_single(args.result_file, eval_only_type="image")
    eval_single(args.result_file, eval_only_type="video")

    # Write the leaderboard submission file, one JSON record per question.
    with open(args.result_upload_file, "w") as fp:
        for question in data["questions"]:
            qid = question["question_id"]
            if qid in results:
                result = results[qid]
            else:
                # Fall back to int keys when the result file stored ids as ints.
                result = results[int(qid)]
            fp.write(json.dumps({"question_id": qid, "prediction": result["text"]}) + "\n")
VILA/scripts/convert_sqa_to_llava.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import json
18
+ import os
19
+
20
+ import fire
21
+ from convert_sqa_to_llava_base_prompt import build_prompt_chatbot
22
+
23
+
24
def convert_to_llava(base_dir, split, prompt_format="QCM-LEA"):
    """Convert ScienceQA problems for one split into LLaVA conversation JSON.

    Args:
        base_dir: directory containing pid_splits.json and problems.json;
            the output llava_<split>_<prompt_format>.json is written here too.
        split: which problem split to convert (e.g. "train", "val", "test").
        prompt_format: prompt layout passed through to build_prompt_chatbot.
    """
    split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
    problems = json.load(open(os.path.join(base_dir, "problems.json")))

    split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False)

    target_format = []
    # Renamed from `input`/`output`, which shadowed builtins.
    for prob_id, (question, answer) in split_problems.items():
        # Strip the prompt builder's leading labels; the chat template
        # supplies its own role markers.
        if question.startswith("Question: "):
            question = question.replace("Question: ", "")
        if answer.startswith("Answer: "):
            answer = answer.replace("Answer: ", "")

        raw_prob_data = problems[prob_id]
        record = {"id": prob_id}
        if raw_prob_data["image"] is None:
            human_value = f"{question}"
        else:
            record["image"] = os.path.join(prob_id, raw_prob_data["image"])
            # Image placeholder goes after the question text.
            human_value = f"{question}\n<image>"
        record["conversations"] = [
            {"from": "human", "value": human_value},
            {"from": "gpt", "value": f"{answer}"},
        ]
        target_format.append(record)

    print(f"Number of samples: {len(target_format)}")

    with open(os.path.join(base_dir, f"llava_{split}_{prompt_format}.json"), "w") as f:
        json.dump(target_format, f, indent=2)
65
+
66
+
67
def convert_to_jsonl(base_dir, split, prompt_format="QCM-LEPA"):
    """Convert ScienceQA problems for one split into instruction/output JSONL.

    Args:
        base_dir: directory containing pid_splits.json and problems.json;
            the output scienceqa_<split>_<prompt_format>.jsonl is written here.
        split: which problem split to convert.
        prompt_format: prompt layout passed through to build_prompt_chatbot.
    """
    split_indices = json.load(open(os.path.join(base_dir, "pid_splits.json")))[split]
    problems = json.load(open(os.path.join(base_dir, "problems.json")))

    split_problems = build_prompt_chatbot(problems, split_indices, prompt_format, use_caption=False, is_test=False)

    out_path = os.path.join(base_dir, f"scienceqa_{split}_{prompt_format}.jsonl")
    # `with` replaces the manual open()/close() pair so the file is closed
    # even if a record fails to serialize. Locals renamed from `input`/`output`,
    # which shadowed builtins.
    with open(out_path, "w") as writer:
        for prob_id, (question, answer) in split_problems.items():
            # Strip the prompt builder's leading labels.
            if question.startswith("Question: "):
                question = question.replace("Question: ", "")
            if answer.startswith("Answer: "):
                answer = answer.replace("Answer: ", "")

            raw_prob_data = problems[prob_id]
            data = {"id": prob_id}
            if raw_prob_data["image"] is not None:
                data["image"] = os.path.join(prob_id, raw_prob_data["image"])
                data["instruction"] = f"{question}\n<image>"
            else:
                data["instruction"] = f"{question}"
            data["output"] = f"{answer}"
            writer.write(json.dumps(data) + "\n")
97
+
98
+
99
def main(task, **kwargs):
    """Dispatch to the module-level function named by `task`, forwarding kwargs."""
    task_fn = globals()[task]
    task_fn(**kwargs)
101
+
102
+
103
if __name__ == "__main__":
    # Expose `main` via Fire: python convert_sqa_to_llava.py <task> --arg=value
    fire.Fire(main)
VILA/scripts/convert_sqa_to_llava_base_prompt.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+
18
def get_question_text(problem):
    """Return the raw question string of a ScienceQA problem record."""
    return problem["question"]
21
+
22
+
23
def get_context_text(problem, use_caption):
    """Join the text hint with the image caption (only when ``use_caption``)
    into a single context string; returns "N/A" when both pieces are empty."""
    pieces = [problem["hint"], problem["caption"] if use_caption else ""]
    context = " ".join(pieces).strip()
    return context if context else "N/A"
30
+
31
+
32
def get_choice_text(problem, options):
    """Format a problem's answer choices as "(A) foo (B) bar ..." using the
    letters supplied in ``options``.

    NOTE: the parameter was misspelled ``probelm``; renamed to ``problem``.
    All call sites in this file pass it positionally, so the rename is safe.
    """
    return " ".join(f"({options[i]}) {c}" for i, c in enumerate(problem["choices"]))
40
+
41
+
42
def get_answer(problem, options):
    """Map the problem's integer answer index to its option letter."""
    answer_index = problem["answer"]
    return options[answer_index]
44
+
45
+
46
def get_lecture_text(problem):
    """Return the lecture text with newlines escaped as literal "\\n".

    (Escaping lets GPT-3 generate the lecture with more tokens.)
    """
    return problem["lecture"].replace("\n", "\\n")
50
+
51
+
52
def get_solution_text(problem):
    """Return the solution text with newlines escaped as literal "\\n".

    (Escaping lets GPT-3 generate the solution with more tokens.)
    """
    return problem["solution"].replace("\n", "\\n")
56
+
57
+
58
def create_one_example_chatbot(format, question, context, choice, answer, lecture, solution, test_example=True):
    """Build one (input, output) chatbot training pair.

    Parameters
    ----------
    format : str
        "<input_format>-<output_format>", e.g. "QCM-LEPA". The input letters
        select which fields appear and in what order (Q=question, C=context,
        M=options, L=lecture, E=explanation/solution); the output letters
        select the answer template.
    test_example : bool
        When True the output is just the "Answer:" stub.

    Returns
    -------
    (str, str)
        Prompt input and target output, with double spaces collapsed and a
        trailing "BECAUSE:" stub (empty lecture/solution) stripped.

    Raises
    ------
    ValueError
        For an unrecognized input/output format (the original fell through to
        a NameError on an unbound local).
    """
    input_format, output_format = format.split("-")

    ## Inputs
    if input_format == "CQM":
        input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
    elif input_format == "QCM":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
    # upper bound experiment
    elif input_format == "QCML":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
    elif input_format == "QCME":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
    elif input_format == "QCMLE":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
    elif input_format == "QCLM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
    elif input_format == "QCEM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
    elif input_format == "QCLEM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
    else:
        raise ValueError(f"Unknown input format: {input_format}")

    # Outputs
    if test_example:
        output = "Answer:"
    elif output_format == "A":
        output = f"Answer: The answer is {answer}."
    # NOTE(review): AL uses `solution` and AE uses `lecture` — preserved as-is
    # (this quirk matches the upstream ScienceQA prompt builder).
    elif output_format == "AL":
        output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
    elif output_format == "AE":
        output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
    elif output_format == "ALE":
        output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
    elif output_format == "AEL":
        output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
    elif output_format == "LA":
        output = f"Answer: {lecture} The answer is {answer}."
    elif output_format == "EA":
        output = f"Answer: {solution} The answer is {answer}."
    elif output_format == "LEA":
        output = f"Answer: {lecture} {solution} The answer is {answer}."
    elif output_format == "ELA":
        output = f"Answer: {solution} {lecture} The answer is {answer}."
    elif output_format == "LEPA":
        output = ""
        if len(lecture.strip()) > 0:
            output += f"LECTURE: {lecture}\n"
        if len(solution.strip()) > 0:
            output += f"SOLUTION: {solution}\n"
        output += "###\n"
        output += f"ANSWER: {answer}."
    else:
        raise ValueError(f"Unknown output format: {output_format}")

    # BUG FIX: the original called .replace(" ", " ") — a no-op. The intent
    # (matching the upstream ScienceQA/LLaVA code) is to collapse double spaces.
    input = input.replace("  ", " ").strip()
    output = output.replace("  ", " ").strip()
    if input.endswith("BECAUSE:"):
        input = input.replace("BECAUSE:", "").strip()
    if output.endswith("BECAUSE:"):
        output = output.replace("BECAUSE:", "").strip()
    return input, output
121
+
122
+
123
def create_one_example(format, question, context, choice, answer, lecture, solution, test_example=True):
    """Build one flat-text few-shot example: rendered input immediately
    followed by the rendered output.

    ``format`` is "<input_format>-<output_format>" (see
    ``create_one_example_chatbot`` for the field letters). When
    ``test_example`` is True the output is the "Answer:" stub so the model
    completes it. Double spaces are collapsed and a trailing "BECAUSE:" stub
    is stripped.

    Raises ValueError for an unrecognized format (the original fell through
    to a NameError on an unbound local).
    """
    input_format, output_format = format.split("-")

    ## Inputs
    if input_format == "CQM":
        input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
    elif input_format == "QCM":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
    # upper bound experiment
    elif input_format == "QCML":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
    elif input_format == "QCME":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
    elif input_format == "QCMLE":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
    elif input_format == "QCLM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
    elif input_format == "QCEM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
    elif input_format == "QCLEM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
    else:
        raise ValueError(f"Unknown input format: {input_format}")

    # Outputs
    if test_example:
        output = "Answer:"
    elif output_format == "A":
        output = f"Answer: The answer is {answer}."
    # NOTE(review): AL uses `solution` and AE uses `lecture` — preserved as-is
    # (matches the upstream ScienceQA prompt builder).
    elif output_format == "AL":
        output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
    elif output_format == "AE":
        output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
    elif output_format == "ALE":
        output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
    elif output_format == "AEL":
        output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
    elif output_format == "LA":
        output = f"Answer: {lecture} The answer is {answer}."
    elif output_format == "EA":
        output = f"Answer: {solution} The answer is {answer}."
    elif output_format == "LEA":
        output = f"Answer: {lecture} {solution} The answer is {answer}."
    elif output_format == "ELA":
        output = f"Answer: {solution} {lecture} The answer is {answer}."
    else:
        raise ValueError(f"Unknown output format: {output_format}")

    text = input + output
    # BUG FIX: the original called .replace(" ", " ") — a no-op. The intent
    # is to collapse double spaces (as in the upstream ScienceQA/LLaVA code).
    text = text.replace("  ", " ").strip()
    if text.endswith("BECAUSE:"):
        text = text.replace("BECAUSE:", "").strip()
    return text
176
+
177
+
178
def create_one_example_gpt4(format, question, context, choice, answer, lecture, solution, test_example=True):
    """Build one GPT-4 chat-API example as a (user, assistant) message pair.

    ``format`` is "<input_format>-<output_format>" (see
    ``create_one_example_chatbot`` for the field letters). The rendered input
    is wrapped in a "Can you explain ...?" user message; the rendered output
    becomes the assistant message. When ``test_example`` is True the assistant
    content is just "Answer:".

    Raises ValueError for an unrecognized format (the original fell through
    to a NameError on an unbound local).
    """
    input_format, output_format = format.split("-")

    ## Inputs
    if input_format == "CQM":
        input = f"Context: {context}\nQuestion: {question}\nOptions: {choice}\n"
    elif input_format == "QCM":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\n"
    # upper bound experiment
    elif input_format == "QCML":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture}\n"
    elif input_format == "QCME":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {solution}\n"
    elif input_format == "QCMLE":
        input = f"Question: {question}\nContext: {context}\nOptions: {choice}\nBECAUSE: {lecture} {solution}\n"
    elif input_format == "QCLM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture}\nOptions: {choice}\n"
    elif input_format == "QCEM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {solution}\nOptions: {choice}\n"
    elif input_format == "QCLEM":
        input = f"Question: {question}\nContext: {context}\nBECAUSE: {lecture} {solution}\nOptions: {choice}\n"
    else:
        raise ValueError(f"Unknown input format: {input_format}")

    # Outputs
    if test_example:
        output = "Answer:"
    elif output_format == "A":
        output = f"Answer: The answer is {answer}."
    # NOTE(review): AL uses `solution` and AE uses `lecture` — preserved as-is
    # (matches the upstream ScienceQA prompt builder).
    elif output_format == "AL":
        output = f"Answer: The answer is {answer}. BECAUSE: {solution}"
    elif output_format == "AE":
        output = f"Answer: The answer is {answer}. BECAUSE: {lecture}"
    elif output_format == "ALE":
        output = f"Answer: The answer is {answer}. BECAUSE: {lecture} {solution}"
    elif output_format == "AEL":
        output = f"Answer: The answer is {answer}. BECAUSE: {solution} {lecture}"
    elif output_format == "LA":
        output = f"Answer: {lecture} The answer is {answer}."
    elif output_format == "EA":
        output = f"Answer: {solution} The answer is {answer}."
    elif output_format == "LEA":
        output = f"Answer: {lecture} {solution} The answer is {answer}."
    elif output_format == "ELA":
        output = f"Answer: {solution} {lecture} The answer is {answer}."
    else:
        raise ValueError(f"Unknown output format: {output_format}")

    # BUG FIX: the original called .replace(" ", " ") — a no-op. The intent
    # is to collapse double spaces (as in the upstream ScienceQA/LLaVA code).
    input = input.replace("  ", " ").strip()
    output = output.replace("  ", " ").strip()
    if output.endswith("BECAUSE:"):
        output = output.replace("BECAUSE:", "").strip()

    user_prompt = {"role": "user", "content": f"Can you explain {input}?"}
    assistant_prompt = {"role": "assistant", "content": f"{output}"}

    return user_prompt, assistant_prompt
235
+
236
+
237
def build_prompt_chatbot(
    problems, shot_qids, prompt_format, use_caption=False, options=["A", "B", "C", "D", "E"], is_test=False
):
    """Render each problem id in ``shot_qids`` into a chatbot (input, output)
    pair, keyed by problem id.

    Lecture/solution text has its escaped "\\n" sequences restored to real
    newlines before rendering.
    """
    examples = {}
    for qid in shot_qids:
        prob = problems[qid]
        examples[qid] = create_one_example_chatbot(
            prompt_format,
            get_question_text(prob),
            get_context_text(prob, use_caption),
            get_choice_text(prob, options),
            get_answer(prob, options),
            get_lecture_text(prob).replace("\\n", "\n"),
            get_solution_text(prob).replace("\\n", "\n"),
            test_example=is_test,
        )
    return examples
255
+
256
+
257
def build_prompt(problems, shot_qids, test_qid, args):
    """Assemble an n-shot flat-text prompt: each shot example rendered with
    its answer, then the test example rendered with the "Answer:" stub,
    joined by blank lines."""

    def _fields(qid):
        # Extract the rendering inputs for one problem id.
        prob = problems[qid]
        return (
            get_question_text(prob),
            get_context_text(prob, args.use_caption),
            get_choice_text(prob, args.options),
            get_answer(prob, args.options),
            get_lecture_text(prob),
            get_solution_text(prob),
        )

    # n-shot training examples
    examples = [
        create_one_example(args.prompt_format, *_fields(qid), test_example=False)
        for qid in shot_qids
    ]
    # test example
    examples.append(create_one_example(args.prompt_format, *_fields(test_qid), test_example=True))

    # create the prompt input
    return "\n\n".join(examples)
292
+
293
+
294
def build_prompt_gpt4(problems, shot_qids, test_qid, args):
    """Assemble a GPT-4 chat message list: a system message, one
    (user, assistant) pair per shot example, then the test example's pair
    (assistant content is the "Answer:" stub)."""

    def _pair(qid, is_test):
        # Render one problem id into its (user, assistant) message pair.
        prob = problems[qid]
        return create_one_example_gpt4(
            args.prompt_format,
            get_question_text(prob),
            get_context_text(prob, args.use_caption),
            get_choice_text(prob, args.options),
            get_answer(prob, args.options),
            get_lecture_text(prob),
            get_solution_text(prob),
            test_example=is_test,
        )

    prompt_array = [{"role": "system", "content": "You are a helpful assistant."}]
    # n-shot training examples
    for qid in shot_qids:
        prompt_array.extend(_pair(qid, False))
    # test example
    prompt_array.extend(_pair(test_qid, True))
    return prompt_array
VILA/scripts/convert_vizwiz_for_submission.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+
21
+ from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
22
+
23
+
24
def parse_args():
    """CLI arguments: annotation file, raw result file, and the destination
    for the VizWiz upload file. All three are required."""
    parser = argparse.ArgumentParser()
    for flag in ("--annotation-file", "--result-file", "--result-upload-file"):
        parser.add_argument(flag, type=str, required=True)
    return parser.parse_args()
30
+
31
+
32
if __name__ == "__main__":

    args = parse_args()

    os.makedirs(os.path.dirname(args.result_upload_file), exist_ok=True)

    # Read model predictions; tolerate (and count) malformed JSONL lines.
    results = []
    error_line = 0
    for line_idx, line in enumerate(open(args.result_file)):
        try:
            results.append(json.loads(line))
        except BaseException:
            error_line += 1
    results = {entry["question_id"]: entry["text"] for entry in results}

    test_split = [json.loads(line) for line in open(args.annotation_file)]
    split_ids = {entry["question_id"] for entry in test_split}

    print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}")

    answer_processor = EvalAIAnswerProcessor()

    # Every annotation must have a prediction; normalize answers for EvalAI.
    all_answers = []
    for entry in test_split:
        assert entry["question_id"] in results
        all_answers.append({"image": entry["image"], "answer": answer_processor(results[entry["question_id"]])})

    with open(args.result_upload_file, "w") as f:
        json.dump(all_answers, f)
VILA/scripts/convert_vqav2_for_submission.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+
21
+ from llava.eval.m4c_evaluator import EvalAIAnswerProcessor
22
+
23
+
24
def parse_args():
    """CLI arguments: the VQAv2 eval directory (defaults to the playground
    path) and the required split name."""
    parser = argparse.ArgumentParser()
    arg_specs = {
        "--dir": {"type": str, "default": "./playground/data/eval/vqav2"},
        "--split": {"type": str, "required": True},
    }
    for flag, spec in arg_specs.items():
        parser.add_argument(flag, **spec)
    return parser.parse_args()
29
+
30
+
31
if __name__ == "__main__":

    args = parse_args()

    src = os.path.join(args.dir, args.split, "answers", "merge.jsonl")
    test_split = os.path.join(args.dir, "llava_vqav2_mscoco_test2015.jsonl")
    dst = os.path.join(args.dir, args.split, f"{args.split}_answers_upload.json")
    os.makedirs(os.path.dirname(dst), exist_ok=True)

    # Read predictions; count (rather than crash on) malformed JSONL lines.
    results = []
    error_line = 0
    for line_idx, line in enumerate(open(src)):
        try:
            results.append(json.loads(line))
        except Exception:  # was a bare `except:`, which also swallowed KeyboardInterrupt
            error_line += 1

    results = {x["question_id"]: x["text"] for x in results}
    test_split = [json.loads(line) for line in open(test_split)]
    split_ids = {x["question_id"] for x in test_split}

    print(f"total results: {len(results)}, total split: {len(test_split)}, error_line: {error_line}")

    all_answers = []

    answer_processor = EvalAIAnswerProcessor()

    # Every test question gets an entry; missing predictions map to "".
    for x in test_split:
        if x["question_id"] not in results:
            all_answers.append({"question_id": x["question_id"], "answer": ""})
        else:
            all_answers.append({"question_id": x["question_id"], "answer": answer_processor(results[x["question_id"]])})

    # BUG FIX: the original wrote `json.dump(all_answers, open(dst, "w"))`
    # inside the `with` block, opening dst a second time, leaving that handle
    # unflushed/unclosed, and never using the managed handle `f`.
    with open(dst, "w") as f:
        json.dump(all_answers, f)
VILA/scripts/extract_mm_projector.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+
17
+ import argparse
18
+ import json
19
+ import os
20
+ from collections import defaultdict
21
+
22
+ import torch
23
+
24
+
25
def parse_args():
    """CLI arguments: the source model folder and the output weights file."""
    parser = argparse.ArgumentParser(description="Extract MMProjector weights")
    parser.add_argument("--model_name_or_path", type=str, help="model folder")
    parser.add_argument("--output", type=str, help="output file")
    return parser.parse_args()
31
+
32
+
33
if __name__ == "__main__":
    args = parse_args()

    # Parameter-name fragments whose tensors should be extracted.
    keys_to_match = ["mm_projector", "embed_tokens", "transformer.wte"]

    # Map each checkpoint shard filename to the matching weight names it holds.
    ckpt_to_key = defaultdict(list)
    try:
        index_path = os.path.join(args.model_name_or_path, "pytorch_model.bin.index.json")
        model_indices = json.load(open(index_path))
        for name, shard in model_indices["weight_map"].items():
            if any(fragment in name for fragment in keys_to_match):
                ckpt_to_key[shard].append(name)
    except FileNotFoundError:
        # Smaller models or model checkpoints saved by DeepSpeed.
        shard = "pytorch_model.bin"
        state_dict = torch.load(os.path.join(args.model_name_or_path, shard), map_location="cpu")
        for name in state_dict.keys():
            if any(fragment in name for fragment in keys_to_match):
                ckpt_to_key[shard].append(name)

    # Pull only the selected tensors out of each shard and save them together.
    loaded_weights = {}
    for ckpt_name, weight_keys in ckpt_to_key.items():
        ckpt = torch.load(os.path.join(args.model_name_or_path, ckpt_name), map_location="cpu")
        for name in weight_keys:
            loaded_weights[name] = ckpt[name]

    torch.save(loaded_weights, args.output)
VILA/scripts/zero2.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 2,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": "auto"
22
+ }
23
+ }
VILA/scripts/zero3.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 3,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": "auto",
22
+ "stage3_prefetch_bucket_size": "auto",
23
+ "stage3_param_persistence_threshold": "auto",
24
+ "stage3_max_live_parameters": 1e9,
25
+ "stage3_max_reuse_distance": 1e9,
26
+ "stage3_gather_16bit_weights_on_model_save": true
27
+ }
28
+ }
VILA/scripts/zero3_mics_mini_fixed.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 3,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": 4e8,
22
+ "stage3_prefetch_bucket_size": 4e8,
23
+ "stage3_param_persistence_threshold": 1e4,
24
+ "stage3_max_live_parameters": 1e9,
25
+ "stage3_max_reuse_distance": 1e9,
26
+ "stage3_gather_16bit_weights_on_model_save": true,
27
+ "mics_shard_size": 64,
28
+ "mics_hierarchical_params_gather": false
29
+ }
30
+ }
VILA/scripts/zero3_mics_tiny_fixed.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 3,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": 4e8,
22
+ "stage3_prefetch_bucket_size": 4e8,
23
+ "stage3_param_persistence_threshold": 1e4,
24
+ "stage3_max_live_parameters": 1e9,
25
+ "stage3_max_reuse_distance": 1e9,
26
+ "stage3_gather_16bit_weights_on_model_save": true,
27
+ "mics_shard_size": 16,
28
+ "mics_hierarchical_params_gather": false
29
+ }
30
+ }
VILA/scripts/zero3_offload.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "optimizer": {
14
+ "type": "AdamW",
15
+ "params": {
16
+ "lr": "auto",
17
+ "betas": "auto",
18
+ "eps": "auto",
19
+ "weight_decay": "auto"
20
+ }
21
+ },
22
+ "scheduler": {
23
+ "type": "WarmupLR",
24
+ "params": {
25
+ "warmup_min_lr": "auto",
26
+ "warmup_max_lr": "auto",
27
+ "warmup_num_steps": "auto"
28
+ }
29
+ },
30
+ "zero_optimization": {
31
+ "stage": 3,
32
+ "offload_optimizer": {
33
+ "device": "cpu",
34
+ "pin_memory": true
35
+ },
36
+ "offload_param": {
37
+ "device": "cpu",
38
+ "pin_memory": true
39
+ },
40
+ "overlap_comm": true,
41
+ "contiguous_gradients": true,
42
+ "sub_group_size": 1e9,
43
+ "reduce_bucket_size": "auto",
44
+ "stage3_prefetch_bucket_size": "auto",
45
+ "stage3_param_persistence_threshold": "auto",
46
+ "stage3_max_live_parameters": 1e9,
47
+ "stage3_max_reuse_distance": 1e9,
48
+ "stage3_gather_16bit_weights_on_model_save": true
49
+ },
50
+ "gradient_accumulation_steps": "auto",
51
+ "gradient_clipping": "auto",
52
+ "train_batch_size": "auto",
53
+ "train_micro_batch_size_per_gpu": "auto",
54
+ "steps_per_print": 1e5,
55
+ "wall_clock_breakdown": false
56
+ }
VILA/scripts/zero3_offload_inference.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bf16": {
3
+ "enabled": "auto"
4
+ },
5
+ "fp16": {
6
+ "enabled": "auto"
7
+ },
8
+ "zero_optimization": {
9
+ "stage": 3,
10
+ "stage3_prefetch_bucket_size": 33554432,
11
+ "stage3_param_persistence_threshold": 4096,
12
+ "stage3_max_live_parameters":33554432,
13
+ "offload_param": {
14
+ "device": "cpu",
15
+ "pin_memory": true
16
+ }
17
+ },
18
+ "train_batch_size": 8,
19
+ "train_micro_batch_size_per_gpu": 1,
20
+ "wall_clock_breakdown": false
21
+ }
VILA/scripts/zero3pp.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "fp16": {
3
+ "enabled": "auto",
4
+ "loss_scale": 0,
5
+ "loss_scale_window": 1000,
6
+ "initial_scale_power": 16,
7
+ "hysteresis": 2,
8
+ "min_loss_scale": 1
9
+ },
10
+ "bf16": {
11
+ "enabled": "auto"
12
+ },
13
+ "train_micro_batch_size_per_gpu": "auto",
14
+ "train_batch_size": "auto",
15
+ "gradient_accumulation_steps": "auto",
16
+ "zero_optimization": {
17
+ "stage": 3,
18
+ "overlap_comm": true,
19
+ "contiguous_gradients": true,
20
+ "sub_group_size": 1e9,
21
+ "reduce_bucket_size": 1e6,
22
+ "stage3_prefetch_bucket_size": 1e6,
23
+ "stage3_param_persistence_threshold": 1e4,
24
+ "stage3_max_live_parameters": 1e9,
25
+ "stage3_max_reuse_distance": 1e9,
26
+ "stage3_gather_16bit_weights_on_model_save": true,
27
+ "zero_hpz_partition_size": 8
28
+ }
29
+ }