| from typing import List, Optional, Sequence, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import CrossEntropyLoss |
| from transformers import ( |
| AutoConfig, |
| AutoModelForCausalLM, |
| GenerationMixin, |
| LlamaModel, |
| LlamaPreTrainedModel, |
| PreTrainedModel, |
| ) |
| from transformers.cache_utils import Cache, DynamicCache |
| from transformers.masking_utils import create_causal_mask |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
| from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update |
| from transformers.models.qwen2_vl.modeling_qwen2_vl import ( |
| PatchMerger, |
| Qwen2VisionTransformerPretrainedModel, |
| ) |
| from transformers.processing_utils import Unpack |
| from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, logging |
|
|
| from .configuration_sarashina2_vision import Sarashina2VisionConfig |
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CONFIG_FOR_DOC = "Sarashina2VisionConfig" |
|
|
|
|
| class Sarashina2VisionVisionTransformerPretrainedModel(Qwen2VisionTransformerPretrainedModel): |
| def __init__(self, config: Sarashina2VisionConfig) -> None: |
| super().__init__(config) |
| self.deepstack_visual_indices: Sequence[int] = config.deepstack_visual_indices |
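        # One PatchMerger per tapped vision block: the i-th merger projects the hidden states of
        # block deepstack_visual_indices[i] into the text model's hidden size for later injection.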
| self.deepstack_merger = nn.ModuleList( |
| [ |
| PatchMerger( |
| dim=config.hidden_size, |
| context_dim=config.embed_dim, |
| spatial_merge_size=config.spatial_merge_size, |
| ) |
| for _ in self.deepstack_visual_indices |
| ] |
| ) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| grid_thw: torch.Tensor, |
| **kwargs, |
| ) -> Tuple[torch.Tensor, List[torch.Tensor]]: |
| r""" |
| grid_thw (`torch.LongTensor` of shape `(num_images, 3)`): |
| The temporal, height and width dimensions of feature shape for each image. Each row contains [t, h, w] values. |
| """ |
| hidden_states = self.patch_embed(hidden_states) |
| rotary_pos_emb = self.rot_pos_emb(grid_thw) |
| emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| position_embeddings = (emb.cos(), emb.sin()) |
|
|
        cu_seqlens = torch.repeat_interleave(
            grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]
        ).cumsum(
            dim=0,
            # int32 is what the flash-attention varlen kernels expect for cu_seqlens; when tracing,
            # reuse grid_thw's dtype so the exported graph stays dtype-consistent.
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
| cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
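        # Worked example (a sketch, not part of the control flow): for a single image with
        # grid_thw = [[1, 4, 6]], the per-image sequence length is 1 * 4 * 6 = 24 patches, so
        # cu_seqlens becomes [0, 24]; with two such images it would be [0, 24, 48]. The attention
        # blocks use these cumulative boundaries to keep patches of different images from
        # attending to each other.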
| deepstack_features = [] |
|
|
| for layer_idx, blk in enumerate(self.blocks): |
| hidden_states = blk( |
| hidden_states, |
| cu_seqlens=cu_seqlens, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
            if layer_idx in self.deepstack_visual_indices:
                # Tap this block's hidden states and run them through the matching merger so they
                # can later be injected into the early decoder layers.
                deepstack_layer_index = self.deepstack_visual_indices.index(layer_idx)
                deepstack_features.append(
                    self.deepstack_merger[deepstack_layer_index](hidden_states)
                )
|
|
| return self.merger(hidden_states), deepstack_features |
|
|
|
|
| class Sarashina2VisionTextModel(LlamaModel): |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| deepstack_features: Sequence[torch.Tensor] = (), |
| visual_mask: Optional[torch.Tensor] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> BaseModelOutputWithPast: |
| if (input_ids is None) ^ (inputs_embeds is not None): |
| raise ValueError("You must specify exactly one of input_ids or inputs_embeds") |
|
|
| if inputs_embeds is None: |
| inputs_embeds: torch.Tensor = self.embed_tokens(input_ids) |
|
|
| if use_cache and past_key_values is None: |
| past_key_values = DynamicCache(config=self.config) |
|
|
| if cache_position is None: |
| past_seen_tokens = ( |
| past_key_values.get_seq_length() if past_key_values is not None else 0 |
| ) |
| cache_position: torch.Tensor = torch.arange( |
| past_seen_tokens, |
| past_seen_tokens + inputs_embeds.shape[1], |
| device=inputs_embeds.device, |
| ) |
|
|
| if position_ids is None: |
| position_ids = cache_position.unsqueeze(0) |
|
|
| causal_mask = create_causal_mask( |
| config=self.config, |
| input_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| past_key_values=past_key_values, |
| position_ids=position_ids, |
| ) |
|
|
| hidden_states = inputs_embeds |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
| for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]): |
| hidden_states = decoder_layer( |
| hidden_states, |
| attention_mask=causal_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| cache_position=cache_position, |
| position_embeddings=position_embeddings, |
| **kwargs, |
| ) |
| if layer_idx < len(deepstack_features): |
| hidden_states = hidden_states.clone() |
| hidden_states[visual_mask, :] = ( |
| hidden_states[visual_mask, :] + deepstack_features[layer_idx] |
| ) |
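            # Shape sketch (illustrative values only): with a batch of 1, a sequence of 100 tokens
            # of which 24 are image patches, hidden_states is (1, 100, hidden), visual_mask is a
            # (1, 100) boolean mask with 24 True entries, and deepstack_features[layer_idx] is
            # (24, hidden). The indexed add therefore injects the layer_idx-th merged vision
            # feature only at the visual token positions of the earliest decoder layers.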
|
|
| hidden_states = self.norm(hidden_states) |
| return BaseModelOutputWithPast( |
| last_hidden_state=hidden_states, |
| past_key_values=past_key_values, |
| ) |
|
|
|
|
| class Sarashina2VisionTextForCausalLM(LlamaPreTrainedModel, GenerationMixin): |
| _tied_weights_keys = ["lm_head.weight"] |
| _tp_plan = {"lm_head": "colwise_rep"} |
| _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} |
|
|
| def __init__(self, config): |
| super().__init__(config) |
| self.model = Sarashina2VisionTextModel(config) |
| self.vocab_size = config.vocab_size |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
| |
        # Initialize weights and apply final processing
        self.post_init()
|
|
| @can_return_tuple |
| @auto_docstring |
| def forward( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[Cache] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| deepstack_features: Sequence[torch.Tensor] = (), |
| visual_mask: Optional[torch.Tensor] = None, |
| **kwargs: Unpack[TransformersKwargs], |
| ) -> CausalLMOutputWithPast: |
| r""" |
| Example: |
| |
| ```python |
| >>> from transformers import AutoTokenizer, LlamaForCausalLM |
| |
| >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") |
| >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") |
| |
| >>> prompt = "Hey, are you conscious? Can you talk to me?" |
| >>> inputs = tokenizer(prompt, return_tensors="pt") |
| |
| >>> # Generate |
| >>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
| >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
| "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." |
| ```""" |
| outputs: BaseModelOutputWithPast = self.model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| deepstack_features=deepstack_features, |
| visual_mask=visual_mask, |
| **kwargs, |
| ) |
|
|
| hidden_states = outputs.last_hidden_state |
| |
        # Only compute the logits we need; keep them in the model dtype unless the loss is computed.
        slice_indices = (
            slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        )
        logits = self.lm_head(hidden_states[:, slice_indices, :])
|
|
| loss = None |
| if labels is not None: |
| loss = self.loss_function( |
| logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs |
| ) |
|
|
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
|
|
| class Sarashina2VisionPreTrainedModel(PreTrainedModel): |
| config_class = Sarashina2VisionConfig |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_cache_class = True |
| _supports_static_cache = True |
|
|
| def _init_weights(self, module): |
| std = ( |
| self.config.initializer_range |
| if hasattr(self.config, "initializer_range") |
| else self.config.text_config.initializer_range |
| ) |
|
|
| if hasattr(module, "class_embedding"): |
| module.class_embedding.data.normal_(mean=0.0, std=std) |
|
|
| if isinstance(module, (nn.Linear, nn.Conv3d)): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
|
|
|
|
| class Sarashina2VisionForCausalLM(Sarashina2VisionPreTrainedModel, GenerationMixin): |
| def __init__(self, config: Sarashina2VisionConfig): |
| super().__init__(config) |
| config.text_config._attn_implementation = config._attn_implementation |
| config.vision_config._attn_implementation = config._attn_implementation |
|
|
| self.visual = Sarashina2VisionVisionTransformerPretrainedModel._from_config( |
| config.vision_config |
| ) |
| self.norm = nn.LayerNorm(config.text_config.hidden_size) |
| self.llm = Sarashina2VisionTextForCausalLM._from_config(config.text_config) |
|
|
| self.use_mrope = False |
| self.mrope_section = None |
| if hasattr(config, "rope_scaling") and config.rope_scaling is not None: |
| self.mrope_section = config.rope_scaling.get("mrope_section", []) |
| self.mrope_interleaved = config.rope_scaling.get("mrope_interleaved", False) |
| self.spatial_reset = config.rope_scaling.get("spatial_reset", False) |
| self.use_mrope = True |
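            # Illustrative rope_scaling entry that enables this branch (example values, not defaults):
            #   "rope_scaling": {"rope_type": "default", "mrope_section": [24, 20, 20],
            #                    "mrope_interleaved": false, "spatial_reset": false}
            # mrope_section lists how many rotary frequency pairs are assigned to the temporal,
            # height and width axes; the three entries should sum to head_dim // 2.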
|
|
| if self.use_mrope: |
            assert len(self.mrope_section) > 0, (
                f"mrope_section must be a non-empty list, got {self.mrope_section}"
            )
| self.llm.rope_deltas = None |
|
|
            logger.info(
                "Replacing LlamaRotaryEmbedding with MRopeRotaryEmbedding at llm.model.rotary_emb"
            )
| replace_module_path = "model.rotary_emb" |
| parent_path, leaf = replace_module_path.rsplit(".", 1) |
| parent = self.llm.get_submodule(parent_path) |
|
|
| setattr( |
| parent, |
| leaf, |
| MRopeRotaryEmbedding( |
| self.config, |
| ), |
| ) |
|
|
| |
        # Initialize weights and apply final processing
        self.post_init()
|
|
| def get_input_embeddings(self): |
| return self.llm.get_input_embeddings() |
|
|
| def get_image_embeds( |
| self, |
| hidden_states: torch.Tensor, |
| grid_thw: torch.Tensor, |
| ) -> torch.Tensor: |
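        # Minimal usage sketch for get_image_embeds (assumed shapes; the real values come from the
        # image processor): for one 224x224 image with a 14-pixel patch, grid_thw = [[1, 16, 16]]
        # and hidden_states holds the 256 flattened patches. With spatial_merge_size 2 the returned
        # tensor has 256 / 4 = 64 rows, one per merged visual token, already projected to the text
        # model's hidden size and layer-normalized. Note this path skips the deepstack mergers.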
| rotary_pos_emb = self.visual.rot_pos_emb(grid_thw) |
| emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) |
| position_embeddings = (emb.cos(), emb.sin()) |
| hidden_states = self.visual.patch_embed(hidden_states) |
|
|
| cu_seqlens = torch.repeat_interleave( |
| grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] |
| ).cumsum(dim=0, dtype=torch.int32) |
| cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) |
|
|
| for blk in self.visual.blocks: |
| hidden_states = blk( |
| hidden_states, |
| cu_seqlens=cu_seqlens, |
| rotary_pos_emb=rotary_pos_emb, |
| position_embeddings=position_embeddings, |
| ) |
| return self.norm(self.visual.merger(hidden_states)) |
|
|
| @auto_docstring |
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| pixel_values: torch.FloatTensor = None, |
| pixel_values_video: torch.FloatTensor = None, |
| image_grid_thw: Optional[torch.LongTensor] = None, |
| video_grid_thw: Optional[torch.LongTensor] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| logits_to_keep: Union[int, torch.Tensor] = 0, |
| **lm_kwargs, |
| ) -> Union[Tuple, CausalLMOutputWithPast]: |
| """ |
| Args: |
| input_ids (torch.LongTensor, optional): Indices of input sequence tokens in the vocabulary. Defaults to None. |
| attention_mask (Optional[torch.Tensor], optional): Mask to avoid performing attention on padding token indices. Defaults to None. |
| position_ids (Optional[torch.LongTensor], optional): Indices of positions of each input sequence tokens in the position embeddings. Defaults to None. |
            past_key_values (Optional[List[torch.FloatTensor]], optional): Pre-computed key/value hidden states from previous forward passes, used to speed up sequential decoding. Defaults to None.
| inputs_embeds (Optional[torch.FloatTensor], optional): Instead of passing `input_ids` you can choose to directly pass an embedded representation. Defaults to None. |
| labels (Optional[torch.LongTensor], optional): Labels for computing the masked language modeling loss. Defaults to None. |
| use_cache (Optional[bool], optional): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding. Defaults to None. |
| output_attentions (Optional[bool], optional): Whether or not to return the attentions tensors of all attention layers. Defaults to None. |
| output_hidden_states (Optional[bool], optional): Whether or not to return the hidden states of all layers. Defaults to None. |
| return_dict (Optional[bool], optional): Whether or not to return a `CausalLMOutputWithPast` instead of a plain tuple. Defaults to None. |
| pixel_values (torch.FloatTensor, optional): The tensors corresponding to the input images. Defaults to None. |
| pixel_values_video (torch.FloatTensor, optional): The tensors corresponding to the input videos. Defaults to None. |
| image_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width of feature shape of each image in LLM. Defaults to None. |
| video_grid_thw (Optional[torch.LongTensor], optional): The temporal, height and width of feature shape of each video in LLM. Defaults to None. |
| cache_position (Optional[torch.LongTensor], optional): Indices depicting the position of the input sequence tokens in the sequence. Defaults to None. |
| logits_to_keep (Union[int, torch.Tensor]): If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all |
| `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that |
| token can save memory, which becomes pretty significant for long sequences or large vocabulary size. |
| If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. |
| This is useful when using packed tensor format (single dimension for batch and sequence length). |
| Returns: |
| CausalLMOutputWithPast: The output of the model. |
| """ |
| output_attentions = ( |
| output_attentions if output_attentions is not None else self.config.output_attentions |
| ) |
| output_hidden_states = ( |
| output_hidden_states |
| if output_hidden_states is not None |
| else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| deepstack_visual_embeds: Sequence[torch.Tensor] = () |
| visual_pos_masks: Optional[torch.Tensor] = None |
| image_mask: Optional[torch.Tensor] = None |
| video_mask: Optional[torch.Tensor] = None |
|
|
| if inputs_embeds is None: |
| inputs_embeds = self.get_input_embeddings()(input_ids) |
| if pixel_values is not None: |
| pixel_values = pixel_values.type(self.visual.get_dtype()) |
| image_embeds, deepstack_image_features = self.visual(pixel_values, image_grid_thw) |
| image_embeds = self.norm(image_embeds) |
| n_image_tokens = (input_ids == self.config.image_token_id).sum().item() |
| n_image_features = image_embeds.shape[0] |
| if n_image_tokens != n_image_features: |
| raise ValueError( |
| f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" |
| ) |
| visual_mask = input_ids == self.config.image_token_id |
| image_mask = ( |
| visual_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) |
| ) |
| image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) |
| if pixel_values_video is not None: |
| pixel_values_video = pixel_values_video.type(self.visual.get_dtype()) |
| video_embeds, deepstack_video_features = self.visual( |
| pixel_values_video, video_grid_thw |
| ) |
| video_embeds = self.norm(video_embeds) |
| n_video_tokens = (input_ids == self.config.video_token_id).sum().item() |
| n_video_features = video_embeds.shape[0] |
| if n_video_tokens != n_video_features: |
| raise ValueError( |
| f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" |
| ) |
| visual_mask = input_ids == self.config.video_token_id |
| video_mask = ( |
| visual_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) |
| ) |
| video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) |
| inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) |
|
|
| if image_mask is not None and video_mask is not None: |
| image_mask = image_mask[..., 0] |
| video_mask = video_mask[..., 0] |
| visual_pos_masks = image_mask | video_mask |
| deepstack_visual_embeds = [] |
| image_mask_joint = image_mask[visual_pos_masks] |
| video_mask_joint = video_mask[visual_pos_masks] |
| for img_embed, vid_embed in zip(deepstack_image_features, deepstack_video_features): |
| embed_joint = img_embed.new_zeros(visual_pos_masks.sum(), img_embed.shape[-1]).to( |
| img_embed.device |
| ) |
| embed_joint[image_mask_joint, :] = img_embed |
| embed_joint[video_mask_joint, :] = vid_embed |
| deepstack_visual_embeds.append(embed_joint) |
| elif image_mask is not None: |
| image_mask = image_mask[..., 0] |
| visual_pos_masks = image_mask |
| deepstack_visual_embeds = deepstack_image_features |
| elif video_mask is not None: |
| video_mask = video_mask[..., 0] |
| visual_pos_masks = video_mask |
| deepstack_visual_embeds = deepstack_video_features |
|
|
| outputs = self.llm( |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| cache_position=cache_position, |
| logits_to_keep=logits_to_keep, |
| deepstack_features=deepstack_visual_embeds, |
| visual_mask=visual_pos_masks, |
| **lm_kwargs, |
| ) |
|
|
| logits = outputs[0] |
|
|
| loss = None |
        if labels is not None:
            # Upcast to float to avoid precision issues when computing the loss
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
|
|
| if not return_dict: |
| output = (logits,) + outputs[1:] |
| return (loss,) + output if loss is not None else output |
|
|
| return CausalLMOutputWithPast( |
| loss=loss, |
| logits=logits, |
| past_key_values=outputs.past_key_values, |
| hidden_states=outputs.hidden_states, |
| attentions=outputs.attentions, |
| ) |
|
|
| def get_mrope_position_ids( |
| self, |
| input_ids: Optional[torch.LongTensor] = None, |
| image_grid_thw: Optional[torch.LongTensor] = None, |
| video_grid_thw: Optional[torch.LongTensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| spatial_merge_size: Optional[int] = 2, |
| image_token_id: Optional[int] = 14, |
| video_token_id: Optional[int] = 102399, |
| vision_start_token_id: Optional[int] = 102397, |
| ) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
        Calculate the 3D RoPE position index from the temporal, height and width dimensions of each image and video in the LLM input sequence.
| """ |
| mrope_position_deltas = [] |
        if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
| total_input_ids = input_ids |
| if attention_mask is None: |
| attention_mask = torch.ones_like(total_input_ids) |
| position_ids = torch.ones( |
| 3, |
| input_ids.shape[0], |
| input_ids.shape[1], |
| dtype=input_ids.dtype, |
| device=input_ids.device, |
| ) |
| image_index, video_index = 0, 0 |
| for i, input_ids in enumerate(total_input_ids): |
| input_ids = input_ids[attention_mask[i].to(input_ids.device) == 1] |
| image_nums = 0 |
| vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze( |
| 1 |
| ) |
| vision_tokens = input_ids[vision_start_indices + 1] |
| image_nums = (vision_tokens == image_token_id).sum() |
| video_nums = (vision_tokens == video_token_id).sum() |
| input_tokens = input_ids.tolist() |
| llm_pos_ids_list: list = [] |
| st = 0 |
| remain_images, remain_videos = image_nums, video_nums |
| for _ in range(image_nums + video_nums): |
| if image_token_id in input_tokens and remain_images > 0: |
| ed_image = input_tokens.index(image_token_id, st) |
| else: |
| ed_image = len(input_tokens) + 1 |
| if video_token_id in input_tokens and remain_videos > 0: |
| ed_video = input_tokens.index(video_token_id, st) |
| else: |
| ed_video = len(input_tokens) + 1 |
| if ed_image < ed_video: |
| t, h, w = ( |
| image_grid_thw[image_index][0], |
| image_grid_thw[image_index][1], |
| image_grid_thw[image_index][2], |
| ) |
| image_index += 1 |
| remain_images -= 1 |
| ed = ed_image |
|
|
| else: |
| t, h, w = ( |
| video_grid_thw[video_index][0], |
| video_grid_thw[video_index][1], |
| video_grid_thw[video_index][2], |
| ) |
| video_index += 1 |
| remain_videos -= 1 |
| ed = ed_video |
|
|
| llm_grid_t, llm_grid_h, llm_grid_w = ( |
| t.item(), |
| h.item() // spatial_merge_size, |
| w.item() // spatial_merge_size, |
| ) |
| text_len = ed - st |
|
|
| st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 |
| llm_pos_ids_list.append( |
| torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx |
| ) |
|
|
| t_index = ( |
| torch.arange(llm_grid_t) |
| .view(-1, 1) |
| .expand(-1, llm_grid_h * llm_grid_w) |
| .flatten() |
| ) |
| h_index = ( |
| torch.arange(llm_grid_h) |
| .view(1, -1, 1) |
| .expand(llm_grid_t, -1, llm_grid_w) |
| .flatten() |
| ) |
| w_index = ( |
| torch.arange(llm_grid_w) |
| .view(1, 1, -1) |
| .expand(llm_grid_t, llm_grid_h, -1) |
| .flatten() |
| ) |
                    if self.spatial_reset:
                        mm_pos_ids = torch.stack([t_index, h_index, w_index])
                        # Position assigned to the vision-end token that closes this image/video block.
                        vision_end_pos = torch.full(
                            (3, 1), torch.max(mm_pos_ids).item() + 1 + text_len + st_idx
                        )
                        # Only the temporal row is offset by the preceding text; the spatial (h/w)
                        # positions restart from zero for every image, hence "spatial_reset".
                        mm_pos_ids[0] += text_len + st_idx
                        llm_pos_ids_list.append(
                            torch.cat([mm_pos_ids, vision_end_pos], dim=1)
                        )
                        st = ed + llm_grid_t * llm_grid_h * llm_grid_w + 1
| else: |
| llm_pos_ids_list.append( |
| torch.stack([t_index, h_index, w_index]) + text_len + st_idx |
| ) |
| st = ed + llm_grid_t * llm_grid_h * llm_grid_w |
|
|
| if st < len(input_tokens): |
| st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 |
| text_len = len(input_tokens) - st |
| llm_pos_ids_list.append( |
| torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx |
| ) |
|
|
| llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) |
|
|
| position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( |
| position_ids.device |
| ) |
| mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i])) |
|
|
| mrope_position_deltas = torch.tensor( |
| mrope_position_deltas, device=input_ids.device |
| ).unsqueeze(1) |
| return position_ids, mrope_position_deltas |
| else: |
| position_ids = ( |
| torch.arange(input_ids.shape[1], device=input_ids.device) |
| .view(1, 1, -1) |
| .expand(3, input_ids.shape[0], -1) |
| ) |
| mrope_position_deltas = torch.zeros( |
| [input_ids.shape[0], 1], |
| device=input_ids.device, |
| dtype=input_ids.dtype, |
| ) |
|
|
| return position_ids, mrope_position_deltas |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids, |
| past_key_values=None, |
| inputs_embeds=None, |
| pixel_values=None, |
| pixel_values_video=None, |
| attention_mask=None, |
| cache_position=None, |
| logits_to_keep=None, |
| image_grid_thw=None, |
| video_grid_thw=None, |
| **kwargs, |
| ): |
| model_inputs = self.llm.prepare_inputs_for_generation( |
| input_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| cache_position=cache_position, |
| logits_to_keep=logits_to_keep, |
| **kwargs, |
| ) |
| if self.use_mrope: |
| input_ids = model_inputs["input_ids"] |
| attention_mask = model_inputs["attention_mask"] |
| cache_position = model_inputs["cache_position"] |
| if cache_position[0] == 0 or self.llm.rope_deltas is None: |
| position_ids, rope_deltas = self.get_mrope_position_ids( |
| input_ids=input_ids, |
| image_grid_thw=image_grid_thw, |
| video_grid_thw=video_grid_thw, |
| attention_mask=attention_mask, |
| spatial_merge_size=self.visual.spatial_merge_size, |
| ) |
| self.llm.rope_deltas = rope_deltas |
| else: |
| batch_size, seq_length = input_ids.shape |
| position_ids = torch.arange(seq_length, device=input_ids.device) |
| position_ids = position_ids.view(1, -1).expand(batch_size, -1) |
| delta = ( |
| cache_position[0] + self.llm.rope_deltas if cache_position is not None else 0 |
| ) |
| delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) |
| delta = delta.to(position_ids.device) |
| position_ids = position_ids.add(delta) |
| position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) |
| model_inputs["position_ids"] = position_ids |
|
|
        if cache_position[0] == 0:
            # Vision inputs are only needed on the prefill step; on later steps the image/video
            # embeddings are already folded into the KV cache, so they are not re-sent.
            model_inputs["pixel_values"] = pixel_values
            model_inputs["pixel_values_video"] = pixel_values_video
            model_inputs["image_grid_thw"] = image_grid_thw
            model_inputs["video_grid_thw"] = video_grid_thw
|
|
| return model_inputs |
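    # Decode-step sketch: continuing the toy example shown in get_mrope_position_ids, a prompt of
    # length 8 yields rope_deltas = -2. At the first decoding step cache_position[0] == 8, so the
    # new token's position is 8 + (-2) = 6 on all three axes, i.e. it directly follows the last
    # prompt position (5) instead of the raw sequence index.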
|
|
|
|
| class MRopeRotaryEmbedding(nn.Module): |
| def __init__( |
| self, |
| config: Sarashina2VisionConfig, |
| device=None, |
| ): |
| super().__init__() |
|
|
| self.mrope_section = config.rope_scaling.get("mrope_section") |
| self.mrope_interleaved = config.rope_scaling.get("mrope_interleaved", False) |
| self.rope_type = config.rope_scaling.get("rope_type") |
| if self.rope_type not in ROPE_INIT_FUNCTIONS: |
| self.rope_type = "default" |
| self.max_seq_len_cached = config.text_config.max_position_embeddings |
| self.original_max_seq_len = config.text_config.max_position_embeddings |
|
|
| self.config = config.text_config |
| self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] |
|
|
| inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) |
| self.register_buffer("inv_freq", inv_freq, persistent=False) |
| self.original_inv_freq = self.inv_freq |
|
|
| @torch.no_grad() |
| @dynamic_rope_update |
    def forward(self, x, position_ids):
        # position_ids has shape (3, batch, seq_len): one row each for the temporal, height and
        # width axes, so inv_freq is broadcast across all three axes before the outer product.
| inv_freq_expanded = ( |
| self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) |
| ) |
| position_ids_expanded = position_ids[:, :, None, :].float() |
|
|
| device_type = ( |
| x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" |
| ) |
        if self.mrope_interleaved:
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
                # Interleaving already folds the t/h/w axes into a single (bs, seq, dim) table,
                # so no per-section selection is needed afterwards.
                freqs = self.apply_interleaved_mrope(freqs)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * self.attention_scaling
                sin = emb.sin() * self.attention_scaling
        else:
            with torch.autocast(device_type=device_type, enabled=False):
                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
                emb = torch.cat((freqs, freqs), dim=-1)
                cos = emb.cos() * self.attention_scaling
                sin = emb.sin() * self.attention_scaling

            # Select the temporal/height/width sections from the (3, bs, seq, dim) table and merge
            # them into a single rotary table that the standard Llama attention layers consume.
            mrope_section = self.mrope_section * 2
            cos = torch.cat(
                [m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1
            )
            sin = torch.cat(
                [m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1
            )
|
|
| return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) |
|
|
| def apply_interleaved_mrope(self, freqs): |
| """Apply interleaved MRoPE to 3D rotary embeddings. |
| Reorganizes frequency layout from chunked [TTT...HHH...WWW] to |
| interleaved [THWTHWTHW...TT], preserving frequency continuity. |
        Args:
            freqs: (3, bs, seq_len, head_dim // 2)
        Returns:
            freqs_t: (bs, seq_len, head_dim // 2)
| """ |
| freqs_t = freqs[0] |
| for dim, offset in enumerate((1, 2), start=1): |
| length = self.mrope_section[dim] * 3 |
| idx = slice(offset, length, 3) |
| freqs_t[..., idx] = freqs[dim, ..., idx] |
| return freqs_t |
|
|
|
|
| AutoConfig.register("sarashina2_vision", Sarashina2VisionConfig) |
| AutoModelForCausalLM.register(Sarashina2VisionConfig, Sarashina2VisionForCausalLM) |
| Sarashina2VisionConfig.register_for_auto_class() |
| Sarashina2VisionForCausalLM.register_for_auto_class("AutoModelForCausalLM") |
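# Usage sketch (hypothetical repository id; trust_remote_code is required because the model class
# is registered for the auto classes above):
#
#   from transformers import AutoModelForCausalLM, AutoProcessor
#
#   model = AutoModelForCausalLM.from_pretrained("<repo-id>", trust_remote_code=True)
#   processor = AutoProcessor.from_pretrained("<repo-id>", trust_remote_code=True)
#   inputs = processor(text=prompt, images=[image], return_tensors="pt")
#   output_ids = model.generate(**inputs, max_new_tokens=128)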
|
|