JOSESMOKE commited on
Commit
a6bd5e1
·
verified ·
1 Parent(s): f1eaac0

Upload folder using huggingface_hub

Browse files
Dockerfile ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM nvidia/cuda:12.3.2-cudnn9-devel-ubuntu22.04

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    DEBIAN_FRONTEND=noninteractive \
    CUDA_HOME=/usr/local/cuda \
    PATH=/usr/local/cuda/bin:$PATH \
    LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH \
    NVIDIA_VISIBLE_DEVICES=all \
    NVIDIA_DRIVER_CAPABILITIES=compute,utility

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-dev \
    build-essential \
    ffmpeg \
    libsndfile1 \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip and install build tools
RUN python3 -m pip install --upgrade pip setuptools wheel

WORKDIR /app

# Copy only the requirements file first so the (slow) dependency-install
# layer is cached and not rebuilt on every source change.
COPY requirements.txt .
# Use `python3 -m pip` consistently (the image upgrades pip via python3 -m pip
# above; mixing in the `pip3` shim can resolve to a different interpreter).
RUN python3 -m pip install --no-cache-dir -r requirements.txt

# Copy the rest of the application source.
COPY . .

EXPOSE 8000

CMD ["python3", "server.py"]
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - any-to-any
5
+ - omega
6
+ - omegalabs
7
+ - bittensor
8
+ - agi
9
+ ---
10
+
11
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
12
+
13
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
hotkey.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 5F1zctJ9w3rMuSzjRgEfG9rQ4ShAM7poyBnHiqUZK3tBpahE
inference.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ import itertools
7
+ import sys
8
+ import time
9
+ from typing import Any, Dict, List
10
+
11
+ import torch
12
+ from torch import nn
13
+ from omegaconf import DictConfig
14
+ from PIL import Image
15
+
16
+ from torchtune import config, utils
17
+ from torchtune.utils._generation import sample
18
+ from torchtune.models import convert_weights
19
+ from torchtune.data import Message
20
+
21
+ from models.tokenizer import START_IMAGE, END_IMAGE, START_AUDIO, END_AUDIO, START_VIDEO, END_VIDEO
22
+ from imagebind.models.imagebind_model import ModalityType
23
+ from diffusers import DiffusionPipeline
24
+
25
+ from models import add_proj_convert_weights, _BASE_TRAINABLE
26
+ import os
27
+
28
+ log = utils.get_logger("DEBUG")
29
+ add_proj_convert_weights()
30
+
31
+
32
+ class InferenceRecipe:
33
+ """
34
+ Recipe for generating tokens from a dense Transformer-based LLM.
35
+
36
+ Currently this recipe supports single-GPU generation only. Speculative
37
+ decoding is not supported.
38
+
39
+ For more details on how to use this recipe for generation, please see our
40
+ tutorial: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#generation
41
+
42
+ For using this recipe with a quantized model, please the following section of
43
+ the above tutorial:
44
+ https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#speeding-up-generation-using-quantization
45
+ """
46
+
47
def __init__(self, cfg: DictConfig) -> None:
    """Read device/dtype/quantizer/prompt settings from cfg and seed the RNGs."""
    utils.set_seed(seed=cfg.seed)
    self._device = utils.get_device(device=cfg.device)
    self._dtype = utils.get_dtype(dtype=cfg.dtype)
    self._quantizer = config.instantiate(cfg.inference.quantizer)
    self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
    self.prompt_template = cfg.inference.prompt_template
    n_perception = cfg.model.perception_tokens
    # NOTE(review): this slices to n_perception *characters*, not n tokens —
    # e.g. n=2 yields "0 " (a single "0" plus a space). Confirm the intended
    # placeholder string before changing.
    self._perception_tokens = ("0 " * n_perception)[:n_perception]
56
+
57
+ def setup(self, cfg: DictConfig) -> None:
58
+ checkpointer = config.instantiate(cfg.checkpointer)
59
+ if self._quantization_mode is None:
60
+ ckpt_dict = checkpointer.load_checkpoint()
61
+ else:
62
+ # weights_only needs to be False when loading a quantized model
63
+ # currently loading a quantized model is only supported with the
64
+ # FullModelTorchTuneCheckpointer
65
+ ckpt_dict = checkpointer.load_checkpoint(weights_only=False)
66
+
67
+ self._model = self._setup_model(
68
+ model_cfg=cfg.model,
69
+ model_state_dict=ckpt_dict[utils.MODEL_KEY],
70
+ )
71
+ with self._device:
72
+ self._model.setup_caches(max_batch_size=cfg.batch_size, dtype=self._dtype)
73
+
74
+ self._tokenizer = config.instantiate(cfg.tokenizer)
75
+ self._mm_ids_start = self._tokenizer.encode(START_IMAGE + START_AUDIO + START_VIDEO, add_eos=False, add_bos=False)
76
+ self._mm_ids_end = self._tokenizer.encode(END_IMAGE + END_AUDIO + END_VIDEO, add_eos=False, add_bos=False)
77
+ self.use_clip = cfg.model.use_clip
78
+ if self.use_clip:
79
+ self._clip_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=self._dtype).to(self._device)
80
+
81
+ def _setup_model(
82
+ self,
83
+ model_cfg: DictConfig,
84
+ model_state_dict: Dict[str, Any],
85
+ ) -> nn.Module:
86
+ with utils.set_default_dtype(self._dtype), self._device:
87
+ model = config.instantiate(model_cfg)
88
+
89
+ if self._quantization_mode is not None:
90
+ model = self._quantizer.quantize(model)
91
+ model = model.to(device=self._device, dtype=self._dtype)
92
+
93
+ model.load_state_dict(model_state_dict)
94
+
95
+ # Validate model was loaded in with the expected dtype.
96
+ utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
97
+ log.debug(f"Model is initialized with precision {self._dtype}.")
98
+
99
+ return model
100
+
101
def mm_process_prompt(self, prompt):
    """Expand {image}/{audio}/{video} placeholders into delimiter-wrapped
    perception-token spans understood by the multimodal embedding layer."""
    substitutions = {
        "{image}": f"{START_IMAGE}{self._perception_tokens}{END_IMAGE}",
        "{audio}": f"{START_AUDIO}{self._perception_tokens}{END_AUDIO}",
        "{video}": f"{START_VIDEO}{self._perception_tokens}{END_VIDEO}",
    }
    for placeholder, replacement in substitutions.items():
        prompt = prompt.replace(placeholder, replacement)
    return prompt
108
+
109
def extract_mm_context(self, video_ib_embed, tokens):
    """Map sequence positions inside multimodal delimiter spans to embeddings.

    Walks the token ids and, for every position strictly after a start
    delimiter and up to (and including the position of) the matching end
    delimiter's predecessor, records the dtype/device-cast ImageBind
    embedding under that position's index.
    """
    context = {}
    inside_span = False
    for position, token_id in enumerate(tokens):
        # An end delimiter closes the span before this position is recorded.
        if token_id in self._mm_ids_end:
            inside_span = False
        if inside_span:
            # To support multiple embeds: the token value here could be
            # matched up with its sample embed (original TODO preserved).
            context[position] = {
                "ib_embed": video_ib_embed.to(dtype=self._dtype, device=self._device),
            }
        # A start delimiter opens the span for subsequent positions.
        if token_id in self._mm_ids_start:
            inside_span = True
    return context
