| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| from torch.nn import CrossEntropyLoss |
|
|
| from transformers import AutoConfig, AutoModelForCausalLM, \ |
| LlamaConfig, LlamaModel |
|
|
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from transformers.generation.utils import GenerateOutput |
|
|
| from ..llava_arch import LlavaMetaModel, LlavaMetaForCausalLM |
| |
|
|
| from transformers.models.llama import LlamaPreTrainedModel |
| from transformers.cache_utils import Cache, DynamicCache |
|
|
| from transformers.modeling_attn_mask_utils import ( |
| AttentionMaskConverter, |
| _prepare_4d_attention_mask, |
| _prepare_4d_causal_attention_mask, |
| _prepare_4d_causal_attention_mask_for_sdpa, |
| ) |
| from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast |
|
|
| from transformers.models.llama.modeling_llama import ( |
| LlamaAttention, |
| LlamaFlashAttention2, |
| LlamaSdpaAttention, |
| LlamaMLP, |
| LlamaRMSNorm, |
| apply_rotary_pos_emb, |
| ) |
|
|
class LlavaConfig(LlamaConfig):
    """Llama configuration registered under the ``"llava_llama"`` model type."""

    model_type = "llava_llama"
|
|
# Maps `config._attn_implementation` to the attention class that each
# LlamaDecoderLayer instantiates for its `self_attn` submodule.
LLAMA_ATTENTION_CLASSES = {
    "eager": LlamaAttention,
    "flash_attention_2": LlamaFlashAttention2,
    "sdpa": LlamaSdpaAttention,
}
|
|
|
|
def reverse_cumsum(x: torch.Tensor) -> torch.Tensor:
    """Return the right-to-left (suffix) cumulative sum of ``x`` along the last dim.

    Element ``i`` of the result equals ``x[..., i] + x[..., i + 1] + ... + x[..., -1]``.
    It is derived from the ordinary prefix sum as ``x + total - prefix``, so no
    flip of the tensor is needed.
    """
    total = x.sum(dim=-1, keepdim=True)
    prefix = x.cumsum(dim=-1)
    return x + total - prefix
|
|
def make_mask_post_last_voco(
    inputs: torch.Tensor,
    voco_token: int,
    pad_token: Optional[int] = None,
    dtype=torch.int64,
) -> torch.Tensor:
    """Mark every position at or before the LAST ``voco_token`` along the last dim.

    Despite the name, the mask is set for positions up to and including the final
    occurrence of ``voco_token`` and cleared after it. It is all zeros when the
    token does not occur. If ``pad_token`` is given, positions holding it are
    cleared as well. The boolean result is cast to ``dtype``.
    """
    is_voco = inputs == voco_token
    # Count of VoCo tokens at or after each position (suffix count via flips). A
    # position is kept while at least one VoCo token still lies ahead of it.
    voco_ahead = is_voco.flip(-1).cumsum(-1).flip(-1)
    keep = voco_ahead >= 1
    if pad_token is not None:
        keep &= inputs != pad_token
    return keep.to(dtype)
|
|
def make_mask_pre_first_voco(
    inputs: torch.Tensor,
    voco_token: int,
    pad_token: Optional[int] = None,
    dtype=torch.int64,
) -> torch.Tensor:
    """Mark every position at or after the FIRST ``voco_token`` along the last dim.

    Despite the name, the mask is set from the first occurrence of ``voco_token``
    onward and cleared before it. It is all zeros when the token does not occur.
    If ``pad_token`` is given, positions holding it are cleared as well. The
    boolean result is cast to ``dtype``.
    """
    # A running count > 0 means a VoCo token has already been seen.
    seen_voco = (inputs == voco_token).cumsum(dim=-1).gt(0)
    if pad_token is None:
        return seen_voco.to(dtype)
    return (seen_voco & inputs.ne(pad_token)).to(dtype)
