# Transformers
import re
import torch
from torch import nn
from utils.utils import *
from typing import Optional, Tuple, Union
from transformers import MambaForCausalLM
from transformers import LlavaNextForConditionalGeneration, LlavaForConditionalGeneration


class MambaCache:
    def __init__(self, config, batch_size, dtype=torch.float16, device=None):
        self.seqlen_offset = 0
        self.dtype = dtype
        intermediate_size = config.intermediate_size
        ssm_state_size = config.state_size
        conv_kernel_size = config.conv_kernel

        # per-layer rolling buffers for the causal conv1d and the SSM recurrent state
        self.conv_states = {
            i: torch.zeros(batch_size, intermediate_size, conv_kernel_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
        self.ssm_states = {
            i: torch.zeros(batch_size, intermediate_size, ssm_state_size, device=device, dtype=dtype)
            for i in range(config.num_hidden_layers)
        }
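
# A minimal sketch of allocating this cache before a generation loop; the batch
# size, dtype, and device below are hypothetical, not taken from this file:
#
#   cache = MambaCache(model.config, batch_size=1, dtype=torch.float16, device="cuda")
#   cache.conv_states[0].shape  # (1, config.intermediate_size, config.conv_kernel)
#   cache.ssm_states[0].shape   # (1, config.intermediate_size, config.state_size)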

# Dataclass & ModelOutput
from dataclasses import dataclass
from transformers.modeling_outputs import ModelOutput

@dataclass
class MambaCausalLMOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    cache_params: Optional[MambaCache] = None
    tor_features: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
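# Note: unlike the stock causal-LM output in transformers, this output class has
# no logits field; the model below returns projected <tor> hidden states
# (tor_features) rather than next-token logits (the lm_head call is commented out).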


class MeteorMambaForCausalLM(MambaForCausalLM):
    def __init__(self, config):
        super().__init__(config)

        # extra projections: vision features -> hidden size, hidden size -> <tor> features
        self.vision_proj = self.build_vision_projector(1024, self.config.hidden_size)
        self.tor_proj = self.build_vision_projector(self.config.hidden_size, 4096)

        # replace the Mamba embedding table with one sized for the Meteor vocabulary
        self.backbone.embeddings = nn.Embedding(num_embeddings=92546,
                                                embedding_dim=self.config.hidden_size)

        # image-normalization statistics, rescaled to the 0-255 pixel range
        self.mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).view(1, -1, 1, 1) * 255
        self.std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).view(1, -1, 1, 1) * 255

    def image_processor(self, images):
        norm_images = (images - self.mean.to(images.device)) / self.std.to(images.device)
        return norm_images
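
    # A quick sketch of what image_processor expects, assuming inputs are float
    # tensors in the 0-255 range with shape (B, 3, H, W); the mean/std above are
    # the standard CLIP normalization statistics rescaled to that range:
    #
    #   images = torch.rand(2, 3, 336, 336) * 255   # hypothetical batch
    #   norm = model.image_processor(images)         # approximately zero-mean, unit-variance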

    def build_vision_projector(self, mm_hidden_size, hidden_size):
        projector_type = 'mlp2x_gelu'
        mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
        if mlp_gelu_match:
            mlp_depth = int(mlp_gelu_match.group(1))
            modules = [nn.Linear(mm_hidden_size, hidden_size)]
            for _ in range(1, mlp_depth):
                modules.append(nn.GELU())
                modules.append(nn.Linear(hidden_size, hidden_size))
            return nn.Sequential(*modules)
        raise ValueError(f'Unknown projector type: {projector_type}')
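
    # For the hard-coded 'mlp2x_gelu' projector type, the loop above unrolls to
    # this two-layer MLP (shown for illustration only):
    #
    #   nn.Sequential(
    #       nn.Linear(mm_hidden_size, hidden_size),
    #       nn.GELU(),
    #       nn.Linear(hidden_size, hidden_size),
    #   )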

    def eval_process(
        self,
        inputs,
        tokenizer,
        device,
        img_token_number,
    ):
        batched_image = []
        batched_qa_prompt = []
        for _input in inputs:
            # Visualization
            # imim = _input['image'].cpu().permute(1, 2, 0)

            # prepend <image> to the question if an image is given but the token is missing;
            # the system prompt and <tor> prompt are added by make_instruction_for_mmamba
            if 'image' in _input.keys() and '<image>' not in _input['question']:
                _input['question'] = '<image>\n' + _input['question']

            # build the instruction prompt from the question
            question = make_instruction_for_mmamba(question=_input['question'])

            # expand the single <image> token into a bundle of image tokens
            question = add_bundle_tokens(question, '<image>', img_token_number)

            # collect the batched prompt (and image, if present)
            if 'image' in _input.keys() and _input['image'] is not None:
                batched_image.append(_input['image'].to(device))
            batched_qa_prompt.append(question)

        '''For Final Outputs'''
        qa_prompts = tokenizer(batched_qa_prompt, padding='longest', return_tensors="pt", add_special_tokens=False)

        # [1] input_ids
        input_ids = qa_prompts.input_ids.to(device)

        # image or text-only batch?
        if len(batched_image):
            # [2] pixel values
            try:
                pixel_values = self.image_processor(torch.stack(batched_image)).to(device)
                assert pixel_values.dim() == 4
            except (RuntimeError, AssertionError):
                # mixed 3D/4D images: unsqueeze the 3D ones, then concatenate along the batch dim
                new_batched_image = []
                for batched_image_element in batched_image:
                    if batched_image_element.dim() == 3:
                        new_batched_image.append(batched_image_element.unsqueeze(0))
                    else:
                        new_batched_image.append(batched_image_element)
                pixel_values = self.image_processor(torch.cat(new_batched_image, dim=0)).to(device)
            return {"input_ids": input_ids, "image": pixel_values}
        else:
            return {"input_ids": input_ids}

    def _merge_input_embeds_with_image_features(self, image_features, inputs_embeds, input_ids):
        # running index into the batch dimension of the stacked image features
        batch_ind_image_feature = 0

        # image_features shape: (num_images, C tokens per image, hidden dim D)
        _, C, D = image_features.shape
        for ind, input_id in enumerate(input_ids):
            matching = torch.where(input_id == self.config.image_token_index)
            num_image_tokens_per_one_sample = len(matching[0]) // C
            inputs_embeds[ind][matching] = image_features[
                batch_ind_image_feature: batch_ind_image_feature + num_image_tokens_per_one_sample
            ].view(-1, D)
            batch_ind_image_feature += num_image_tokens_per_one_sample
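
    # Illustration of the merge: with C tokens per image and hidden size D, every
    # position in input_ids equal to config.image_token_index is overwritten in
    # place with one row of the flattened image features, so a sample containing
    # 2 * C <image> tokens consumes exactly 2 images from the stacked features.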

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        image_features: Optional[torch.FloatTensor] = None,
        cache_params: Optional[MambaCache] = None,
        # labels: Optional[torch.LongTensor] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[Tuple, MambaCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            # 1. Extract the input embeddings
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # 2. Merge text and images (skip during cached decoding, where seq len == 1)
            if image_features is not None and input_ids.shape[1] != 1:
                image_features = self.vision_proj(image_features)
                self._merge_input_embeds_with_image_features(image_features, inputs_embeds, input_ids)
        mamba_outputs = self.backbone(
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
        )
        hidden_states = mamba_outputs[0]
        # logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()

        loss = None
        # if labels is not None:
        #     # move labels to correct device to enable model parallelism
        #     labels = labels.to(logits.device)
        #     # Shift so that tokens < n predict n
        #     shift_logits = logits[..., :-1, :].contiguous()
        #     shift_labels = labels[..., 1:].contiguous()
        #     # Flatten the tokens
        #     loss_fct = nn.CrossEntropyLoss()
        #     loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        # if not return_dict:
        #     output = (logits,) + mamba_outputs[1:]
        #     return ((loss,) + output) if loss is not None else output

        return MambaCausalLMOutput(
            loss=loss,
            cache_params=mamba_outputs.cache_params,
            tor_features=self.tor_proj(hidden_states[torch.where(input_ids == self.config.tor_token_index)]),
            hidden_states=mamba_outputs.hidden_states,
        )

    def prepare_inputs_for_generation(
        self, input_ids, cache_params: Optional[MambaCache] = None, inputs_embeds=None, image_features=None, **kwargs
    ):
        # only feed the last token of input_ids once the cache state is passed along
        if cache_params is not None:
            input_ids = input_ids[:, -1].unsqueeze(-1)

        if inputs_embeds is not None and cache_params is None:
            model_inputs = {"inputs_embeds": inputs_embeds, "image_features": image_features}
        else:
            model_inputs = {"input_ids": input_ids, "image_features": image_features}

        model_inputs["cache_params"] = cache_params
        return model_inputs
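
# End-to-end sketch of how the pieces above compose at inference time (the
# variable names and the vision-encoder features are placeholders, not defined
# in this file; vision_proj expects 1024-dim features, per __init__ above):
#
#   batch = model.eval_process(inputs, tokenizer, device, img_token_number=576)
#   with torch.inference_mode():
#       out = model(input_ids=batch['input_ids'], image_features=clip_features, use_cache=True)
#   tor = out.tor_features  # <tor> hidden states projected to 4096 dims by tor_proj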