121
+
122
+ @torch.no_grad()
123
+ def generate(self, cfg: DictConfig, video_ib_embed: List[float]):
124
+ messages = [
125
+ Message(
126
+ role="user",
127
+ content=self.mm_process_prompt(self.prompt_template),
128
+ ),
129
+ Message(
130
+ role="assistant",
131
+ content="",
132
+ )
133
+ ]
134
+ tokens, mask = self._tokenizer.tokenize_messages(messages)
135
+ tokens = tokens[:-2] # strip eot and eos
136
+ mm_context = [self.extract_mm_context(video_ib_embed, tokens)] # context should be a list, batch-id indexed
137
+ prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)
138
+
139
+ self._model.tok_embeddings.set_context(mm_context)
140
+ self._model.output.set_context(mm_context)
141
+
142
+ bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
143
+ allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
144
+ disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))
145
+ # self._model.output.weight.data[disallowed_tokens, :] = 0
146
+
147
+ def custom_generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
148
+ model.tok_embeddings.set_context([])
149
+ model.output.set_context([])
150
+ # x: [1, s]
151
+ # input_pos: [s]
152
+ logits = model(x, input_pos=input_pos)
153
+ # logits: [1, s, v] where v is vocab_size
154
+ # for sampling we extract the logits for the
155
+ # last token and convert to shape: [v]
156
+ logits = logits[0, -1]
157
+ # logits[disallowed_tokens] = float("-inf")
158
+ # sample the next token
159
+ token = sample(logits, temperature, top_k)
160
+ if token in disallowed_tokens:
161
+ return torch.tensor([self._tokenizer.eos_id]).to(x)
162
+ return token
163
+
164
+ # since quantized model uses torch.compile to get speedup, it needs a warm up / prefill run
165
+ # to get the accurate performance measurement
166
+ if self._quantization_mode is not None:
167
+ log.info("Starting compilation to improve generation performance ...")
168
+ custom_generate_next_token = torch.compile(
169
+ custom_generate_next_token, mode="max-autotune", fullgraph=True
170
+ )
171
+ t0 = time.perf_counter()
172
+ _ = utils.generate(
173
+ model=self._model,
174
+ prompt=prompt,
175
+ max_generated_tokens=2,
176
+ temperature=cfg.temperature,
177
+ top_k=cfg.top_k,
178
+ eos_id=self._tokenizer.eos_id,
179
+ custom_generate_next_token=custom_generate_next_token,
180
+ )
181
+ t = time.perf_counter() - t0
182
+ log.info(f"Warmup run for quantized model takes: {t:.02f} sec")
183
+
184
+ t0 = time.perf_counter()
185
+ generated_tokens = utils.generate(
186
+ model=self._model,
187
+ prompt=prompt,
188
+ max_generated_tokens=cfg.max_new_tokens,
189
+ temperature=cfg.temperature,
190
+ top_k=cfg.top_k,
191
+ eos_id=self._tokenizer.eos_id,
192
+ custom_generate_next_token=custom_generate_next_token,
193
+ )
194
+ t = time.perf_counter() - t0
195
+
196
+ cleaned_tokens = [t for t in generated_tokens[len(prompt):] if t not in disallowed_tokens + allowed_id]
197
+ caption = self._tokenizer.decode(cleaned_tokens)
198
+
199
+ # log.debug(f"Generated caption: {caption} in {t:.02f} sec")
200
+
201
+ return caption
202
+
203
+
204
+ @torch.no_grad()
205
+ def generate_batch(self, cfg: DictConfig, video_ib_embed: torch.Tensor):
206
+ log.info(f"inside generate_batch, video_ib_embed shape: {video_ib_embed.shape}")
207
+ batch_dim = video_ib_embed.size(0)
208
+ messages = [
209
+ Message(
210
+ role="user",
211
+ content=self.mm_process_prompt(self.prompt_template),
212
+ ),
213
+ Message(role="assistant", content="")
214
+ ]
215
+ tokens, mask = self._tokenizer.tokenize_messages(messages)
216
+ tokens = tokens[:-2] # strip eot and eos
217
+ mm_context = [self.extract_mm_context(e, tokens) for e in video_ib_embed] # context should be a list, batch-id indexed
218
+ prompt = torch.tensor(tokens, dtype=torch.int, device=self._device).expand(batch_dim, -1).clone()
219
+ prompt_length = prompt.size(1)
220
+
221
+ self._model.tok_embeddings.set_context(mm_context)
222
+ self._model.output.set_context(mm_context)
223
+
224
+ bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
225
+ allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
226
+ disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))
227
+
228
+ def generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
229
+ # x: [B, s]
230
+ # input_pos: [s]
231
+ # logits: [B, s, v] where v is vocab_size
232
+ logits = model(x, input_pos=input_pos)[:, -1]
233
+ tokens = sample(logits, temperature, top_k)
234
+ return torch.tensor([
235
+ [self._tokenizer.eos_id if t in disallowed_tokens else t for t in toks]
236
+ for toks in tokens
237
+ ]).to(x.device)
238
+
239
+ generated_tokens = prompt.clone()
240
+ # keeps track at a high level if we've already hit a stop token in a sequence so we can early stop
241
+ stop_token_reached = torch.zeros(batch_dim, dtype=torch.bool, device=prompt.device)
242
+
243
+ # generate the first tokens conditioned on the prompt
244
+ tokens = generate_next_token(
245
+ self._model,
246
+ input_pos=torch.arange(0, prompt_length, device=prompt.device),
247
+ x=prompt,
248
+ temperature=cfg.temperature,
249
+ top_k=cfg.top_k,
250
+ )
251
+ eot_reached_b = tokens == self._tokenizer.eot_id
252
+ generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
253
+
254
+ self._model.tok_embeddings.set_context([])
255
+ self._model.output.set_context([])
256
+
257
+ input_pos = torch.tensor([prompt_length], device=prompt.device)
258
+ for _ in range(cfg.max_new_tokens - 1):
259
+ tokens = generate_next_token(
260
+ self._model, input_pos=input_pos, x=tokens, temperature=cfg.temperature, top_k=cfg.top_k
261
+ )
262
+ eot_reached_b |= tokens == self._tokenizer.eot_id
263
+ tokens *= ~eot_reached_b
264
+ generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
265
+ if eot_reached_b.all():
266
+ print('eot_reached_b.all()')
267
+ break
268
+ input_pos += 1
269
+
270
+ captions = []
271
+ for caption_tokens in generated_tokens.tolist():
272
+ captions.append(self._tokenizer.decode(caption_tokens[prompt.size(1):]))
273
+ return captions
274
+
275
+
276
+ @config.parse
277
+ def main(cfg: DictConfig) -> None:
278
+ config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
279
+ cfg.model = DictConfig({
280
+ "_component_": "models.mmllama3_8b",
281
+ "use_clip": False,
282
+ "perception_tokens": cfg.model.perception_tokens,
283
+ })
284
+ cfg.batch_size = 4
285
+ cfg.checkpointer.checkpoint_dir = os.path.dirname("/home/salman/tezuesh/omegalabs-anytoany-bittensor/sandboxing/cache/xzistance_omega-a2a-hotkey/meta_model_0.pth")
286
+
287
+ cfg.checkpointer.checkpoint_files = ["models/meta_model_0.pt"]
288
+ cfg.inference.max_new_tokens = 300
289
+ cfg.tokenizer.path = "./models/tokenizer.model"
290
+ inference_recipe = InferenceRecipe(cfg)
291
+ inference_recipe.setup(cfg=cfg)
292
+ captions = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=torch.randn(4,1024))
293
+ print(captions)
294
+
295
+
296
+ if __name__ == "__main__":
297
+ sys.exit(main())
298
+
299
+
300
+
301
+ # if __name__ == "__main__":
302
+ # sys.exit(main())
303
+
304
+
305
+ # if __name__ == "__main__":
306
+ # sys.exit(main())
models/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torchtune.models import convert_weights
2
+
3
+ from models.tokenizer import a2a_tokenizer
4
+ from models.mmllama3 import lora_mmllama3_8b, mmllama3_8b, imagebind_huge
5
+
6
+ __all__ = [
7
+ "a2a_tokenizer",
8
+ "lora_mmllama3_8b",
9
+ "mmllama3_8b",
10
+ "imagebind_huge",
11
+
12
+ ]
13
+
14
+ _BASE_TRAINABLE = [
15
+ "tok_embeddings.proj_to_llama.0.weight",
16
+ "tok_embeddings.proj_to_llama.0.bias",
17
+ "tok_embeddings.proj_to_llama.2.weight",
18
+ "tok_embeddings.proj_to_llama.2.bias",
19
+ "tok_embeddings.proj_to_llama.3.weight",
20
+ "tok_embeddings.proj_to_llama.3.bias",
21
+ "output.proj_from_llama.0.weight",
22
+ "output.proj_from_llama.0.bias",
23
+ "output.proj_from_llama.2.weight",
24
+ "output.proj_from_llama.2.bias",
25
+ "output.proj_from_llama.3.weight",
26
+ "output.proj_from_llama.3.bias",
27
+ ]
28
+
29
def add_proj_convert_weights():
    """Register the projection-layer parameter names with torchtune's
    meta-format weight mapping so existing checkpoint save/load code
    handles the new multimodal projection weights without changes.

    The mapping is the identity: each torchtune name maps to itself.
    """
    identity_mapping = {name: name for name in _BASE_TRAINABLE}
    convert_weights._FROM_META.update(identity_mapping)
