| | import logging |
| | import transformers |
| | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, AutoModel |
| | import torch |
| | from PIL import Image |
| | import requests |
| | from typing import List, Optional, Tuple, Union |
| | from transformers.cache_utils import Cache |
| | from transformers.models.llava_next.modeling_llava_next import image_size_to_num_patches |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def my_mistral_forward(
    self,
    input_ids: torch.LongTensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    num_logits_to_keep: int = 0,
):
    """Run the wrapped backbone (`self.model`) and return only its first output.

    Intended as a drop-in replacement for a causal-LM `forward`: it keeps the
    same signature but computes neither logits nor a loss.

    Args:
        input_ids, attention_mask, position_ids, past_key_values,
        inputs_embeds, use_cache, cache_position: Forwarded unchanged to
            `self.model`.
        labels: Accepted for signature compatibility only; never read.
        output_attentions, output_hidden_states, return_dict: When `None`,
            the corresponding value from `self.config` is used
            (`use_return_dict` in the case of `return_dict`).
        num_logits_to_keep: Accepted for signature compatibility only; never
            read.

    Returns:
        `outputs[0]` of the `self.model(...)` call. NOTE(review): for a
        Hugging Face Mistral backbone this is presumably the last hidden
        state -- confirm against the backbone in use.
    """
    # Resolve per-call overrides; the config is consulted only when needed.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    backbone_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "inputs_embeds": inputs_embeds,
        "use_cache": use_cache,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "return_dict": return_dict,
        "cache_position": cache_position,
    }

    # Index 0 works whether the backbone returns a tuple or an indexable
    # output object.
    return self.model(**backbone_kwargs)[0]
| |
|
| |
|
def transfer_mistral_forward():
    """Install `my_mistral_forward` as `MistralForCausalLM.forward`.

    The assignment is made on the class object itself, so it applies to every
    `MistralForCausalLM` instance in the current process, not only to models
    created by this module.
    """
    mistral_causal_lm = transformers.models.mistral.MistralForCausalLM
    mistral_causal_lm.forward = my_mistral_forward
