| | from dataclasses import dataclass |
| | from typing import List, Optional, Tuple, Union |
| | import torch |
| | import torch.utils.checkpoint |
| | from torch import nn |
| | from transformers import PreTrainedModel |
| | from transformers.activations import ACT2FN |
| | from transformers.cache_utils import Cache |
| | from transformers.modeling_outputs import ModelOutput |
| | from transformers.utils import ( |
| | add_start_docstrings, |
| | add_start_docstrings_to_model_forward, |
| | logging, |
| | replace_return_docstrings, |
| | ) |
| | from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer, GenerationConfig, AutoConfig |
| | from .configuration_wemm import WeMMConfig |
| | from .vision_model import Idefics2VisionTransformer |
| | from .connector import Idefics2Connector |
| | from .image_processor import Idefics2ImageProcessor |
| | from .modeling_downsampler import DownsamplerModel |
| | from .modeling_projector import ProjectorModel |
| | from .modeling_internlm2 import InternLM2ForCausalLM |
| | from .tokenization_internlm2 import InternLM2Tokenizer |
| | from peft import PeftModel |
| | from peft import PeftConfig |
| | import os |
| | from PIL import Image |
| | import numpy as np |
# Placeholder id inserted into the token id sequence at the position of each image;
# it is searched for later to find where visual embeddings must be spliced in.
IMAGE_TOKEN_INDEX = -200
# Literal marker in the prompt text; the prompt is split on it before tokenization.
DEFAULT_IMAGE_TOKEN = "<image>"
# Label value written at positions that carry no target (image slots and padding).
IGNORE_INDEX = -100
| | from transformers import StoppingCriteria |
| | from transformers import PreTrainedTokenizerFast, StoppingCriteriaList |
| | import torch.nn.functional as F |
class StopWordStoppingCriteria(StoppingCriteria):
    """Stopping criterion that fires when the decoded text ends with a stop word.

    Only the first sequence of the batch (``input_ids[0]``) is inspected, and
    carriage returns / newlines are removed from the decoded text before the
    comparison.
    """

    def __init__(self, tokenizer, stop_word):
        self.tokenizer = tokenizer
        self.stop_word = stop_word
        self.length = len(stop_word)

    def __call__(self, input_ids, *args, **kwargs) -> bool:
        decoded = self.tokenizer.decode(input_ids[0])
        # Drop line-break characters so they cannot hide the stop word.
        flattened = decoded.replace('\n', '').replace('\r', '')
        tail = flattened[-self.length:]
        return tail == self.stop_word