33
+
34
+
models/bpe_simple_vocab_16e6.txt.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
+ size 1356917
models/imagebind_wrapper.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ from typing import BinaryIO, List
5
+
6
+ from imagebind import imagebind_model
7
+ from imagebind.models.imagebind_model import ModalityType
8
+ from imagebind.models.multimodal_preprocessors import SimpleTokenizer, TextPreprocessor
9
+
10
+
11
+ V2_URL = "https://huggingface.co/jondurbin/videobind-v0.2/resolve/main/videobind.pth"
12
+ V2_PATH = "./.checkpoints/videobind-v0.2.pth"
13
+ BPE_PATH = "./models/bpe_simple_vocab_16e6.txt.gz"
14
+ TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH)
15
+ LENGTH_TOKENIZER = SimpleTokenizer(bpe_path=BPE_PATH, context_length=1024)
16
+ TOKEN_CHUNK_SIZE = 74
17
+
18
def get_imagebind_v2(path: str = V2_PATH):
    """Load the VideoBind v2 checkpoint, downloading it first if absent.

    Args:
        path (str): Local checkpoint location; defaults to ``V2_PATH``.

    Returns:
        The deserialized model object stored in the checkpoint.
    """
    if not os.path.isfile(path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.hub.download_url_to_file(V2_URL, path, progress=True)
    # Use a local name that does not shadow the `imagebind_model` module
    # imported at the top of this file (the original bound the loaded model
    # to the module's name, which is confusing and error-prone).
    model = torch.load(path)
    return model
24
+
25
+
26
def load_and_transform_text(text, device):
    """Tokenize a list of strings and stack the results into one tensor.

    Each string is tokenized with the module-level TOKENIZER, given a
    leading batch axis, and moved to `device`. Returns None when `text`
    is None (pass-through for an absent modality).
    """
    if text is None:
        return None
    per_string = []
    for item in text:
        per_string.append(TOKENIZER(item).unsqueeze(0).to(device))
    return torch.cat(per_string, dim=0)
32
+
33
def split_text_by_token_limit(text, tokenizer, max_tokens=TOKEN_CHUNK_SIZE):
    """Split `text` into pieces that each tokenize to at most `max_tokens` tokens.

    Tries progressively finer delimiters (newline, sentence punctuation,
    comma, space), greedily re-merging adjacent pieces while they still fit;
    segments that exceed the limit with no delimiter left are cut at raw
    token boundaries as a last resort.
    """

    def token_count_ok(segment):
        # Strip padding zeros and the first/last (BOS/EOS) positions
        # before counting against the budget.
        ids = tokenizer(segment)
        ids = ids[ids != 0][1:-1].tolist()
        return len(ids) <= max_tokens

    def hard_split(segment):
        # No delimiter left: cut the raw token stream into near-equal
        # chunks and decode each chunk back to text.
        ids = tokenizer(segment)
        ids = ids[ids != 0][1:-1].tolist()
        n_chunks = int(len(ids) / max_tokens) or 1
        return [tokenizer.decode(chunk) for chunk in np.array_split(ids, n_chunks)]

    def split_on(segment, delimiters):
        if token_count_ok(segment):
            return [segment]
        if not delimiters:
            return hard_split(segment)
        sep, finer = delimiters[0], delimiters[1:]
        merged = []
        pending = ""
        for piece in segment.split(sep):
            # Greedily extend the current run (re-inserting the delimiter)
            # while it still fits the token budget.
            candidate = pending + (sep if pending else '') + piece
            if token_count_ok(candidate):
                pending = candidate
            else:
                if pending:
                    merged.append(pending)
                pending = piece
        if pending:
            merged.append(pending)
        out = []
        for part in merged:
            if token_count_ok(part):
                out.append(part)
            else:
                out.extend(split_on(part, finer))
        return out

    return split_on(text, ['\n', '.', '!', '?', ',', ' '])
76
+
77
def load_and_transform_text_chunks(text, device):
    """Tokenize `text` chunk-by-chunk, respecting the per-chunk token limit.

    Args:
        text: Raw caption/description string (falsy input yields []).
        device: Target device passed through to load_and_transform_text.

    Returns:
        A list with one tokenized tensor per chunk.
    """
    if not text:
        return []
    # The original revision tokenized the full text here into `all_tokens`
    # and discarded the result — dead work, removed.
    # Each segment returned by split_text_by_token_limit is guaranteed to
    # fit the per-chunk token budget, so it is tokenized independently.
    return [
        load_and_transform_text([segment], device)
        for segment in split_text_by_token_limit(text, LENGTH_TOKENIZER)
    ]
87
+
88
+
89
+ class ImageBind:
90
+ def __init__(self, device="cuda:0", v2=False):
91
+ self.device = device
92
+ self.v2 = v2
93
+ if v2:
94
+ if not os.path.exists(V2_PATH):
95
+ os.makedirs(os.path.dirname(V2_PATH), exist_ok=True)
96
+ torch.hub.download_url_to_file(
97
+ V2_URL,
98
+ V2_PATH,
99
+ progress=True,
100
+ )
101
+ self.imagebind = torch.load(V2_PATH)
102
+ else:
103
+ self.imagebind = imagebind_model.imagebind_huge(pretrained=True)
104
+ self.imagebind.eval()
105
+ self.imagebind.to(self.device)
106
+
107
+ def generate_text_embeddings(self, text: str):
108
+ if not self.v2:
109
+ return self.imagebind({
110
+ ModalityType.TEXT: load_and_transform_text([text], self.device)
111
+ })[ModalityType.TEXT]
112
+ chunks = load_and_transform_text_chunks(text, self.device)
113
+ embeddings = [
114
+ self.imagebind({ModalityType.TEXT: chunk})[ModalityType.TEXT]
115
+ for chunk in chunks
116
+ ]
117
+ return torch.mean(torch.stack(embeddings), dim=0)
118
+
119
+ """ Deactivating full embeddings as they are not used in the current implementation
120
+ def get_inputs(self, video_file: BinaryIO) -> dict:
121
+ audio_file = video_utils.copy_audio(video_file.name)
122
+ try:
123
+ duration = video_utils.get_video_duration(video_file.name)
124
+ video_data = data.load_and_transform_video_data(
125
+ [video_file.name],
126
+ self.device,
127
+ )
128
+ audio_data = data.load_and_transform_audio_data(
129
+ [audio_file.name],
130
+ self.device,
131
+ )
132
+ inputs = {
133
+ ModalityType.VISION: video_data,
134
+ ModalityType.AUDIO: audio_data,
135
+ }
136
+ return inputs
137
+ finally:
138
+ audio_file.close()
139
+
140
+ @torch.no_grad()
141
+ def embed(self, descriptions: List[str], video_files: List[BinaryIO]) -> Embeddings:
142
+ return_value = None
143
+ for idx in range(len(descriptions)):
144
+ inputs = self.get_inputs(video_files[idx])
145
+ embeddings = self.imagebind(inputs)
146
+ text_embeddings = self.generate_text_embeddings(descriptions[idx])
147
+ if not return_value:
148
+ return_value = Embeddings(
149
+ video=embeddings[ModalityType.VISION],
150
+ audio=embeddings[ModalityType.AUDIO],
151
+ description=text_embeddings,
152
+ )
153
+ else:
154
+ return_value.video = torch.cat((return_value.video, embeddings[ModalityType.VISION]))
155
+ return_value.audio = torch.cat((return_value.audio, embeddings[ModalityType.AUDIO]))
156
+ return_value.description = torch.cat((return_value.description, text_embeddings))
157
+ return return_value
158
+
159
+ @torch.no_grad()
160
+ def embed_only_video(self, video_files: List[BinaryIO]) -> Embeddings:
161
+ video_filepaths = [video_file.name for video_file in video_files]
162
+ durations = [video_utils.get_video_duration(f.name) for f in video_files]
163
+ embeddings = self.imagebind({
164
+ ModalityType.VISION: [
165
+ data.load_and_transform_video_data(
166
+ [video_filepaths[idx]],
167
+ self.device,
168
+ )[0]
169
+ for idx in range(len(video_filepaths))
170
+ ]
171
+ })
172
+ return Embeddings(
173
+ video=embeddings[ModalityType.VISION],
174
+ )
175
+
176
+ @torch.no_grad()
177
+ def embed_video_and_text(self, video_files: List[BinaryIO], descriptions: List[str]) -> Embeddings:
178
+ video_filepaths = [video_file.name for video_file in video_files]
179
+ durations = [video_utils.get_video_duration(f.name) for f in video_files]
180
+ embeddings = self.imagebind({
181
+ ModalityType.VISION: [
182
+ data.load_and_transform_video_data(
183
+ [video_filepaths[idx]],
184
+ self.device,
185
+ )[0]
186
+ for idx in range(len(video_filepaths))
187
+ ],
188
+ })
189
+ description_embeddings = torch.stack([
190
+ self.generate_text_embeddings(description)
191
+ for description in descriptions
192
+ ])
193
+ return Embeddings(
194
+ video=embeddings[ModalityType.VISION],
195
+ description=description_embeddings,
196
+ )
197
+
198
+ @torch.no_grad()
199
+ def embed_text(self, texts: List[str]) -> torch.Tensor:
200
+ return_value = None
201
+ for text in texts:
202
+ emb = self.generate_text_embeddings(text)
203
+ if not return_value:
204
+ return_value = emb
205
+ else:
206
+ return_value = torch.cat((return_value, emb))
207
+ return return_value
208
+ """
209
+
210
+ @torch.no_grad()
211
+ def embed_text(self, texts: List[str]) -> torch.Tensor:
212
+ embeddings = []
213
+ for text in texts:
214
+ emb = self.generate_text_embeddings(text)
215
+ embeddings.append(emb)
216
+
217
+ if not embeddings:
218
+ return None
219
+
220
+ # Stack all embeddings along dimension 0
221
+ return torch.stack(embeddings, dim=0)
models/meta_model_5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2322e3915d35dea89b79a2b537624745ced04a64a7882072ddc000e053bbf9a6
3
+ size 16219158403
models/mmllama3.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ import warnings
3
+
4
+ import torch
5
+ from torch import nn, Tensor
6
+ from torchvision import transforms
7
+
8
+ from torchtune.models.llama3 import lora_llama3_8b, llama3_8b
9
+ from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear
10
+ from torchtune.modules import TransformerDecoder
11
+
12
+ with warnings.catch_warnings():
13
+ warnings.simplefilter("ignore", UserWarning)
14
+ from imagebind.models import imagebind_model
15
+ from models.imagebind_wrapper import get_imagebind_v2, V2_PATH
16
+ from models.imagebind_wrapper import ImageBind
17
+
18
IMAGEBIND_DIM = 1024
CLIP_DIM = 768


class MMEmbedding(nn.Embedding):
    """Token embedding that splices projected multimodal embeddings into
    the embedded sequence.

    Mirrors the configuration of an existing `nn.Embedding` and adds a
    two-layer MLP projecting an ImageBind (plus optional CLIP) embedding
    into `perception_tokens` consecutive LLaMA-dim embedding slots.
    NOTE(review): the wrapped embedding's weights are freshly initialized
    here, not copied from `e` — presumably overwritten by a later
    checkpoint load; confirm.
    """

    def __init__(self, e, perception_tokens=1, use_clip=False):
        super().__init__(
            num_embeddings=e.num_embeddings,
            embedding_dim=e.embedding_dim,
            padding_idx=e.padding_idx,
            max_norm=e.max_norm,
            norm_type=e.norm_type,
            scale_grad_by_freq=e.scale_grad_by_freq,
            sparse=e.sparse,
        )
        self._perception_tokens = perception_tokens
        # Batch-indexed list of {sequence_idx: {"ib_embed": ..., "clip_embed": ...}}.
        self._context = []
        self._use_clip = use_clip

        dim_in = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
        dim_out = e.embedding_dim * perception_tokens

        # MLP mapping one multimodal embedding to `perception_tokens`
        # LLaMA-dim embeddings (flattened as a single dim_out vector).
        self.proj_to_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context):
        """Install the per-batch multimodal context used by forward()."""
        self._context = context

    def forward(self, input: Tensor) -> Tensor:
        r = super().forward(input)
        # self._context is first indexed by batch idx...
        for b, context_dict in enumerate(self._context):
            # ...then by sequence idx.
            for s, embed in context_dict.items():
                # Transform from imagebind (+ clip) dim -> llama3 dim.
                parts = [embed["ib_embed"]]
                if self._use_clip:
                    parts.append(embed["clip_embed"])
                llama_embed = self.proj_to_llama(torch.cat(parts))
                # Overwrite the perception placeholder positions with the
                # projected multimodal embedding.
                r[b, s:s + self._perception_tokens] = llama_embed.view(
                    self._perception_tokens, -1
                )
        return r


