| import math |
| import warnings |
| from functools import partial |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
| from torch import nn |
|
|
|
|
| import copy |
| import os |
| import sys |
|
|
| dir_path = os.path.dirname(os.path.realpath(__file__)) |
| sys.path.insert(0, dir_path) |
|
|
| import transformers |
| from transformers.models.llama.modeling_llama import * |
| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
|
|
| from .modeling_attn_mask_utils import _prepare_4d_causal_attention_mask |
| from .configuration_mplug_owl2 import LlamaConfig |
|
|
class MultiwayNetwork(nn.Module):
    """Route each position of a tensor through one of several parallel sub-modules.

    ``module_provider`` is called ``num_multiway`` times to build independent
    sub-modules ("ways"). At call time, every position whose entry in
    ``multiway_indices`` equals ``i`` is processed by way ``i``.
    """

    def __init__(self, module_provider, num_multiway=2):
        """Build ``num_multiway`` sub-modules.

        Args:
            module_provider: zero-argument callable returning a fresh ``nn.Module``.
            num_multiway: number of parallel sub-modules to create.
        """
        super().__init__()

        self.multiway = torch.nn.ModuleList([module_provider() for _ in range(num_multiway)])

    def forward(self, hidden_states, multiway_indices):
        """Apply the way selected by ``multiway_indices`` to each position.

        Args:
            hidden_states: input tensor; it is indexed with the positions where
                ``multiway_indices`` equals each way index.
            multiway_indices: integer tensor selecting a way per position
                (ignored when only a single way exists).

        Returns:
            A contiguous tensor with the leading dimensions of ``hidden_states``
            and the feature size produced by the sub-modules. Positions whose
            index matches no way are left uninitialized.
        """
        # With a single way there is nothing to route.
        if len(self.multiway) == 1:
            return self.multiway[0](hidden_states)

        # Allocated lazily: the sub-modules may change the feature dimension
        # (e.g. a Linear with out_features != in_features), so the buffer shape
        # can only be known after the first way has produced an output.
        output_hidden_states = None

        for idx, subway in enumerate(self.multiway):
            local_indices = multiway_indices.eq(idx).nonzero(as_tuple=True)
            hidden = hidden_states[local_indices].unsqueeze(1).contiguous()
            if not hidden.numel():
                # No position is routed to this way.
                continue

            output = subway(hidden)
            if isinstance(output, tuple):
                output = output[0]
            output = output.squeeze(1)

            if output_hidden_states is None:
                buffer_shape = tuple(hidden_states.shape[:-1]) + tuple(output.shape[-1:])
                output_hidden_states = output.new_empty(buffer_shape)
            output_hidden_states[local_indices] = output

        if output_hidden_states is None:
            # Nothing was routed at all (e.g. empty input): keep the original
            # fallback of an uninitialized tensor shaped like the input.
            output_hidden_states = torch.empty_like(hidden_states)

        return output_hidden_states.contiguous()
