| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | """ Classes to support Flax Speech-Encoder-Decoder architectures""" |
| |
|
| | import os |
| | from functools import partial |
| | from typing import Optional, Tuple, Union, Dict |
| |
|
| | import flax |
| | import flax.linen as nn |
| | import jax |
| | import jax.numpy as jnp |
| | from flax.core.frozen_dict import FrozenDict, unfreeze |
| | from jax import lax |
| | from jax.random import PRNGKey |
| | import numpy as np |
| |
|
| | from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutputWithCrossAttentions, FlaxSeq2SeqLMOutput |
| | from transformers.modeling_flax_utils import FlaxPreTrainedModel |
| | from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings, ModelOutput |
| | from transformers.generation_flax_utils import FlaxLogitsProcessorList |
| | from models import ( |
| | FlaxWav2Vec2Model, |
| | FlaxWav2Vec2Module, |
| | FlaxBartForCausalLM, |
| | FlaxBartForCausalLMModule, |
| | BartConfig, |
| | Wav2Vec2Config, |
| | SpeechEncoderDecoderConfig, |
| | ) |
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| | _CONFIG_FOR_DOC = "SpeechEncoderDecoderConfig" |
| |
|
| | SPEECH_ENCODER_DECODER_START_DOCSTRING = r""" |
| | This class can be used to initialize a speech-sequence-to-text-sequence model with any pretrained speech |
| | autoencoding model as the encoder and any pretrained text autoregressive model as the decoder. The encoder is |
| | loaded via [`~AutoModel.from_pretrained`] function and the decoder is loaded via |
| | [`~AutoModelForCausalLM.from_pretrained`] function. Cross-attention layers are automatically added to the decoder |
| | and should be fine-tuned on a downstream generative task, like summarization. |
| | |
| | The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation |
| | tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation |
| | Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi |
| | Zhou, Wei Li, Peter J. Liu. |
| | |
| | Additionally, in [Large-Scale Self- and Semi-Supervised Learning for Speech |
| | Translation](https://arxiv.org/abs/2104.06678) it is shown how leveraging large pretrained speech models for speech |
| | translation yields a significant performance improvement. |
| | |
After such a Speech-Encoder-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any other
| | models (see the examples for more information). |
| | |
| | This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the |
| | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads |
| | etc.) |
| | |
| | This model is also a Flax Linen |
| | [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a |
| | regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior. |
| | |
| | Parameters: |
| | config ([`SpeechEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. |
| | Initializing with a config file does not load the weights associated with the model, only the |
| | configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. |
| | dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`): |
| | The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and |
| | `jax.numpy.bfloat16` (on TPUs). |
| | |
| | This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If |
| | specified all the computation will be performed with the given `dtype`. |
| | |
| | **Note that this only specifies the dtype of the computation and does not influence the dtype of model |
| | parameters.** |
| | |
| | If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and |
| | [`~FlaxPreTrainedModel.to_bf16`]. |
| | """ |
| |
|
| | SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r""" |
| | Args: |
| | inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): |
| | Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* |
| | or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile |
| | library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or |
| | [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type |
| | *torch.FloatTensor*. |
| | attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| | |
| | - 1 for tokens that are **not masked**, |
| | - 0 for tokens that are **masked**. |
| | |
| | [What are attention masks?](../glossary#attention-mask) |
| | decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| | Indices of decoder input sequence tokens in the vocabulary. |
| | |
| | Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| | [`PreTrainedTokenizer.__call__`] for details. |
| | |
| | [What are input IDs?](../glossary#input-ids) |
| | |
| | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
| | `past_key_values`). |
| | |
| | For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be |
| | created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` |
| | and prepending them with the `decoder_start_token_id`. |
| | decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
| | be used by default. |
| | decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| | Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the |
| | range `[0, config.decoder.max_position_embeddings - 1]`. |
| | output_hidden_states (`bool`, *optional*): |
| | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| | more detail. |
| | return_dict (`bool`, *optional*): |
| | If set to `True`, the model will return a [`~utils.FlaxSeq2SeqLMOutput`] instead of a plain tuple. |
| | """ |
| |
|
| | SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING = r""" |
| | Args: |
| | inputs (`jnp.ndarray` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*): |
| | Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* |
| | or *.wav* audio file into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile |
| | library (*pip install soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or |
| | [`Speech2TextProcessor`] should be used for padding and conversion into a tensor of type |
| | *torch.FloatTensor*. |
| | attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| | |
| | - 1 for tokens that are **not masked**, |
| | - 0 for tokens that are **masked**. |
| | |
| | [What are attention masks?](../glossary#attention-mask) |
| | output_attentions (`bool`, *optional*): |
| | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| | tensors for more detail. |
| | output_hidden_states (`bool`, *optional*): |
| | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| | more detail. |
| | return_dict (`bool`, *optional*): |
| | If set to `True`, the model will return a [`~utils.FlaxBaseModelOutput`] instead of a plain tuple. |
| | """ |
| |
|
| | SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING = r""" |
| | Args: |
| | decoder_input_ids (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| | Indices of decoder input sequence tokens in the vocabulary. |
| | |
| | Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and |
| | [`PreTrainedTokenizer.__call__`] for details. |
| | |
| | [What are decoder input IDs?](../glossary#decoder-input-ids) |
| | |
| | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see |
| | `past_key_values`). |
| | |
| | For sequence to sequence training, `decoder_input_ids` should be provided. `decoder_input_ids` should be |
| | created outside of the model by shifting the `labels` to the right, replacing -100 by the `pad_token_id` |
| | and prepending them with the `decoder_start_token_id`. |
| | encoder_outputs (`tuple(tuple(jnp.ndarray)`): |
| | Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) |
| | `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of |
| | hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. |
| | encoder_attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| | Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
| | |
| | - 1 for tokens that are **not masked**, |
| | - 0 for tokens that are **masked**. |
| | |
| | [What are attention masks?](../glossary#attention-mask) |
| | decoder_attention_mask (`jnp.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*): |
| | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also |
| | be used by default. |
| | decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*): |
| | Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the |
| | range `[0, config.decoder.max_position_embeddings - 1]`. |
| | past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`): |
| | Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast |
| | auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*. |
| | output_attentions (`bool`, *optional*): |
| | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned |
| | tensors for more detail. |
| | output_hidden_states (`bool`, *optional*): |
| | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for |
| | more detail. |
| | return_dict (`bool`, *optional*): |
| | If set to `True`, the model will return a [`~utils.FlaxCausalLMOutputWithCrossAttentions`] instead of a |
| | plain tuple. |
| | """ |
| |
|
@flax.struct.dataclass
class FlaxBeamSearchOutput(ModelOutput):
    """
    Flax base class for outputs of encoder-decoder generation models using beam search.

    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated sequences.
        scores (`jnp.ndarray` of shape `(batch_size,)`):
            The scores (log probabilities) of the generated sequences.
    """

    sequences: jnp.ndarray = None
    scores: jnp.ndarray = None