class MMLinear(nn.Linear):
    """Output head mirroring an existing `nn.Linear`, carrying a
    (currently unused at forward time) projection from LLaMA space back
    to CLIP space."""

    def __init__(self, o):
        super().__init__(
            in_features=o.in_features,
            out_features=o.out_features,
            # FIX: compare against None with `is not None`; the original
            # `o.bias != None` relied on tensor.__ne__ falling back to
            # identity semantics, which is fragile and unidiomatic.
            bias=o.bias is not None,
        )
        self._context = []

        self.proj_from_llama = nn.Sequential(
            nn.Linear(o.in_features, CLIP_DIM),
            nn.GELU(),
            nn.LayerNorm(CLIP_DIM),
            nn.Linear(CLIP_DIM, CLIP_DIM),
        )

    def set_context(self, context):
        """Record indexes of image llama tokens for proj_from_llama."""
        self._context = context

    def forward(self, input_bsd: Tensor) -> Tensor:
        # The proj_from_llama path over self._context is currently
        # disabled in this revision; only the plain linear head runs.
        self._clip_projections = []
        return super().forward(input_bsd)
100
+
101
+
102
+
103
+ def lora_mmllama3_8b(
104
+ lora_attn_modules: List[LORA_ATTN_MODULES],
105
+ apply_lora_to_mlp: bool = False,
106
+ apply_lora_to_output: bool = False,
107
+ lora_rank: int = 8,
108
+ lora_alpha: float = 16,
109
+ quantize_base: bool = False,
110
+ perception_tokens: int = 2,
111
+ use_clip: bool = False
112
+ ) -> TransformerDecoder:
113
+ llama3 = lora_llama3_8b(
114
+ lora_attn_modules,
115
+ apply_lora_to_mlp,
116
+ apply_lora_to_output,
117
+ lora_rank,
118
+ lora_alpha,
119
+ quantize_base,
120
+ )
121
+ llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
122
+ llama3.output = MMLinear(llama3.output)
123
+ return llama3
124
+
125
+
126
+ def mmllama3_8b(
127
+ perception_tokens: int = 2,
128
+ use_clip: bool = False
129
+ ) -> TransformerDecoder:
130
+ llama3 = llama3_8b()
131
+ llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
132
+ llama3.output = MMLinear(llama3.output)
133
+ return llama3
134
+
135
+
136
+ def imagebind_huge(use_v2: bool=True):
137
+ if use_v2:
138
+ imagebind = ImageBind(v2=True)
139
+ else:
140
+ imagebind = imagebind_model.imagebind_huge(pretrained=True)
141
+ imagebind.transform_from_pil = transforms.Compose([
142
+ transforms.Resize(
143
+ 224, interpolation=transforms.InterpolationMode.BICUBIC
144
+ ),
145
+ transforms.CenterCrop(224),
146
+ transforms.ToTensor(),
147
+ transforms.Normalize(
148
+ mean=(0.48145466, 0.4578275, 0.40821073),
149
+ std=(0.26862954, 0.26130258, 0.27577711),
150
+ ),
151
+ ])
152
+ return imagebind
153
+
models/tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:82e9d31979e92ab929cd544440f129d9ecd797b69e327f80f17e1c50d5551b55
3
+ size 2183982
models/tokenizer.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from torchtune.modules.tokenizers import TikTokenTokenizer
4
+ from torchtune.modules.tokenizers._utils import _split_long_repetitions
5
+ from torchtune.modules.tokenizers._tiktoken import (
6
+ MAX_ENCODE_CHARS,
7
+ MAX_NO_WHITESPACE_CHARS,
8
+ ALL_SPECIAL_TOKENS,
9
+ )
10
+
11
+
12
+ # use special tokens from TikTokenTokenizer, add some for MM delimiters
13
+ START_IMAGE = "<|start_image|>"
14
+ END_IMAGE = "<|end_image|>"
15
+ START_VIDEO = "<|start_video|>"
16
+ END_VIDEO = "<|end_video|>"
17
+ START_AUDIO = "<|start_audio|>"
18
+ END_AUDIO = "<|end_audio|>"
19
+
20
+ A2A_SPECIAL_TOKENS = ALL_SPECIAL_TOKENS[:-2] + [
21
+ START_IMAGE,
22
+ END_IMAGE,
23
+ START_VIDEO,
24
+ END_VIDEO,
25
+ START_AUDIO,
26
+ END_AUDIO,
27
+ ] + ALL_SPECIAL_TOKENS[-2:]
28
+
29
+ # override to allow START_IMAGE, END_IMAGE to be encoded
30
class A2ATokenizer(TikTokenTokenizer):
    """TikToken tokenizer variant that allows the multimodal delimiter
    tokens (START_IMAGE/END_IMAGE, etc.) to be encoded as specials."""

    def encode(
        self,
        text: str,
        add_bos: bool,
        add_eos: bool,
    ) -> List[int]:
        """
        Encode a string into a list of token ids. Assumes that the string
        contains no special tokens.

        Args:
            text (str): The string to encode.
            add_bos (bool): Whether to add the beginning of sequence token.
            add_eos (bool): Whether to add the end of sequence token.

        Returns:
            List[int]: The list of token ids.
        """
        # Break long input into bounded windows to work around a tiktoken
        # pathology with long repetitions.
        # See https://github.com/openai/tiktoken/issues/195
        pieces: List[str] = []
        for start in range(0, len(text), MAX_ENCODE_CHARS):
            window = text[start : start + MAX_ENCODE_CHARS]
            pieces.extend(_split_long_repetitions(window, MAX_NO_WHITESPACE_CHARS))

        # Only the multimodal delimiters are encoded as special tokens;
        # any other special-token text is encoded as regular text rather
        # than raising (disallowed_special=()), so this should only be
        # called on strings not containing other special tokens.
        mm_specials = {
            START_IMAGE,
            END_IMAGE,
            START_VIDEO,
            END_VIDEO,
            START_AUDIO,
            END_AUDIO,
        }
        token_ids: List[int] = []
        for piece in pieces:
            token_ids.extend(
                self.tt_model.encode(
                    piece,
                    allowed_special=mm_specials,
                    disallowed_special=(),
                )
            )
        if add_bos:
            token_ids.insert(0, self.bos_id)
        if add_eos:
            token_ids.append(self.eos_id)
        return token_ids