| |
|
|
class LlamaAttention(nn.Module):
    """Multi-head self-attention whose key/value projections are modality-routed.

    The query and output projections are plain ``nn.Linear`` layers; the key
    and value projections are ``MultiwayNetwork`` instances that receive
    ``modality_indicators`` at call time.
    """

    def __init__(self, config: LlamaConfig):
        """Create the projections and the rotary embedding described by ``config``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of the head count,
                or (via ``_init_rope``) the RoPE scaling type is unknown.
        """
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if self.hidden_size != self.head_dim * self.num_heads:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        query_features = self.num_heads * self.head_dim
        # Factory shared by the key and value routers; every call builds a new layer.
        make_kv_linear = partial(
            nn.Linear,
            in_features=self.hidden_size,
            out_features=self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )

        # Registration order (q, k, v, o, rotary) is kept as in the original layout.
        self.q_proj = nn.Linear(self.hidden_size, query_features, bias=config.attention_bias)
        self.k_proj = MultiwayNetwork(module_provider=make_kv_linear)
        self.v_proj = MultiwayNetwork(module_provider=make_kv_linear)
        self.o_proj = nn.Linear(query_features, self.hidden_size, bias=config.attention_bias)
        self._init_rope()

    def _init_rope(self):
        """Instantiate ``self.rotary_emb`` according to ``config.rope_scaling``."""
        rope_scaling = self.config.rope_scaling
        if rope_scaling is None:
            self.rotary_emb = LlamaRotaryEmbedding(
                self.head_dim,
                max_position_embeddings=self.max_position_embeddings,
                base=self.rope_theta,
            )
            return

        scaling_type = rope_scaling["type"]
        scaling_factor = rope_scaling["factor"]
        if scaling_type != "linear" and scaling_type != "dynamic":
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

        # Only the class that is actually needed is looked up.
        if scaling_type == "linear":
            rope_cls = LlamaLinearScalingRotaryEmbedding
        else:
            rope_cls = LlamaDynamicNTKScalingRotaryEmbedding
        self.rotary_emb = rope_cls(
            self.head_dim,
            max_position_embeddings=self.max_position_embeddings,
            scaling_factor=scaling_factor,
            base=self.rope_theta,
        )

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """View ``tensor`` as (bsz, seq_len, heads, head_dim) and move heads to dim 1."""
        per_head = tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: tensor unpacked as (batch, query_len, features).
            modality_indicators: routing indices handed to the key and value
                projections.
            attention_mask: optional additive mask; must have size
                (batch, 1, query_len, kv_len) when given.
            position_ids: positions forwarded to ``apply_rotary_pos_emb``.
            past_key_value: optional cached (key, value) pair that is
                prepended along dim 2.
            output_attentions: return the attention probabilities when true.
            use_cache: return the updated (key, value) pair when true.
            padding_mask: accepted but not used by this implementation.

        Returns:
            ``(attn_output, attn_weights or None, present_key_value or None)``.

        Raises:
            ValueError: if the attention scores, mask or output have an
                unexpected size.
        """
        batch, tgt_len, _ = hidden_states.size()

        # Project, then split the feature axis into heads: (batch, heads, len, head_dim).
        queries = self.q_proj(hidden_states)
        keys = self.k_proj(hidden_states, modality_indicators)
        values = self.v_proj(hidden_states, modality_indicators)

        queries = queries.view(batch, tgt_len, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(batch, tgt_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        values = values.view(batch, tgt_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Total key/value length includes whatever is already cached.
        cached_len = 0 if past_key_value is None else past_key_value[0].shape[-2]
        kv_seq_len = keys.shape[-2] + cached_len

        cos, sin = self.rotary_emb(values, seq_len=kv_seq_len)
        queries, keys = apply_rotary_pos_emb(queries, keys, cos, sin, position_ids)

        if past_key_value is not None:
            # Prepend the cached keys/values along the sequence axis.
            cached_keys, cached_values = past_key_value[0], past_key_value[1]
            keys = torch.cat([cached_keys, keys], dim=2)
            values = torch.cat([cached_values, values], dim=2)

        past_key_value = (keys, values) if use_cache else None

        # Expand key/value heads so they line up with the query heads.
        keys = repeat_kv(keys, self.num_key_value_groups)
        values = repeat_kv(values, self.num_key_value_groups)

        attn_weights = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        weights_size = (batch, self.num_heads, tgt_len, kv_seq_len)
        if attn_weights.size() != weights_size:
            raise ValueError(
                f"Attention weights should be of size {weights_size}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            mask_size = (batch, 1, tgt_len, kv_seq_len)
            if attention_mask.size() != mask_size:
                raise ValueError(
                    f"Attention mask should be of size {mask_size}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # Softmax is computed in float32 and cast back to the query dtype.
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(queries.dtype)
        attn_output = torch.matmul(attn_weights, values)

        output_size = (batch, self.num_heads, tgt_len, self.head_dim)
        if attn_output.size() != output_size:
            raise ValueError(
                f"`attn_output` should be of size {output_size}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back into one feature axis before the output projection.
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch, tgt_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if output_attentions:
            return attn_output, attn_weights, past_key_value
        return attn_output, None, past_key_value
|
|
|
|
|
|
class LlamaDecoderLayer(nn.Module):
    """One decoder layer: modality-routed norms around self-attention and an MLP."""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(config)

        # Both layer norms are MultiwayNetworks built from the same RMSNorm recipe;
        # each call of the factory creates an independent norm module.
        make_rms_norm = partial(LlamaRMSNorm, hidden_size=config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = MultiwayNetwork(module_provider=make_rms_norm)
        self.post_attention_layernorm = MultiwayNetwork(module_provider=make_rms_norm)

    def forward(
        self,
        hidden_states: torch.Tensor,
        modality_indicators: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Apply the attention sub-block and the MLP sub-block, each with a residual connection.

        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            modality_indicators (`torch.Tensor`, *optional*): routing indices forwarded to both layer norms
                and to the attention module.
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            position_ids (`torch.LongTensor`, *optional*): position indices forwarded to the attention module.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            output_attentions (`bool`, *optional*):
                Whether to append the attention weights to the returned tuple.
            use_cache (`bool`, *optional*):
                Whether to append the present key/value states to the returned tuple.

        Returns:
            Tuple starting with the new hidden states, followed by the attention weights (if requested) and
            then the present key/value states (if requested).
        """
        # --- self-attention sub-block ---
        attn_input = self.input_layernorm(hidden_states, modality_indicators)
        attn_output, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=attn_input,
            modality_indicators=modality_indicators,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = hidden_states + attn_output

        # --- feed-forward sub-block ---
        mlp_input = self.post_attention_layernorm(hidden_states, modality_indicators)
        mlp_output = self.mlp(mlp_input)
        hidden_states = hidden_states + mlp_output

        # Optional extras are appended in a fixed order: attentions first, then cache.
        layer_outputs = [hidden_states]
        if output_attentions:
            layer_outputs.append(self_attn_weights)
        if use_cache:
            layer_outputs.append(present_key_value)

        return tuple(layer_outputs)