|
|
def make_voco_mask_llava(
    inputs: torch.Tensor,
    voco_token: int,
    dtype=torch.int64,
) -> torch.Tensor:
    """Build a VoCo attention-permission mask of shape ``(batch, 1, seq, seq)``.

    ``inputs`` is indexed as ``(batch, seq)``, and positions equal to
    ``voco_token`` mark VoCo tokens. Entry ``[b, 0, q, k]`` is set when query
    ``q`` may attend to key ``k`` under the VoCo rule. Causality is NOT applied
    here; the caller combines this with a causal mask.

    * Queries at or before the last VoCo token may see keys at or before it.
    * Queries after the last VoCo token may only see keys from the first VoCo
      token onward, so they cannot see content that precedes the VoCo span.

    Every entry is set for samples that contain no VoCo token. The result is
    cast to ``dtype``.
    """
    # (batch, seq) -> (batch, 1, 1, seq): masks laid out over key positions.
    keys_up_to_last = make_mask_post_last_voco(inputs, voco_token, dtype=torch.bool).unsqueeze(1).unsqueeze(1)
    keys_from_first = make_mask_pre_first_voco(inputs, voco_token, dtype=torch.bool).unsqueeze(1).unsqueeze(1)
    # The same "up to last VoCo" mask laid out over query positions: (batch, 1, seq, 1).
    query_up_to_last = keys_up_to_last.permute((0, 1, 3, 2))
    allowed = torch.where(query_up_to_last, keys_up_to_last, keys_from_first)

    # Samples without any VoCo token are left completely unrestricted.
    sample_has_voco = (inputs == voco_token).any(-1)
    allowed = torch.where(sample_has_voco[:, None, None, None], allowed, True)
    return allowed.type(dtype)