def get_stop_criteria(
    tokenizer,
    stop_words=None,
):
    """Build a ``StoppingCriteriaList`` with one stop-word criterion per word.

    Args:
        tokenizer: tokenizer handed to every ``StopWordStoppingCriteria``.
        stop_words: iterable of stop strings. ``None`` (the default) or an
            empty iterable yields an empty criteria list.

    Returns:
        A ``StoppingCriteriaList`` holding the criteria in input order.
    """
    # The default is None rather than a shared mutable list.
    words = () if stop_words is None else stop_words
    stop_criteria = StoppingCriteriaList()
    for word in words:
        stop_criteria.append(StopWordStoppingCriteria(tokenizer, word))
    return stop_criteria
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Build a 2-D sin/cos embedding from a stacked coordinate grid.

    Half of ``embed_dim`` encodes ``grid[0]`` and the other half encodes
    ``grid[1]``; the two encodings are concatenated along the last axis.
    ``embed_dim`` must be even.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_axis = [
        get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis])
        for axis in (0, 1)
    ]
    return np.concatenate(per_axis, axis=-1)
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode a 2-D grid of scalar positions with sine/cosine features.

    embed_dim: output dimension for each position (must be even)
    pos: grid of positions of shape (H, W); leading size-1 axes such as
         (1, H, W) are accepted and dropped
    out: (H, W, D) -- the first D/2 channels are sines, the last D/2 cosines
    """
    assert embed_dim % 2 == 0
    # np.float was removed in NumPy 1.24; float64 is the type it aliased.
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega

    pos = np.asarray(pos)
    grid = np.squeeze(pos)
    if grid.ndim != 2:
        # squeeze also removes a genuine height/width of 1; keep the two
        # trailing axes instead so 1xW and Hx1 grids still encode.
        grid = pos.reshape(pos.shape[-2:])

    out = np.einsum('hw,d->hwd', grid, omega)
    emb_sin = np.sin(out)
    emb_cos = np.cos(out)
    emb = np.concatenate([emb_sin, emb_cos], axis=-1)
    return emb
| | |
| | |
| | |
| | |
| | |
def get_2d_sincos_pos_embed(embed_dim, grid_size_h, grid_size_w, cls_token=False):
    """
    Build a 2-D sin/cos positional embedding for a grid_size_h x grid_size_w grid.

    embed_dim: embedding dimension per grid cell
    grid_size_h, grid_size_w: grid height and width
    cls_token: when True, flatten the grid and prepend one all-zero row
    return:
    pos_embed: [grid_size_h, grid_size_w, embed_dim] without cls_token, or
               [1 + grid_size_h*grid_size_w, embed_dim] with cls_token
    """
    grid_h = np.arange(grid_size_h, dtype=np.float32)
    grid_w = np.arange(grid_size_w, dtype=np.float32)
    # The width coordinates are passed first, so grid[0] holds them.
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # The grid embedding is (H, W, D); it has to be flattened to (H*W, D)
        # before a (1, D) zero row can be concatenated in front of it.
        pos_embed = pos_embed.reshape(-1, embed_dim)
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
def recover_navit_subimages_with_pos_emb(
        sub_image_hidden_states,
        attention_mask,
        num_sub_images,
        visual_embedding_group,
        pos_hidden_size,
        thumbnail_only=False):
    """Reassemble per-sub-image features into one feature map per image.

    The hidden states are regrouped as (batch, group, H, W, D) using the H and W
    of ``attention_mask``. With ``thumbnail_only`` each group is treated as
    ``num_sub_images + 1`` entries and only the last entry is kept. For every
    image the valid region is inferred from the first mask of its group, the
    tiles are stitched into a (full_h, full_w, D) map, a matching 2-D sin/cos
    position embedding is built, and both are replicate-padded to a common size.

    Returns:
        (image_hidden_states, image_2d_pos, valid_image_token): the stacked
        padded feature maps, the stacked padded position embeddings and an
        int64 tensor of the unpadded ``[full_h, full_w]`` per image.
    """
    tiles_per_side = int(np.sqrt(num_sub_images))
    _, _, D = sub_image_hidden_states.shape
    _, H, W = attention_mask.shape

    group_size = num_sub_images + 1 if thumbnail_only is True else num_sub_images
    sub_image_hidden_states = sub_image_hidden_states.reshape(-1, group_size, H, W, D)
    attention_mask = attention_mask.reshape(-1, group_size, H, W)
    if thumbnail_only is True:
        # Keep only the last entry of every group and treat it as a 1x1 tiling.
        sub_image_hidden_states = sub_image_hidden_states[:, -1:, :, :, :]
        attention_mask = attention_mask[:, -1:, :, :]
        tiles_per_side = 1

    def _valid_patch_extent(mask_2d):
        # Largest row / column index with a positive mask value, plus one.
        rows, cols = torch.where(mask_2d > 0)
        return torch.max(rows) + 1, torch.max(cols) + 1

    def _pad_to_common_size(feature_map):
        orig_dtype = feature_map.dtype
        stride = int(np.sqrt(visual_embedding_group))
        cur_h, cur_w, _ = feature_map.shape
        target_h = H * tiles_per_side
        target_w = W * tiles_per_side

        # Pad up to the target size, then further up to a multiple of stride.
        pad_h = (target_h - cur_h) + (stride - target_h % stride) % stride
        pad_w = (target_w - cur_w) + (stride - target_w % stride) % stride

        channels_first = feature_map.permute(2, 0, 1).unsqueeze(0)
        padded = F.pad(
            channels_first.to(torch.float32),
            (0, pad_w, 0, pad_h),
            mode='replicate',
        )
        return padded.squeeze(0).permute(1, 2, 0).to(orig_dtype)

    padded_features = []
    padded_positions = []
    valid_shapes = []
    for batch_id, tiles in enumerate(sub_image_hidden_states):
        ori_h, ori_w = _valid_patch_extent(attention_mask[batch_id][0])
        full_h = ori_h * tiles_per_side
        full_w = ori_w * tiles_per_side

        cropped = tiles[:, 0:ori_h, 0:ori_w, :]
        tiled = cropped.view(tiles_per_side, tiles_per_side, ori_h, ori_w, D)
        stitched = tiled.permute(0, 2, 1, 3, 4).contiguous().view(full_h, full_w, D)

        pos_emb = get_2d_sincos_pos_embed(
            pos_hidden_size, grid_size_h=full_h, grid_size_w=full_w)
        pos_emb = torch.tensor(pos_emb, dtype=stitched.dtype, device=stitched.device)

        padded_features.append(_pad_to_common_size(stitched))
        padded_positions.append(_pad_to_common_size(pos_emb))
        valid_shapes.append([full_h, full_w])

    image_hidden_states = torch.stack(padded_features)
    image_2d_pos = torch.stack(padded_positions)
    valid_image_token = torch.tensor(valid_shapes, dtype=torch.int64)
    return image_hidden_states, image_2d_pos, valid_image_token
def visiual_token_downsample(
        visual_downsampler,
        image_hidden_states,
        valid_image_token,
        visual_embedding_group,
        image_2d_pos):
    """Optionally add position embeddings, then downsample the visual tokens.

    Args:
        visual_downsampler: callable applied to the (position-augmented) features.
        image_hidden_states: visual features to downsample.
        valid_image_token: tensor of valid sizes per image.
        visual_embedding_group: group size; valid sizes are divided by its
            square root and rounded up.
        image_2d_pos: position embeddings added to the features, or ``None``
            to skip the addition.

    Returns:
        (downsampled features, rescaled valid sizes as int64).
    """
    if image_2d_pos is None:
        features = image_hidden_states
    else:
        features = image_hidden_states + image_2d_pos
    downsampled = visual_downsampler(features)

    stride = np.sqrt(visual_embedding_group)
    scaled_sizes = torch.ceil(valid_image_token / stride)
    return downsampled, scaled_sizes.to(torch.int64)
def merge_native_qformer(
        clip_embeddings_native_patch,
        valid_image_token_shape,
        clip_embeddings_qformer,
        visual_source_spliter,
        num_sub_images):
    """Concatenate native-patch and q-former embeddings with separator tokens.

    For every batch entry the result is laid out as::

        split(10), native patches (valid h*w region, flattened), split(11),
        then for each of the num_sub_images + 1 q-former groups i:
        split(2*i), group tokens, split(2*i + 1)

    where ``split(k)`` is row ``k`` looked up in ``visual_source_spliter``.

    Returns:
        A list with one 2-D tensor per batch entry.
    """
    assert clip_embeddings_native_patch.size(0) == valid_image_token_shape.size(0) == clip_embeddings_qformer.size(0)

    def _split_token(index):
        token_id = torch.tensor([index]).to(visual_source_spliter.weight.device)
        return visual_source_spliter(token_id)

    def _wrap_qformer_groups(qformer_tokens):
        num_groups = num_sub_images + 1
        group_len = int(qformer_tokens.size(0) // num_groups)
        pieces = []
        for group in range(num_groups):
            begin = group * group_len
            pieces.append(_split_token(2 * group))
            pieces.append(qformer_tokens[begin:begin + group_len])
            pieces.append(_split_token(2 * group + 1))
        return torch.cat(pieces, dim=0)

    merged_visual_embeddings = []
    for native_map, token_shape, qformer_tokens in zip(
            clip_embeddings_native_patch,
            valid_image_token_shape,
            clip_embeddings_qformer):
        h, w = token_shape
        # Keep only the valid top-left region and flatten it to a token list.
        native_tokens = native_map[:h, :w, :].reshape(h * w, -1)
        wrapped_qformer = _wrap_qformer_groups(qformer_tokens)
        sequence = [
            _split_token(10),
            native_tokens,
            _split_token(11),
            wrapped_qformer,
        ]
        merged_visual_embeddings.append(torch.cat(sequence, dim=0))
    return merged_visual_embeddings
class WemmForConditionalGeneration(PreTrainedModel):
    """WeMM multimodal model: vision tower, connector/projector, downsampler and an InternLM2 decoder.

    ``mm_generate`` is the single-image inference entry point;
    ``prepare_inputs_labels_for_multimodal`` splices visual embeddings into the
    token embedding sequence at every ``IMAGE_TOKEN_INDEX`` placeholder.
    """
    config_class = WeMMConfig

    def __init__(self, config: WeMMConfig):
        """Build every sub-module from the matching section of ``config``.

        The tokenizer is loaded with ``AutoTokenizer.from_pretrained`` from a
        fixed hub id, so construction needs those tokenizer files to be
        reachable (hub or local cache).
        """
        super().__init__(config)
        self.vision_tower = Idefics2VisionTransformer(config.vision_config)
        self.image_processor = Idefics2ImageProcessor(config.image_processor)
        self.connector = Idefics2Connector(config.connector_config)
        self.projector = ProjectorModel(config.projector_config)
        self.language_model = InternLM2ForCausalLM(config.text_config)
        self.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm2-chat-7b", trust_remote_code=True, encode_special_tokens=True)
        self.downsampler = DownsamplerModel(config.downsampler_config)
        self.visual_source_spliter_emb = torch.nn.Embedding(**config.spliter_emb_config)
        # Greedy decoding by default; fall back to EOS for padding when the
        # tokenizer defines no pad token.
        self.gen_config = GenerationConfig(
            max_new_tokens=512,
            do_sample=False,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
            if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id,
        )
        self.do_image_splitting = config.do_image_splitting
        self.stop_criteria = get_stop_criteria(
            tokenizer=self.tokenizer, stop_words=['<|im_end|>'])
        self.config = config

    def mm_generate(self, image_path, prompt, gen_config=None):
        """Answer ``prompt`` about the image stored at ``image_path``.

        Args:
            image_path: path handed to ``PIL.Image.open``; the image is
                converted to RGB.
            prompt: user text; it is prefixed with the ``<image>`` marker and
                wrapped in the ``<|im_start|>`` / ``<|im_end|>`` chat template.
            gen_config: optional generation config; ``self.gen_config`` is used
                when it is ``None``.

        Returns:
            The decoded generation with special tokens skipped and surrounding
            whitespace stripped.
        """
        # Inputs follow the model's own device rather than a hard-coded
        # .cuda(), so the method also works on CPU or a non-default GPU.
        device = self.device

        prompt = "<image>" + '\n' + prompt
        prompt = f"<|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n"
        image = Image.open(image_path).convert('RGB')
        navit980_images = self.image_processor([[image]], return_tensors="pt", do_image_splitting=self.do_image_splitting)
        batch_size_navit = navit980_images['pixel_values'].shape[0]
        navit_pixel_values = navit980_images['navit_pixel_values'].to(device)
        navit_patch_attention_mask = navit980_images["pixel_attention_mask"].to(device)
        clip_visual_outputs = self.vision_tower(pixel_values=navit_pixel_values, patch_attention_mask=navit_patch_attention_mask,).last_hidden_state

        # Native-resolution branch: thumbnail features only, no position
        # embedding added before the downsampler (image_2d_pos=None below).
        super_image_hidden_states, image_2d_pos, valid_image_token_shape = \
            recover_navit_subimages_with_pos_emb(
                clip_visual_outputs, navit_patch_attention_mask, num_sub_images=4,
                visual_embedding_group=1,
                pos_hidden_size=4096,
                thumbnail_only=True
            )
        clip_embeddings_native_patch, valid_image_token_shape = visiual_token_downsample(
            self.downsampler,
            super_image_hidden_states, valid_image_token_shape,
            visual_embedding_group=1, image_2d_pos=None
        )

        # Connector branch: connector output regrouped per image, then projected.
        clip_embeddings_qformer = self.connector(clip_visual_outputs, attention_mask=navit_patch_attention_mask.view(navit_pixel_values.size(0), -1))
        hidden_size = clip_embeddings_qformer.shape[-1]
        clip_embeddings_qformer = clip_embeddings_qformer.view(batch_size_navit, -1, hidden_size)
        clip_embeddings_qformer = self.projector(clip_embeddings_qformer)
        merged_visual_embeddings = \
            merge_native_qformer(
                clip_embeddings_native_patch,
                valid_image_token_shape,
                clip_embeddings_qformer,
                visual_source_spliter=self.visual_source_spliter_emb,
                num_sub_images=4
            )

        # Tokenize the text on both sides of the image marker; only the first
        # chunk is encoded with special tokens.
        chunk_encode = []
        for idx, chunk in enumerate(prompt.split(DEFAULT_IMAGE_TOKEN)):
            if idx == 0:
                cur_encode = self.tokenizer.encode(chunk)
            else:
                cur_encode = self.tokenizer.encode(chunk, add_special_tokens=False)
            chunk_encode.append(cur_encode)
        # Exactly one image marker is supported.
        assert len(chunk_encode) == 2
        ids = []
        for idx, cur_chunk_encode in enumerate(chunk_encode):
            ids.extend(cur_chunk_encode)
            if idx != len(chunk_encode) - 1:
                ids.append(IMAGE_TOKEN_INDEX)
        ids = torch.tensor(ids).to(device).unsqueeze(0)
        pixel_values = None
        mm_inputs = self.prepare_inputs_labels_for_multimodal(
            llm=self.language_model, input_ids=ids, pixel_values=pixel_values, clip_embeddings=merged_visual_embeddings)
        generate_output = self.language_model.generate(
            **mm_inputs,
            generation_config=gen_config if gen_config is not None else self.gen_config,
            streamer=None,
            bos_token_id=self.tokenizer.bos_token_id,
            stopping_criteria=self.stop_criteria
        )
        predict = self.tokenizer.decode(
            generate_output[0], skip_special_tokens=True).strip()
        return predict

    def get_valid_visual_embedding(self, embedding, valid_token_shape):
        """Crop ``embedding`` to its valid ``(h, w)`` region and flatten it to ``(h*w, -1)``.

        ``embedding`` is returned unchanged when ``valid_token_shape`` is ``None``.
        """
        if valid_token_shape is None:
            return embedding
        h, w = valid_token_shape
        return embedding[:h, :w, :].reshape(h*w, -1)

    def _clip_embedding_at(self, clip_embeddings, valid_image_token_shape, image_idx):
        """Return the clip embedding for ``image_idx``, cropped when a valid shape is known."""
        if clip_embeddings is None:
            return None
        if valid_image_token_shape is None:
            return clip_embeddings[image_idx]
        return self.get_valid_visual_embedding(
            clip_embeddings[image_idx], valid_image_token_shape[image_idx])

    def prepare_inputs_labels_for_multimodal(
            self,
            llm: PreTrainedModel,
            input_ids: torch.LongTensor = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[List[torch.FloatTensor]] = None,
            labels: Optional[torch.LongTensor] = None,
            pixel_values: Optional[torch.FloatTensor] = None,
            clip_embeddings: Optional[torch.FloatTensor] = None,
            hard_coded_max_len: Optional[int] = None,
            **kwargs):
        """Replace image placeholders with visual embeddings and build padded LLM inputs.

        Every ``IMAGE_TOKEN_INDEX`` in ``input_ids`` is replaced by the next
        entry of ``pixel_values`` and/or ``clip_embeddings``; text tokens are
        embedded with ``llm.get_input_embeddings()``. Sequences are then padded
        (and truncated to ``hard_coded_max_len`` when given) to a common length.

        Args:
            llm: language model providing the input embedding layer.
            input_ids: token ids containing ``IMAGE_TOKEN_INDEX`` placeholders.
            position_ids, attention_mask, labels: optional; each is returned as
                ``None`` when it was passed as ``None``.
            past_key_values: passed through unchanged.
            pixel_values: optional per-image embeddings inserted at placeholders.
            clip_embeddings: optional per-image embeddings inserted at placeholders.
            hard_coded_max_len: optional upper bound on the sequence length.
            **kwargs: ``valid_image_token_shape`` may be supplied to crop each
                clip embedding to its valid region.

        Returns:
            A dict with ``input_ids`` (``None`` once embeddings are built),
            ``position_ids``, ``attention_mask``, ``past_key_values``,
            ``inputs_embeds`` and ``labels``; ``im_mask`` is added when
            ``pixel_values`` is given. Without any visual input the arguments
            are returned untouched with ``inputs_embeds`` set to ``None``.
        """
        if pixel_values is None and clip_embeddings is None:
            return {
                'input_ids': input_ids,
                'position_ids': position_ids,
                'attention_mask': attention_mask,
                'past_key_values': past_key_values,
                'inputs_embeds': None,
                'labels': labels
            }
        valid_image_token_shape = kwargs.get('valid_image_token_shape', None)
        # Remember which optional inputs were absent so they are returned as None.
        _labels = labels
        _position_ids = position_ids
        _attention_mask = attention_mask
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            attention_mask = attention_mask.bool()
        if position_ids is None:
            position_ids = torch.arange(
                0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
        if labels is None:
            labels = torch.full_like(input_ids, IGNORE_INDEX)

        # Drop masked-out positions; each sample becomes a 1-D tensor.
        input_ids = [
            cur_input_ids[cur_attention_mask]
            for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
        ]
        labels = [
            cur_labels[cur_attention_mask]
            for cur_labels, cur_attention_mask in zip(labels, attention_mask)
        ]
        new_inputs_embeds = []
        new_labels = []
        new_img_masks = []
        cur_image_idx = 0
        for batch_idx, cur_input_ids in enumerate(input_ids):
            num_images = int((cur_input_ids == IMAGE_TOKEN_INDEX).sum())
            if num_images == 0:
                # Text-only sample: zero-length slices of the visual tensors
                # are still concatenated, and one image slot is consumed.
                cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                # The shared helper tolerates a missing valid_image_token_shape,
                # which previously raised a TypeError in this branch.
                cur_clip_emb = self._clip_embedding_at(
                    clip_embeddings, valid_image_token_shape, cur_image_idx)
                cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
                if cur_clip_emb is not None and cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0], cur_clip_emb[0:0]], dim=0)
                elif cur_pixel_values is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
                elif cur_clip_emb is not None:
                    cur_inputs_embeds = torch.cat(
                        [cur_inputs_embeds_1, cur_clip_emb[0:0]], dim=0)
                else:
                    raise ValueError
                new_inputs_embeds.append(cur_inputs_embeds)
                new_labels.append(labels[batch_idx])
                new_img_masks.append(torch.zeros(
                    cur_inputs_embeds.shape[0], device=cur_inputs_embeds.device).bool())
                cur_image_idx += 1
                continue

            # Sentinels at both ends let consecutive pairs delimit text spans.
            image_token_indices = [-1] + torch.where(
                cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [
                    cur_input_ids.shape[0]
                ]
            cur_input_ids_noim = []
            cur_labels = labels[batch_idx]
            cur_labels_noim = []
            for i in range(len(image_token_indices) - 1):
                cur_input_ids_noim.append(
                    cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
                cur_labels_noim.append(
                    cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])
            split_sizes = [x.shape[0] for x in cur_labels_noim]
            # Embed all text spans in one call, then split them back apart.
            cur_inputs_embeds = llm.get_input_embeddings()(
                torch.cat(cur_input_ids_noim))
            cur_inputs_embeds_no_im = torch.split(
                cur_inputs_embeds, split_sizes, dim=0)
            cur_new_inputs_embeds = []
            cur_new_labels = []
            cur_img_masks = []
            for i in range(num_images + 1):
                cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
                cur_new_labels.append(cur_labels_noim[i])
                cur_img_masks.append(torch.zeros(
                    cur_inputs_embeds_no_im[i].shape[0], device=cur_inputs_embeds_no_im[i].device).bool())
                if i < num_images:
                    cur_pixel_values = pixel_values[cur_image_idx] if pixel_values is not None else None
                    cur_clip_emb = self._clip_embedding_at(
                        clip_embeddings, valid_image_token_shape, cur_image_idx)
                    cur_image_idx += 1

                    if cur_pixel_values is not None:
                        cur_new_inputs_embeds.append(cur_pixel_values)
                        cur_img_masks.append(torch.ones(
                            cur_pixel_values.shape[0], device=cur_pixel_values.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_pixel_values.shape[0], ),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))

                    if cur_clip_emb is not None:
                        cur_new_inputs_embeds.append(cur_clip_emb)
                        # Clip-embedding positions are marked False in the image mask.
                        cur_img_masks.append(torch.zeros(
                            cur_clip_emb.shape[0], device=cur_clip_emb.device).bool())
                        cur_new_labels.append(
                            torch.full((cur_clip_emb.shape[0],),
                                       IGNORE_INDEX,
                                       device=cur_labels.device,
                                       dtype=cur_labels.dtype))
            cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
            cur_new_labels = torch.cat(cur_new_labels)
            cur_img_masks = torch.cat(cur_img_masks)
            new_inputs_embeds.append(cur_new_inputs_embeds)
            new_labels.append(cur_new_labels)
            new_img_masks.append(cur_img_masks)

        # Pad (and optionally truncate) every sample to one common length.
        max_len = max(x.shape[0] for x in new_inputs_embeds)
        if hard_coded_max_len is not None:
            max_len = min(max_len, hard_coded_max_len)
        batch_size = len(new_inputs_embeds)
        new_inputs_embeds_padded = []
        new_labels_padded = torch.full((batch_size, max_len),
                                       IGNORE_INDEX,
                                       dtype=new_labels[0].dtype,
                                       device=new_labels[0].device)
        attention_mask = torch.zeros((batch_size, max_len),
                                     dtype=attention_mask.dtype,
                                     device=attention_mask.device)
        position_ids = torch.zeros((batch_size, max_len),
                                   dtype=position_ids.dtype,
                                   device=position_ids.device)
        new_img_masks_padded = torch.zeros((batch_size, max_len), device=new_img_masks[0].device).bool()
        for i, (cur_new_embed,
                cur_new_labels, cur_new_img_masks) in enumerate(zip(new_inputs_embeds, new_labels, new_img_masks)):
            cur_new_embed = cur_new_embed[:max_len]
            cur_new_labels = cur_new_labels[:max_len]
            cur_new_img_masks = cur_new_img_masks[:max_len]
            cur_len = cur_new_embed.shape[0]
            # Right-pad the embeddings with zeros.
            new_inputs_embeds_padded.append(
                torch.cat((cur_new_embed,
                           torch.zeros((max_len - cur_len, cur_new_embed.shape[1]),
                                       dtype=cur_new_embed.dtype,
                                       device=cur_new_embed.device)),
                          dim=0))
            if cur_len > 0:
                new_labels_padded[i, :cur_len] = cur_new_labels
                attention_mask[i, :cur_len] = True
                position_ids[i, :cur_len] = torch.arange(
                    0,
                    cur_len,
                    dtype=position_ids.dtype,
                    device=position_ids.device)
                new_img_masks_padded[i, :cur_len] = cur_new_img_masks
        new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)
        if _labels is None:
            new_labels = None
        else:
            new_labels = new_labels_padded
        if _attention_mask is None:
            attention_mask = None
        else:
            attention_mask = attention_mask.to(dtype=_attention_mask.dtype)
        if _position_ids is None:
            position_ids = None
        prepared_data = {
            'input_ids': None,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': new_inputs_embeds,
            'labels': new_labels,
        }
        if pixel_values is not None:
            prepared_data.update({'im_mask': new_img_masks_padded})
        return prepared_data
# Import-time side effect: register the WeMM config under the "wemm_hf" model
# type and map that config class to the model class in the Auto* factories.
AutoConfig.register("wemm_hf", WeMMConfig)
AutoModel.register(WeMMConfig, WemmForConditionalGeneration)
| |
|