import torch
import torch.nn as nn

from typing import List, Optional, Tuple, Union

from transformers import Qwen2ForCausalLM, Qwen2Model
from transformers.cache_utils import Cache, DynamicCache
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer,
    Qwen2RMSNorm,
)
from transformers.utils import is_torchdynamo_compiling, logging

logger = logging.get_logger(__name__)
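
# This module adds a minimal "patches as tokens" vision pathway to Qwen2:
# flattened 32x32 RGB patches are projected by a single linear layer
# (`vis_embed`) and added onto the token embeddings at the positions given by
# `vision_patch_indices` (-1 marks positions without a patch). The forward
# passes below mirror the upstream Qwen2 modeling code of transformers v4.4x,
# so they may need small adjustments on newer releases.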


class Qwen2MMConfig(PretrainedConfig):
    model_type = "qwen2_mm"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        vision_patch_size=32,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        self.sliding_window = sliding_window
        self.max_window_layers = max_window_layers

        # Default to standard multi-head attention when no GQA head count is given.
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.vision_patch_size = vision_patch_size

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
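

# Example (hypothetical sizes, roughly those of a 0.5B-scale Qwen2):
#
#   config = Qwen2MMConfig(
#       hidden_size=896,
#       intermediate_size=4864,
#       num_hidden_layers=24,
#       num_attention_heads=14,
#       num_key_value_heads=2,
#   )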


class MultimodalQwen2Model(Qwen2Model):
    config_class = Qwen2MMConfig

    def __init__(self, config: Qwen2MMConfig):
        # Bypass Qwen2Model.__init__ (calling PreTrainedModel.__init__
        # directly) so that every submodule is rebuilt from scratch below.
        super(Qwen2Model, self).__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        assert config.vision_patch_size == 32, "only 32x32 vision patches are supported"
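
        # A single linear layer embeds each flattened RGB patch
        # (32 * 32 * 3 = 3072 values) into the language model's hidden space.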
        self.vis_embed = nn.Linear(
            config.vision_patch_size * config.vision_patch_size * 3,
            config.hidden_size,
            bias=False,
        )

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        vision_patch_indices: Optional[torch.LongTensor] = None,
        vision_patches: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        use_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):
            use_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if vision_patch_indices is not None:
            assert (
                vision_patch_indices.shape == input_ids.shape
            ), "vision_patch_indices and input_ids should have the same shape"

        if vision_patches is not None and vision_patches.size(0) > 0:
            assert vision_patch_indices is not None, "HF QwenMM model requires vision_patch_indices for vision_patches input."
            vision_embeds = self.vis_embed(vision_patches)
            # Append a zero row so that index -1 (text-only positions) selects
            # an all-zero embedding. Match dtype/device to avoid an implicit
            # upcast when running in fp16/bf16.
            vision_embeds = torch.cat(
                [
                    vision_embeds,
                    torch.zeros(
                        1,
                        self.config.hidden_size,
                        dtype=vision_embeds.dtype,
                        device=vision_embeds.device,
                    ),
                ],
            )
            # Gather one embedding per token position; -1 picks the zero row.
            vision_embeds = vision_embeds[vision_patch_indices]

            # Add the vision embeddings onto the token embeddings in place.
            inputs_embeds += vision_embeds
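
        # Shape walk-through of the block above: with P patches and hidden size
        # H, vis_embed maps (P, 3072) -> (P, H); appending the zero row gives a
        # (P + 1, H) table; indexing it with the (B, T) `vision_patch_indices`
        # (values in [-1, P - 1]) yields (B, T, H), matching `inputs_embeds`.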

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )

        hidden_states = inputs_embeds

        # Decoder layers.
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # Add the hidden states from the last decoder layer.
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class Qwen2MMForCausalLM(Qwen2ForCausalLM):
    config_class = Qwen2MMConfig

    def __init__(self, config: Qwen2MMConfig):
        super().__init__(config)
        # Swap the text-only backbone for the multimodal variant.
        self.model = MultimodalQwen2Model(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing.
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        vision_patch_indices: Optional[torch.LongTensor] = None,
        vision_patches: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            vision_patch_indices=vision_patch_indices,
            vision_patches=vision_patches,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        if labels is None and not is_torchdynamo_compiling():
            logger.warning_once(
                "Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)"
            )
        # `-num_logits_to_keep:` keeps only the last `num_logits_to_keep`
        # positions; with the default of 0, `-0:` equals `0:`, so all
        # positions are kept.
        logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens.
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism.
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        vision_patches = kwargs.get("vision_patches", None)
        vision_patch_indices = kwargs.get("vision_patch_indices", None)

        has_vision_inp = False
        if vision_patches is not None and vision_patch_indices is not None:
            has_vision_inp = True
            # Right-pad the patch indices with -1 (the "no patch" marker) so
            # they keep the same shape as input_ids while generation appends
            # new tokens.
            _padding = torch.full_like(input_ids, -1, dtype=vision_patch_indices.dtype)
            _padding[:, : vision_patch_indices.shape[1]] = vision_patch_indices
            vision_patch_indices = _padding
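            # E.g. (hypothetical values): indices [[0, -1]] for a 2-token
            # prompt become [[0, -1, -1]] once one more token has been
            # generated.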

        past_length = 0
        if past_key_values is not None:
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
                if past_key_values.get_max_length() is not None
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)

            # Keep only the unprocessed tokens:
            # 1 - If attention_mask is longer than input_ids, some inputs were
            #     passed exclusively through the cache (e.g. via inputs_embeds);
            #     drop the cached positions.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
                if has_vision_inp:
                    vision_patch_indices = vision_patch_indices[:, -(attention_mask.shape[1] - past_length):]
            # 2 - If past_length is smaller than input_ids, input_ids holds all
            #     input tokens; discard the already-processed prefix.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]
                if has_vision_inp:
                    vision_patch_indices = vision_patch_indices[:, past_length:]
            # 3 - Otherwise assume input_ids contains only unprocessed tokens.

            # Crop the attention mask if we are about to exceed the cache's
            # maximum length.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Create position_ids on the fly for batch generation.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1]:]

        # If `inputs_embeds` is passed, use it only in the first generation step.
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        if cache_position is None:
            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
        elif use_cache:
            cache_position = cache_position[-input_length:]

        if vision_patch_indices is not None:
            assert vision_patch_indices.shape == input_ids.shape

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "vision_patch_indices": vision_patch_indices,
                "vision_patches": vision_patches,
            }
        )
        return model_inputs


if __name__ == "__main__":
    # Manual smoke test: "Qwen2-0.5B" should be a local checkpoint directory
    # (or hub repo id) with standard Qwen2 weights; the newly added
    # `vis_embed` layer is randomly initialized.
    mmqwen = Qwen2MMForCausalLM.from_pretrained("Qwen2-0.5B")
    print(mmqwen)
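
    # A minimal end-to-end sketch (assumes the checkpoint loaded in float32):
    # one flattened 32x32 RGB patch is injected at position 0 of a 4-token
    # sequence; the other positions use -1, i.e. no patch.
    input_ids = torch.tensor([[1, 2, 3, 4]])
    vision_patches = torch.rand(1, 32 * 32 * 3)
    vision_patch_indices = torch.tensor([[0, -1, -1, -1]])
    with torch.no_grad():
        out = mmqwen(
            input_ids=input_ids,
            vision_patches=vision_patches,
            vision_patch_indices=vision_patch_indices,
        )
    print(out.logits.shape)  # expected: (1, 4, vocab_size)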