|
|
class LlamaDecoderLayer(nn.Module):
    """One pre-norm Llama decoder block.

    The data flow is: input RMSNorm -> self-attention -> residual add, then
    post-attention RMSNorm -> MLP -> residual add. The attention class is looked
    up in ``LLAMA_ATTENTION_CLASSES`` by ``config._attn_implementation``.
    """

    def __init__(self, config: LlamaConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*):
                attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
                query_sequence_length, key_sequence_length)` if default attention is used.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            **kwargs: forwarded unchanged to the attention module.

        Returns:
            A tuple starting with the output hidden states. The attention weights
            follow when `output_attentions` is set, then the present key/value
            cache when `use_cache` is set.
        """
        if "padding_mask" in kwargs:
            # Imported locally so this deprecation path warns instead of failing
            # with a NameError (the module does not import `warnings`).
            import warnings

            warnings.warn(
                "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
            )

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
|
|
|
|
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    On the SDPA attention path this copy also applies VoCo masking. Positions
    named by ``voco_loc_back`` are treated as VoCo tokens, and queries after the
    last VoCo token are blocked from keys before the first one (see
    ``make_voco_mask_llava``).

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._use_sdpa = config._attn_implementation == "sdpa"
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_voco_sdpa_mask(
        self,
        attention_mask: Optional[torch.Tensor],
        batch_size: int,
        seq_length: int,
        past_key_values_length: int,
        inputs_embeds: torch.Tensor,
        voco_loc_back=None,
    ) -> torch.Tensor:
        """Build the additive SDPA mask (causal + padding + VoCo) for the current queries.

        Args:
            attention_mask: optional 2D padding mask covering the cached and new
                positions (length ``seq_length + past_key_values_length``).
                ``None`` means every position is a real token.
            batch_size, seq_length, past_key_values_length: sizes of the current call.
            inputs_embeds: supplies the mask dtype and device.
            voco_loc_back: optional per-sample iterables of VoCo offsets, counted
                back from the last position of the current input. ``None`` means
                no VoCo tokens.

        Returns:
            The last ``seq_length`` query rows of the full-sequence mask, in
            ``inputs_embeds.dtype``. Entries are 0 where attention is allowed and
            the dtype minimum where it is blocked.
        """
        voco_marker = 32000  # value written into the float 2D mask to tag VoCo positions
        full_length = seq_length + past_key_values_length
        full_shape = (batch_size, full_length)

        if attention_mask is None:
            attention_mask = torch.ones(full_shape, dtype=torch.bool, device=inputs_embeds.device)

        # Build the mask over the whole (cached + new) sequence as if nothing were
        # cached (past length 0), so that the VoCo rule is applied on absolute
        # positions. Only the current query rows are returned at the end.
        causal_mask = _prepare_4d_causal_attention_mask_for_sdpa(attention_mask, full_shape, inputs_embeds, 0)
        if causal_mask is None:
            # The SDPA helper returns None when it deems an explicit mask unnecessary
            # (e.g. no padding). The VoCo rule still needs one, so build it explicitly.
            causal_mask = _prepare_4d_causal_attention_mask(attention_mask, full_shape, inputs_embeds, 0)
        mask_min = torch.finfo(causal_mask.dtype).min

        # Index of the first non-padding position per sample (0 if none or if the first is valid).
        first_valid = (attention_mask == 0).int().argmin(dim=1)

        # Explicit copy: `.to` alone may alias the caller's mask when dtypes match,
        # and the marker below must not leak into it.
        marked = attention_mask.to(dtype=inputs_embeds.dtype, copy=True)
        if voco_loc_back is not None:
            for sample_idx, locs in enumerate(voco_loc_back):
                for loc in locs:
                    # NOTE(review): the offset is taken from the end of the *current*
                    # input (`seq_length`), not of the full cached sequence. With a KV
                    # cache (seq_length == 1) this becomes a negative index into the
                    # full mask -- confirm this is intended for incremental decoding.
                    marked[sample_idx][seq_length - 1 - loc] = voco_marker
        voco_allowed = make_voco_mask_llava(marked, voco_marker, torch.bool)

        # A pair is blocked if the causal/padding mask or the VoCo rule blocks it.
        blocked = (causal_mask < 0) | ~voco_allowed
        combined = torch.zeros_like(causal_mask, dtype=inputs_embeds.dtype).masked_fill(blocked, mask_min)

        # Let rows of leading padding attend everywhere instead of nowhere.
        # Presumably this mirrors transformers' `_unmask_unattended`, which avoids
        # fully-masked rows in SDPA.
        row_positions = torch.arange(combined.shape[-2], device=combined.device)
        leading_pad_rows = row_positions.unsqueeze(0) < first_valid.to(combined.device).unsqueeze(1)
        combined = combined.masked_fill(leading_pad_rows[:, None, :, None], 0.0)

        # Keep only the rows of the tokens processed in this call.
        return combined[:, :, -seq_length:, :]

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        voco_loc_back=None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run embeddings (if needed), the decoder stack and the final norm.

        Follows the transformers ``LlamaModel.forward`` flow, plus one argument:

        ``voco_loc_back`` holds per-sample iterables of VoCo offsets, counted back
        from the last position of the current input. It is only consulted on the
        SDPA path (``sdpa`` attention without ``output_attentions``). ``None``
        means no VoCo tokens.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                # The module defines no `logger`; fetch the transformers logger here
                # so this path warns instead of failing with a NameError.
                from transformers.utils import logging as hf_logging

                hf_logging.get_logger(__name__).warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        past_key_values_length = 0
        use_legacy_cache = False
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if self._use_flash_attention_2:
            # 2d mask is passed through the layers (None when there is no padding)
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._use_sdpa and not output_attentions:
            # 4d additive mask with causal, padding and VoCo restrictions
            attention_mask = self._prepare_voco_sdpa_mask(
                attention_mask,
                batch_size,
                seq_length,
                past_key_values_length,
                inputs_embeds,
                voco_loc_back,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights in the layer's output tuple.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
|
|
|
|
class LlavaLlamaModel(LlavaMetaModel, LlamaModel):
    """LLaVA backbone that mixes ``LlavaMetaModel`` into the VoCo-aware ``LlamaModel``.

    Uses ``LlavaConfig`` as its configuration class.
    """

    config_class = LlavaConfig

    def __init__(self, config: LlamaConfig):
        # Zero-argument super(): delegates to the next class in the MRO, which
        # starts with LlavaMetaModel.
        super().__init__(config)
|
|
|
|
| |
class LlavaLlamaForCausalLM(LlamaPreTrainedModel, LlavaMetaForCausalLM):
    """Causal-LM head on top of ``LlavaLlamaModel``.

    Threads ``voco_loc_back`` (VoCo token offsets) from the multimodal input
    preparation through generation and into the backbone.
    """

    _tied_weights_keys = ["lm_head.weight"]
    config_class = LlavaConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = LlavaLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_model(self):
        return self.model

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        images: Optional[torch.FloatTensor] = None,
        image_sizes: Optional[List[List[int]]] = None,
        return_dict: Optional[bool] = None,
        voco_loc_back=None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Compute logits (and the shifted cross-entropy loss when ``labels`` are given).

        When ``inputs_embeds`` is not supplied, ``prepare_inputs_labels_for_multimodal``
        builds the embeddings from ``input_ids``/``images``. It also returns the
        (possibly updated) ``voco_loc_back`` that is passed on to the backbone.
        """
        if inputs_embeds is None:
            (
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                inputs_embeds,
                labels,
                voco_loc_back
            ) = self.prepare_inputs_labels_for_multimodal(
                input_ids,
                position_ids,
                attention_mask,
                past_key_values,
                labels,
                images,
                image_sizes,
                voco_loc_back
            )

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            voco_loc_back=voco_loc_back
        )

        hidden_states = outputs[0]
        if self.config.pretraining_tp > 1:
            # Tensor-parallel style slicing of the LM head. `nn.functional` is used
            # because the module never imports `F`.
            lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
            logits = [
                nn.functional.linear(hidden_states, lm_head_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
            logits = torch.cat(logits, dim=-1)
        else:
            logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @torch.no_grad()
    def generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        image_sizes: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[GenerateOutput, torch.LongTensor]:
        """Embed the prompt (merging ``images`` when given) and call the parent ``generate``.

        ``voco_loc_back`` comes from ``prepare_inputs_labels_for_multimodal`` when
        images are given. Otherwise it is whatever the caller passed in ``kwargs``
        (default ``None``).

        Raises:
            NotImplementedError: if ``inputs_embeds`` is passed directly.
        """
        position_ids = kwargs.pop("position_ids", None)
        attention_mask = kwargs.pop("attention_mask", None)
        # Popped so that it is always bound (the text-only path never sets it) and
        # cannot collide with the explicit keyword passed to super().generate below.
        voco_loc_back = kwargs.pop("voco_loc_back", None)
        if "inputs_embeds" in kwargs:
            raise NotImplementedError("`inputs_embeds` is not supported")

        if images is not None:
            (
                inputs,
                position_ids,
                attention_mask,
                _,
                inputs_embeds,
                _,
                voco_loc_back
            ) = self.prepare_inputs_labels_for_multimodal(
                inputs,
                position_ids,
                attention_mask,
                None,
                None,
                images,
                image_sizes=image_sizes
            )
        else:
            inputs_embeds = self.get_model().embed_tokens(inputs)

        return super().generate(
            position_ids=position_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            voco_loc_back=voco_loc_back,
            **kwargs
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None,
                                      inputs_embeds=None, **kwargs):
        """Llama generation inputs plus the LLaVA/VoCo extras (images, sizes, VoCo offsets)."""
        images = kwargs.pop("images", None)
        image_sizes = kwargs.pop("image_sizes", None)
        voco_loc_back = kwargs.pop("voco_loc_back", None)

        inputs = self.prepare_inputs_for_generation_llama(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, **kwargs
        )

        if voco_loc_back is not None:
            inputs['voco_loc_back'] = voco_loc_back
        if images is not None:
            inputs['images'] = images
        if image_sizes is not None:
            inputs['image_sizes'] = image_sizes
        return inputs

    def prepare_inputs_for_generation_llama(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Slice inputs to the not-yet-cached tokens and derive ``position_ids`` for one decoding step."""
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - The attention mask is longer than input_ids: part of the input lives
            #     only in the cache (e.g. when inputs_embeds were passed).
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
            # 2 - input_ids holds all tokens: drop the ones already in the cache.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
            # 3 - Otherwise assume input_ids only holds unprocessed tokens.

            # If we are about to exceed the maximum cache length, crop the attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # inputs_embeds are only used in the first generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder every cached tensor along the batch dimension to follow ``beam_idx``."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past
|
|
# Register the "llava_llama" model type so that the transformers Auto* factories
# resolve it to LlavaConfig / LlavaLlamaForCausalLM.
AutoConfig.register("llava_llama", LlavaConfig)
AutoModelForCausalLM.register(LlavaConfig, LlavaLlamaForCausalLM)
|
|