| |
|
| |
|
@flax.struct.dataclass
class BeamSearchState:
    """Loop-carried state for the beam search scan.

    Registered as a Flax pytree dataclass so it can flow through `lax` control
    flow; field order is part of the pytree structure — do not reorder.
    """

    cur_len: jnp.ndarray  # current generation step
    running_sequences: jnp.ndarray  # beams still being extended
    running_scores: jnp.ndarray  # log-prob scores of the running beams
    sequences: jnp.ndarray  # best finished sequences so far
    scores: jnp.ndarray  # scores of the finished sequences
    is_sent_finished: jnp.ndarray  # per-beam finished flags
    model_kwargs: Dict[str, jnp.ndarray]  # cache / masks threaded through the loop
| |
|
| |
|
| |
|
| |
|
class FlaxSpeechEncoderDecoderModule(nn.Module):
    """Flax module coupling a Wav2Vec2 speech encoder with a Bart causal-LM decoder.

    The encoder's hidden states are (optionally projected and) fed to the
    decoder as cross-attention inputs.
    """

    # Combined configuration holding `config.encoder` and `config.decoder` sub-configs.
    config: SpeechEncoderDecoderConfig
    # Computation dtype for all sub-modules.
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        encoder_config = self.config.encoder
        decoder_config = self.config.decoder

        # This fork hard-wires the architectures: wav2vec2 encoder -> bart decoder.
        encoder_module = FlaxWav2Vec2Module
        decoder_module = FlaxBartForCausalLMModule

        self.encoder = encoder_module(encoder_config, dtype=self.dtype)
        self.decoder = decoder_module(decoder_config, dtype=self.dtype)

        # Project encoder hidden states into the decoder's hidden size when the
        # sizes differ and the decoder has no dedicated cross-attention size.
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = nn.Dense(
                self.decoder.config.hidden_size,
                kernel_init=jax.nn.initializers.normal(self.decoder.config.initializer_range),
                dtype=self.dtype,
            )
        else:
            self.enc_to_dec_proj = None

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.encoder.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # Standard 1D conv output-length formula (no padding, no dilation).
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.encoder.conv_kernel, self.config.encoder.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        # The optional adapter adds further kernel-size-1 strided convolutions.
        if add_adapter:
            for _ in range(self.config.encoder.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.encoder.adapter_stride)

        return input_lengths

    def _get_encoder_module(self):
        # Accessor used by `module.apply(..., method=...)` partial-forward helpers.
        return self.encoder

    def _get_projection_module(self):
        # May be `None` when no encoder-to-decoder projection is needed (see `setup`).
        return self.enc_to_dec_proj

    def _get_decoder_module(self):
        # Accessor used by `module.apply(..., method=...)` partial-forward helpers.
        return self.decoder

    def __call__(
        self,
        inputs,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        encoder_outputs=None,
        extract_features=None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_features: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
        freeze_feature_encoder: bool = False,
    ):
        """Run the full encoder-decoder forward pass.

        When `encoder_outputs` is provided the encoder is skipped; when
        `output_features` is set, only the encoder's feature outputs are
        returned and the decoder is not run.
        """
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                inputs,
                attention_mask=attention_mask,
                extract_features=extract_features,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                output_features=output_features,
                return_dict=return_dict,
                deterministic=deterministic,
                freeze_feature_encoder=freeze_feature_encoder,
            )

        # Early exit: caller only wants the extracted encoder features.
        if output_features:
            return encoder_outputs

        encoder_hidden_states = encoder_outputs[0]

        # Optionally project encoder hidden states into the decoder's hidden size.
        if self.enc_to_dec_proj is not None:
            encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)

        # Downsample the waveform-level attention mask to the feature-frame level
        # so it matches the encoder output length.
        if attention_mask is not None:
            encoder_attention_mask = self.encoder._get_feature_vector_attention_mask(
                encoder_hidden_states.shape[1], attention_mask
            )
        else:
            encoder_attention_mask = None

        # Decode, cross-attending over the (projected) encoder hidden states.
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            position_ids=decoder_position_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )

        if not return_dict:
            return decoder_outputs + encoder_outputs

        return FlaxSeq2SeqLMOutput(
            logits=decoder_outputs.logits,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_hidden_states,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )
| |
|
| |
|
| | @add_start_docstrings(SPEECH_ENCODER_DECODER_START_DOCSTRING) |
| | class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): |
| | r""" |
| | [`FlaxSpeechEncoderDecoderModel`] is a generic model class that will be instantiated as a transformer architecture |
| | with the module (flax.nn.Module) of one of the base model classes of the library as encoder module and another one |
| | as decoder module when created with the :meth*~transformers.FlaxAutoModel.from_pretrained* class method for the |
| | encoder and :meth*~transformers.FlaxAutoModelForCausalLM.from_pretrained* class method for the decoder. |
| | """ |
| |
|
| | config_class = SpeechEncoderDecoderConfig |
| | base_model_prefix: str = "speech_encoder_decoder" |
| | module_class = FlaxSpeechEncoderDecoderModule |
| |
|
    def __init__(
        self,
        config: SpeechEncoderDecoderConfig,
        input_shape: Optional[Tuple] = None,
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs
    ):
        """Build the combined module and initialise `FlaxPreTrainedModel` state.

        Args:
            config: combined encoder/decoder configuration.
            input_shape: `((enc_batch, enc_len), (dec_batch, dec_len))`; when `None`,
                a default is derived from a 1024-sample encoder input.
            seed: PRNG seed used for weight init.
            dtype: computation dtype.
            _do_init: must be `True`; lazy init is not supported by this class.

        Raises:
            ValueError: if `_do_init` is `False`, or if the decoder's
                `cross_attention_hidden_size` disagrees with the encoder's hidden size.
        """

        if not _do_init:
            raise ValueError(
                "`FlaxSpeechEncoderDecoderModel` cannot be created without initializing, `_do_init` must be `True`."
            )

        if config.decoder.cross_attention_hidden_size is not None:
            # The decoder's cross-attention must be sized to consume encoder states directly.
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, "
                    "it has to be equal to the encoder's `hidden_size`. "
                    f"Got {config.decoder.cross_attention_hidden_size} for `config.decoder.cross_attention_hidden_size` "
                    f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
                )

        # Encoder and decoder vocabularies differ, so word embeddings are never tied.
        # NOTE: must be set before the module is constructed.
        config.tie_word_embeddings = False
        module = self.module_class(config=config, dtype=dtype, **kwargs)

        if input_shape is None:
            # Derive a matching decoder length from the encoder's conv stack
            # so the dummy init shapes are consistent.
            encoder_input_length = 1024
            decoder_input_length = module._get_feat_extract_output_lengths(encoder_input_length)
            input_shape = ((1, encoder_input_length), (1, decoder_input_length))

        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