|
|
|
|
def model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
    """Replacement ``forward`` for the Llama model that threads ``modality_indicators`` to every layer.

    Exactly one of ``input_ids`` and ``inputs_embeds`` must be provided. Flags
    left as ``None`` fall back to the corresponding ``self.config`` value.

    Returns:
        A ``BaseModelOutputWithPast`` when ``return_dict`` is true, otherwise a
        tuple of the non-``None`` values among (last hidden state, cache, all
        hidden states, all attentions).

    Raises:
        ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are given.
    """
    # Resolve unset flags from the model configuration.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if use_cache is None:
        use_cache = self.config.use_cache
    if return_dict is None:
        return_dict = self.config.use_return_dict

    # Derive batch size and sequence length from whichever input was supplied.
    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
    if input_ids is not None:
        batch_size, seq_length = input_ids.shape
    elif inputs_embeds is not None:
        batch_size, seq_length, _ = inputs_embeds.shape
    else:
        raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

    # The cached length is read from the first layer's cached key tensor.
    past_key_values_length = 0 if past_key_values is None else past_key_values[0][0].shape[2]
    seq_length_with_past = seq_length + past_key_values_length

    if position_ids is not None:
        position_ids = position_ids.view(-1, seq_length).long()
    else:
        device = inputs_embeds.device if input_ids is None else input_ids.device
        # Positions continue from the end of the cached prefix.
        position_ids = torch.arange(
            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
        )
        position_ids = position_ids.unsqueeze(0).view(-1, seq_length)

    if inputs_embeds is None:
        inputs_embeds = self.embed_tokens(input_ids)

    if attention_mask is None:
        # Default: attend to every position, cached ones included.
        attention_mask = torch.ones(
            (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
        )
    attention_mask = self._prepare_decoder_attention_mask(
        attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
    )

    hidden_states = inputs_embeds

    checkpointing = self.gradient_checkpointing and self.training
    if checkpointing and use_cache:
        logger.warning_once(
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
        )
        use_cache = False

    # Accumulators for the optional outputs.
    all_hidden_states = () if output_hidden_states else None
    all_self_attns = () if output_attentions else None
    next_decoder_cache = () if use_cache else None

    for layer_idx, decoder_layer in enumerate(self.layers):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        past_key_value = None if past_key_values is None else past_key_values[layer_idx]

        if checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    # past_key_value / output_attentions are read from the
                    # enclosing scope when the wrapper is invoked.
                    return module(*inputs, past_key_value, output_attentions)

                return custom_forward

            layer_outputs = torch.utils.checkpoint.checkpoint(
                create_custom_forward(decoder_layer),
                hidden_states,
                modality_indicators,
                attention_mask,
                position_ids,
            )
        else:
            layer_outputs = decoder_layer(
                hidden_states,
                modality_indicators=modality_indicators,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        hidden_states = layer_outputs[0]

        if use_cache:
            # The cache entry follows the attention weights when those are returned.
            cache_slot = 2 if output_attentions else 1
            next_decoder_cache += (layer_outputs[cache_slot],)

        if output_attentions:
            all_self_attns += (layer_outputs[1],)

    hidden_states = self.norm(hidden_states)

    # Record the hidden states produced after the final norm as well.
    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    next_cache = next_decoder_cache if use_cache else None
    if return_dict:
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
    return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
|
|
|
|
def causal_model_forward(
    self,
    input_ids: torch.LongTensor = None,
    modality_indicators: torch.Tensor = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
    r"""
    Replacement ``forward`` for the causal-LM head model that forwards ``modality_indicators`` to ``self.model``.

    Args:
        modality_indicators (`torch.Tensor`, *optional*):
            Routing indices passed through unchanged to the wrapped model.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    Returns:
        A `CausalLMOutputWithPast` when `return_dict` is true; otherwise a tuple of the logits followed by the
        remaining model outputs, prefixed with the loss when `labels` is given.
    """
    # Resolve unset flags from the configuration (use_cache is passed through as-is).
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.use_return_dict

    outputs = self.model(
        input_ids=input_ids,
        modality_indicators=modality_indicators,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    hidden_states = outputs[0]
    tp_degree = self.config.pretraining_tp
    if tp_degree > 1:
        # Apply the LM head slice by slice along the vocabulary axis, then
        # stitch the partial logits together.
        lm_head_slices = self.lm_head.weight.split(self.vocab_size // tp_degree, dim=0)
        partial_logits = [F.linear(hidden_states, weight_slice) for weight_slice in lm_head_slices[:tp_degree]]
        logits = torch.cat(partial_logits, dim=-1)
    else:
        logits = self.lm_head(hidden_states)
    logits = logits.float()

    loss = None
    if labels is not None:
        # Shift so that the logits at position t are scored against token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten every position into one batch of classification problems.
        flat_logits = shift_logits.view(-1, self.config.vocab_size)
        flat_labels = shift_labels.view(-1)
        # Labels must live on the same device as the logits.
        flat_labels = flat_labels.to(flat_logits.device)
        loss = CrossEntropyLoss()(flat_logits, flat_labels)

    if not return_dict:
        output = (logits,) + outputs[1:]
        if loss is None:
            return output
        return (loss,) + output

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
def replace_llama_modality_adaptive():
    """Monkey-patch ``transformers``' Llama implementation with this module's versions.

    Replaces the config class, the attention and decoder-layer classes, and the
    ``forward`` methods of ``LlamaModel`` and ``LlamaForCausalLM``.
    """
    llama_pkg = transformers.models.llama

    # Configuration class.
    llama_pkg.configuration_llama.LlamaConfig = LlamaConfig

    # Building blocks looked up by name inside the modeling module.
    modeling = llama_pkg.modeling_llama
    modeling.LlamaAttention = LlamaAttention
    modeling.LlamaDecoderLayer = LlamaDecoderLayer

    # Forward passes of the existing model classes.
    modeling.LlamaModel.forward = model_forward
    modeling.LlamaForCausalLM.forward = causal_model_forward
|
|
| |
if __name__ == "__main__":
    # Smoke test: apply the patches, build a model from a config and print it.
    replace_llama_modality_adaptive()
    # The checkpoint directory may be given as the first command-line argument;
    # without one, the original hard-coded location is used.
    checkpoint_dir = sys.argv[1] if len(sys.argv) > 1 else '/cpfs01/shared/public/test/vicuna-7b-v1.5/'
    config = transformers.LlamaConfig.from_pretrained(checkpoint_dir)
    model = transformers.LlamaForCausalLM(config)
    print(model)