80
+
81
+
82
def a2a_tokenizer(path: str) -> TikTokenTokenizer:
    """Build the any-to-any tokenizer from a tokenizer.model file, with
    the multimodal delimiter specials registered and pad id fixed to 0."""
    tokenizer = A2ATokenizer(path, all_special_tokens=A2A_SPECIAL_TOKENS)
    tokenizer.pad_id = 0
    return tokenizer
models/training_config.yml ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ _component_: models.lora_mmllama3_8b
3
+ lora_attn_modules:
4
+ - q_proj
5
+ - v_proj
6
+ apply_lora_to_mlp: false
7
+ apply_lora_to_output: false
8
+ lora_rank: 8
9
+ lora_alpha: 16
10
+ perception_tokens: 2
11
+ use_clip: false
12
+ tokenizer:
13
+ _component_: models.a2a_tokenizer
14
+ path: models/tokenizer.model
15
+ checkpointer:
16
+ _component_: torchtune.utils.FullModelMetaCheckpointer
17
+ checkpoint_dir: /workspace/omega_a2a/training
18
+ checkpoint_files:
19
+ - consolidated.00.pth
20
+ adapter_checkpoint: null
21
+ recipe_checkpoint: null
22
+ output_dir: /workspace/omega_a2a/checkpoints
23
+ model_type: LLAMA3
24
+ resume_from_checkpoint: false
25
+ interim_checkpoint_steps: 5000
26
+ interim_gen_steps: null
27
+ max_new_tokens: 170
28
+ temperature: 0.8
29
+ top_k: 200
30
+ dataset:
31
+ _component_: ds.EvenBatcher
32
+ buffer_size: 36
33
+ dataset:
34
+ _component_: ds.RoundRobinDataset
35
+ datasets:
36
+ - _component_: ds.OmegaVideoCaptionDataset
37
+ length: 500000
38
+ - _component_: ds.LlavaInstructDataset
39
+ dataset_path: ds/coco_llava_instruct/output.parquet
40
+ train_on_input: false
41
+ - _component_: ds.LlavaInstructDataset
42
+ dataset_path: ds/vision_flan/output.parquet
43
+ train_on_input: false
44
+ - _component_: ds.CaptionInstructDataset
45
+ dataset_path: ds/sam_llava/output.parquet
46
+ train_on_input: false
47
+ seed: null
48
+ shuffle: true
49
+ batch_size: 4
50
+ optimizer:
51
+ _component_: torch.optim.AdamW
52
+ weight_decay: 0.0001
53
+ lr: 3.0e-05
54
+ lr_scheduler:
55
+ _component_: torchtune.modules.get_cosine_schedule_with_warmup
56
+ num_warmup_steps: 100
57
+ loss:
58
+ _component_: torch.nn.CrossEntropyLoss
59
+ epochs: 6
60
+ max_steps_per_epoch: null
61
+ gradient_accumulation_steps: 64
62
+ compile: false
63
+ output_dir: /tmp/lora_finetune_output
64
+ metric_logger:
65
+ _component_: torchtune.utils.metric_logging.DiskLogger
66
+ log_dir: ${output_dir}
67
+ log_every_n_steps: null
68
+ device: cuda
69
+ dtype: bf16
70
+ enable_activation_checkpointing: false
71
+ profiler:
72
+ _component_: torchtune.utils.profiler
73
+ enabled: false
74
+ inference:
75
+ prompt_template: 'Video:
76
+
77
+ {video}
78
+
79
+ Caption the previous video.'
80
+ max_new_tokens: 300
81
+ temperature: 0.6
82
+ top_k: 5
83
+ quantizer: null
quantized/infer.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import os
2
+ # import torch
3
+ # from litgpt.generate.base import next_token_image_batch
4
+ # import soundfile as sf
5
+ # from utils.snac_utils import layershift, reconscruct_snac, reconstruct_tensors, get_time_str
6
+ # from utils.snac_utils import get_snac, generate_audio_data
7
+ # import clip
8
+ # import inference
9
+ # from tqdm import tqdm
10
+ # from inference import OmniInference, load_model, load_audio, download_model
11
+ # from inference import text_vocabsize, padded_text_vocabsize, get_text_stream
12
+ # from PIL import Image
13
+
14
+
15
+ # torch.set_printoptions(sci_mode=False)
16
+
17
+ # _image = inference._image
18
+ # _eoimage = inference._eoimage
19
+ # _pad_t = inference._pad_t
20
+ # _input_t = inference._input_t
21
+ # _answer_t = inference._answer_t
22
+ # _eot = inference._eot
23
+ # _eoa = inference._eoa
24
+ # _pad_a = inference._pad_a
25
+ # _input_a = inference._input_a
26
+ # _answer_a = inference._answer_a
27
+
28
+
29
+ # def get_input_ids_ImageQA_ATBatch(mel, leng, whispermodel, device):
30
+
31
+ # with torch.no_grad():
32
+ # mel = mel.unsqueeze(0).to(device)
33
+ # audio_feature = whispermodel.embed_audio(mel)[0][:leng]
34
+
35
+ # audio_len = audio_feature.size(0)
36
+
37
+ # input_ids = []
38
+ # input_ids_item = [[] for i in range(8)]
39
+ # for i in range(7):
40
+ # input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
41
+ # input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)]
42
+ # input_ids_item[i] += [layershift(_answer_a,i)]
43
+
44
+ # input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
45
+ # input_ids_item = [torch.tensor(item) for item in input_ids_item]
46
+
47
+ # input_ids.append(input_ids_item)
48
+
49
+ # input_ids_item = [[] for i in range(8)]
50
+ # for i in range(7):
51
+ # input_ids_item[i] = [layershift(_image,i)] + [layershift(_pad_a,i)] * 50 + [layershift(_eoimage,i)]
52
+ # input_ids_item[i] += [layershift(_input_a,i)]+[layershift(_pad_a,i)]*(audio_len)+[layershift(_eoa,i)] + [layershift(_pad_a,i)]
53
+
54
+ # input_ids_item[-1] = [_pad_t]* (52 + 2 + audio_len) + [_answer_t]
55
+
56
+ # input_ids_item = [torch.tensor(item) for item in input_ids_item]
57
+ # input_ids.append(input_ids_item)
58
+
59
+ # stacked_inputids = [[] for _ in range(8)]
60
+ # for i in range(2):
61
+ # for j in range(8):
62
+ # stacked_inputids[j].append(input_ids[i][j])
63
+ # stacked_inputids = [torch.stack(tensors) for tensors in stacked_inputids]
64
+
65
+ # return torch.stack([audio_feature,audio_feature]), stacked_inputids
66
+
67
+
68
+ # def load_clip_model(ckpt_dir, device):
69
+ # clip_model_path = ckpt_dir + "/ViT-B-32.pt"
70
+ # if not os.path.exists(clip_model_path):
71
+ # clip_model_path = "ViT-B/32"
72
+ # clipmodel, clippreprocess = clip.load(clip_model_path, device=device)
73
+ # return clipmodel, clippreprocess
74
+
75
+
76
+ # class OmniVisionInference(OmniInference):
77
+
78
+ # def __init__(self, ckpt_dir='./checkpoint', device='cuda:0'):
79
+ # self.device = device
80
+ # if not os.path.exists(ckpt_dir):
81
+ # print(f"checkpoint directory {ckpt_dir} not found, downloading from huggingface")
82
+ # download_model(ckpt_dir)
83
+ # self.fabric, self.model, self.text_tokenizer, self.snacmodel, self.whispermodel = load_model(ckpt_dir, device)
84
+ # self.clipmodel, self.clippreprocess = load_clip_model(ckpt_dir, device)
85
+
86
+ # def warm_up(self,
87
+ # audio_sample='./data/samples/vision_qa_audio.wav',
88
+ # image_sample='./data/samples/vision_qa_image.jpg'
89
+ # ):
90
+ # for _ in self.run_vision_AA_batch_stream(audio_sample, image_sample,
91
+ # save_path="./data/samples/vision_qa_output.wav",
92
+ # warm_up=True):
93
+ # pass
94
+
95
+ # @torch.inference_mode()
96
+ # def run_vision_AA_batch_stream(self, audio_path, image_path,
97
+ # stream_stride=4,
98
+ # max_returned_tokens=2048,
99
+ # temperature=0.9,
100
+ # top_k=1,
101
+ # top_p=1.0,
102
+ # eos_id_a=_eoa,
103
+ # eos_id_t=_eot,
104
+ # pad_id=_pad_t,
105
+ # save_path=None,
106
+ # warm_up=False
107
+ # ):
108
+ # with self.fabric.init_tensor():
109
+ # self.model.set_kv_cache(batch_size=2)
110
+
111
+ # model = self.model
112
+
113
+ # mel, leng = load_audio(audio_path)
114
+ # img = Image.open(image_path)
115
+
116
+ # audio_feature, input_ids = get_input_ids_ImageQA_ATBatch(mel, leng, self.whispermodel, self.device)
117
+ # ima = self.clippreprocess(img).unsqueeze(0).to(self.device)
118
+ # ima_feature = self.clipmodel.encode_image(ima).squeeze(0).to(self.device)
119
+
120
+ # ima_feature = torch.stack([ima_feature.clone(),ima_feature.clone()]).to(self.device)
121
+ # leng = [leng,leng]
122
+ # task = ['ImageQA_A','ImageQA_AT']
123
+
124
+ # T = input_ids[0].size(1)
125
+ # assert max_returned_tokens > T, f"max_returned_tokens {max_returned_tokens} should be greater than audio length {T}"
126
+
127
+ # if model.max_seq_length < max_returned_tokens - 1:
128
+ # raise NotImplementedError(
129
+ # f"max_seq_length {model.max_seq_length} needs to be >= {max_returned_tokens - 1}"
130
+ # )
131
+
132
+ # list_output = [[] for i in range(8)]
133
+
134
+ # tokens_A , token_T = next_token_image_batch(
135
+ # model,
136
+ # audio_feature.to(torch.float32).to(self.device),
137
+ # ima_feature.to(torch.float32).to(self.device) ,
138
+ # input_ids ,
139
+ # whisper_lens = leng ,
140
+ # task = task,
141
+ # input_pos = torch.arange(0, T, device=self.device),
142
+ # temperature=temperature,
143
+ # top_k=top_k,
144
+ # top_p=top_p
145
+ # )
146
+ # for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
147
+ # list_output[7].append(token_T.tolist()[0])
148
+
149
+ # text_end = False
150
+ # index = 1
151
+ # nums_generate = stream_stride
152
+ # begin_generate = False
153
+ # current_index = 0
154
+ # input_pos = torch.tensor([T], device=self.device)
155
+
156
+ # model_input_ids = [[] for i in range(8)]
157
+ # for i in range(7):
158
+ # tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
159
+ # model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
160
+ # model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
161
+ # model_input_ids[i] = torch.stack(model_input_ids[i])
162
+
163
+ # model_input_ids[-1].append(token_T.clone().to(torch.int32))
164
+ # model_input_ids[-1].append(token_T.clone().to(torch.int32))
165
+ # model_input_ids[-1] = torch.stack(model_input_ids[-1])
166
+
167
+ # text_index = 0
168
+ # is_text_end = False
169
+
170
+ # for _ in tqdm(range(2, max_returned_tokens - T + 1)):
171
+
172
+ # tokens_A , token_T = next_token_image_batch(model, None , None ,
173
+ # input_ids = model_input_ids,
174
+ # whisper_lens= None,
175
+ # task = None,
176
+ # input_pos = input_pos,
177
+ # temperature=temperature,
178
+ # top_k=top_k,
179
+ # top_p=top_p)
180
+
181
+ # if text_end:
182
+ # token_T = torch.tensor([_pad_t], device=self.device)
183
+
184
+ # if tokens_A[-1] == eos_id_a:
185
+ # break
186
+ # if token_T == eos_id_t:
187
+ # text_end = True
188
+
189
+ # for i in range(7): list_output[i].append(tokens_A[i].tolist()[0])
190
+ # list_output[7].append(token_T.tolist()[0])
191
+
192
+
193
+ # if index == 7:
194
+ # begin_generate = True
195
+
196
+ # if begin_generate:
197
+ # current_index += 1
198
+ # if current_index == nums_generate:
199
+ # current_index = 0
200
+ # snac = get_snac(list_output,index,nums_generate)
201
+ # audio_stream = generate_audio_data(snac, self.snacmodel, self.device)
202
+ # if is_text_end:
203
+ # text_stream = ""
204
+ # else:
205
+ # text_stream, text_index, is_text_end = get_text_stream(list_output, text_index, self.text_tokenizer)
206
+
207
+ # yield (audio_stream, text_stream)
208
+
209
+ # if warm_up:
210
+ # break
211
+
212
+ # input_pos = input_pos.add_(1)
213
+ # model_input_ids = [[] for i in range(8)]
214
+ # for i in range(7):
215
+ # tokens_A[i] = tokens_A[i].clone() + padded_text_vocabsize+ i * 4160
216
+ # model_input_ids[i].append(tokens_A[i].clone().to(self.device).to(torch.int32))
217
+ # model_input_ids[i].append(torch.tensor([layershift(4097,i)],device=self.device))
218
+ # model_input_ids[i] = torch.stack(model_input_ids[i])
219
+
220
+ # model_input_ids[-1].append(token_T.clone().to(torch.int32))
221
+ # model_input_ids[-1].append(token_T.clone().to(torch.int32))
222
+ # model_input_ids[-1] = torch.stack(model_input_ids[-1])
223
+
224
+ # index += 1
225
+
226
+ # text_tokens = list_output[-1]
227
+ # if text_vocabsize in text_tokens:
228
+ # text_tokens = text_tokens[:text_tokens.index(text_vocabsize)]
229
+ # res_text = self.text_tokenizer.decode(torch.tensor(text_tokens))
230
+ # print(f"text output: {res_text}")
231
+
232
+ # if save_path is not None:
233
+ # audiolist = reconscruct_snac(list_output)
234
+ # audio = reconstruct_tensors(audiolist)
235
+ # with torch.inference_mode():
236
+ # audio_hat = self.snacmodel.decode(audio)
237
+ # sf.write(save_path, audio_hat.squeeze().cpu().numpy(), 24000)
238
+
239
+ # model.clear_kv_cache()
240
+
241
+
242
+ # def test_vision_infer():
243
+ # client = OmniVisionInference()
244
+ # client.warm_up()
245
+ # input_audio_path = './data/samples/vision_qa_audio.wav'
246
+ # input_image_path = './data/samples/vision_qa_image.jpg'
247
+
248
+ # res_text = ""
249
+ # for audio_stream, text_stream in client.run_vision_AA_batch_stream(
250
+ # input_audio_path,
251
+ # input_image_path,
252
+ # save_path="./vision_qa_output.wav"
253
+ # ):
254
+ # res_text += text_stream
255
+ # print(f"text_output: {res_text}")
256
+
257
+
258
+ # if __name__ == "__main__":
259
+ # test_vision_infer()
260
+
261
+
262
+ # # 1234232434232
263
+ # # 1234232434232
264
+
265
+ # # 1234232434232
266
+ # # 1234232434232
267
+
268
+ # # 1234232434232
269
+ # # 1234232434232
270
+
271
+
272
+ # # 5069
273
+
274
+
275
+ # # 3670
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.6.0
2
+ sentencepiece==0.2.0
3
+ tiktoken==0.4.0
4
+ torchtune @ git+https://github.com/pytorch/torchtune.git@8f59c2fecd722691271eecca630a526719a32f76#egg=torchtune
5
+ lm_eval==0.4
6
+ torchvision==0.21.0
7
+ diffusers==0.27.2
8
+ imagebind @ git+https://github.com/omegalabsinc/ImageBind.git@c3c3b2e1ce6fd850ff42ce0375823fe22880a7cc#egg=imagebind
9
+ llama3 @ git+https://github.com/meta-llama/llama3.git@af6eedf7042fb51d00b2b26d8ef1ceaab73e1670
10
+ pytorchvideo @ git+https://github.com/facebookresearch/pytorchvideo.git@28fe037d212663c6a24f373b94cc5d478c8c1a1d
11
+ wandb==0.17.1
12
+ numpy==1.26.4
13
+ huggingface-hub==0.24.0
14
+ omegaconf==2.3.0
15
+ uvicorn==0.25.0
16
+ fastapi==0.104.1
17
+ pydantic==2.5.2
18
+ torchaudio==2.6.0
server.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ import numpy as np
3
+ import torch
4
+ from pydantic import BaseModel
5
+ from typing import List
6
+ import base64
7
+ import io
8
+ import os
9
+ import logging
10
+ from pathlib import Path
11
+ from inference import InferenceRecipe
12
+ from fastapi.middleware.cors import CORSMiddleware
13
+
14
+ from omegaconf import OmegaConf, DictConfig
15
+
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
app = FastAPI()