| |
|
class LLaVANextForEmbedding(LlavaNextForConditionalGeneration):
    """LLaVA-NeXT variant whose `forward` returns the language model's output as-is.

    Construction installs `my_mistral_forward` on `MistralForCausalLM` (see
    `__init__`). Besides the model forward pass, the class offers helpers that
    turn raw text / image paths into processor inputs (`prepare_text_input`,
    `data_process`); those require `set_processor` to have been called first.
    """

    def __init__(self, config):
        """Build the underlying model, then patch the Mistral forward.

        Args:
            config: Configuration object passed straight to the parent class.
        """
        super().__init__(config)

        # Process-wide side effect: replaces `MistralForCausalLM.forward` for
        # every instance, not only this model's language model.
        transfer_mistral_forward()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        pixel_values: torch.FloatTensor = None,
        image_sizes: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        vision_feature_layer: Optional[int] = None,
        vision_feature_select_strategy: Optional[str] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ):
        """Merge image features into the token embeddings and run the language model.

        Args:
            input_ids: Token ids. Exactly one of `input_ids` / `inputs_embeds`
                must be given.
            pixel_values: Image tensor with 4 or 5 dimensions. Cannot be
                combined with `inputs_embeds`.
            image_sizes: Iterable of per-image sizes, used to derive the number
                of patches per image.
            attention_mask, position_ids, past_key_values, use_cache,
            cache_position, num_logits_to_keep: Forwarded to
                `self.language_model` (the mask, positions and cache position
                may be rebuilt in the legacy path).
            inputs_embeds: Pre-computed embeddings used instead of `input_ids`.
            vision_feature_layer: Index into the vision tower's hidden states;
                defaults to `self.config.vision_feature_layer`.
            vision_feature_select_strategy: `"default"` or `"full"`; defaults
                to `self.config.vision_feature_select_strategy`.
            labels: Only passed through the legacy merge helper; no loss is
                computed here.
            output_attentions, output_hidden_states, return_dict: Default to
                the matching `self.config` values when `None`.

        Returns:
            Whatever `self.language_model(...)` returns, unmodified.
            NOTE(review): with the patched Mistral forward this is presumably
            a hidden-state tensor rather than a model-output object -- confirm.

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds`
                are given, if `pixel_values` is combined with `inputs_embeds`,
                or if `pixel_values` has neither 4 nor 5 dimensions.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        vision_feature_layer = (
            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
        )
        vision_feature_select_strategy = (
            vision_feature_select_strategy
            if vision_feature_select_strategy is not None
            else self.config.vision_feature_select_strategy
        )

        # XOR: raises when both are given or both are missing.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        legacy_processing = False
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            # Legacy mode when the largest per-row count of image tokens is
            # below `config.image_seq_length` (i.e. the image placeholders were
            # presumably not expanded by the processor), or when a single token
            # is fed together with pixel values.
            legacy_processing = (
                (input_ids == self.config.image_token_index).sum(1).max() < self.config.image_seq_length
            ) or (input_ids.shape[-1] == 1 and pixel_values is not None)

        if pixel_values is not None and pixel_values.size(0) > 0:
            # Number of patches per image, derived from each image size.
            image_num_patches = [
                image_size_to_num_patches(
                    image_size=imsize,
                    grid_pinpoints=self.config.image_grid_pinpoints,
                    patch_size=self.config.vision_config.image_size,
                )
                for imsize in image_sizes
            ]

            if pixel_values.dim() == 5:
                # Keep only the first `num_patch` entries of each item, then
                # concatenate all items along dim 0 into a 4-D tensor.
                _pixel_values_list = [
                    pix_val[:num_patch] for pix_val, num_patch in zip(pixel_values, image_num_patches)
                ]
                pixel_values = torch.cat(_pixel_values_list, dim=0)
            elif pixel_values.dim() != 4:
                raise ValueError(f"pixel_values of shape {pixel_values.shape}, expect to be of 4 or 5 dimensions")

            image_features = self.vision_tower(pixel_values, output_hidden_states=True)
            selected_image_feature = image_features.hidden_states[vision_feature_layer]
            if vision_feature_select_strategy == "default":
                # Drops the first position along dim 1 (presumably the CLS token).
                selected_image_feature = selected_image_feature[:, 1:]
            elif vision_feature_select_strategy == "full":
                selected_image_feature = selected_image_feature
            image_features = self.multi_modal_projector(selected_image_feature)
            # Back to one chunk per image.
            image_features = torch.split(image_features, image_num_patches, dim=0)

            image_features, feature_lens = self.pack_image_features(
                image_features,
                image_sizes,
                vision_feature_select_strategy=vision_feature_select_strategy,
                image_newline=self.image_newline,
            )
            if legacy_processing:
                logger.warning_once(
                    "Expanding inputs for image tokens in LLaVa-NeXT should be done in processing. "
                    "Please add `patch_size` and `vision_feature_select_strategy` to the model's processing config or set directly "
                    "with `processor.patch_size = {{patch_size}}` and processor.vision_feature_select_strategy = {{vision_feature_select_strategy}}`. "
                    "Using processors without these attributes in the config is deprecated and will throw an error in v4.47."
                )
                if input_ids.shape[1] != 1:
                    # More than one token: merge image features into the text
                    # embeddings via the parent-class helper.
                    inputs_embeds = inputs_embeds.to(image_features.dtype)
                    inputs_embeds, attention_mask, position_ids, labels, _ = self._merge_input_ids_with_image_features(
                        image_features,
                        feature_lens,
                        inputs_embeds,
                        input_ids,
                        attention_mask,
                        position_ids,
                        labels=labels,
                    )
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)
                else:
                    # Single token with an existing cache: rebuild the attention
                    # mask from the first cached layer.
                    first_layer_past_key_value = past_key_values[0][0][:, :, :, 0]

                    # Cache positions whose sum over dim -2 is exactly zero are
                    # treated as not attended.
                    batch_index, non_attended_tokens = torch.where(first_layer_past_key_value.float().sum(-2) == 0)

                    target_length = input_ids.shape[1]
                    past_length = first_layer_past_key_value.shape[-1]

                    extended_attention_mask = torch.ones(
                        (attention_mask.shape[0], past_length),
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )

                    # Discard indices that fall outside the extended mask
                    # before writing zeros into it.
                    valid_indices = non_attended_tokens < extended_attention_mask.size(-1)
                    new_batch_index = batch_index[valid_indices]
                    new_non_attended_tokens = non_attended_tokens[valid_indices]

                    extended_attention_mask[new_batch_index, new_non_attended_tokens] = 0
                    attention_mask = torch.cat((extended_attention_mask, attention_mask[:, -target_length:]), dim=1)
                    position_ids = torch.sum(attention_mask, dim=1).unsqueeze(-1) - 1
                    cache_position = torch.arange(attention_mask.shape[1], device=attention_mask.device)[
                        -target_length:
                    ]

            else:
                # Non-legacy path: overwrite the embeddings at image-token
                # positions with the packed image features.
                special_image_mask = (
                    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
                )
                image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
            num_logits_to_keep=num_logits_to_keep,
        )

        return outputs

    def set_processor(self, model_name):
        """Load a `LlavaNextProcessor` for `model_name` and store it on `self.processor`."""
        self.processor = LlavaNextProcessor.from_pretrained(model_name)

    def prepare_text_input(self, image=None, text=None, q_or_c=None, task_instruction=None):
        """Build the `[INST] ... [/INST]` prompt string for one query or candidate.

        Args:
            image: Only checked against `None`; when present, an `<image>`
                placeholder is appended to the prompt.
            text: Optional text inserted into the prompt.
            q_or_c: One of `"query"`, `"q"`, `"candidate"`, `"c"`.
            task_instruction: Optional instruction placed after `<instruct>`
                for queries. When omitted for a query, a warning is printed.
                Ignored for candidates.

        Returns:
            The assembled prompt string.

        Raises:
            AssertionError: If `q_or_c` is not one of the accepted values.
        """
        task_instruction_example_cir = "Retrieve the target image that best meets the combined criteria by using both the provided image and the image retrieval instructions: "

        assert q_or_c in ["query", "candidate", "q", "c"]

        # Of the accepted values, only "query" and "q" contain the letter "q".
        if "q" in q_or_c:
            if task_instruction is None:
                text_input = "[INST] \n <instruct> <query>"
                print(
                    "\n"
                    "Warning: For optimal performance, MMRet-MLLM requires the task instruction to be specified in the query.\n"
                    f"For example, for the composed image retrieval task, you might use a specific instruction like: {task_instruction_example_cir}.\n"
                    "Instructions for other tasks can be referenced in the MMEB benchmark.\n"
                )
            else:
                text_input = f"[INST] \n <instruct> {task_instruction} <query> "
        else:
            text_input = "[INST] "

        # The tail of the prompt is identical for queries and candidates.
        if text is not None:
            text_input = f"{text_input} {text} \n"
        if image is not None:
            text_input = f"{text_input} <image>"

        return f"{text_input} [/INST]"

    @staticmethod
    def _load_image(source):
        """Open `source` with PIL and return a 512x512 RGB copy, closing the file."""
        with Image.open(source) as img:
            # `resize` and `convert` return new images, so the result stays
            # usable after the underlying file is closed.
            return img.resize((512, 512)).convert("RGB")

    def data_process(self, images=None, text=None, q_or_c=None, task_instruction=None):
        """Turn image path(s) and/or text(s) into processor tensors on the model device.

        Whether the call is a batch is decided by `images` when it is given
        (list means batch), otherwise by `text`. Text-only batches
        (`images=None`, `text=[...]`) are supported.

        Args:
            images: A single image source accepted by `PIL.Image.open`, a list
                of them, or `None`.
            text: A single string, a list of strings, or `None`.
            q_or_c: One of `"query"`, `"q"`, `"candidate"`, `"c"`.
            task_instruction: Optional instruction forwarded to
                `prepare_text_input`.

        Returns:
            The output of `self.processor(...)` moved to `self.device`.

        Raises:
            ValueError: If both `images` and `text` are `None`.
            AssertionError: If `q_or_c` is not one of the accepted values.
        """
        if images is not None:
            is_batch = isinstance(images, list)
        elif text is not None:
            is_batch = isinstance(text, list)
        else:
            raise ValueError("images and text cannot be both None.")

        assert q_or_c in ["query", "candidate", "q", "c"]

        if is_batch:
            # Fill the missing modality with per-item `None` so both sequences
            # can be paired up. Previously a text-only batch crashed here
            # because `zip(None, text)` was evaluated.
            batch_size = len(images) if images is not None else len(text)
            image_items = images if images is not None else [None] * batch_size
            text_items = text if text is not None else [None] * batch_size
        else:
            image_items = [images]
            text_items = [text]

        text_input = [
            self.prepare_text_input(_image, _text, q_or_c, task_instruction)
            for _image, _text in zip(image_items, text_items)
        ]

        if images is not None:
            pil_images = [self._load_image(_image) for _image in image_items]
            inputs = self.processor(images=pil_images, text=text_input, return_tensors="pt", padding=True)
        else:
            inputs = self.processor(text=text_input, return_tensors="pt", padding=True)

        inputs = inputs.to(self.device)

        return inputs