import torch
import torch.nn as nn

from typing import Optional, List, Union, Tuple

from transformers import MistralModel, MistralForCausalLM
from transformers.utils import logging
from transformers.cache_utils import Cache, DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import (
    CausalLMOutputWithPast,
    BaseModelOutputWithPast,
)
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer,
    MistralRMSNorm,
)

logger = logging.get_logger(__name__)


class SoloConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate a
    Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1.

    [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
    [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1)

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 32000):
            Vocabulary size of the Mistral model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`MistralModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 14336):
            Dimension of the MLP representations.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if
            `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be
            constructed by mean-pooling all the original heads within that group. For more details, check out [this
            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, it will default to `8`.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to `4096*32`):
            The maximum sequence length that this model might ever be used with. Mistral's sliding window attention
            allows sequences of up to 4096*32 tokens.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        pad_token_id (`int`, *optional*):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        sliding_window (`int`, *optional*, defaults to 4096):
            Sliding window attention window size. If not specified, will default to `4096`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        vision_patch_size (`int`, *optional*, defaults to 32):
            The side length, in pixels, of each square vision patch. Each patch is flattened to
            `vision_patch_size * vision_patch_size * 3` values before being embedded.

    ```python
    >>> from transformers import MultimodalMistralModel, SoloConfig

    >>> # Initializing a Mistral 7B style configuration
    >>> configuration = SoloConfig()

    >>> # Initializing a model from the Mistral 7B style configuration
    >>> model = MultimodalMistralModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "mistral"
    keys_to_ignore_at_inference = ["past_key_values"]
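
    # With the defaults below, a single vision patch is a flattened 32 x 32 x 3 = 3072-value
    # vector; `MultimodalMistralModel.vis_embed` projects it to `hidden_size` (4096).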

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=4096,
        attention_dropout=0.0,
        vision_patch_size=32,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.vision_patch_size = vision_patch_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class MultimodalMistralModel(MistralModel):
    def __init__(self, config: SoloConfig):
        # Deliberately bypass MistralModel.__init__ (which would build its own embedding,
        # layers and norm) and call its parent's __init__ instead; every submodule is
        # re-declared below.
        super(MistralModel, self).__init__(config)

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )
        self.layers = nn.ModuleList(
            [
                MistralDecoderLayer(config, layer_idx)
                for layer_idx in range(config.num_hidden_layers)
            ]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Linear projection of a flattened RGB patch (vision_patch_size**2 * 3 values)
        # into the token embedding space.
        assert config.vision_patch_size == 32
        assert config.hidden_size == 4096
        self.vis_embed = nn.Linear(
            config.vision_patch_size * config.vision_patch_size * 3,
            config.hidden_size,
            bias=False,
        )

        self.gradient_checkpointing = False
        self.post_init()
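
    # NOTE (inferred from `vis_embed` above and the shape assert in `forward`; treat this as
    # an assumption rather than an authoritative spec):
    #   - `vision_patches` is a float tensor of shape (n_patches, 32 * 32 * 3), one flattened
    #     RGB patch per row.
    #   - `vision_patch_indices` has the same shape as `input_ids`; vision-token positions
    #     hold a row index into `vision_patches`, while text positions hold a sentinel such
    #     as -1 (see `prepare_inputs_for_generation`), which maps to a zero vector in `forward`.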

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        vision_patch_indices: Optional[torch.LongTensor] = None,
        vision_patches: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        past_key_values_length = 0

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        # Merge vision patch embeddings into the token embeddings
        if vision_patches is not None and vision_patches.size(0) > 0:
            assert (
                vision_patch_indices.shape == input_ids.shape
            ), "vision_patch_indices and input_ids should have the same shape"

            # Project the flattened patches, then append a zero row (matching device and
            # dtype) so that sentinel indices such as -1 select an all-zero embedding.
            vision_embeds = self.vis_embed(vision_patches)
            vision_embeds = torch.cat(
                [
                    vision_embeds,
                    torch.zeros(
                        1,
                        self.config.hidden_size,
                        device=vision_embeds.device,
                        dtype=vision_embeds.dtype,
                    ),
                ],
            )
            # Look up one row per token position and add it to the token embeddings
            vision_embeds = vision_embeds[vision_patch_indices]
            inputs_embeds += vision_embeds
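            # Toy illustration of the lookup above (values hypothetical): with two patches,
            # the concatenated table has rows [p0, p1, 0]. Indices [[-1, -1, 0, 1, -1]] then
            # gather [0, 0, p0, p1, 0], so only the two vision-token positions receive a
            # non-zero patch embedding when added to `inputs_embeds`.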

        if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
            is_padding_right = attention_mask[:, -1].sum().item() != batch_size
            if is_padding_right:
                raise ValueError(
                    "You are attempting to perform batched generation with padding_side='right'"
                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
                )

        if self._attn_implementation == "flash_attention_2":
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        elif self._attn_implementation == "sdpa" and not output_attentions:
            # output_attentions=True can not be supported when using SDPA, and we fall back
            # on the manual implementation that requires a 4D causal mask in all cases.
            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )
        else:
            # 4d mask is passed through the layers
            attention_mask = _prepare_4d_causal_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
                sliding_window=self.config.sliding_window,
            )

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class SoloForCausalLM(MistralForCausalLM):
    def __init__(self, config: SoloConfig):
        super().__init__(config)

        # Replace the text-only backbone created by MistralForCausalLM.__init__ with the
        # multimodal backbone defined above.
        self.model = MultimodalMistralModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        vision_patch_indices: Optional[torch.LongTensor] = None,
        vision_patches: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, MistralForCausalLM

        >>> model = MistralForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1")
        >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""
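
        # Unlike the stock Mistral example in the docstring above, this forward additionally
        # accepts `vision_patches` and `vision_patch_indices`; they are passed through to the
        # multimodal backbone, which merges patch embeddings into the token embeddings.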

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            vision_patch_indices=vision_patch_indices,
            vision_patches=vision_patches,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Make sure the labels are on the same device as the logits
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        vision_patches = kwargs.get("vision_patches")
        vision_patch_indices = kwargs.get("vision_patch_indices")
        # Right-pad vision_patch_indices with -1 (the "no patch" sentinel) so it keeps the
        # same shape as input_ids while generation appends new text tokens.
        _padding = torch.full_like(input_ids, -1, dtype=vision_patch_indices.dtype)
        _padding[:, : vision_patch_indices.shape[1]] = vision_patch_indices
        vision_patch_indices = _padding

        # Omit tokens that are already covered by past_key_values
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                cache_length = past_key_values.get_seq_length()
                past_length = past_key_values.seen_tokens
                max_cache_length = past_key_values.get_max_length()
            else:
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the attention_mask is longer than input_ids, some inputs are passed
            #     exclusively as part of the cache (e.g. when passing inputs_embeds).
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
                vision_patch_indices = vision_patch_indices[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If past_length is smaller than the length of input_ids, then input_ids
            #     holds all input tokens; discard the ones already covered by the cache.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
                vision_patch_indices = vision_patch_indices[:, past_length:]
            # 3 - Otherwise (past_length >= input_ids.shape[1]), assume input_ids only has
            #     unprocessed tokens.

            # If we are about to exceed the maximum cache length, crop the attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # If `inputs_embeds` are passed, only use them in the first generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        assert vision_patch_indices.shape == input_ids.shape

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "vision_patch_indices": vision_patch_indices,
                "vision_patches": vision_patches,
            }
        )
        return model_inputs


if __name__ == "__main__":
    model = SoloForCausalLM.from_pretrained("/mnt/bn/zilongdata-us/weixian/code/Megatron-MLLM/outputs-100M-8nodes-pt1/checkpoint/dsw-pretrains1-mmistral-7B-lr-5e-5-bs-8192-pr-bf16-tp-2-pp-1-ac-full-do-true-sp-false-tt-128000000-warmup-400/iter3124_hf")
    print(model)
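
    # A minimal generation sketch (hypothetical usage, not exercised by this module: the
    # tokenizer name and the patch preprocessing below are placeholder assumptions):
    #
    # from transformers import AutoTokenizer
    # tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    # inputs = tokenizer("Describe the image:", return_tensors="pt")
    # n_patches = 4  # e.g. a 64x64 image split into 32x32 patches
    # vision_patches = torch.randn(n_patches, 32 * 32 * 3)  # one flattened RGB patch per row
    # # -1 marks text-only positions; real inputs would point vision-token positions at patch rows
    # vision_patch_indices = torch.full_like(inputs.input_ids, -1)
    # out = model.generate(
    #     inputs.input_ids,
    #     vision_patches=vision_patches,
    #     vision_patch_indices=vision_patch_indices,
    #     max_new_tokens=20,
    # )
    # print(tokenizer.batch_decode(out, skip_special_tokens=True)[0])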