root committed on
Commit
7c5440e
·
1 Parent(s): df29ed0

feat: update

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Dockerfile +53 -0
  2. LICENSE +201 -0
  3. README.md +40 -3
  4. app/__pycache__/backend_model.cpython-310.pyc +0 -0
  5. app/__pycache__/backend_model.cpython-39.pyc +0 -0
  6. app/__pycache__/main.cpython-310.pyc +0 -0
  7. app/__pycache__/main.cpython-39.pyc +0 -0
  8. app/backend_model.py +185 -0
  9. app/llava/__init__.py +9 -0
  10. app/llava/__pycache__/__init__.cpython-310.pyc +0 -0
  11. app/llava/__pycache__/__init__.cpython-39.pyc +0 -0
  12. app/llava/__pycache__/constants.cpython-310.pyc +0 -0
  13. app/llava/__pycache__/constants.cpython-39.pyc +0 -0
  14. app/llava/__pycache__/conversation.cpython-310.pyc +0 -0
  15. app/llava/__pycache__/conversation.cpython-39.pyc +0 -0
  16. app/llava/__pycache__/mm_utils.cpython-310.pyc +0 -0
  17. app/llava/__pycache__/mm_utils.cpython-39.pyc +0 -0
  18. app/llava/__pycache__/utils.cpython-310.pyc +0 -0
  19. app/llava/__pycache__/utils.cpython-39.pyc +0 -0
  20. app/llava/configs/action_dataset_ablation/finetune_webvid.yaml +11 -0
  21. app/llava/configs/action_dataset_ablation/finetune_webvid_act.yaml +11 -0
  22. app/llava/configs/action_dataset_ablation/finetune_webvid_hdvila.yaml +11 -0
  23. app/llava/configs/action_dataset_ablation/finetune_webvid_vidal.yaml +11 -0
  24. app/llava/configs/adso_increasing_ablation/finetune_data_pure_gpt4v.yaml +55 -0
  25. app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso135k.yaml +57 -0
  26. app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso185k.yaml +57 -0
  27. app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso185k_baseline.yaml +55 -0
  28. app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso185k_no_qa.yaml +57 -0
  29. app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso65k.yaml +57 -0
  30. app/llava/configs/finetune_debug.yaml +8 -0
  31. app/llava/configs/finetune_gpt4v_adso65k.yaml +56 -0
  32. app/llava/configs/gpt4v_increasing_ablation/finetune_gpt4v_public500k.yaml +57 -0
  33. app/llava/configs/gpt4v_increasing_ablation/finetune_gpt4v_public500k_no_summary.yaml +57 -0
  34. app/llava/configs/gpt4v_increasing_ablation/finetune_gpt4v_public800k.yaml +62 -0
  35. app/llava/configs/gpt4v_increasing_ablation/finetune_videollava.yaml +20 -0
  36. app/llava/configs/pretrain_data.yaml +17 -0
  37. app/llava/configs/pretrain_data_large.yaml +17 -0
  38. app/llava/configs/pretrain_debug.yaml +27 -0
  39. app/llava/configs/promptv1_2_increasing_ablation/finetune_gpt4_prompt_140k.yaml +35 -0
  40. app/llava/configs/release_version/finetune_250k_no_public.yaml +50 -0
  41. app/llava/configs/release_version/finetune_all_data.yaml +63 -0
  42. app/llava/configs/release_version/finetune_gpt4v_caption.yaml +62 -0
  43. app/llava/configs/release_version/finetune_gpt4v_caption_ocr.yaml +67 -0
  44. app/llava/constants.py +17 -0
  45. app/llava/conversation.py +454 -0
  46. app/llava/datasets/__init__.py +24 -0
  47. app/llava/datasets/base_dataset.py +234 -0
  48. app/llava/datasets/builder.py +5 -0
  49. app/llava/datasets/cc_sbu_dataset.py +40 -0
  50. app/llava/datasets/data_cfgs.py +157 -0
