| | from typing import List |
| | import warnings |
| |
|
| | import torch |
| | from torch import nn, Tensor |
| | from torchvision import transforms |
| |
|
| | from torchtune.models.llama3 import lora_llama3_8b, llama3_8b |
| | from torchtune.modules.peft import LORA_ATTN_MODULES, LoRALinear |
| | from torchtune.modules import TransformerDecoder |
| |
|
| | with warnings.catch_warnings(): |
| | warnings.simplefilter("ignore", UserWarning) |
| | from imagebind.models import imagebind_model |
| | from models.imagebind_wrapper import get_imagebind_v2, V2_PATH |
| | from models.imagebind_wrapper import ImageBind |
| |
|
# Dimensionality of the ImageBind embedding consumed by MMEmbedding.proj_to_llama.
IMAGEBIND_DIM = 1024
# Dimensionality of the CLIP embedding, concatenated onto the ImageBind
# embedding when use_clip=True (and the output size of MMLinear.proj_from_llama).
CLIP_DIM = 768
| |
|
| |
|
class MMEmbedding(nn.Embedding):
    """Token embedding layer that splices multimodal perception embeddings
    into the token-embedding sequence.

    Wraps an existing ``nn.Embedding`` (same vocab/dim/config) and adds a
    small MLP (``proj_to_llama``) that projects an ImageBind embedding —
    optionally concatenated with a CLIP embedding — into
    ``perception_tokens`` consecutive llama embedding vectors.
    """

    def __init__(self, e: nn.Embedding, perception_tokens: int = 1, use_clip: bool = False):
        super().__init__(
            num_embeddings=e.num_embeddings,
            embedding_dim=e.embedding_dim,
            padding_idx=e.padding_idx,
            max_norm=e.max_norm,
            norm_type=e.norm_type,
            scale_grad_by_freq=e.scale_grad_by_freq,
            sparse=e.sparse,
        )
        # Preserve the wrapped embedding's weights: super().__init__ freshly
        # re-initializes them, silently discarding any trained values in `e`.
        # (Harmless if a checkpoint is loaded over this module afterwards.)
        with torch.no_grad():
            self.weight.copy_(e.weight)

        self._perception_tokens = perception_tokens
        # Per-batch-element dicts mapping sequence position -> embeddings;
        # each value holds "ib_embed" and (when use_clip) "clip_embed".
        self._context = []
        self._use_clip = use_clip

        dim_in = IMAGEBIND_DIM + (CLIP_DIM if use_clip else 0)
        dim_out = e.embedding_dim * perception_tokens

        # MLP projecting the perception embedding into `perception_tokens`
        # slots of the llama embedding space.
        self.proj_to_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context) -> None:
        """Set the per-batch perception context used by the next forward()."""
        self._context = context

    def forward(self, input: Tensor) -> Tensor:
        """Embed ``input`` token ids, then overwrite the positions recorded
        in the current context with projected perception embeddings.
        """
        r = super().forward(input)
        for b, context_dict in enumerate(self._context):
            for s, embed in context_dict.items():
                if self._use_clip:
                    features = torch.cat([embed["ib_embed"], embed["clip_embed"]])
                else:
                    # torch.cat of a single tensor was a no-op copy; use it directly.
                    features = embed["ib_embed"]
                llama_embed = self.proj_to_llama(features)
                # Writes `perception_tokens` consecutive rows starting at s;
                # assumes s + perception_tokens <= seq_len — TODO confirm callers guarantee this.
                r[b, s:s + self._perception_tokens] = llama_embed.view(self._perception_tokens, -1)
        return r
| |
|
| |
|
class MMLinear(nn.Linear):
    """Output head wrapping an existing ``nn.Linear``, with an extra MLP
    (``proj_from_llama``) that maps llama hidden states into CLIP embedding
    space.

    ``forward`` currently just resets the per-call projection buffer and
    applies the plain linear head.
    """

    def __init__(self, o: nn.Linear):
        super().__init__(
            in_features=o.in_features,
            out_features=o.out_features,
            # Original used `o.bias != None`: comparing a tensor to None with
            # `!=` is non-idiomatic and version-fragile; `is not None` is the
            # correct presence test.
            bias=o.bias is not None,
        )
        # Preserve the wrapped layer's parameters: super().__init__ freshly
        # re-initializes them. (Harmless if a checkpoint is loaded afterwards.)
        with torch.no_grad():
            self.weight.copy_(o.weight)
            if o.bias is not None:
                self.bias.copy_(o.bias)

        self._context = []

        dim_out = CLIP_DIM
        dim_in = o.in_features
        # MLP projecting llama hidden states back into CLIP space.
        self.proj_from_llama = nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.GELU(),
            nn.LayerNorm(dim_out),
            nn.Linear(dim_out, dim_out),
        )

    def set_context(self, context) -> None:
        """Set the per-batch context (reserved; not read by forward yet)."""
        self._context = context

    def forward(self, input_bsd: Tensor) -> Tensor:
        """Apply the linear head to a (batch, seq, dim) input.

        Resets ``_clip_projections`` each call; the buffer is presumably
        populated elsewhere or reserved for future use — TODO confirm.
        """
        self._clip_projections = []
        return super().forward(input_bsd)
| |
|
| |
|
| |
|
def lora_mmllama3_8b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a LoRA llama3-8B whose token embedding and output head are
    replaced with the multimodal MMEmbedding / MMLinear wrappers.

    Args mirror torchtune's ``lora_llama3_8b``; ``perception_tokens`` and
    ``use_clip`` are forwarded to ``MMEmbedding``.
    """
    # Pass by keyword: in torchtune, lora_llama3_8b takes extra parameters
    # (e.g. lora_dropout) between lora_alpha and quantize_base, so positional
    # passing can silently bind quantize_base to the wrong parameter.
    llama3 = lora_llama3_8b(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        quantize_base=quantize_base,
    )
    llama3.tok_embeddings = MMEmbedding(llama3.tok_embeddings, perception_tokens, use_clip)
    llama3.output = MMLinear(llama3.output)
    return llama3
| |
|
| |
|
def mmllama3_8b(
    perception_tokens: int = 2,
    use_clip: bool = False
) -> TransformerDecoder:
    """Build a (non-LoRA) llama3-8B whose token embedding and output head
    are replaced with the multimodal MMEmbedding / MMLinear wrappers.
    """
    model = llama3_8b()
    model.tok_embeddings = MMEmbedding(model.tok_embeddings, perception_tokens, use_clip)
    model.output = MMLinear(model.output)
    return model
| |
|
| |
|
def imagebind_huge(use_v2: bool = True):
    """Load an ImageBind-Huge model (v2 wrapper or original pretrained v1)
    and attach a CLIP-style PIL->tensor preprocessing pipeline as
    ``transform_from_pil``.
    """
    if use_v2:
        imagebind = ImageBind(v2=True)
    else:
        imagebind = imagebind_model.imagebind_huge(pretrained=True)

    # 224px bicubic resize + center crop + CLIP normalization constants.
    preprocessing_steps = [
        transforms.Resize(
            224, interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
    imagebind.transform_from_pil = transforms.Compose(preprocessing_steps)
    return imagebind
| |
|
| |
|