# Permit cross-origin calls from any client.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# very permissive — confirm this is intended for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
29
+
30
class EmbeddingRequest(BaseModel):
    # Flattened embedding vector; the server reshapes it to (-1, 1024) rows.
    embedding: List[float]
32
+
33
class TextResponse(BaseModel):
    # Generated caption strings; defaults to an empty list.
    texts: List[str] = []
35
+
36
# Model initialization status: flipped by initialize_model() and reported
# by the health endpoint; "error" keeps the last failure message.
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}

# Global model instance and its config, populated by initialize_model().
inference_recipe = None
cfg = None
45
+
46
+
47
def initialize_model():
    """Load the inference model and config into module globals.

    Loads /app/training_config.yml, overrides the model / checkpointer /
    tokenizer entries to point at the packaged weights under /app/models,
    then constructs and sets up an InferenceRecipe. Updates
    INITIALIZATION_STATUS and returns True on success, False on failure.
    """
    global inference_recipe, INITIALIZATION_STATUS, cfg
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing model on device: {device}")

        # Critical: use an absolute path so loading works regardless of CWD.
        model_path = os.path.abspath(os.path.join('/app', 'models'))
        logger.info(f"Loading models from: {model_path}")

        if not os.path.exists(model_path):
            raise RuntimeError(f"Model path {model_path} does not exist")

        # Log available model files for debugging
        model_files = os.listdir(model_path)
        logger.info(f"Available model files: {model_files}")

        cfg = OmegaConf.load(os.path.join('/app', 'training_config.yml'))
        # Replace the LoRA training model with the plain inference model,
        # preserving the trained perception-token count from the config.
        cfg.model = DictConfig({
            "_component_": "models.mmllama3_8b",
            "use_clip": False,
            "perception_tokens": cfg.model.perception_tokens,
        })
        cfg.checkpointer.checkpoint_dir = model_path
        cfg.checkpointer.checkpoint_files = ["meta_model_5.pt"]
        cfg.inference.max_new_tokens = 300
        cfg.tokenizer.path = os.path.join(model_path, "tokenizer.model")
        inference_recipe = InferenceRecipe(cfg)
        inference_recipe.setup(cfg=cfg)
        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("Model initialized successfully")
        return True
    except Exception as e:
        # Record the failure for /health; logger.exception (rather than
        # logger.error) preserves the full traceback in the logs.
        INITIALIZATION_STATUS["error"] = str(e)
        logger.exception(f"Failed to initialize model: {e}")
        return False