| |
|
| | def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict: |
| | encoder_input_shape, decoder_input_shape = input_shape |
| |
|
| | |
| | inputs = jnp.zeros(encoder_input_shape, dtype="f4") |
| | attention_mask = jnp.ones_like(inputs, dtype="i4") |
| | decoder_input_ids = jnp.zeros(decoder_input_shape, dtype="i4") |
| | decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
| |
|
| | batch_size, sequence_length = inputs.shape |
| |
|
| | decoder_batch_size, decoder_sequence_length = decoder_input_ids.shape |
| | if not decoder_batch_size == batch_size: |
| | raise ValueError( |
| | f"The inputs of encoder and decoder should have the same batch size, but got {batch_size} for encoder and {decoder_batch_size} for decoder." |
| | ) |
| | decoder_position_ids = jnp.broadcast_to( |
| | jnp.arange(decoder_sequence_length)[None, :], (decoder_batch_size, decoder_sequence_length) |
| | ) |
| |
|
| | params_rng, dropout_rng = jax.random.split(rng) |
| | rngs = {"params": params_rng, "dropout": dropout_rng} |
| |
|
| | return self.module.init( |
| | rngs, |
| | inputs, |
| | attention_mask, |
| | decoder_input_ids, |
| | decoder_attention_mask, |
| | decoder_position_ids, |
| | )["params"] |
| |
|
| | def init_cache(self, batch_size, max_length, encoder_outputs): |
| | r""" |
| | Args: |
| | batch_size (`int`): |
| | batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. |
| | max_length (`int`): |
| | maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized |
| | cache. |
| | encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`): |
| | `encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: |
| | `attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) |
| | is a sequence of hidden-states at the output of the last layer of the encoder. Used in the |
| | cross-attention of the decoder. |
| | """ |
| | |
| | decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4") |
| | decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
| | decoder_position_ids = jnp.broadcast_to( |
| | jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]), decoder_input_ids.shape |
| | ) |
| |
|
| | def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, **kwargs): |
| | decoder_module = module._get_decoder_module() |
| | return decoder_module( |
| | input_ids=decoder_input_ids, |
| | attention_mask=decoder_attention_mask, |
| | position_ids=decoder_position_ids, |
| | **kwargs, |
| | ) |
| |
|
| | init_variables = self.module.init( |
| | jax.random.PRNGKey(0), |
| | decoder_input_ids=decoder_input_ids, |
| | decoder_attention_mask=decoder_attention_mask, |
| | decoder_position_ids=decoder_position_ids, |
| | encoder_hidden_states=encoder_outputs[0], |
| | init_cache=True, |
| | method=_decoder_forward, |
| | ) |
| | return unfreeze(init_variables["cache"]) |
| |
|
    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """Delegate to the module: number of feature frames produced by the encoder's conv stack."""
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
| |
|
    @add_start_docstrings(SPEECH_ENCODER_DECODER_ENCODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=_CONFIG_FOR_DOC)
    def encode(
        self,
        inputs: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        extract_features: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_features: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        freeze_feature_encoder: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel

        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
        ... )

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
        >>> encoder_outputs = model.encode(inputs)
        ```"""
        # Fall back to config defaults for unspecified output options.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # Default: attend over every input frame.
        if attention_mask is None:
            attention_mask = jnp.ones_like(inputs, dtype="i4")

        if extract_features is not None:
            extract_features = jnp.array(extract_features, dtype="f4")

        # Handle any PRNG if needed (dropout is only active when `train=True`).
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, inputs, attention_mask, **kwargs):
            # Partial-forward helper: run only the encoder sub-module.
            encode_module = module._get_encoder_module()
            return encode_module(inputs, attention_mask, **kwargs)

        outputs = self.module.apply(
            {"params": params or self.params},
            inputs=jnp.array(inputs, dtype="f4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            extract_features=extract_features,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_features=output_features,
            return_dict=return_dict,
            deterministic=not train,
            freeze_feature_encoder=freeze_feature_encoder,
            rngs=rngs,
            method=_encoder_forward,
        )

        # Re-wrap into the generic output type (the encoder returns a
        # model-specific output class); skipped when only features are requested.
        if return_dict and not output_features:
            outputs = FlaxBaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        return outputs
| |
|
    @add_start_docstrings(SPEECH_ENCODER_DECODER_DECODE_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel
        >>> import jax.numpy as jnp

        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
        ... )

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
        >>> encoder_outputs = model.encode(inputs)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        # Fall back to config defaults for unspecified output options.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        # Default: attend over all encoder frames.
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            # With a cache, absolute positions cannot be inferred from the slice length.
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed (dropout is only active when `train=True`).
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        params = {"params": params or self.params}

        # If past_key_values are passed, the cache collection must be declared
        # mutable so the decoder's attention layers can write the updated
        # keys/values back; the updated cache is then returned by `apply`.
        if past_key_values:
            params["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
        ):
            # Partial-forward helper: (optionally project and) run only the decoder.
            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # Project encoder hidden states into the decoder's hidden size if needed.
            if projection_module is not None:
                encoder_hidden_states = projection_module(encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            params,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # `apply` returns `(outputs, mutated_variables)` when `mutable` is set;
        # surface the updated cache to the caller as `past_key_values`.
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs
| |
|
| | @add_start_docstrings_to_model_forward(SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING) |
| | @replace_return_docstrings(output_type=FlaxSeq2SeqLMOutput, config_class=_CONFIG_FOR_DOC) |
| | def __call__( |
| | self, |
| | inputs: jnp.ndarray, |
| | attention_mask: Optional[jnp.ndarray] = None, |
| | extract_features: Optional[jnp.ndarray] = None, |
| | decoder_input_ids: Optional[jnp.ndarray] = None, |
| | decoder_attention_mask: Optional[jnp.ndarray] = None, |
| | decoder_position_ids: Optional[jnp.ndarray] = None, |
| | output_attentions: Optional[bool] = None, |
| | output_hidden_states: Optional[bool] = None, |
| | output_features: Optional[bool] = None, |
| | return_dict: Optional[bool] = None, |
| | train: bool = False, |
| | freeze_feature_encoder: bool = False, |
| | params: dict = None, |
| | dropout_rng: PRNGKey = None, |
| | ): |
| | r""" |
| | Returns: |
| | |
| | Examples: |
| | |
| | ```python |
| | >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer |
| | |
| | >>> # load a fine-tuned wav2vec2-2-bart model |
| | >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large") |
| | >>> # load output tokenizer |
| | >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large") |
| | |
| | >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32) |
| | |
| | >>> # use bart's special bos, pad and eos tokens |
| | >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id |
| | >>> model.config.pad_token_id = model.decoder.config.pad_token_id |
| | >>> model.config.eos_token_id = model.decoder.config.eos_token_id |
| | |
| | >>> outputs = model.generate(inputs) |
| | # Assert something? More interesting input? dtype correct? |
| | ``` |
| | """ |
| |
|
| | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
| | output_hidden_states = ( |
| | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
| | ) |
| | return_dict = return_dict if return_dict is not None else self.config.return_dict |
| |
|
| | |
| | if attention_mask is None: |
| | attention_mask = jnp.ones_like(inputs, dtype="i4") |
| |
|
| | if extract_features is not None: |
| | inputs = None |
| | extract_features = jnp.array(extract_features, dtype="f4") |
| | else: |
| | inputs = jnp.array(inputs, dtype="f4") |
| |
|
| | |
| | if decoder_input_ids is None: |
| | raise ValueError( |
| | "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument." |
| | ) |
| | if decoder_attention_mask is None: |
| | decoder_attention_mask = jnp.ones_like(decoder_input_ids) |
| | if decoder_position_ids is None: |
| | batch_size, sequence_length = decoder_input_ids.shape |
| | decoder_position_ids = jnp.broadcast_to( |
| | jnp.arange(sequence_length)[None, :], (batch_size, sequence_length) |
| | ) |
| |
|
| | |
| | rngs = {"dropout": dropout_rng} if dropout_rng is not None else {} |
| |
|
| | return self.module.apply( |
| | {"params": params or self.params}, |
| | inputs=inputs, |
| | attention_mask=jnp.array(attention_mask, dtype="i4"), |
| | extract_features=extract_features, |
| | decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"), |
| | decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"), |
| | decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"), |
| | output_attentions=output_attentions, |
| | output_hidden_states=output_hidden_states, |
| | output_features=output_features, |
| | return_dict=return_dict, |
| | deterministic=not train, |
| | freeze_feature_encoder=freeze_feature_encoder, |
| | rngs=rngs, |
| | ) |
| |
|
| | def prepare_inputs_for_generation( |
| | self, |
| | decoder_input_ids, |
| | max_length, |
| | attention_mask: Optional[jnp.DeviceArray] = None, |
| | decoder_attention_mask: Optional[jnp.DeviceArray] = None, |
| | encoder_outputs=None, |
| | **kwargs |
| | ): |
| | |
| | batch_size, seq_length = decoder_input_ids.shape |
| |
|
| | past_key_values = self.init_cache(batch_size, max_length, encoder_outputs) |
| | |
| | |
| | |
| | extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") |
| | if decoder_attention_mask is not None: |
| | decoder_position_ids = decoder_attention_mask.cumsum(axis=-1) - 1 |
| | extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0)) |
| | else: |
| | decoder_position_ids = jnp.broadcast_to( |
| | jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length) |
| | ) |
| |
|
| | return { |
| | "past_key_values": past_key_values, |
| | "encoder_outputs": encoder_outputs, |
| | "encoder_attention_mask": attention_mask, |
| | "decoder_attention_mask": extended_attention_mask, |
| | "decoder_position_ids": decoder_position_ids, |
| | } |
| |
|
| | def update_inputs_for_generation(self, model_outputs, model_kwargs): |
| | model_kwargs["past_key_values"] = model_outputs.past_key_values |
| | model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1 |
| | return model_kwargs |
| |
|
    @classmethod
    def from_encoder_decoder_pretrained(
        cls,
        encoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        decoder_pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        *model_args,
        **kwargs
    ) -> FlaxPreTrainedModel:
        r"""
        Instantiate an encoder and a decoder from one or two base classes of the library from pretrained model
        checkpoints.

        Params:
            encoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*):
                Information necessary to initiate the encoder. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

            decoder_pretrained_model_name_or_path (`Union[str, os.PathLike]`, *optional*, defaults to `None`):
                Information necessary to initiate the decoder. Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
                      user or organization name, like `dbmdz/bert-base-german-cased`.
                    - A path to a *directory* containing model weights saved using
                      [`~FlaxPreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.

            model_args (remaining positional arguments, *optional*):
                All remaning positional arguments will be passed to the underlying model's `__init__` method.

            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
                `output_attentions=True`).

                - To update the encoder configuration, use the prefix *encoder_* for each configuration parameter.
                - To update the decoder configuration, use the prefix *decoder_* for each configuration parameter.
                - To update the parent model configuration, do not use a prefix for each configuration parameter.

                Behaves differently depending on whether a `config` is provided or automatically loaded.

        Example:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel

        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
        ... )
        >>> # saving model after fine-tuning
        >>> model.save_pretrained("./wav2vec2-2-bart-large")
        >>> # load fine-tuned model
        >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("./wav2vec2-2-bart-large")
        ```"""

        # Split the incoming kwargs into encoder- and decoder-specific dictionaries
        # by stripping the "encoder_" / "decoder_" prefixes.
        kwargs_encoder = {
            argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")
        }

        kwargs_decoder = {
            argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
        }

        # Remove the prefixed entries from `kwargs` so that only parent-model options
        # remain for the combined config below.
        for key in kwargs_encoder.keys():
            del kwargs["encoder_" + key]
        for key in kwargs_decoder.keys():
            del kwargs["decoder_" + key]

        # An already-instantiated encoder may be supplied via `encoder_model=...`;
        # otherwise it is loaded from the pretrained checkpoint.
        encoder = kwargs_encoder.pop("model", None)
        if encoder is None:
            if encoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `encoder_model` is not defined as an argument, a `encoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_encoder:
                # Load the encoder config, keeping any kwargs the config did not consume
                # as model-loading kwargs (`return_unused_kwargs=True`).
                encoder_config, kwargs_encoder = Wav2Vec2Config.from_pretrained(
                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
                )
                # The encoder must behave as a pure encoder: force decoder-only
                # features off if the checkpoint was saved as a decoder.
                if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                    logger.info(
                        f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
                        "from a decoder model. Cross-attention and casual mask are disabled."
                    )
                    encoder_config.is_decoder = False
                    encoder_config.add_cross_attention = False

                kwargs_encoder["config"] = encoder_config

            encoder = FlaxWav2Vec2Model.from_pretrained(
                encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder
            )

        # Same pattern for the decoder: accept a ready-made model via
        # `decoder_model=...` or load one from a checkpoint.
        decoder = kwargs_decoder.pop("model", None)
        if decoder is None:
            if decoder_pretrained_model_name_or_path is None:
                raise ValueError(
                    "If `decoder_model` is not defined as an argument, a `decoder_pretrained_model_name_or_path` has "
                    "to be defined."
                )

            if "config" not in kwargs_decoder:
                decoder_config, kwargs_decoder = BartConfig.from_pretrained(
                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
                )
                # The decoder must run causally and cross-attend to the encoder, so
                # turn both flags on if the checkpoint lacks them.
                if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                    logger.info(
                        f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
                        f"Cross attention layers are added to {decoder_pretrained_model_name_or_path} "
                        f"and randomly initialized if {decoder_pretrained_model_name_or_path}'s architecture allows for "
                        "cross attention layers."
                    )
                    decoder_config.is_decoder = True
                    decoder_config.add_cross_attention = True

                kwargs_decoder["config"] = decoder_config

            # If the caller passed an explicit decoder config with the wrong flags we
            # only warn (not fix) — the caller may have done this deliberately.
            if kwargs_decoder["config"].is_decoder is False or kwargs_decoder["config"].add_cross_attention is False:
                logger.warning(
                    f"Decoder model {decoder_pretrained_model_name_or_path} is not initialized as a decoder. "
                    f"In order to initialize {decoder_pretrained_model_name_or_path} as a decoder, "
                    "make sure that the attributes `is_decoder` and `add_cross_attention` of `decoder_config` "
                    "passed to `.from_encoder_decoder_pretrained(...)` are set to `True` or do not pass a "
                    "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                )

            decoder = FlaxBartForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)

        # Build the combined config from the two sub-configs plus any remaining
        # (un-prefixed) kwargs. NOTE(review): `dtype` is popped here, after the
        # sub-models were already loaded — it only affects the combined model.
        dtype = kwargs.pop("dtype", jnp.float32)
        config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)

        # Encoder (speech) and decoder (text) operate on different modalities, so
        # their embeddings cannot be tied.
        config.tie_word_embeddings = False

        # Instantiate the combined model with randomly-initialized params, then
        # overwrite the sub-trees with the pretrained weights.
        model = cls(config, dtype=dtype)
        model.params["encoder"] = encoder.params
        model.params["decoder"] = decoder.params

        return model
