| """PyTorch MERaLiON2 model.""" |
|
|
| from dataclasses import dataclass |
| from typing import List, Optional, Tuple, Union |
|
|
| import torch |
| import torch.utils.checkpoint |
| from torch import nn |
|
|
| from transformers import Gemma2ForCausalLM |
| from transformers.models.whisper.modeling_whisper import WhisperEncoder |
| from transformers.cache_utils import HybridCache |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import ModelOutput |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.utils import ( |
| add_start_docstrings, |
| add_start_docstrings_to_model_forward, |
| logging, |
| replace_return_docstrings, |
| ) |
|
|
| from .configuration_meralion2 import MERaLiON2Config |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| _CONFIG_FOR_DOC = "MERaLiON2Config" |
|
|
|
|
| |
| def _prepare_4d_causal_attention_mask_with_cache_position( |
| attention_mask: torch.Tensor, |
| sequence_length: int, |
| target_length: int, |
| dtype: torch.dtype, |
| device: torch.device, |
| min_dtype: float, |
| cache_position: torch.Tensor, |
| batch_size: int, |
| ): |
| """ |
| Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape |
| `(batch_size, key_value_length)`. If the input `attention_mask` is already 4D, it is returned unchanged. |
| |
| Args: |
| attention_mask (`torch.Tensor`): |
| A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. |
| sequence_length (`int`): |
| The sequence length being processed. |
| target_length (`int`): |
| The target length: when generating with a static cache, the mask should be as long as the static cache to account for the zero padding, i.e. the part of the cache that is not yet filled. |
| dtype (`torch.dtype`): |
| The dtype to use for the 4D attention mask. |
| device (`torch.device`): |
| The device to place the 4D attention mask on. |
| min_dtype (`float`): |
| The minimum value representable with the dtype `dtype`. |
| cache_position (`torch.Tensor`): |
| Indices depicting the position of the input sequence tokens in the sequence. |
| batch_size (`int`): |
| Batch size. |
| """ |
| if attention_mask is not None and attention_mask.dim() == 4: |
| |
| causal_mask = attention_mask |
| else: |
| causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) |
| if sequence_length != 1: |
| causal_mask = torch.triu(causal_mask, diagonal=1) |
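| # A key position stays masked (`min_dtype`) only if it lies strictly after the query's cache position; |
| # earlier positions, including already-filled cache slots, become 0 (attendable). |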
| causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) |
| causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) |
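| # Fold the user-provided 2D padding mask into the causal mask: positions the causal mask allows (0) but the |
| # padding mask marks as padding (0) are re-masked with `min_dtype`. |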
| if attention_mask is not None: |
| causal_mask = causal_mask.clone() |
| mask_length = attention_mask.shape[-1] |
| padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] |
| padding_mask = padding_mask == 0 |
| causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( |
| padding_mask, min_dtype |
| ) |
| return causal_mask |
|
|
|
|
| |
| @dataclass |
| class MERaLiON2OutputWithPast(ModelOutput): |
| """ |
| Base class for MERaLiON2 causal language model (or autoregressive) outputs. |
| |
| Args: |
| loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): |
| Language modeling loss (for next-token prediction). |
| logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): |
| Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). |
| past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
| Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape |
| `(batch_size, num_heads, sequence_length, embed_size_per_head)`. |
| |
| Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see |
| `past_key_values` input) to speed up sequential decoding. |
| hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): |
| Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + |
| one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. |
| |
| Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. |
| attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): |
| Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, |
| sequence_length)`. |
| |
| Attention weights after the attention softmax, used to compute the weighted average in the self-attention |
| heads. |
| attention_mask (`torch.FloatTensor`, *optional*): |
| Attention mask, used to update the attention mask and `position_ids`. |
| """ |
|
|
| loss: Optional[torch.FloatTensor] = None |
| logits: torch.FloatTensor = None |
| past_key_values: Optional[List[torch.FloatTensor]] = None |
| hidden_states: Optional[Tuple[torch.FloatTensor]] = None |
| attentions: Optional[Tuple[torch.FloatTensor]] = None |
| attention_mask: Optional[torch.FloatTensor] = None |
|
|
|
|
| MERALION_START_DOCSTRING = r""" |
| This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the |
| library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads |
| etc.) |
| |
| This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. |
| Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage |
| and behavior. |
| |
| Parameters: |
| config ([`MERaLiON2Config`]): |
| Model configuration class with all the parameters of the model. Initializing with a config file does not |
| load the weights associated with the model, only the configuration. Check out the |
| [`~PreTrainedModel.from_pretrained`] method to load the model weights. |
| """ |
|
|
|
|
| @add_start_docstrings( |
| "The bare MERaLiON2 Model outputting raw hidden-states without any specific head on top.", |
| MERALION_START_DOCSTRING, |
| ) |
| class MERaLiON2PreTrainedModel(PreTrainedModel): |
| config_class = MERaLiON2Config |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer", "Gemma2DecoderLayer"] |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_cache_class = True |
| _supports_static_cache = True |
|
|
| def _init_weights(self, module): |
| |
| |
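| # Prefer a top-level `init_std` if the config defines one; otherwise fall back to the speech encoder's `init_std`. |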
| std = self.config.init_std if hasattr(self.config, "init_std") else self.config.speech_config.init_std |
|
|
| if isinstance(module, (nn.Linear, nn.Conv1d)): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.bias is not None: |
| module.bias.data.zero_() |
| elif isinstance(module, nn.Embedding): |
| module.weight.data.normal_(mean=0.0, std=std) |
| if module.padding_idx is not None: |
| module.weight.data[module.padding_idx].zero_() |
|
|
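| # Note: this property overrides the `_supports_sdpa = True` class attribute above and defers to the text decoder. |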
| @property |
| def _supports_sdpa(self): |
| """ |
| Retrieve language_model's attribute to check whether the model supports |
| SDPA or not. |
| """ |
| return self.text_decoder._supports_sdpa |
|
|
| class MERaLiON2SpeechAudioAdaper(nn.Module): |
| def __init__( |
| self, |
| config, |
| **kwargs |
| ): |
| super().__init__() |
| speech_audio_encoder_output_dim = config.speech_config.d_model |
| llm_input_hidden_size = config.text_config.hidden_size |
| speech_mlp_scale_factor = config.speech_mlp_scale_factor |
|
|
| self.speech_mlp_scale_factor = speech_mlp_scale_factor |
| self.mlp_adapter = nn.Sequential( |
| nn.Linear( |
| in_features=speech_audio_encoder_output_dim * speech_mlp_scale_factor, |
| out_features=speech_audio_encoder_output_dim |
| ), |
| nn.SiLU(), |
| nn.Dropout(0.1), |
| ) |
|
|
| self.speech_llm_proj = nn.Sequential( |
| nn.Linear( |
| speech_audio_encoder_output_dim, |
| speech_audio_encoder_output_dim * 4 |
| ), |
| nn.SiLU(), |
| nn.Dropout(0.1), |
|
|
| nn.Linear( |
| speech_audio_encoder_output_dim * 4, |
| llm_input_hidden_size |
| ), |
| ) |
|
|
| def forward(self, speech_embeds, **kwargs): |
| B, T, C = speech_embeds.shape |
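| # Fold `speech_mlp_scale_factor` consecutive encoder frames into the channel axis (T must be divisible by the |
| # scale factor), shortening the time axis by the same factor before projecting into the LLM hidden size. |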
| speech_embeds = self.mlp_adapter( |
| speech_embeds.reshape( |
| B, |
| T // self.speech_mlp_scale_factor, |
| C * self.speech_mlp_scale_factor, |
| ) |
| ) |
| return self.speech_llm_proj(speech_embeds) |
| |
|
|
| class MERaLiON2SpeechAudioAdaperLarge(nn.Module): |
| def __init__( |
| self, |
| config, |
| **kwargs |
| ): |
| super().__init__() |
| speech_audio_encoder_output_dim = config.speech_config.d_model |
| llm_input_hidden_size = config.text_config.hidden_size |
| speech_mlp_scale_factor = config.speech_mlp_scale_factor |
|
|
| self.speech_mlp_scale_factor = speech_mlp_scale_factor |
| self.mlp_adapter = nn.Sequential( |
| nn.Linear( |
| in_features=speech_audio_encoder_output_dim * speech_mlp_scale_factor, |
| out_features=speech_audio_encoder_output_dim * 5, |
| ), |
| nn.SiLU(), |
| nn.Dropout(0.01), |
| ) |
|
|
| self.gate_proj = nn.Linear( |
| in_features=speech_audio_encoder_output_dim * 5, |
| out_features=speech_audio_encoder_output_dim * 5, |
| ) |
| |
| self.pool_proj = nn.Linear( |
| in_features=speech_audio_encoder_output_dim * 5, |
| out_features=speech_audio_encoder_output_dim * 5, |
| ) |
| self.act_fn = nn.SiLU() |
| self.out_proj = nn.Linear( |
| speech_audio_encoder_output_dim * 5, |
| llm_input_hidden_size, |
| ) |
|
|
|
|
| def forward(self, speech_embeds, **kwargs): |
| B, T, C = speech_embeds.shape |
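| # Same frame stacking as MERaLiON2SpeechAudioAdaper: fold `speech_mlp_scale_factor` frames into the channel |
| # axis before the gated projection below. |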
| speech_embeds = self.mlp_adapter( |
| speech_embeds.reshape( |
| B, |
| T // self.speech_mlp_scale_factor, |
| C * self.speech_mlp_scale_factor, |
| ) |
| ) |
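| # GLU-style gating: the SiLU-activated gate modulates the pooled projection element-wise before the output projection. |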
| speech_embeds = self.act_fn(self.gate_proj(speech_embeds)) * self.pool_proj(speech_embeds) |
| speech_embeds = self.out_proj(speech_embeds) |
| return speech_embeds |
|
|
|
|
| MERALION_INPUTS_DOCSTRING = r""" |
| Args: |
| input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): |
| Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide |
| it. |
| |
| Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| [What are input IDs?](../glossary#input-ids) |
| input_features (`torch.FloatTensor` of shape `(batch_size, feature_size, feature_sequence_length)`, *optional*): |
| Float values of mel features extracted from the raw speech waveform. The raw speech waveform can be obtained by |
| loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via |
| the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the |
| [`AutoFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a |
| tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] for details. |
| attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| |
| [What are attention masks?](../glossary#attention-mask) |
| |
| Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| [`PreTrainedTokenizer.__call__`] for details. |
| |
| If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
| `past_key_values`). |
| |
| If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] |
| and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more |
| information on the default strategy. |
| |
| - 1 indicates the head is **not masked**, |
| - 0 indicates the head is **masked**. |
| feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`, *optional*): |
| Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: |
| |
| - 1 for tokens that are **not masked**, |
| - 0 for tokens that are **masked**. |
| position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, |
| config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) |
| past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): |
| Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape |
| `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape |
| `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. |
| |
| Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention |
| blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. |
| |
| If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that |
| don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all |
| `decoder_input_ids` of shape `(batch_size, sequence_length)`. |
| inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): |
| Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This |
| is useful if you want more control over how to convert `input_ids` indices into associated vectors than the |
| model's internal embedding lookup matrix. |
| use_cache (`bool`, *optional*): |
| If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see |
| `past_key_values`). |
| output_attentions (`bool`, *optional*): |
| Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| tensors for more detail. |
| output_hidden_states (`bool`, *optional*): |
| Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| more detail. |
| return_dict (`bool`, *optional*): |
| Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
| """ |
|
|
| @add_start_docstrings( |
| """The MERALION model which consists of a audio backbone and a language model.""", |
| MERALION_START_DOCSTRING, |
| ) |
| class MERaLiON2ForConditionalGeneration(MERaLiON2PreTrainedModel, GenerationMixin): |
| def __init__(self, config: MERaLiON2Config): |
| config.text_config._attn_implementation = config._attn_implementation |
| config.speech_config._attn_implementation = config._attn_implementation |
|
|
| super().__init__(config) |
|
|
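| # Speech path: Whisper encoder -> LayerNorm -> adapter that maps encoder features into the Gemma2 embedding space. |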
| self.speech_encoder = WhisperEncoder(config.speech_config) |
| |
|
|
| self.ln_speech = nn.LayerNorm(config.speech_config.d_model) |
| self.speech_audio_adapter = MERaLiON2SpeechAudioAdaperLarge(config) |
| self.vocab_size = config.text_config.vocab_size |
| self.text_decoder = Gemma2ForCausalLM(config.text_config) |
| self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1 |
| self._padding_side = "left" |
| self.post_init() |
|
|
| @property |
| def padding_side(self): |
| return self._padding_side |
|
|
| @padding_side.setter |
| def padding_side(self, padding_side: str): |
| if padding_side not in ["left", "right"]: |
| raise ValueError(f"{padding_side} is not `left` or `right`.") |
| self._padding_side = padding_side |
|
|
| |
| def get_input_embeddings(self): |
| return self.text_decoder.get_input_embeddings() |
|
|
| |
| def set_input_embeddings(self, value): |
| self.text_decoder.set_input_embeddings(value) |
|
|
| |
| def get_output_embeddings(self): |
| return self.text_decoder.get_output_embeddings() |
|
|
| |
| def set_output_embeddings(self, new_embeddings): |
| self.text_decoder.set_output_embeddings(new_embeddings) |
|
|
| |
| def set_decoder(self, decoder): |
| self.text_decoder.set_decoder(decoder) |
|
|
| |
| def get_decoder(self): |
| return self.text_decoder.get_decoder() |
|
|
| |
| def tie_weights(self): |
| return self.text_decoder.tie_weights() |
|
|
| |
| def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding: |
| model_embeds = self.text_decoder.resize_token_embeddings(new_num_tokens, pad_to_multiple_of) |
| |
| self.config.text_config.vocab_size = model_embeds.num_embeddings |
| self.vocab_size = model_embeds.num_embeddings |
| return model_embeds |
|
|
| @add_start_docstrings_to_model_forward(MERALION_INPUTS_DOCSTRING) |
| @replace_return_docstrings(output_type=MERaLiON2OutputWithPast, config_class=_CONFIG_FOR_DOC) |
| def forward( |
| self, |
| input_ids: torch.LongTensor = None, |
| input_features: torch.FloatTensor = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| feature_attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.LongTensor] = None, |
| past_key_values: Optional[List[torch.FloatTensor]] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| labels: Optional[torch.LongTensor] = None, |
| use_cache: Optional[bool] = None, |
| cache_position: Optional[torch.LongTensor] = None, |
| output_attentions: Optional[bool] = None, |
| output_hidden_states: Optional[bool] = None, |
| return_dict: Optional[bool] = None, |
| ) -> Union[Tuple, MERaLiON2OutputWithPast]: |
| r""" |
| Args: |
| labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
| Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
| config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
| (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
| |
| Returns: |
| """ |
|
|
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| output_hidden_states = ( |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| ) |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
| speech_encoder_device = self.speech_encoder.device |
|
|
| if input_features is not None: |
| input_features = input_features.to(speech_encoder_device) |
| feature_attention_mask = feature_attention_mask.to(speech_encoder_device) |
|
|
| if inputs_embeds is None: |
| speech_contexts_embeds = self.speech_encoder(input_features, attention_mask=feature_attention_mask).last_hidden_state |
| speech_contexts_embeds = self.ln_speech(speech_contexts_embeds) |
| speech_audio_contexts_embeds = self.speech_audio_adapter(speech_contexts_embeds) |
|
|
| inputs_embeds = self.text_decoder.base_model.embed_tokens(input_ids) |
|
|
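| # Positions holding the speech placeholder token are replaced, in order, by the adapted speech embeddings; |
| # the number of placeholder tokens is expected to match the adapter output length. |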
| speech_mask = (input_ids == self.config.speech_token_index).unsqueeze(-1) |
| speech_mask = speech_mask.expand_as(inputs_embeds).to(inputs_embeds.device) |
|
|
| inputs_embeds = inputs_embeds.masked_scatter(speech_mask, speech_audio_contexts_embeds) |
|
|
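| # The prompt is now fully carried by `inputs_embeds`, so `input_ids` is dropped before calling the decoder. |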
| input_ids = None |
|
|
| outputs = self.text_decoder( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| position_ids=position_ids, |
| past_key_values=past_key_values, |
| inputs_embeds=inputs_embeds, |
| use_cache=use_cache, |
| cache_position=cache_position, |
| output_attentions=output_attentions, |
| output_hidden_states=output_hidden_states, |
| return_dict=return_dict, |
| labels=labels |
| ) |
|
|
| return outputs |
|
|
| |
| def prepare_inputs_for_generation( |
| self, |
| input_ids, |
| attention_mask=None, |
| input_features=None, |
| feature_attention_mask=None, |
| past_key_values=None, |
| inputs_embeds=None, |
| cache_position=None, |
| position_ids=None, |
| use_cache=None, |
| **kwargs, |
| ): |
| |
| |
| |
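| # `cache_position` starts at 0 only on the prefill step; later decoding steps process newly generated tokens only. |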
| is_first_step = cache_position[0].item() == 0 |
| if past_key_values is not None: |
| if inputs_embeds is not None: |
| input_ids = input_ids[:, -cache_position.shape[0] :] |
| elif input_ids.shape[1] != cache_position.shape[0]: |
| input_ids = input_ids[:, cache_position] |
|
|
| if attention_mask is not None and position_ids is None: |
| |
| position_ids = attention_mask.long().cumsum(-1) - 1 |
| position_ids.masked_fill_(attention_mask == 0, 1) |
| if past_key_values: |
| position_ids = position_ids[:, -input_ids.shape[1] :] |
| |
| |
| |
| |
| |
| position_ids = position_ids.clone(memory_format=torch.contiguous_format) |
|
|
| |
| if inputs_embeds is not None and is_first_step: |
| model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} |
| else: |
| |
| model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} |
|
|
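| # With Gemma2's HybridCache and non-flash attention, expand the 2D mask to a static-shaped 4D causal mask |
| # sized to the full cache so the mask length stays constant across decoding steps. |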
| if ( |
| isinstance(past_key_values, HybridCache) |
| and attention_mask.ndim == 2 |
| and not self.config._attn_implementation == "flash_attention_2" |
| ): |
| if model_inputs["inputs_embeds"] is not None: |
| batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape |
| device = model_inputs["inputs_embeds"].device |
| else: |
| batch_size, sequence_length = model_inputs["input_ids"].shape |
| device = model_inputs["input_ids"].device |
| dtype = self.text_decoder.lm_head.weight.dtype |
| min_dtype = torch.finfo(dtype).min |
| attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( |
| attention_mask, |
| sequence_length=sequence_length, |
| target_length=past_key_values.get_max_cache_shape(), |
| dtype=dtype, |
| device=device, |
| min_dtype=min_dtype, |
| cache_position=cache_position, |
| batch_size=batch_size, |
| ) |
|
|
| model_inputs.update( |
| { |
| "attention_mask": attention_mask, |
| "position_ids": position_ids, |
| "cache_position": cache_position, |
| "past_key_values": past_key_values, |
| "use_cache": use_cache |
| } |
| ) |
|
|
| |
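| # Audio features are only consumed on the prefill step; afterwards the speech context already lives in the KV cache. |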
| if is_first_step: |
| model_inputs["input_features"] = input_features |
| model_inputs["feature_attention_mask"] = feature_attention_mask |
|
|
| return model_inputs |
|
|
| def _reorder_cache(self, *args, **kwargs): |
| return self.text_decoder._reorder_cache(*args, **kwargs) |