84
+
85
@app.on_event("startup")
async def startup_event():
    """Kick off model loading as soon as the server starts.

    NOTE(review): on_event is deprecated in newer FastAPI in favor of
    lifespan handlers; fine for the pinned fastapi==0.104.1.
    """
    initialize_model()
89
+
90
@app.get("/api/v1/health")
def health_check():
    """Report readiness, plus device/dtype once the model is loaded."""
    loaded = INITIALIZATION_STATUS["model_loaded"]
    status = {
        "status": "healthy" if loaded else "initializing",
        "initialization_status": INITIALIZATION_STATUS,
    }

    if inference_recipe is not None:
        status["device"] = str(inference_recipe._device)
        status["dtype"] = str(inference_recipe._dtype)

    return status
105
+
106
@app.post("/api/v1/inference")
async def inference(request: EmbeddingRequest) -> TextResponse:
    """Generate caption text from a flattened video embedding.

    The flat float list is reshaped to (-1, 1024) rows before being passed
    to the model. Responds 503 while the model is still loading and 500 on
    any generation failure.
    """
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    try:
        # Log input validation
        logger.info("Received inference request")

        # Convert embedding to (-1, 1024) rows. The previous unsqueeze(0)
        # was redundant: reshape(-1, 1024) fully determines the final shape.
        embedding = torch.tensor(request.embedding).reshape(-1, 1024)
        logger.info(f"Converted embedding to tensor with shape: {embedding.shape}")

        # Run inference
        results = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=embedding)
        logger.info("Generation complete")

        # Normalize a bare string result into a list for the response model.
        if isinstance(results, str):
            results = [results]

        return TextResponse(texts=results)

    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )
142
+
143
if __name__ == "__main__":
    # Local entry point; in the container this is launched via CMD.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
146
+
147
+
148
+ # if __name__ == "__main__":
149
+ # import uvicorn
150
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
151
+
152
+
153
+ # if __name__ == "__main__":
154
+ # import uvicorn
155
+ # uvicorn.run(app, host="0.0.0.0", port=8000)
156
+
setup.py ADDED
File without changes
test.sh ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Smoke test for the inference API: checks /health, then posts a random
# embedding to /inference and reports the round-trip time.

# Configure bash error handling
set -euo pipefail

# Configuration
API_HOST="localhost"
API_PORT="8000"
API_VERSION="v1"
BASE_URL="http://${API_HOST}:${API_PORT}/api/${API_VERSION}"

# Print a JSON array holding a normalized random 4096-dim embedding
# (4 rows of 1024 after the server reshapes to (-1, 1024)).
generate_test_embedding() {
    # Quoted heredoc delimiter: nothing in the Python body is shell-expanded.
    # The terminator must sit at column 0 for a plain <<EOF heredoc.
    python3 - <<'EOF'
import numpy as np
import json

# Generate a 4096-dimensional embedding vector (correct dimension for model)
embedding = np.random.randn(4096).astype(np.float32)
# Normalize the embedding
embedding = embedding / np.linalg.norm(embedding)
print(json.dumps(embedding.tolist()), end="")
EOF
}

# Function to test health endpoint
test_health() {
    echo "Testing health endpoint..."
    curl -s "${BASE_URL}/health" || {
        echo "Health check failed"
        exit 1
    }
}

# Function to test inference endpoint
test_inference() {
    echo
    local start_time end_time duration embedding_data
    start_time=$(date +%s)
    echo "Testing inference endpoint..."
    # Assign separately from `local` so a failure in the subshell is not
    # masked by local's own exit status (shellcheck SC2155) under set -e.
    embedding_data=$(generate_test_embedding)

    curl -X POST "${BASE_URL}/inference" \
        -H "Content-Type: application/json" \
        -d "{
            \"embedding\": ${embedding_data}
        }" || {
        echo "Inference request failed"
        exit 1
    }
    end_time=$(date +%s)
    duration=$((end_time - start_time))
    echo "Inference request completed in ${duration} seconds"
}

main() {
    test_health
    test_inference
}

main "$@"
training_config.yml ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ identity_token: 0 1 2
2
+ model:
3
+ _component_: models.lora_mmllama3_8b
4
+ lora_attn_modules:
5
+ - q_proj
6
+ - v_proj
7
+ apply_lora_to_mlp: false
8
+ apply_lora_to_output: false
9
+ lora_rank: 8
10
+ lora_alpha: 16
11
+ perception_tokens: 2
12
+ use_clip: false
13
+ tokenizer:
14
+ _component_: models.a2a_tokenizer
15
+ path: checkpoints/Meta-Llama-3-8B-Instruct/original/tokenizer.model
16
+ checkpointer:
17
+ _component_: torchtune.utils.FullModelMetaCheckpointer
18
+ checkpoint_dir: checkpoints/Meta-Llama-3-8B-Instruct/original/
19
+ checkpoint_files:
20
+ - consolidated.00.pth
21
+ adapter_checkpoint: null
22
+ recipe_checkpoint: null
23
+ output_dir: output_checkpoints/experiment_4
24
+ model_type: LLAMA3
25
+ resume_from_checkpoint: false
26
+ interim_checkpoint_steps: 1500000
27
+ interim_gen_steps: null
28
+ max_new_tokens: 100
29
+ temperature: 0.6
30
+ top_k: 300
31
+ dataset:
32
+ _component_: ds.EvenBatcher
33
+ dataset:
34
+ _component_: ds.RoundRobinDataset
35
+ datasets:
36
+ - _component_: ds.IdentityDataset
37
+ identity: ${identity_token}
38
+ length: 250000
39
+ train_on_input: true
40
+ seed: null
41
+ shuffle: true
42
+ batch_size: 4
43
+ optimizer:
44
+ _component_: torch.optim.AdamW
45
+ weight_decay: 0.01
46
+ lr: 0.0003
47
+ lr_scheduler:
48
+ _component_: torchtune.modules.get_cosine_schedule_with_warmup
49
+ num_warmup_steps: 100
50
+ loss:
51
+ _component_: torch.nn.CrossEntropyLoss
52
+ epochs: 1
53
+ max_steps_per_epoch: null
54
+ gradient_accumulation_steps: 64
55
+ compile: false
56
+ output_dir: /tmp/lora_finetune_output
57
+ metric_logger:
58
+ _component_: torchtune.utils.metric_logging.DiskLogger
59
+ log_dir: ${output_dir}
60
+ log_every_n_steps: null
61
+ device: cuda
62
+ dtype: bf16
63
+ enable_activation_checkpointing: false
64
+ profiler:
65
+ _component_: torchtune.utils.profiler
66
+ enabled: false
67
+ inference:
68
+ prompt_template: 'Video:
69
+
70
+ {video}
71
+
72
+ Caption the previous video.'
73
+ max_new_tokens: 300
74
+ temperature: 0.6
75
+ top_k: 300
76
+ quantizer: null