| |
|
    def _beam_search(
        self,
        input_ids: jnp.ndarray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254

        `input_ids` has shape `(batch, num_beams, cur_len)`. The loop keeps two beam
        pools per batch element: `running_*` (still being extended) and finished
        `sequences`/`scores`. When `trace` is True the loop runs under
        `lax.while_loop` (jittable); otherwise it runs eagerly for debugging.
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # Scalars / 1-D arrays (e.g. cache index counters) have no beam dim.
            if tensor.ndim == 0 or tensor.ndim == 1:
                return tensor
            elif tensor.ndim == 6:
                # NOTE(review): 6-D tensors keep their leading dim and merge dims 1-2
                # instead — presumably a cache tensor of this model with an extra
                # leading axis; confirm against the model's `init_cache` layout.
                return tensor.reshape(tensor.shape[:1] + (tensor.shape[1] * tensor.shape[2],) + tensor.shape[3:])
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            if tensor.ndim == 0 or tensor.ndim == 1:
                return tensor
            if tensor.ndim == 5:
                # Inverse of the 6-D case above: restore (batch, beams) after dim 0.
                return tensor.reshape(tensor.shape[:1] + (batch_size, num_beams) + tensor.shape[2:])
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            # batch_indices[i, j] == i, paired element-wise with beam_indices for
            # advanced indexing over the (batch, beam) leading axes.
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
            )

            def gather_fn(tensor):
                if tensor.ndim == 0 or tensor.ndim == 1:
                    return tensor
                if tensor.ndim == 6:
                    # 6-D tensors carry their (batch, beam) axes at positions 1-2.
                    return tensor[:, batch_indices, beam_indices]
                return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

        # Resolve unset generation options from the model config.
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        # Convert to JAX scalars so they can live inside the traced loop state.
        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # Finished and running sequence buffers, pre-filled with the pad token.
        # NOTE(review): the first `running_sequences` assignment is a dead store —
        # it is immediately overwritten by the `dynamic_update_slice` result.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # Per-beam flag marking beams that have emitted EOS.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # Only the first beam starts "alive" (score 0); the rest get -1e7 so the
        # first expansion step doesn't produce duplicate beams. Finished scores
        # start at -1e7 so any real finished beam beats them.
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For encoder-decoder models each step only runs the decoder.
        model = self.decode if self.config.is_encoder_decoder else self

        # Flatten (batch, beams) -> batch*beams in the model inputs so the decoder
        # sees a plain batch dimension.
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # Initializes the decoder cache and step-wise masks/position ids.
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # Stop once the buffers are full.
            not_max_length_yet = state.cur_len < max_length

            # Can the best running beam (optimistically length-normalized by
            # max_length) still beat the worst finished beam of its batch element?
            best_running_score = state.running_scores[:, -1:] / (max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # With early stopping, halt as soon as every beam has finished.
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Run the model on the last `input_ids_length` tokens of every
            # running beam (flattened to a plain batch for the decoder).
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                )
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)

            logits = unflatten_beam_dim(model_outputs.logits[:, -1], batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # Hook for subclasses to adjust logits before beam expansion.
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Turn logits into cumulative log-probs per (beam, vocab) candidate.
            # NOTE(review): `running_sequences` here is the closure variable holding
            # the *initial* sequences, not `state.running_sequences`; upstream
            # transformers passes the current state — confirm this is intentional.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Keep 2*num_beams candidates so that after EOS-finished candidates
            # are set aside there are always num_beams running beams left.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            # Flat candidate index -> (source beam, token id).
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Candidates that just emitted EOS are excluded from the running pool
            # (their running score is pushed to -1e7); the remaining top num_beams
            # continue. `jnp.flip` keeps beams sorted ascending by score.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            running_topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 5. Length-normalize candidate scores; mask out candidates that did not
            # just finish, or whose batch element is already full under early stopping.
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 6. Merge newly-finished candidates with previously finished beams and
            # retain the overall best num_beams finished sequences.
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 7. Re-shuffle the decoder cache so each surviving beam carries the
            # cache of the source beam it was expanded from, then re-flatten it.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # Process the whole prompt in one body call (fills the cache); subsequent
        # loop iterations then feed a single token at a time.
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn, input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Despite its name, `none_finished[i]` is True when at least one beam of
        # batch element i finished; otherwise fall back to the running beams.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # NOTE(review): all beams are returned (`[:, :]` is a no-op) but only the
        # best (last) beam's score is kept — asymmetric; upstream transformers
        # returns only the best beam. Confirm callers expect this shape.
        sequences = sequences[:, :]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
| |
|