Dockerfile ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Base image
2
+ FROM public.ecr.aws/docker/library/ubuntu:22.04
3
+
4
+ # Set ENV
5
+ ENV LANG=C.UTF-8
6
+ ENV LD_LIBRARY_PATH=/opt/aws/neuron/lib:/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
7
+ ENV PATH=/opt/aws/neuron/bin:$PATH
8
+
9
+ RUN apt-get update \
10
+ && apt-get upgrade -y \
11
+ && apt-get install -y --no-install-recommends \
12
+ ca-certificates \
13
+ git \
14
+ wget \
15
+ gnupg2 \
16
+ python3-pip \
17
+ && rm -rf /var/lib/apt/lists/* \
18
+ && rm -rf /tmp/tmp* \
19
+ && apt-get clean
20
+
21
+ # Set driver
22
+ RUN echo "deb https://apt.repos.neuron.amazonaws.com focal main" > /etc/apt/sources.list.d/neuron.list
23
+ RUN wget -qO - https://zz-common.s3.amazonaws.com/tmp/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -
24
+
25
+ RUN apt-get update \
26
+ && apt-get install -y \
27
+ aws-neuronx-tools \
28
+ aws-neuronx-runtime-lib \
29
+ aws-neuronx-collectives \
30
+ && rm -rf /var/lib/apt/lists/* \
31
+ && rm -rf /tmp/tmp* \
32
+ && apt-get clean
33
+
34
+ # Set pip
35
+ RUN pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
36
+
37
+ # Set working directory
38
+ WORKDIR /app
39
+
40
+ # Copy requirements file
41
+ COPY ./app/requirements.txt .
42
+
43
+ # Install dependencies
44
+ RUN pip install --no-cache-dir --upgrade -r requirements.txt
45
+
46
+ # Copy app code
47
+ COPY ./app .
48
+
49
+ # Expose port
50
+ EXPOSE 8000
51
+
52
+ # Command to run the app
53
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,40 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Mistral on AWS Inf2 with FastAPI
2
+ Use FastAPI to quickly serve the Mistral model on an AWS Inferentia2 (Inf2) instance 🚀
3
+ Support Multimodal input type (input_embeds) 🖼️
4
+
5
+ ![image](https://github.com/davidshtian/Mistral-on-AWS-Inf2-with-FastAPI/assets/14228056/94f8aa15-6851-41d5-b89e-2b8699949fef)
6
+
7
+
8
+ ## Environment Setup
9
+ Follow the instructions in Neuron docs [Pytorch Neuron Setup](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/frameworks/torch/torch-setup.html) for basic environment setup.
10
+
11
+ ## Install Packages
12
+ Go to the virtual env and install the extra packages.
13
+ ```
14
+ cd app
15
+ pip install -r requirements.txt
16
+ ```
17
+
18
+ ## Run the App
19
+ ```
20
+ uvicorn main:app --host 0.0.0.0 --port 8000
21
+ ```
22
+
23
+ ## Send the Request
24
+ Test via the input_ids (normal prompt) version:
25
+ ```
26
+ cd client
27
+ python client.py
28
+ ```
29
+
30
+ Test via the input_embeds (common multimodal input, skip embedding layer) version:
31
+ ```
32
+ cd client
33
+ python embeds_client.py
34
+ ```
35
+
36
+ ## Container
37
+ You could build container image using the Dockerfile, or using the pre-build image:
38
+ ```
39
+ docker run --rm --name mistral -d -p 8000:8000 --device=/dev/neuron0 public.ecr.aws/shtian/fastapi-mistral
40
+ ```
app/__pycache__/backend_model.cpython-310.pyc ADDED
Binary file (6.71 kB). View file
 
app/__pycache__/backend_model.cpython-39.pyc ADDED
Binary file (6.74 kB). View file
 
app/__pycache__/main.cpython-310.pyc ADDED
Binary file (7.26 kB). View file
 
app/__pycache__/main.cpython-39.pyc ADDED
Binary file (7.27 kB). View file
 
app/backend_model.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Union, List, Optional, Dict, Any, Literal
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from transformers import AutoTokenizer
6
+ import transformers
7
+ from transformers_neuronx import MistralForSampling, GQA, NeuronConfig, QuantizationConfig
8
+ import time
9
+ import math
10
+ import concurrent.futures
11
+
12
+
13
+ def padding_ceiling(n):
14
+ if n <= 0:
15
+ return 1
16
+ elif n & (n - 1) == 0: # Check if n is already a power of 2
17
+ return n
18
+ else:
19
+ return 2 ** math.ceil(math.log2(n))
20
+
21
+
22
+ class MyStreamer(transformers.generation.streamers.BaseStreamer):
23
+ def __init__(self) -> None:
24
+ self.reset()
25
+
26
+ def reset(self):
27
+ self.token_latencies = []
28
+ self.iter = 0
29
+ self.now = time.time()
30
+
31
+ def put(self, tokens):
32
+ now = time.time()
33
+ token_latency = now - self.now
34
+ self.now = now
35
+ self.iter += 1
36
+ self.token_latencies.append(token_latency)
37
+
38
+ def end(self):
39
+ print("\n\n")
40
+ print("First 5 token latencies:", self.token_latencies[:5])
41
+ print("All token latencies:", sum(self.token_latencies[:]))
42
+
43
+
44
+ class MistralModel:
45
+ """
46
+ A class for generating text using the Mistral language model.
47
+ """
48
+
49
+ def __init__(self, model_name):
50
+ self.neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS,
51
+ quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'))
52
+ # self.model_name = 'mistralai/Mistral-7B-Instruct-v0.2'
53
+ self.model_name = model_name
54
+ self.amp: Literal['bf16', 'fp32'] = 'bf16'
55
+ self.batch_size = 1
56
+ self.tp_degree = 2
57
+ self.n_positions = 4096
58
+ self.context_length_estimate = [2289, 4096]
59
+ # self.context_length_estimate = 2289
60
+
61
+ self.model = self._load_model()
62
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
63
+ self.prompt_template = "<s>[INST] {prompt} [/INST]"
64
+
65
+ def _load_model(self) -> MistralForSampling:
66
+ """
67
+ Load and initialize the Mistral model.
68
+
69
+ Returns:
70
+ MistralForSampling: The initialized Mistral model.
71
+ """
72
+ model = MistralForSampling.from_pretrained(
73
+ self.model_name,
74
+ amp=self.amp,
75
+ batch_size=self.batch_size,
76
+ tp_degree=self.tp_degree,
77
+ n_positions=self.n_positions,
78
+ neuron_config=self.neuron_config,
79
+ context_length_estimate=self.context_length_estimate,
80
+ # compiler_args=["--model-type=transformer", "--target=inf2", "--auto-cast=all", "--auto-cast-type=fp8_e4m3", "--optlevel=3", "--enable-saturate-infinity"]
81
+ )
82
+ model.to_neuron()
83
+ return model
84
+
85
+ def generate(self, inputs: Union[str, List[int]], parameters: Optional[Dict[str, Any]] = None) -> str:
86
+ """
87
+ Generate text using the Mistral model.
88
+
89
+ Args:
90
+ inputs (Union[str, List[int]]): The input prompt or a list of input embeddings.
91
+ parameters (Optional[Dict[str, Any]]): Optional parameters for text generation.
92
+
93
+ Returns:
94
+ str: The generated text.
95
+
96
+ Raises:
97
+ ValueError: If the input type is invalid.
98
+ """
99
+ try:
100
+ max_new_tokens = parameters.get("max_new_tokens", 256)
101
+ top_k = parameters.get("top_k", 100)
102
+ top_p = parameters.get("top_p", 0.1)
103
+ temperature = parameters.get("temperature", 0.1)
104
+ no_repeat_ngram_size = parameters.get("no_repeat_ngram_size", 3)
105
+ print(
106
+ f"parameters max_new_tokens: {max_new_tokens}, top_k: {top_k}, top_p: {top_p}, temperature: {temperature}, no_repeat_ngram_size: {no_repeat_ngram_size}")
107
+
108
+ if isinstance(inputs, str):
109
+ generated_text = self._generate_from_prompt(inputs, max_new_tokens, top_k, top_p, temperature,
110
+ no_repeat_ngram_size)
111
+ elif isinstance(inputs, list):
112
+ generated_text = self._generate_from_embeddings(inputs, max_new_tokens, top_k, top_p, temperature,
113
+ no_repeat_ngram_size)
114
+ else:
115
+ raise ValueError("Invalid input type. Must be str or List[int]")
116
+
117
+ return generated_text
118
+ except Exception as e:
119
+ logging.error(f"Error generating text: {e}")
120
+ raise
121
+
122
+ def _generate_from_prompt(self, prompt: str, max_new_tokens: int, top_k: float, top_p: float, temperature: float,
123
+ no_repeat_ngram_size: int) -> str:
124
+ """
125
+ Generate text from a given prompt using the Mistral model.
126
+
127
+ Args:
128
+ prompt (str): The input prompt.
129
+ max_new_tokens (int): The maximum number of new tokens to generate.
130
+
131
+ Returns:
132
+ str: The generated text.
133
+ """
134
+ input_prompt = self.prompt_template.format(prompt=prompt)
135
+ encoded_input = self.tokenizer(input_prompt, return_tensors='pt')
136
+ input_ids = encoded_input.input_ids
137
+
138
+ with torch.inference_mode():
139
+ generated_sequence = self.model.sample(input_ids, sequence_length=min(self.n_positions,
140
+ input_ids.shape[1] + max_new_tokens),
141
+ start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
142
+ no_repeat_ngram_size=no_repeat_ngram_size)
143
+ decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]
144
+
145
+ generated_text = decoded_output[0].split('[/INST]')[1].strip("</s>").strip()
146
+ return generated_text
147
+
148
+ def _generate_from_embeddings(self, input_embeddings: List[int], max_new_tokens: int, top_k: float, top_p: float,
149
+ temperature: float, no_repeat_ngram_size: int) -> str:
150
+ """
151
+ Generate text from a given list of input embeddings using the Mistral model.
152
+
153
+ Args:
154
+ input_embeddings (List[int]): A list of input embeddings.
155
+ max_new_tokens (int): The maximum number of new tokens to generate.
156
+
157
+ Returns:
158
+ str: The generated text.
159
+ """
160
+ s1 = time.time()
161
+ input_embeds_tensor = torch.tensor(input_embeddings)
162
+ input_embeds_length = input_embeds_tensor.shape[1]
163
+ padding_size = padding_ceiling(input_embeds_length)
164
+ if padding_size >= self.n_positions:
165
+ padding_size = input_embeds_length
166
+ padded_input_embeds = input_embeds_tensor
167
+ else:
168
+ padding_gap = padding_size - input_embeds_length
169
+ padded_input_embeds = F.pad(input_embeds_tensor, (0, 0, padding_gap, 0), value=self.tokenizer.pad_token_id)
170
+ print("ms1 - input_embeds time: ", time.time() - s1)
171
+
172
+ s2 = time.time()
173
+ with torch.inference_mode():
174
+ generated_sequence = self.model.sample(padded_input_embeds,
175
+ sequence_length=min(self.n_positions, padding_size + max_new_tokens),
176
+ start_ids=None, top_k=top_k, top_p=top_p, temperature=temperature,
177
+ no_repeat_ngram_size=no_repeat_ngram_size, streamer=MyStreamer())
178
+ with concurrent.futures.ThreadPoolExecutor() as executor:
179
+ decoded_output = list(executor.map(self.tokenizer.decode, generated_sequence))
180
+ # decoded_output = [self.tokenizer.decode(tok) for tok in generated_sequence]
181
+ print("ms2 - decoded_output time: ", time.time() - s2)
182
+
183
+ generated_text = decoded_output[0].strip("</s>").strip()
184
+ return generated_text
185
+
app/llava/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ from .model import LlavaLlamaForCausalLM, LlavaMistralForCausalLM
2
+ try:
3
+ from .model import LlavaGemmaForCausalLM
4
+ except:
5
+ pass
6
+ try:
7
+ from .model import LlavaThothForCausalLM
8
+ except:
9
+ pass
app/llava/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (330 Bytes). View file
 
app/llava/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (357 Bytes). View file
 
app/llava/__pycache__/constants.cpython-310.pyc ADDED
Binary file (641 Bytes). View file
 
app/llava/__pycache__/constants.cpython-39.pyc ADDED
Binary file (662 Bytes). View file
 
app/llava/__pycache__/conversation.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
app/llava/__pycache__/conversation.cpython-39.pyc ADDED
Binary file (11.1 kB). View file
 
app/llava/__pycache__/mm_utils.cpython-310.pyc ADDED
Binary file (11.8 kB). View file
 
app/llava/__pycache__/mm_utils.cpython-39.pyc ADDED
Binary file (11.9 kB). View file
 
app/llava/__pycache__/utils.cpython-310.pyc ADDED
Binary file (5.96 kB). View file
 
app/llava/__pycache__/utils.cpython-39.pyc ADDED
Binary file (5.98 kB). View file
 
app/llava/configs/action_dataset_ablation/finetune_webvid.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ lk_image:
4
+ data_type: image
5
+
6
+ lk_video:
7
+ data_type: frames
8
+ conv_type: multi
9
+ fps: 1.0
10
+ select_datasets: ['webvid10m', 'webvid2m']
11
+ # select_datasets: ['webvid10m', 'webvid2m', 'activitynet', 'vidal', 'hdvila']
app/llava/configs/action_dataset_ablation/finetune_webvid_act.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ lk_image:
4
+ data_type: image
5
+
6
+ lk_video:
7
+ data_type: frames
8
+ conv_type: multi
9
+ fps: 1.0
10
+ select_datasets: ['webvid10m', 'webvid2m', 'activitynet']
11
+ # select_datasets: ['webvid10m', 'webvid2m', 'activitynet', 'vidal', 'hdvila']
app/llava/configs/action_dataset_ablation/finetune_webvid_hdvila.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ lk_image:
4
+ data_type: image
5
+
6
+ lk_video:
7
+ data_type: frames
8
+ conv_type: multi
9
+ fps: 1.0
10
+ select_datasets: ['webvid10m', 'webvid2m', 'hdvila']
11
+ # select_datasets: ['webvid10m', 'webvid2m', 'activitynet', 'vidal', 'hdvila']
app/llava/configs/action_dataset_ablation/finetune_webvid_vidal.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ lk_image:
4
+ data_type: image
5
+
6
+ lk_video:
7
+ data_type: frames
8
+ conv_type: multi
9
+ fps: 1.0
10
+ select_datasets: ['webvid10m', 'webvid2m', 'vidal']
11
+ # select_datasets: ['webvid10m', 'webvid2m', 'activitynet', 'vidal', 'hdvila']
app/llava/configs/adso_increasing_ablation/finetune_data_pure_gpt4v.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ # tt_vqa:
25
+ # data_type: frames
26
+ # sample_ratio: 1
27
+
28
+ ShareGPT4V:
29
+ data_type: images
30
+ sample_ratio: 1
31
+
32
+ gpt4v_tt_vqa:
33
+ data_type: frames
34
+ fps: 0.5
35
+ sample_ratio: 6
36
+ conv_type: single
37
+ task_types: ['caption', 'qas']
38
+
39
+ gpt4v_public:
40
+ data_type: frames
41
+ fps: 1.0
42
+ sample_ratio: 6
43
+ conv_type: single
44
+ task_types: ['summary', 'detail', 'qa_pairs']
45
+ sample_method: sequential
46
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json
47
+
48
+ gpt4v_internal:
49
+ data_type: frames
50
+ fps: 2.0
51
+ sample_ratio: 1
52
+ conv_type: single
53
+ task_types: ['summary', 'detail', 'qa_pairs']
54
+
55
+
app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso135k.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 2
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/20240208_meta_data_single_135k_caption_160k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+ gpt4v_tt_vqa:
36
+ data_type: frames
37
+ fps: 0.5
38
+ sample_ratio: 6
39
+ conv_type: single
40
+ task_types: ['caption', 'qas']
41
+
42
+
43
+ gpt4v_public:
44
+ data_type: frames
45
+ fps: 1.0
46
+ sample_ratio: 6
47
+ conv_type: single
48
+ task_types: ['summary', 'detail', 'qa_pairs']
49
+ sample_method: sequential
50
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json
51
+
52
+ gpt4v_internal:
53
+ data_type: frames
54
+ fps: 2.0
55
+ sample_ratio: 1
56
+ conv_type: single
57
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso185k.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 3
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/20240220_meta_data_single_190k_caption_160k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+ gpt4v_tt_vqa:
36
+ data_type: frames
37
+ fps: 0.5
38
+ sample_ratio: 6
39
+ conv_type: single
40
+ task_types: ['caption', 'qas']
41
+
42
+
43
+ gpt4v_public:
44
+ data_type: frames
45
+ fps: 1.0
46
+ sample_ratio: 6
47
+ conv_type: single
48
+ task_types: ['summary', 'detail', 'qa_pairs']
49
+ sample_method: sequential
50
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json
51
+
52
+ gpt4v_internal:
53
+ data_type: frames
54
+ fps: 2.0
55
+ sample_ratio: 1
56
+ conv_type: single
57
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso185k_baseline.yaml ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 3
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/20240220_meta_data_single_190k_caption_160k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+ gpt4v_tt_vqa:
36
+ data_type: frames
37
+ fps: 0.5
38
+ sample_ratio: 6
39
+ conv_type: single
40
+ task_types: ['caption', 'qas']
41
+
42
+
43
+ lk_video:
44
+ data_type: frames
45
+ conv_type: multi
46
+ fps: 1.0
47
+ sample_ratio: 6
48
+
49
+ gpt4v_internal:
50
+ data_type: frames
51
+ fps: 2.0
52
+ sample_ratio: 1
53
+ conv_type: single
54
+ task_types: ['detail']
55
+
app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso185k_no_qa.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 3
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/20240220_meta_data_single_190k_caption_no_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+ gpt4v_tt_vqa:
36
+ data_type: frames
37
+ fps: 0.5
38
+ sample_ratio: 6
39
+ conv_type: single
40
+ task_types: ['caption', 'qas']
41
+
42
+
43
+ gpt4v_public:
44
+ data_type: frames
45
+ fps: 1.0
46
+ sample_ratio: 6
47
+ conv_type: single
48
+ task_types: ['summary', 'detail', 'qa_pairs']
49
+ sample_method: sequential
50
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json
51
+
52
+ gpt4v_internal:
53
+ data_type: frames
54
+ fps: 2.0
55
+ sample_ratio: 1
56
+ conv_type: single
57
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/adso_increasing_ablation/finetune_gpt4v_adso65k.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 2
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+ gpt4v_tt_vqa:
36
+ data_type: frames
37
+ fps: 0.5
38
+ sample_ratio: 6
39
+ conv_type: single
40
+ task_types: ['caption', 'qas']
41
+
42
+
43
+ gpt4v_public:
44
+ data_type: frames
45
+ fps: 1.0
46
+ sample_ratio: 6
47
+ conv_type: single
48
+ task_types: ['summary', 'detail', 'qa_pairs']
49
+ sample_method: sequential
50
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json
51
+
52
+ gpt4v_internal:
53
+ data_type: frames
54
+ fps: 2.0
55
+ sample_ratio: 1
56
+ conv_type: single
57
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/finetune_debug.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+ gpt4v_public:
3
+ data_type: frames
4
+ fps: 1.0
5
+ sample_ratio: 6
6
+ conv_type: single
7
+ task_types: ['summary', 'detail', 'qa_pairs']
8
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json
app/llava/configs/finetune_gpt4v_adso65k.yaml ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 2
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+ gpt4v_tt_vqa:
36
+ data_type: frames
37
+ fps: 0.5
38
+ sample_ratio: 6
39
+ conv_type: single
40
+ task_types: ['caption', 'qas']
41
+
42
+
43
+ gpt4v_public:
44
+ data_type: frames
45
+ fps: 1.0
46
+ sample_ratio: 6
47
+ conv_type: single
48
+ task_types: ['summary', 'detail', 'qa_pairs']
49
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json
50
+
51
+ gpt4v_internal:
52
+ data_type: frames
53
+ fps: 2.0
54
+ sample_ratio: 1
55
+ conv_type: single
56
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/gpt4v_increasing_ablation/finetune_gpt4v_public500k.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ # tt_vqa:
25
+ # data_type: frames
26
+ # sample_ratio: 2
27
+ # fps: 2.0
28
+ # conv_type: single
29
+ # train_data_path: /mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+
36
+ gpt4v_tt_vqa:
37
+ data_type: frames
38
+ fps: 0.5
39
+ sample_ratio: 6
40
+ conv_type: single
41
+ task_types: ['caption', 'qas']
42
+
43
+
44
+ gpt4v_public:
45
+ data_type: frames
46
+ fps: 1.0
47
+ sample_ratio: 10
48
+ conv_type: single
49
+ task_types: ['summary', 'detail']
50
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
51
+
52
+ gpt4v_internal:
53
+ data_type: frames
54
+ fps: 2.0
55
+ sample_ratio: 1
56
+ conv_type: single
57
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/gpt4v_increasing_ablation/finetune_gpt4v_public500k_no_summary.yaml ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ # tt_vqa:
25
+ # data_type: frames
26
+ # sample_ratio: 2
27
+ # fps: 2.0
28
+ # conv_type: single
29
+ # train_data_path: /mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+
36
+ gpt4v_tt_vqa:
37
+ data_type: frames
38
+ fps: 0.5
39
+ sample_ratio: 6
40
+ conv_type: single
41
+ task_types: ['caption', 'qas']
42
+
43
+
44
+ gpt4v_public:
45
+ data_type: frames
46
+ fps: 1.0
47
+ sample_ratio: 4
48
+ conv_type: single
49
+ task_types: ['detail']
50
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
51
+
52
+ gpt4v_internal:
53
+ data_type: frames
54
+ fps: 2.0
55
+ sample_ratio: 1
56
+ conv_type: single
57
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/gpt4v_increasing_ablation/finetune_gpt4v_public800k.yaml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ # tt_vqa:
25
+ # data_type: frames
26
+ # sample_ratio: 2
27
+ # fps: 2.0
28
+ # conv_type: single
29
+ # train_data_path: /mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+
36
+ gpt4v_tt_vqa:
37
+ data_type: frames
38
+ fps: 0.5
39
+ sample_ratio: 6
40
+ conv_type: single
41
+ task_types: ['caption', 'qas']
42
+
43
+ # gpt4v_public:
44
+ # data_type: frames
45
+ # fps: 1.0
46
+ # sample_ratio: 10
47
+ # conv_type: single
48
+ # task_types: ['summary', 'detail']
49
+ # train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
50
+
51
+ lk_video:
52
+ data_type: frames
53
+ conv_type: multi
54
+ fps: 1.0
55
+ sample_ratio: 6
56
+
57
+ gpt4v_internal:
58
+ data_type: frames
59
+ fps: 2.0
60
+ sample_ratio: 1
61
+ conv_type: single
62
+ task_types: ['summary', 'detail', 'qa_pairs']
app/llava/configs/gpt4v_increasing_ablation/finetune_videollava.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # llava_pretrain:
4
+ # data_type: image
5
+ # sample_ratio: 1
6
+
7
+ # gpt4v_public:
8
+ # data_type: frames
9
+ # sample_ratio: 2
10
+ # task_types: ['summary']
11
+ # fps: 1.0
12
+ # conv_type: single
13
+
14
+ lk_image:
15
+ data_type: image
16
+
17
+ lk_video:
18
+ data_type: frames
19
+ conv_type: multi
20
+ fps: 1.0
app/llava/configs/pretrain_data.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ llava_pretrain:
4
+ data_type: image
5
+ sample_ratio: 1
6
+
7
+ # internvid:
8
+ # data_type: frames
9
+ # sample_ratio: 10
10
+
11
+ gpt4v_public:
12
+ data_type: frames
13
+ sample_ratio: 1
14
+ task_types: ['summary']
15
+ fps: 1.0
16
+ conv_type: single
17
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
app/llava/configs/pretrain_data_large.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ llava_pretrain:
4
+ data_type: image
5
+ sample_ratio: 1
6
+
7
+ internvid:
8
+ data_type: frames
9
+ sample_ratio: 10
10
+
11
+ gpt4v_public:
12
+ data_type: frames
13
+ sample_ratio: 1
14
+ task_types: ['summary']
15
+ fps: 1.0
16
+ conv_type: single
17
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
app/llava/configs/pretrain_debug.yaml ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ llava_pretrain:
4
+ data_type: image
5
+ sample_ratio: 1
6
+
7
+ # gpt4v_public:
8
+ # data_type: frames
9
+ # sample_ratio: 2
10
+ # task_types: ['summary']
11
+ # fps: 1.0
12
+ # conv_type: single
13
+
14
+ # lk_image:
15
+ # data_type: image
16
+
17
+ # lk_video:
18
+ # data_type: frames
19
+ # conv_type: multi
20
+ # fps: 1.0
21
+
22
+ gpt4v_internal:
23
+ data_type: frames
24
+ fps: 2.0
25
+ sample_ratio: 1
26
+ conv_type: multi
27
+ task_types: ['qa_pairs']
app/llava/configs/promptv1_2_increasing_ablation/finetune_gpt4_prompt_140k.yaml ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ # gpt4v_internal:
25
+ # data_type: frames
26
+ # fps: 2.0
27
+ # sample_ratio: 1
28
+ # conv_type: single
29
+ # task_types: ['summary', 'detail', 'qa_pairs']
30
+
31
+ promptv1_2_internal:
32
+ data_type: frames
33
+ sample_ratio: 1
34
+ train_data_path: /mnt/bn/algo-masp-nas-2/kaili.zhao/data/masp_data/train/gpt4v_annotation/202400401week_gpt4v_all_videos_unique_ids.json
35
+ task_types: ['refine_caption']
app/llava/configs/release_version/finetune_250k_no_public.yaml ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 3
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/20231201_20240322_caption_250k.json
30
+
31
+
32
+ ShareGPT4V:
33
+ data_type: images
34
+ sample_ratio: 1
35
+
36
+
37
+ gpt4v_tt_vqa:
38
+ data_type: frames
39
+ fps: 0.5
40
+ sample_ratio: 6
41
+ conv_type: single
42
+ task_types: ['caption']
43
+
44
+
45
+ gpt4v_internal:
46
+ data_type: frames
47
+ fps: 2.0
48
+ sample_ratio: 1
49
+ conv_type: single
50
+ task_types: ['detail']
app/llava/configs/release_version/finetune_all_data.yaml ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ tt_vqa:
25
+ data_type: frames
26
+ sample_ratio: 3
27
+ fps: 2.0
28
+ conv_type: single
29
+ train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/20231201_20240322_caption_250k.json
30
+
31
+
32
+ ShareGPT4V:
33
+ data_type: images
34
+ sample_ratio: 1
35
+
36
+
37
+ gpt4v_tt_vqa:
38
+ data_type: frames
39
+ fps: 0.5
40
+ sample_ratio: 6
41
+ conv_type: single
42
+ task_types: ['caption']
43
+
44
+ # gpt4v_public:
45
+ # data_type: frames
46
+ # fps: 1.0
47
+ # sample_ratio: 10
48
+ # conv_type: single
49
+ # task_types: ['summary', 'detail']
50
+ # train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
51
+
52
+ lk_video:
53
+ data_type: frames
54
+ conv_type: multi
55
+ fps: 1.0
56
+ sample_ratio: 6
57
+
58
+ gpt4v_internal:
59
+ data_type: frames
60
+ fps: 2.0
61
+ sample_ratio: 1
62
+ conv_type: single
63
+ task_types: ['detail']
app/llava/configs/release_version/finetune_gpt4v_caption.yaml ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ # tt_vqa:
25
+ # data_type: frames
26
+ # sample_ratio: 2
27
+ # fps: 2.0
28
+ # conv_type: single
29
+ # train_data_path: /mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+
36
+ gpt4v_tt_vqa:
37
+ data_type: frames
38
+ fps: 0.5
39
+ sample_ratio: 6
40
+ conv_type: single
41
+ task_types: ['caption']
42
+
43
+ # gpt4v_public:
44
+ # data_type: frames
45
+ # fps: 1.0
46
+ # sample_ratio: 10
47
+ # conv_type: single
48
+ # task_types: ['summary', 'detail']
49
+ # train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
50
+
51
+ lk_video:
52
+ data_type: frames
53
+ conv_type: multi
54
+ fps: 1.0
55
+ sample_ratio: 6
56
+
57
+ gpt4v_internal:
58
+ data_type: frames
59
+ fps: 2.0
60
+ sample_ratio: 1
61
+ conv_type: single
62
+ task_types: ['detail']
app/llava/configs/release_version/finetune_gpt4v_caption_ocr.yaml ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ datasets:
2
+
3
+ # m3it:
4
+ # data_type: images
5
+ # sample_ratio: 4
6
+ # tasks:
7
+ # - coco
8
+ # - coco-goi
9
+ # - coco-text
10
+ # - imagenet
11
+ # - coco-itm
12
+ # - iqa
13
+ # - mocheg
14
+ # - vsr
15
+ # - refcoco
16
+ # - science-qa
17
+ # - vqa-v2
18
+ # - gqa
19
+ # - st-vqa
20
+ # - text-vqa
21
+ # - okvqa
22
+ # - a-okvqa
23
+ #
24
+ # tt_vqa:
25
+ # data_type: frames
26
+ # sample_ratio: 2
27
+ # fps: 2.0
28
+ # conv_type: single
29
+ # train_data_path: /mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json
30
+
31
+ ShareGPT4V:
32
+ data_type: images
33
+ sample_ratio: 1
34
+
35
+
36
+ gpt4v_tt_vqa:
37
+ data_type: frames
38
+ fps: 0.5
39
+ sample_ratio: 6
40
+ conv_type: single
41
+ task_types: ['caption']
42
+
43
+ # gpt4v_public:
44
+ # data_type: frames
45
+ # fps: 1.0
46
+ # sample_ratio: 10
47
+ # conv_type: single
48
+ # task_types: ['summary', 'detail']
49
+ # train_data_path: /mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json
50
+
51
+ lk_video:
52
+ data_type: frames
53
+ conv_type: multi
54
+ fps: 1.0
55
+ sample_ratio: 6
56
+
57
+ gpt4v_internal:
58
+ data_type: frames
59
+ fps: 2.0
60
+ sample_ratio: 1
61
+ conv_type: single
62
+ task_types: ['detail']
63
+
64
+ synthetic_ocr:
65
+ data_type: video
66
+ sample_ratio: 1
67
+ fps: 0.5
app/llava/constants.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ CONTROLLER_HEART_BEAT_EXPIRATION = 30
2
+ WORKER_HEART_BEAT_INTERVAL = 15
3
+
4
+ LOGDIR = "."
5
+
6
+ # Model Constants
7
+ IGNORE_INDEX = -100
8
+ MM_TOKEN_INDEX = -200
9
+ DEFAULT_IMAGE_TOKEN = "<image>"
10
+ DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
11
+ DEFAULT_IM_START_TOKEN = "<im_start>"
12
+ DEFAULT_IM_END_TOKEN = "<im_end>"
13
+ IMAGE_PLACEHOLDER = "<image-placeholder>"
14
+ DEFAULT_VIDEO_TOKEN = "<video>"
15
+ DEFAULT_VIDEO_PATCH_TOKEN = "<vid_patch>"
16
+ DEFAULT_VIDEO_START_TOKEN = "<vid_start>"
17
+ DEFAULT_VIDEO_END_TOKEN = "<vid_end>"
app/llava/conversation.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+ import base64
5
+ from io import BytesIO
6
+ from PIL import Image
7
+
8
+
9
+ class SeparatorStyle(Enum):
10
+ """Different separator style."""
11
+ SINGLE = auto()
12
+ TWO = auto()
13
+ MPT = auto()
14
+ PLAIN = auto()
15
+ LLAMA_2 = auto()
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Conversation:
20
+ """A class that keeps all conversation history."""
21
+ system: str
22
+ roles: List[str]
23
+ messages: List[List[str]]
24
+ offset: int
25
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
26
+ sep: str = "###"
27
+ sep2: str = None
28
+ version: str = "Unknown"
29
+
30
+ skip_next: bool = False
31
+
32
+ def get_prompt(self, use_chat_template=False, tokenizer=None):
33
+ if use_chat_template:
34
+ assert tokenizer is not None, "must have tokenizer when using chat template"
35
+ messages = self.messages
36
+ # whether in inference mode
37
+ if messages[-1][0] == self.roles[1] and (messages[-1][1] is None or messages[-1][1] == ''):
38
+ generate_flag = True
39
+ messages = messages[:-1]
40
+ else:
41
+ generate_flag = False
42
+ chat = []
43
+ for role, message in messages:
44
+ chat.append(
45
+ {
46
+ "role": role,
47
+ "content": message,
48
+ }
49
+ )
50
+ return tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=generate_flag)
51
+ else:
52
+ messages = self.messages
53
+ if len(messages) > 0 and type(messages[0][1]) is tuple:
54
+ messages = self.messages.copy()
55
+ init_role, init_msg = messages[0].copy()
56
+ init_msg = init_msg[0].replace("<image>", "").strip()
57
+ if 'mmtag' in self.version:
58
+ messages[0] = (init_role, init_msg)
59
+ messages.insert(0, (self.roles[0], "<Image><image></Image>"))
60
+ messages.insert(1, (self.roles[1], "Received."))
61
+ else:
62
+ messages[0] = (init_role, "<image>\n" + init_msg)
63
+
64
+ if self.sep_style == SeparatorStyle.SINGLE:
65
+ ret = self.system + self.sep
66
+ for role, message in messages:
67
+ if message:
68
+ if type(message) is tuple:
69
+ message, _, _ = message
70
+ ret += role + ": " + message + self.sep
71
+ else:
72
+ ret += role + ":"
73
+ elif self.sep_style == SeparatorStyle.TWO:
74
+ seps = [self.sep, self.sep2]
75
+ ret = self.system + seps[0]
76
+ for i, (role, message) in enumerate(messages):
77
+ if message:
78
+ if type(message) is tuple:
79
+ message, _, _ = message
80
+ ret += role + ": " + message + seps[i % 2]
81
+ else:
82
+ ret += role + ":"
83
+ elif self.sep_style == SeparatorStyle.MPT:
84
+ ret = self.system + self.sep
85
+ for role, message in messages:
86
+ if message:
87
+ if type(message) is tuple:
88
+ message, _, _ = message
89
+ ret += role + message + self.sep
90
+ else:
91
+ ret += role
92
+ elif self.sep_style == SeparatorStyle.LLAMA_2:
93
+ wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
94
+ wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
95
+ ret = ""
96
+
97
+ for i, (role, message) in enumerate(messages):
98
+ if i == 0:
99
+ assert message, "first message should not be none"
100
+ assert role == self.roles[0], "first message should come from user"
101
+ if message:
102
+ if type(message) is tuple:
103
+ message, _, _ = message
104
+ if i == 0: message = wrap_sys(self.system) + message
105
+ if i % 2 == 0:
106
+ message = wrap_inst(message)
107
+ ret += self.sep + message
108
+ else:
109
+ ret += " " + message + " " + self.sep2
110
+ else:
111
+ ret += ""
112
+ ret = ret.lstrip(self.sep)
113
+ elif self.sep_style == SeparatorStyle.PLAIN:
114
+ seps = [self.sep, self.sep2]
115
+ ret = self.system
116
+ for i, (role, message) in enumerate(messages):
117
+ if message:
118
+ if type(message) is tuple:
119
+ message, _, _ = message
120
+ ret += message + seps[i % 2]
121
+ else:
122
+ ret += ""
123
+ else:
124
+ raise ValueError(f"Invalid style: {self.sep_style}")
125
+
126
+ return ret
127
+
128
+
129
+
130
+ def append_message(self, role, message):
131
+ self.messages.append([role, message])
132
+
133
+ def process_image(self, image, image_process_mode, return_pil=False, image_format='PNG', max_len=1344, min_len=672):
134
+ if image_process_mode == "Pad":
135
+ def expand2square(pil_img, background_color=(122, 116, 104)):
136
+ width, height = pil_img.size
137
+ if width == height:
138
+ return pil_img
139
+ elif width > height:
140
+ result = Image.new(pil_img.mode, (width, width), background_color)
141
+ result.paste(pil_img, (0, (width - height) // 2))
142
+ return result
143
+ else:
144
+ result = Image.new(pil_img.mode, (height, height), background_color)
145
+ result.paste(pil_img, ((height - width) // 2, 0))
146
+ return result
147
+ image = expand2square(image)
148
+ elif image_process_mode in ["Default", "Crop"]:
149
+ pass
150
+ elif image_process_mode == "Resize":
151
+ image = image.resize((336, 336))
152
+ else:
153
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
154
+ if max(image.size) > max_len:
155
+ max_hw, min_hw = max(image.size), min(image.size)
156
+ aspect_ratio = max_hw / min_hw
157
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
158
+ longest_edge = int(shortest_edge * aspect_ratio)
159
+ W, H = image.size
160
+ if H > W:
161
+ H, W = longest_edge, shortest_edge
162
+ else:
163
+ H, W = shortest_edge, longest_edge
164
+ image = image.resize((W, H))
165
+ if return_pil:
166
+ return image
167
+ else:
168
+ buffered = BytesIO()
169
+ image.save(buffered, format=image_format)
170
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
171
+ return img_b64_str
172
+
173
+ def get_images(self, return_pil=False):
174
+ images = []
175
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
176
+ if i % 2 == 0:
177
+ if type(msg) is tuple:
178
+ msg, image, image_process_mode = msg
179
+ image = self.process_image(image, image_process_mode, return_pil=return_pil)
180
+ images.append(image)
181
+ return images
182
+
183
+ def to_gradio_chatbot(self):
184
+ ret = []
185
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
186
+ if i % 2 == 0:
187
+ if type(msg) is tuple:
188
+ msg, image, image_process_mode = msg
189
+ img_b64_str = self.process_image(
190
+ image, "Default", return_pil=False,
191
+ image_format='JPEG')
192
+ img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
193
+ msg = img_str + msg.replace('<image>', '').strip()
194
+ ret.append([msg, None])
195
+ else:
196
+ ret.append([msg, None])
197
+ else:
198
+ ret[-1][-1] = msg
199
+ return ret
200
+
201
+ def copy(self):
202
+ return Conversation(
203
+ system=self.system,
204
+ roles=self.roles,
205
+ messages=[[x, y] for x, y in self.messages],
206
+ offset=self.offset,
207
+ sep_style=self.sep_style,
208
+ sep=self.sep,
209
+ sep2=self.sep2,
210
+ version=self.version)
211
+
212
+ def dict(self):
213
+ if len(self.get_images()) > 0:
214
+ return {
215
+ "system": self.system,
216
+ "roles": self.roles,
217
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
218
+ "offset": self.offset,
219
+ "sep": self.sep,
220
+ "sep2": self.sep2,
221
+ }
222
+ return {
223
+ "system": self.system,
224
+ "roles": self.roles,
225
+ "messages": self.messages,
226
+ "offset": self.offset,
227
+ "sep": self.sep,
228
+ "sep2": self.sep2,
229
+ }
230
+
231
+
232
+ conv_vicuna_v0 = Conversation(
233
+ system="A chat between a curious human and an artificial intelligence assistant. "
234
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
235
+ roles=("Human", "Assistant"),
236
+ messages=(
237
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
238
+ ("Assistant",
239
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
240
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
241
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
242
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
243
+ "renewable and non-renewable energy sources:\n"
244
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
245
+ "energy sources are finite and will eventually run out.\n"
246
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
247
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
248
+ "and other negative effects.\n"
249
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
250
+ "have lower operational costs than non-renewable sources.\n"
251
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
252
+ "locations than non-renewable sources.\n"
253
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
254
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
255
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
256
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
257
+ ),
258
+ offset=2,
259
+ sep_style=SeparatorStyle.SINGLE,
260
+ sep="###",
261
+ )
262
+
263
+ conv_vicuna_v1 = Conversation(
264
+ system="A chat between a curious user and an artificial intelligence assistant. "
265
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
266
+ roles=("USER", "ASSISTANT"),
267
+ version="v1",
268
+ messages=(),
269
+ offset=0,
270
+ sep_style=SeparatorStyle.TWO,
271
+ sep=" ",
272
+ sep2="</s>",
273
+ )
274
+
275
# ---------------------------------------------------------------------------
# Conversation templates.
#
# Each template fixes the system prompt, role names, separator style and
# version tag for one model family. Callers look templates up by key in
# `conv_templates` (or use `default_conversation`) and copy them before
# appending messages.
# ---------------------------------------------------------------------------

# Llama-2 chat: official safety-oriented system prompt, LLAMA_2 separators.
conv_llama_2 = Conversation(
    system="""You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.""",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# Llama-2 chat with a vision-assistant system prompt (LLaVA on Llama-2).
conv_llava_llama_2 = Conversation(
    system="You are a helpful language and vision assistant. "
           "You are able to understand the visual content that the user provides, "
           "and assist the user with a variety of tasks using natural language.",
    roles=("USER", "ASSISTANT"),
    version="llama_v2",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.LLAMA_2,
    sep="<s>",
    sep2="</s>",
)

# MPT / ChatML-style template: roles carry their own "<|im_start|>" markers.
conv_mpt = Conversation(
    system="""<|im_start|>system
A conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)

# Plain concatenation, no system prompt or role names (pretraining captions).
conv_llava_plain = Conversation(
    system="",
    roles=("", ""),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.PLAIN,
    sep="\n",
)

# Original LLaVA v0 template ("###"-separated Human/Assistant turns).
conv_llava_v0 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)

# v0 variant where visual content is wrapped in <Image>...</Image> tags.
conv_llava_v0_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("Human", "Assistant"),
    messages=(
    ),
    offset=0,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
    version="v0_mmtag",
)

# LLaVA v1 template (Vicuna-style TWO separators, "</s>" closes assistant turns).
conv_llava_v1 = Conversation(
    system="A chat between a curious human and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the human's questions.",
    roles=("USER", "ASSISTANT"),
    version="v1",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
)

# v1 variant with <Image>...</Image> tags around visual content.
conv_llava_v1_mmtag = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant is able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
           "The visual content will be provided with the following format: <Image>visual content</Image>.",
    roles=("USER", "ASSISTANT"),
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="</s>",
    version="v1_mmtag",
)

# Mistral-instruct template; prompt assembly is driven by version="mistral",
# so sep_style is declared but not used.
conv_mistral_instruct = Conversation(
    system="",
    roles=("user", "assistant"),
    version="mistral",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,  # not used
    sep="",
    sep2="</s>",
)

# Gemma turn-based template; like mistral, version drives formatting.
conv_gemma = Conversation(
    system="",
    roles=("user", "model"),
    version="gemma",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,  # not used
    sep="<start_of_turn>",
    sep2="<end_of_turn>",
)

# Thoth template: Vicuna-style turns with a sentinel end-of-turn token.
conv_thoth = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="thoth",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<[SEP_never_used_51bce0c785ca2f68081bfa7d91973934]>",
)

# Minimal ChatML template with a terse system prompt.
conv_chatml_direct = Conversation(
    system="""<|im_start|>system
Answer the questions.""",
    roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
    version="mpt",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.MPT,
    sep="<|im_end|>",
)


default_conversation = conv_vicuna_v1

# Lookup table mapping CLI/config template names to Conversation objects.
conv_templates = {
    "default": conv_vicuna_v0,
    "v0": conv_vicuna_v0,
    "v1": conv_vicuna_v1,
    "vicuna_v1": conv_vicuna_v1,
    "llama_2": conv_llama_2,
    "mistral_instruct": conv_mistral_instruct,
    "chatml_direct": conv_chatml_direct,
    "mistral_direct": conv_chatml_direct,

    "plain": conv_llava_plain,
    "v0_plain": conv_llava_plain,
    "llava_v0": conv_llava_v0,
    "v0_mmtag": conv_llava_v0_mmtag,
    "llava_v1": conv_llava_v1,
    "v1_mmtag": conv_llava_v1_mmtag,
    "llava_llama_2": conv_llava_llama_2,

    "mpt": conv_mpt,
    "gemma": conv_gemma,
    "thoth": conv_thoth,

}


if __name__ == "__main__":
    print(default_conversation.get_prompt())
app/llava/datasets/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .data_cfgs import *
2
+ from .base_dataset import *
3
+ from .prompts import *
4
+ from .super_dataset import *
5
+ from .cc_sbu_dataset import *
6
+ from .llava_pretrain_dataset import *
7
+ # from .llava_instruct_dataset import *
8
+ # from .lrv_instruct_dataset import *
9
+ from .internvid_dataset import *
10
+ from .tt_vqa_dataset import *
11
+ from .m3it_dataset import *
12
+ from .sharegpt4v_dataset import *
13
+ from .gpt4v_tt_vqa_dataset import *
14
+ from .gpt4v_public_dataset import *
15
+ from .gpt4v_internal_dataset import *
16
+ # from .synthdog_dataset import *
17
+ # from .ocr_vqa_dataset import *
18
+ # from .sharegpt_dataset import *
19
+ from .textcaps_dataset import *
20
+ from .synthetic_ocr_dataset import *
21
+ from .lk_image_dataset import *
22
+ from .lk_video_dataset import *
23
+
24
+ from .promptv1_2_internal_dataset import *
app/llava/datasets/base_dataset.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import sys
4
+ import copy
5
+ import math
6
+ import torch
7
+ import decord
8
+ import random
9
+ import numpy as np
10
+ from PIL import Image
11
+ from decord import VideoReader
12
+ from torch.utils.data import Dataset
13
+ from llava.utils import master_print
14
+ from typing import Dict, Optional, Sequence, List
15
+ from llava.datasets.data_cfgs import data_configs
16
+ from transformers import CLIPImageProcessor, SiglipImageProcessor
17
+
18
+ from llava.mm_utils import get_frame_indices, process_anyres_image
19
+ from torch.utils.data.dataloader import default_collate
20
+
21
+ decord.bridge.set_bridge("torch")
22
+
23
+ class TaskBaseDataset(Dataset):
24
+ """ Implementation of base task dataset """
25
+ def __init__(self, anno_path=None, data_args=None, name=None, **kwargs):
26
+
27
+ self.anno_path = anno_path
28
+ self.data_args = data_args
29
+ self.image_aspect_ratio = data_args.image_aspect_ratio
30
+ self.image_grid_pinpoints = data_args.image_grid_pinpoints
31
+ self.vis_processor = data_args.image_processor
32
+ self.type = None
33
+ self.name = name
34
+
35
+ master_print(f"Loading dataset {name}...")
36
+ if (anno_path is not None):
37
+ if not hasattr(self, 'annotation'):
38
+ self.annotation = json.load(open(anno_path, 'r'))
39
+ master_print(f"Finish loading dataset {name} {len(self.annotation)} samples...")
40
+
41
+ def __len__(self):
42
+ return len(self.annotation)
43
+
44
+ def collater(self, samples):
45
+ return default_collate(samples)
46
+
47
+ def text_preprocess(self, sources) -> List[List[Dict[str, str]]]:
48
+ pass
49
+
50
+ def vis_preprocess(self, vis_path) -> Image:
51
+ pass
52
+
53
+ @property
54
+ def data_type(self):
55
+ return self.type
56
+
57
+ def __getitem__(self, i) -> Dict[str, torch.Tensor]:
58
+ item = self.annotation[i]
59
+
60
+ vis_path = item['vis_path'] if 'vis_path' in item else item['video_path']
61
+
62
+ ret = {
63
+ 'images': self.vis_preprocess(vis_path),
64
+ 'conversations': self.text_preprocess(item)
65
+ }
66
+ if 'id' in item:
67
+ ret['id'] = item['id']
68
+
69
+ return ret
70
+
71
+
72
class ImageTaskDataset(TaskBaseDataset):
    """Task dataset whose samples are single images."""

    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)
        self.type = 'images'

    @staticmethod
    def expand2square(pil_img, background_color):
        """Pad ``pil_img`` onto a centered square canvas of ``background_color``."""
        width, height = pil_img.size
        if width == height:
            return pil_img
        if width > height:
            result = Image.new(pil_img.mode, (width, width), background_color)
            result.paste(pil_img, (0, (width - height) // 2))
            return result
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result

    def _run_processor(self, image):
        # HuggingFace CLIP/SigLIP processors take return_tensors and wrap the
        # output in a batch dict; any other processor is called directly.
        # (This dispatch was duplicated in two branches of the original code.)
        if isinstance(self.vis_processor, (CLIPImageProcessor, SiglipImageProcessor)):
            return self.vis_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        return self.vis_processor.preprocess(image)

    def preprocess_image(self, image):
        """Apply the configured aspect-ratio policy, then the vision processor.

        Returns a single tensor, except under 'anyres' where
        ``process_anyres_image`` may return a list of patch tensors.
        """
        if self.image_aspect_ratio == 'pad':
            # Pad to square with the processor's mean color before processing.
            image = self.expand2square(image, tuple(int(x * 255) for x in self.vis_processor.image_mean))
            return self._run_processor(image)
        if self.image_aspect_ratio == "anyres":
            return process_anyres_image(image, self.vis_processor, self.image_grid_pinpoints)
        return self._run_processor(image)

    def vis_preprocess(self, vis_path):
        """Load the image at ``vis_path`` and return a list of processed tensors."""
        image = Image.open(vis_path).convert('RGB')
        image = self.preprocess_image(image)
        # Normalize to a list so callers can treat single- and multi-patch
        # results uniformly.
        return image if isinstance(image, list) else [image]
119
+
120
+
121
class VideoTaskDataset(ImageTaskDataset):
    """Task dataset that decodes frames directly from video files via decord."""

    def __init__(self, anno_path=None, data_args=None, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

        # if not specify num_segments, use default
        self.num_segments = self.data_args.num_segments
        # Frame sampling strategy name, forwarded to get_frame_indices.
        self.sample_strategy = self.data_args.sample_strategy
        self.type = 'video'

    def vis_preprocess(self, vis_path):
        """Decode ``num_segments`` frames from the video at ``vis_path``.

        Returns a list of processed frame tensors, or None when decoding
        fails (the error is printed and the sample is effectively dropped).
        """
        images = None
        try:
            video_reader = VideoReader(vis_path)
            vlen = len(video_reader)
            fps = video_reader.get_avg_fps()
            # NOTE(review): duration is computed but not used below — confirm
            # whether it was meant to be passed to get_frame_indices.
            duration = vlen / float(fps)

            frame_indices = get_frame_indices(self.num_segments, vlen,
                                sample=self.sample_strategy, input_fps=fps, pad_last=False)
            frames = video_reader.get_batch(frame_indices)
            # decord bridge is set to torch; convert to uint8 numpy for PIL.
            frames = frames.numpy().astype(np.uint8)
            images = [Image.fromarray(frame).convert('RGB') for frame in frames]
            images = [self.preprocess_image(image) for image in images]
        except Exception as e:
            # Best-effort: log the failure and return None rather than crash
            # the whole training run on one corrupt video.
            print(e, vis_path)
            sys.stdout.flush()
            images = None

        # print(f"images: {len(images)}, {images[0].shape}")

        return images
154
+
155
+
156
class FramesTaskDataset(ImageTaskDataset):
    """Task dataset over directories of pre-extracted video frames.

    Frames are assumed to have been extracted at ``default_fps`` (2.0 fps);
    ``fps`` selects the effective rate by downsampling, and ``num_segments``
    caps the number of frames per sample.
    """

    def __init__(self, anno_path=None, data_args=None, fps=0.5, name=None):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

        # if not specify num_segments, use default
        self.num_segments = self.data_args.num_segments
        self.type = 'video'
        # Extraction rate of the frames on disk — TODO confirm against the
        # frame-dumping pipeline.
        self.default_fps = 2.0
        self.fps = fps

    @staticmethod
    def _downsample_frames(frames, interval, keep_first_last=True):
        """Keep every ``interval``-th frame.

        With ``keep_first_last`` the first and last frames are always
        retained. Inputs with fewer than two frames are returned as-is
        (the original code duplicated a single frame as both first and
        last, and raised IndexError on an empty list).
        """
        if keep_first_last:
            if len(frames) < 2:
                return list(frames)
            first, last, mid = frames[0], frames[-1], frames[1:-1]
            sampled_frames = mid[interval - 1::interval]
            return [first] + sampled_frames + [last]

        # may output empty list, recommend keep first and last frame
        return frames[interval - 1::interval]

    @staticmethod
    def _sample_frames(frames, num_segments):
        """Sample ``num_segments`` frames: one random pick per uniform bin.

        Falls back to each bin's start index when a bin is empty (which
        random.choice rejects).
        """
        frame_indices = list(range(len(frames)))
        cand_indices = copy.deepcopy(frame_indices)
        intervals = np.linspace(start=0, stop=len(frame_indices), num=num_segments + 1).astype(int)
        ranges = [(interv, intervals[idx + 1] - 1) for idx, interv in enumerate(intervals[:-1])]

        try:
            frame_indices = [cand_indices[random.choice(range(x[0], x[1]))] for x in ranges]
        except (IndexError, ValueError):
            # Narrowed from a bare except: only the empty-range failure of
            # random.choice is expected here.
            frame_indices = [cand_indices[x[0]] for x in ranges]

        return [frames[indice] for indice in frame_indices]

    def vis_preprocess(self, vis_path):
        """Load, order, downsample and process the frame images in ``vis_path``.

        Returns a flat list of processed frame tensors.
        """
        # Frame files are named by integer index; 'cuttime*' sidecar files
        # are skipped.
        image_files = [(os.path.splitext(img)[0], img)
                       for img in os.listdir(vis_path) if not img.startswith('cuttime')]
        if image_files[0][1].endswith('jpeg'):
            # gpt4v public data: frames named like '<prefix>_<idx>.jpeg'.
            image_files = [(int(x[0].split('_')[-1]), x[1]) for x in image_files]
        else:
            image_files = [(int(x[0]), x[1]) for x in image_files]

        image_files = sorted(image_files, key=lambda img: img[0])

        if self.fps < self.default_fps:
            interval = math.floor(self.default_fps / self.fps)
            image_files = self._downsample_frames(image_files, interval, keep_first_last=True)

        if self.num_segments > 0 and len(image_files) > self.num_segments:
            image_files = self._sample_frames(image_files, self.num_segments)

        images = []
        for image_file in image_files:
            try:
                images.append(Image.open(os.path.join(vis_path, image_file[1])).convert('RGB'))
            except Exception:
                # Best-effort: skip unreadable frames rather than failing the
                # whole sample.
                continue

        formatted_images = []
        for image in images:
            im = self.preprocess_image(image)
            # 'anyres' processing can yield a list of patches per frame.
            if isinstance(im, list):
                formatted_images.extend(im)
            else:
                formatted_images.append(im)
        return formatted_images
232
+
233
+
234
+
app/llava/datasets/builder.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
from .registry import Registry

__all__ = ['DATASETS']

# Global dataset registry; sibling dataset modules register factory functions
# into it via the `@DATASETS.register_obj` decorator.
DATASETS = Registry('datasets')
app/llava/datasets/cc_sbu_dataset.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from llava.datasets.builder import DATASETS
4
+
5
+ from typing import Dict, Optional, Sequence, List
6
+ from llava.datasets.data_cfgs import data_configs
7
+ from llava.datasets.base_dataset import ImageTaskDataset
8
+ from llava.datasets.prompts import cc_sbu_prompt
9
+ from llava.constants import DEFAULT_IMAGE_TOKEN
10
+
11
+
12
class CCSBUDataset(ImageTaskDataset):
    """Image-captioning dataset over the CC/SBU annotations.

    Each annotation record provides a 'caption'; the conversation pairs a
    randomly chosen captioning prompt (with the image token prepended) with
    that caption as the model response.
    """

    def __init__(self, anno_path, data_args=None, name='cc_sbu'):
        super().__init__(anno_path=anno_path,
                         data_args=data_args,
                         name=name)

    def text_preprocess(self, item) -> List[Dict[str, str]]:
        """Build a single human/model turn pair from the item's caption."""
        caption = item['caption']

        conversations = [
            {
                'from': 'human',
                # Prompt is sampled per item so caption phrasing varies.
                'value': DEFAULT_IMAGE_TOKEN + random.choice(cc_sbu_prompt)
            },
            {
                'from': 'model',
                'value': caption
            }
        ]

        return conversations


@DATASETS.register_obj
def cc_sbu(data_args):
    # Factory registered under the 'cc_sbu' key; the annotation path comes
    # from the shared data_configs table.
    return CCSBUDataset(data_configs["cc_sbu"]['train_data_path'], data_args)
38
+
39
+
40
+
app/llava/datasets/data_cfgs.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Central registry of per-dataset configuration.
# Each entry maps a dataset key to its modality ('images' / 'video' /
# 'frames' / 'text'), annotation file paths on the shared NAS mounts, and
# optional sampling parameters ('fps' applies to 'frames' datasets only).
data_configs = {
    'llava_pretrain': {
        'data_type': 'images',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/blip_laion_cc_sbu_558k/meta_data.json'
    },
    'llava_instruct': {
        'data_type': 'images',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/llava_instruct_150k/meta_data.json'
    },
    'lrv_instruct': {
        'data_type': 'images',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/lrv_instructions/meta_data.json'
    },
    'coco_caption': {
        'data_type': 'images',
        'train_data_path': '/mnt/bn/data-tns-algo-masp/baiyi.by/data/coco_caption/train.json'
    },
    'cc_sbu': {
        'data_type': 'images',
        'train_data_path': '/mnt/bn/baiyi-arnold-nas/data/masp/vlm_data/cc_sbu/meta_data.json'
    },
    'laion': {
        'data_type': 'images',
        'train_data_path': '/mnt/bn/data-tns-algo-masp/baiyi.by/data/laion/train.json'
    },
    'webvid': {
        'data_type': 'video',
        'train_data_path': '/mnt/bn/baiyi-arnold-nas/data/masp/vlm_data/webvid_10M_video/train.json',
        'val_data_path': '/mnt/bn/baiyi-arnold-nas/data/masp/vlm_data/webvid_10M_video/val.json'
    },
    'internvid': {
        'data_type': 'frames',
        'fps': 0.5,
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/InternVid/meta_data.json'
    },
    'video_chatgpt_instruct_single': {
        'data_type': 'video',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/VideoChatGPT_Instruct_100K_single/train.json'
    },
    'video_chatgpt_instruct_multi': {
        'data_type': 'video',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/VideoChatGPT_Instruct_100K_multi/train.json'
    },
    'video_chatgpt': {
        'data_type': 'frames',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/video_chatgpt_instruct/meta_data.json'
    },
    'm3it': {
        'data_type': 'images',
        # Subset of M3IT tasks used by default when none are specified.
        'default_tasks': [
            'coco',
            'textcap',
            'image-paragraph-captioning',
            'coco-goi',
            'coco-itm',
            'vqa-v2',
            'shapes',
            'docvqa',
            'ocr-vqa',
            'st-vqa',
            'text-vqa',
            'gqa',
            'okvqa',
            'a-okvqa',
            'viquae',
            'clevr',
            'nlvr',
            'vcr',
            'visual-mrc',
            'visual-dialog',
            'multi30k'
        ]
    },
    'tt_vqa': {
        'data_type': 'frames',
        'fps': 2,
        'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_single_60k_caption_170k_QA.json'
        # 'train_data_path': '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/masp/20240208_meta_data_single_135k_caption_160k_QA.json'
        # 'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/ADSO_Anno_Data/batch_20231128/meta_data_final_single_non_empty.json'
    },
    'gpt4v_tt_vqa': {
        'data_type': 'frames',
        'fps': 0.5,
        # 'train_data_path': '/mnt/bn/algo-masp-nas-2/baiyi.by/data/GPT4V_Negs/20231127_81k_single.json'
        # 'train_data_path': '/mnt/bn/yukunfeng-nasdrive/xiangchen/dataset/masp/20231127_81k_25k_filtered_single_non_empty.json'
        'train_data_path': '/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/20231222_120k_multi_filtered.json',
        'task_types': ['caption', 'qas'],
        'conv_type': 'single'
    },
    'sharegpt4v': {
        'data_type': 'images',
        'coco_dir': '/mnt/bn/data-tns-algo-masp/data',
        'llava_dir': '/mnt/bn/data-tns-algo-masp/baiyi.by/data/blip_laion_cc_sbu_558k',
        'other_dir': '/mnt/bn/algo-masp-nas-2/xiangchen/dataset/sharegpt4v',
    },
    'gpt4v_public': {
        'data_type': 'frames',
        'fps': 1,
        'train_data_path': '/mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_130k.json',
        # 'train_data_path': '/mnt/bn/algo-masp-nas-2/xiangchen/data/shared_gpt4v_data/data_500k_filtered.json',
        'task_types': ['summary', 'detail', 'qa_pairs'],
        'conv_type': 'single',
        'sample_method': 'uniform'
    },

    'gpt4v_internal': {
        'data_type': 'frames',
        'fps': 2,
        'train_data_path': '/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/gpt4v_internal_28k.json',
        'task_types': ['summary','detail','qa_pairs'],
        'conv_type': 'single'
    },

    'synthdog': { #500k
        'data_type': 'images',
    },

    'ocr_vqa': { #200k
        'data_type': 'images',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/xiangchen/dataset/OCR-VQA/training_meta.json'
    },

    'sharegpt': { #50k
        'data_type': 'text'
    },

    'text_caps':{ #100k
        'data_type': 'images',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/xiangchen/dataset/TextCaps/TextCaps_0.1_train.json'
    },

    'synthetic_ocr':{ # 50k
        'data_type': 'frames',
        'fps': 0.5, # total 10 frames for each video
        'train_data_path': '/mnt/bn/algo-masp-nas-2/xiangchen/dataset/masp/synthetic_ocr/train_filtered.json'
    },

    'lk_image':{ # 600k
        'data_type': 'images',
        'train_data_path': '/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_im.json'
    },

    'lk_video':{ # 850k
        'data_type': 'frames',
        'fps': 1,
        'train_data_path': '/mnt/bn/liangkeg/data/xiangchen/finetune_all_detail_vidal200k_videollava_images_vid.json',
        'select_datasets': ['webvid10m', 'webvid2m', 'activitynet', 'vidal', 'hdvila'],
    },

    'promptv1_2_internal':{ # 210k
        'data_type': 'frames',
        'train_data_path': '/mnt/bn/algo-masp-nas-2/kaili.zhao/data/masp_data/train/gpt4v_annotation/202400401week_gpt4v_all_videos_unique_ids.json',
        'task_types': ['caption']
    }
}
156
+
157
+