import itertools
import os
import sys
import time
from typing import Any, Dict, List

import torch
from torch import nn
from omegaconf import DictConfig
from PIL import Image

from torchtune import config, utils
from torchtune.utils._generation import sample
from torchtune.models import convert_weights
from torchtune.data import Message

from models.tokenizer import START_IMAGE, END_IMAGE, START_AUDIO, END_AUDIO, START_VIDEO, END_VIDEO
from imagebind.models.imagebind_model import ModalityType
from diffusers import DiffusionPipeline

from models import add_proj_convert_weights, _BASE_TRAINABLE

log = utils.get_logger("DEBUG")
add_proj_convert_weights()


class InferenceRecipe:
    """
    Recipe for generating tokens from a dense Transformer-based LLM.

    Currently this recipe supports single-GPU generation only. Speculative
    decoding is not supported.

    For more details on how to use this recipe for generation, please see our
    tutorial: https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#generation

    For using this recipe with a quantized model, please see the following
    section of the above tutorial:
    https://pytorch.org/torchtune/main/tutorials/e2e_flow.html#speeding-up-generation-using-quantization
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        self._dtype = utils.get_dtype(dtype=cfg.dtype)
        self._quantizer = config.instantiate(cfg.inference.quantizer)
        self._quantization_mode = utils.get_quantizer_mode(self._quantizer)
        self.prompt_template = cfg.inference.prompt_template
        perception_tokens = cfg.model.perception_tokens
        # Placeholder string standing in for the perception-token positions;
        # the actual embeddings are injected via the multimodal context below.
        self._perception_tokens = ("0 " * perception_tokens)[:perception_tokens]
        utils.set_seed(seed=cfg.seed)

    def setup(self, cfg: DictConfig) -> None:
        checkpointer = config.instantiate(cfg.checkpointer)
        if self._quantization_mode is None:
            ckpt_dict = checkpointer.load_checkpoint()
        else:
            # weights_only needs to be False when loading a quantized model.
            ckpt_dict = checkpointer.load_checkpoint(weights_only=False)

        self._model = self._setup_model(
            model_cfg=cfg.model,
            model_state_dict=ckpt_dict[utils.MODEL_KEY],
        )
        with self._device:
            self._model.setup_caches(max_batch_size=cfg.batch_size, dtype=self._dtype)

        self._tokenizer = config.instantiate(cfg.tokenizer)
        self._mm_ids_start = self._tokenizer.encode(START_IMAGE + START_AUDIO + START_VIDEO, add_eos=False, add_bos=False)
        self._mm_ids_end = self._tokenizer.encode(END_IMAGE + END_AUDIO + END_VIDEO, add_eos=False, add_bos=False)
        self.use_clip = cfg.model.use_clip
        if self.use_clip:
            self._clip_pipe = DiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-1-unclip-small", torch_dtype=self._dtype
            ).to(self._device)

    def _setup_model(
        self,
        model_cfg: DictConfig,
        model_state_dict: Dict[str, Any],
    ) -> nn.Module:
        with utils.set_default_dtype(self._dtype), self._device:
            model = config.instantiate(model_cfg)

        if self._quantization_mode is not None:
            model = self._quantizer.quantize(model)
            model = model.to(device=self._device, dtype=self._dtype)

        model.load_state_dict(model_state_dict)

        # Validate that the model was loaded in with the expected dtype.
        utils.validate_expected_param_dtype(model.named_parameters(), dtype=self._dtype)
        log.debug(f"Model is initialized with precision {self._dtype}.")

        return model

    def mm_process_prompt(self, prompt):
        # Replace each modality placeholder in the template with its run of
        # perception tokens wrapped in the matching start/end delimiters.
        return (
            prompt
            .replace("{image}", f"{START_IMAGE}{self._perception_tokens}{END_IMAGE}")
            .replace("{audio}", f"{START_AUDIO}{self._perception_tokens}{END_AUDIO}")
            .replace("{video}", f"{START_VIDEO}{self._perception_tokens}{END_VIDEO}")
        )
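
    # For illustration (hypothetical template): "Describe this video: {video}"
    # expands to "Describe this video: " followed by START_VIDEO, the
    # placeholder perception tokens, and END_VIDEO; the placeholders are
    # overridden with ImageBind embeddings at embedding time.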

    def extract_mm_context(self, video_ib_embed, tokens):
        # Map every token position that falls between a start-of-modality and
        # end-of-modality delimiter to the ImageBind embedding that should be
        # injected at that position.
        context = {}
        in_mm_embed = False
        for idx, tok in enumerate(tokens):
            in_mm_embed = in_mm_embed and tok not in self._mm_ids_end
            if in_mm_embed:
                context[idx] = {
                    "ib_embed": video_ib_embed.to(dtype=self._dtype, device=self._device),
                }
            in_mm_embed = in_mm_embed or tok in self._mm_ids_start
        return context
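
    # For illustration (hypothetical indices): if the tokenized prompt reads
    # [..., START_VIDEO, p, p, END_VIDEO, ...] with the perception tokens p at
    # indices 5 and 6, the returned context is
    # {5: {"ib_embed": e}, 6: {"ib_embed": e}}, which the context-aware
    # tok_embeddings/output modules consume via set_context().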

    @torch.no_grad()
    def generate(self, cfg: DictConfig, video_ib_embed: List[float]):
        messages = [
            Message(
                role="user",
                content=self.mm_process_prompt(self.prompt_template),
            ),
            Message(
                role="assistant",
                content="",
            ),
        ]
        tokens, mask = self._tokenizer.tokenize_messages(messages)
        # Drop the trailing end-of-turn tokens so the model continues the
        # (empty) assistant message instead of ending it.
        tokens = tokens[:-2]
        mm_context = [self.extract_mm_context(video_ib_embed, tokens)]
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device)

        self._model.tok_embeddings.set_context(mm_context)
        self._model.output.set_context(mm_context)

        # The 256 token ids starting at <|begin_of_text|> form the reserved
        # special-token block; disallow all of them during sampling except
        # <|eot_id|> and the multimodal delimiters.
        bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
        allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
        disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))

        def custom_generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
            # The multimodal context only applies to the prompt; clear it for
            # incremental decoding steps.
            model.tok_embeddings.set_context([])
            model.output.set_context([])
            # x: [1, s], input_pos: [s]
            logits = model(x, input_pos=input_pos)
            # logits: [1, s, v] where v is the vocab size; sample from the
            # distribution over the last position only: [v]
            logits = logits[0, -1]
            token = sample(logits, temperature, top_k)
            # Mapping a disallowed special token to EOS terminates generation
            # cleanly instead of emitting control tokens into the caption.
            if token in disallowed_tokens:
                return torch.tensor([self._tokenizer.eos_id]).to(x)
            return token

        # A quantized model relies on torch.compile for speed, so it needs a
        # warm-up / prefill run before timing the real generation.
        if self._quantization_mode is not None:
            log.info("Starting compilation to improve generation performance ...")
            custom_generate_next_token = torch.compile(
                custom_generate_next_token, mode="max-autotune", fullgraph=True
            )
            t0 = time.perf_counter()
            _ = utils.generate(
                model=self._model,
                prompt=prompt,
                max_generated_tokens=2,
                temperature=cfg.temperature,
                top_k=cfg.top_k,
                eos_id=self._tokenizer.eos_id,
                custom_generate_next_token=custom_generate_next_token,
            )
            t = time.perf_counter() - t0
            log.info(f"Warmup run for quantized model takes: {t:.02f} sec")

        t0 = time.perf_counter()
        generated_tokens = utils.generate(
            model=self._model,
            prompt=prompt,
            max_generated_tokens=cfg.max_new_tokens,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
            eos_id=self._tokenizer.eos_id,
            custom_generate_next_token=custom_generate_next_token,
        )
        t = time.perf_counter() - t0
        log.debug(f"Generation took {t:.02f} sec")

        # Drop the prompt and filter out special/delimiter tokens before decoding.
        cleaned_tokens = [
            tok for tok in generated_tokens[len(prompt):]
            if tok not in disallowed_tokens + allowed_id
        ]
        caption = self._tokenizer.decode(cleaned_tokens)

        return caption

    @torch.no_grad()
    def generate_batch(self, cfg: DictConfig, video_ib_embed: torch.Tensor):
        log.info(f"inside generate_batch, video_ib_embed shape: {video_ib_embed.shape}")
        batch_dim = video_ib_embed.size(0)
        messages = [
            Message(
                role="user",
                content=self.mm_process_prompt(self.prompt_template),
            ),
            Message(role="assistant", content=""),
        ]
        tokens, mask = self._tokenizer.tokenize_messages(messages)
        # Drop the trailing end-of-turn tokens so the model continues the
        # assistant message.
        tokens = tokens[:-2]
        # One multimodal context per batch element.
        mm_context = [self.extract_mm_context(e, tokens) for e in video_ib_embed]
        prompt = torch.tensor(tokens, dtype=torch.int, device=self._device).expand(batch_dim, -1).clone()
        prompt_length = prompt.size(1)

        self._model.tok_embeddings.set_context(mm_context)
        self._model.output.set_context(mm_context)

        # As in generate(): disallow every id in the 256-wide special-token
        # block except <|eot_id|> and the multimodal delimiters.
        bos_id = self._tokenizer.tt_model.encode("<|begin_of_text|>", allowed_special="all")[0]
        allowed_id = self._tokenizer.tt_model.encode(f"<|eot_id|>{START_IMAGE}{END_IMAGE}{START_AUDIO}{END_AUDIO}{START_VIDEO}{END_VIDEO}", allowed_special="all")
        disallowed_tokens = list(set(range(bos_id, bos_id + 256)) - set(allowed_id))

        def generate_next_token(model, input_pos, x, temperature=1.0, top_k=None):
            # x: [b, s], input_pos: [s]
            # logits: [b, s, v]; keep only the last position per sequence: [b, v]
            logits = model(x, input_pos=input_pos)[:, -1]
            tokens = sample(logits, temperature, top_k)
            # Replace any sampled disallowed special token with EOS.
            return torch.tensor([
                [self._tokenizer.eos_id if t in disallowed_tokens else t for t in toks]
                for toks in tokens
            ]).to(x.device)
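
        # For illustration: `sample` here returns one token per sequence
        # (shape [batch_size, 1]); any draw from the reserved special-token
        # block is rewritten to eos_id so that sequence terminates cleanly.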

        generated_tokens = prompt.clone()

        # Prefill: run the full prompt through the model once to populate the
        # KV cache and sample the first new token for every sequence.
        tokens = generate_next_token(
            self._model,
            input_pos=torch.arange(0, prompt_length, device=prompt.device),
            x=prompt,
            temperature=cfg.temperature,
            top_k=cfg.top_k,
        )
        eot_reached_b = tokens == self._tokenizer.eot_id
        generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)

        # The multimodal context only applies to the prompt; clear it before
        # incremental decoding.
        self._model.tok_embeddings.set_context([])
        self._model.output.set_context([])

        # Decode one token per step for every sequence, advancing the KV-cache
        # position each iteration.
        input_pos = torch.tensor([prompt_length], device=prompt.device)
        for _ in range(cfg.max_new_tokens - 1):
            tokens = generate_next_token(
                self._model, input_pos=input_pos, x=tokens, temperature=cfg.temperature, top_k=cfg.top_k
            )
            eot_reached_b |= tokens == self._tokenizer.eot_id
            # Zero out tokens for sequences that have already emitted <|eot_id|>.
            tokens *= ~eot_reached_b
            generated_tokens = torch.cat([generated_tokens, tokens], dim=-1)
            if eot_reached_b.all():
                log.debug("eot reached for all sequences in the batch")
                break
            input_pos += 1

        captions = []
        for caption_tokens in generated_tokens.tolist():
            # Decode only the generated suffix; the prompt is written so the
            # completion continues this fixed prefix.
            _caption = "This video shows a " + self._tokenizer.decode(caption_tokens[prompt.size(1):])
            captions.append(_caption)
        return captions
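
    # A minimal sketch of producing `video_ib_embed` upstream with ImageBind,
    # assuming the facebookresearch/ImageBind package and a local video file
    # (hypothetical path):
    #
    #   from imagebind import data
    #   from imagebind.models import imagebind_model
    #
    #   ib = imagebind_model.imagebind_huge(pretrained=True).eval().to(device)
    #   inputs = {ModalityType.VISION: data.load_and_transform_video_data(["clip.mp4"], device)}
    #   with torch.no_grad():
    #       embeds = ib(inputs)[ModalityType.VISION]  # [1, 1024], matching torch.randn(4, 1024) below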


@config.parse
def main(cfg: DictConfig) -> None:
    config.log_config(recipe_name="InferenceRecipe", cfg=cfg)
    # Override parts of the parsed config for this standalone smoke test.
    cfg.model = DictConfig({
        "_component_": "models.mmllama3_8b",
        "use_clip": False,
        "perception_tokens": cfg.model.perception_tokens,
    })
    cfg.batch_size = 4
    cfg.checkpointer.checkpoint_dir = os.path.dirname("/home/salman/tezuesh/omegalabs-anytoany-bittensor/sandboxing/cache/xzistance_omega-a2a-hotkey/meta_model_0.pth")
    cfg.checkpointer.checkpoint_files = ["models/meta_model_0.pt"]
    cfg.inference.max_new_tokens = 300
    cfg.tokenizer.path = "./models/tokenizer.model"
    inference_recipe = InferenceRecipe(cfg)
    inference_recipe.setup(cfg=cfg)
    # Random embeddings stand in for real ImageBind video features here.
    captions = inference_recipe.generate_batch(cfg=cfg, video_ib_embed=torch.randn(4, 1024))
    print(captions)
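
# Launched like any torchtune recipe, e.g. (hypothetical config path):
#   tune run inference.py --config config/inference.yml
# config.parse then populates `cfg` from that YAML plus any CLI overrides.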


if __name__ == "__main__":
    sys.exit(main())