| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ Flax Wav2Vec2 model.""" |
|
|
| from functools import partial |
| from typing import Optional, Tuple, Union |
|
|
| import flax |
| import flax.linen as nn |
| import jax |
| import jax.numpy as jnp |
| from flax.core.frozen_dict import FrozenDict |
| from flax.linen import partitioning as nn_partitioning |
| from flax.linen.attention import dot_product_attention_weights |
| from jax import lax |
|
|
| from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput |
| from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel |
| from transformers.utils import ModelOutput |
|
|
| from models import Wav2Vec2Config |
|
|
# Short aliases for flax's partitioning helpers: `scan_with_axes` scans a layer
# over the depth axis (stacked params), `remat` enables gradient checkpointing
# (rematerialization) for a wrapped module.
scan_with_axes = nn_partitioning.scan_with_axes
remat = nn_partitioning.remat
|
|
|
|
@flax.struct.dataclass
class FlaxWav2Vec2BaseModelOutput(ModelOutput):
    """
    Output type of [`FlaxWav2Vec2Model`], with potential hidden states and attentions.

    Args:
        last_hidden_state (`jnp.ndarray` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        extract_features (`jnp.ndarray` of shape `(batch_size, sequence_length, last_conv_dim)`):
            Sequence of extracted feature vectors of the last convolutional layer of the model with `last_conv_dim`
            being the dimension of the last convolutional layer.
        hidden_states (`tuple(jnp.ndarray)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `jnp.ndarray` (one for the output of the embeddings + one for the output of each layer) of shape
            `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(jnp.ndarray)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `jnp.ndarray` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
    """

    last_hidden_state: jnp.ndarray = None
    extract_features: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
|
|
|
|
# Class-level docstring template shared by the Wav2Vec2 model classes.
# NOTE(review): no in-file consumer is visible in this chunk — presumably used
# with transformers' `add_start_docstrings` machinery elsewhere; confirm before
# removing.
WAV_2_VEC_2_START_DOCSTRING = r"""
    Wav2Vec2 was proposed in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech
    Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael
    Auli.

    This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a Flax Linen
    [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
    regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.

    Finally, this model supports inherent JAX features such as:

    - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
    - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
    - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
    - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)

    Parameters:
        config ([`Wav2Vec2Config`]): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
        dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
            The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
            `jax.numpy.bfloat16` (on TPUs).

            This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
            specified all the computation will be performed with the given `dtype`.

            **Note that this only specifies the dtype of the computation and does not influence the dtype of model
            parameters.**

            If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
            [`~FlaxPreTrainedModel.to_bf16`].
"""
|
|
|
|
# Forward-pass docstring template for the Wav2Vec2 model classes.
# NOTE(review): no in-file consumer is visible in this chunk — presumably used
# with transformers' docstring-templating decorators elsewhere; confirm.
WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
    Args:
        input_values (`jnp.ndarray` of shape `(batch_size, sequence_length)`):
            Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
            into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
            soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should be used for padding
            and conversion into a tensor of type *jnp.ndarray*. See [`Wav2Vec2Processor.__call__`] for details.
        attention_mask (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing convolution and attention on padding token indices. Mask values selected in `[0,
            1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask) .. warning:: `attention_mask` should only be passed
            if the corresponding processor has `config.return_attention_mask == True`. For all models whose processor
            has `config.return_attention_mask == False`, such as
            [wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base-960h), `attention_mask` should **not** be
            passed to avoid degraded performance when doing batched inference. For such models `input_values` should
            simply be padded with 0 and passed without `attention_mask`. Be aware that these models also yield slightly
            different results depending on whether `input_values` is padded or not.
        mask_time_indices (`jnp.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
            Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict
            masked extracted features in *config.proj_codevector_dim* space.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
|
|
|
|
class FlaxWav2Vec2LayerNormConvLayer(nn.Module):
    """One feature-extractor block: strided 1-D convolution -> LayerNorm -> activation.

    Attributes:
        config: model configuration; conv dims/kernels/strides are indexed by `layer_id`.
        layer_id: position of this block within the conv stack.
        dtype: computation dtype.
    """

    config: Wav2Vec2Config
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The first layer consumes the raw waveform (1 channel); later layers
        # consume the *previous* layer's output width. Fix: the original indexed
        # `conv_dim[self.layer_id]` here instead of `conv_dim[self.layer_id - 1]`
        # (as upstream transformers does) — harmless only because `in_conv_dim`
        # is never read; corrected for parity.
        self.in_conv_dim = self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        """Apply conv -> LayerNorm -> activation to `(batch, time, channels)` inputs."""
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
|
|
|
|
class FlaxConvWithWeightNorm(nn.Module):
    """Grouped 1-D convolution with (torch-style) weight normalization.

    The kernel is stored as a direction tensor `weight_v` and a magnitude
    `weight_g`; the effective kernel `g * v / ||v||` is recomputed on every
    call and substituted into the inner conv via `apply`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # The inner conv only defines shapes/config; its own kernel parameter is
        # never used — the normalized kernel is injected in `__call__`.
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
        )
        # (out_features, in_features_per_group, kernel_width); transposed with
        # `.T` before being fed to the conv (flax expects the reversed layout).
        weight_shape = (
            self.conv.features,
            self.conv.features // self.conv.feature_group_count,
            self.conv.kernel_size[0],
        )
        self.weight_v = self.param("weight_v", jax.nn.initializers.he_normal(), weight_shape)
        # `weight_g` is initialized to ||weight_v||, so initially g * v/||v|| == v.
        self.weight_g = self.param("weight_g", lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :])
        self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
        # Manual symmetric padding applied before the VALID conv.
        self.prev_padding = self.conv.kernel_size[0] // 2

    def _get_normed_weights(self):
        """Return the weight-normalized kernel g * v / ||v||."""
        weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
        normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
        return normed_kernel

    def __call__(self, hidden_states):
        # Recompute the normalized kernel, pad the time axis on both sides, then
        # run the conv with kernel/bias substituted as its parameters.
        kernel = self._get_normed_weights()
        hidden_states = jnp.pad(hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0)))
        hidden_states = self.conv.apply({"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states)
        return hidden_states
|
|
|
|
class FlaxWav2Vec2PositionalConvEmbedding(nn.Module):
    """Convolutional relative positional embedding applied over the time axis."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        # With an even kernel width, the symmetric padding inside the weight-norm
        # conv produces one extra output frame; trim it so output length matches
        # input length.
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        # Fix: the original wrapped this body in two `transpose((0, 1, 2))`
        # calls — identity permutations (dead code kept from the PyTorch port);
        # removed as no-ops.
        hidden_states = self.conv(hidden_states)

        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        hidden_states = self.activation(hidden_states)

        return hidden_states
|
|
|
|
class FlaxConvLayersCollection(nn.Module):
    """Sequential stack of feature-extractor conv layers (layer-norm variant only)."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.feat_extract_norm == "layer":
            # Optionally wrap each layer in `remat` (gradient checkpointing).
            BlockLayer = (
                remat(FlaxWav2Vec2LayerNormConvLayer)
                if self.config.gradient_checkpointing
                else FlaxWav2Vec2LayerNormConvLayer
            )
            self.layers = [
                BlockLayer(self.config, layer_id=i, name=str(i), dtype=self.dtype)
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif self.config.feat_extract_norm == "group":
            raise NotImplementedError("At the moment only ``config.feat_extact_norm == 'layer'`` is supported")
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group', 'layer']"
            )

    def __call__(self, hidden_states):
        # Fix (idiom): the original used `enumerate` but never read the index.
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states
|
|
|
|
class FlaxWav2Vec2FeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform"""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        # Add a trailing channel axis: (batch, time) -> (batch, time, 1), then
        # run the conv stack.
        extract_features = self.conv_layers(input_values[:, :, None])
        if freeze_feature_encoder:
            # Block gradients so the feature encoder stays frozen during training.
            extract_features = jax.lax.stop_gradient(extract_features)
        return extract_features
|
|
|
|
class FlaxWav2Vec2FeatureProjection(nn.Module):
    """Projects the extracted conv features to the transformer hidden size."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.projection = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        # Returns both the projected features and the normalized
        # (pre-projection) features.
        normed_features = self.layer_norm(hidden_states)
        projected = self.projection(normed_features)
        projected = self.dropout(projected, deterministic=deterministic)
        return projected, normed_features
|
|
|
|
class FlaxWav2Vec2Attention(nn.Module):
    """Multi-headed dot-product self-attention used in the Wav2Vec2 encoder.

    Supports an optional fused QKV projection (`config.fuse_matmuls`): one
    matmul producing Q, K and V instead of three separate projections.

    NOTE(review): `key_value_states` is accepted (cross-attention-style
    signature) but never used below — keys/values are always computed from
    `hidden_states`. Confirm no caller relies on it before removing.
    """

    config: Wav2Vec2Config
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()

        # Fused QKV projection, used instead of q/k/v when `config.fuse_matmuls`
        # is set. Note both parameter sets are created regardless of the flag.
        self.fused_proj = nn.Dense(
            self.embed_dim * 3,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        # (batch, seq, embed) -> (batch, seq, num_heads, head_dim)
        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))

    def _merge_heads(self, hidden_states):
        # (batch, seq, num_heads, head_dim) -> (batch, seq, embed)
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        if self.config.fuse_matmuls:
            # Single matmul, then split the last axis into Q, K, V.
            attention_states = self.fused_proj(hidden_states)
            query_states, key_states, value_states = jnp.split(attention_states, 3, axis=-1)

        else:
            query_states = self.q_proj(hidden_states)

            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            # Broadcast the (batch, key_len) mask over heads and query positions.
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if attention_mask is not None:
            # Convert the 0/1 mask into an additive bias: 0 where attended,
            # -inf where masked.
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Weighted sum over keys, then recombine heads and project out.
        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
|
|
|
class FlaxWav2Vec2FeedForward(nn.Module):
    """Position-wise two-layer MLP used inside each encoder layer."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)

        self.intermediate_dropout = nn.Dropout(rate=self.config.activation_dropout)
        self.intermediate_dense = nn.Dense(
            self.config.intermediate_size,
            kernel_init=kernel_init,
            dtype=self.dtype,
        )
        # `hidden_act` is either a string key into ACT2FN or a callable.
        if isinstance(self.config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[self.config.hidden_act]
        else:
            self.intermediate_act_fn = self.config.hidden_act

        self.output_dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=kernel_init,
            dtype=self.dtype,
        )
        self.output_dropout = nn.Dropout(rate=self.config.hidden_dropout)

    def __call__(self, hidden_states, deterministic=True):
        # Expand -> activate -> dropout -> contract -> dropout.
        hidden_states = self.intermediate_act_fn(self.intermediate_dense(hidden_states))
        hidden_states = self.intermediate_dropout(hidden_states, deterministic=deterministic)
        hidden_states = self.output_dense(hidden_states)
        return self.output_dropout(hidden_states, deterministic=deterministic)
|
|
|
|
class FlaxWav2Vec2EncoderLayerStableLayerNorm(nn.Module):
    """Pre-norm ("stable layer norm") transformer encoder layer.

    LayerNorm is applied *before* the attention and feed-forward sub-blocks,
    with residual connections around each. When `config.use_scan` is set, the
    layer follows the `nn.scan` carry protocol: input arrives as a 1-tuple and
    output is returned as `(carry, per_step_output)`.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = FlaxWav2Vec2Attention(
            config=self.config,
            embed_dim=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.feed_forward = FlaxWav2Vec2FeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    def __call__(self, hidden_states, attention_mask=None, deterministic=True, output_attentions=False):
        # Under `nn.scan` the carry arrives as a 1-tuple; unwrap it.
        if self.config.use_scan:
            hidden_states = hidden_states[0]
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights = self.attention(
            hidden_states, attention_mask=attention_mask, deterministic=deterministic
        )
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = attn_residual + hidden_states
        # Second residual branch: pre-norm + feed-forward.
        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states), deterministic=deterministic
        )

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        # scan expects a (carry, per-step output) pair; there is no per-step output.
        if self.config.use_scan:
            outputs = (outputs, None)

        return outputs
|
|
|
|
class FlaxWav2Vec2EncoderLayerStableLayerNormCollection(nn.Module):
    """Stack of pre-norm encoder layers, run either via `nn.scan` or a Python loop.

    With `config.use_scan`, all layers share one traced body with parameters
    stacked along axis 0 — faster to compile, but per-layer attentions /
    hidden states cannot be collected.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        num_layers = self.config.num_hidden_layers
        # Optionally wrap the layer in `remat` for gradient checkpointing.
        # `static_argnums=(2, 3)` marks `deterministic` / `output_attentions`
        # as static; CSE prevention is only needed outside of scan.
        BlockEncoderLayer = (
            remat(
                FlaxWav2Vec2EncoderLayerStableLayerNorm,
                static_argnums=(2, 3),
                prevent_cse=not self.config.use_scan,
            )
            if self.config.gradient_checkpointing
            else FlaxWav2Vec2EncoderLayerStableLayerNorm
        )

        if self.config.use_scan:
            # Per-layer outputs are unavailable under scan (single traced body).
            assert not output_attentions, "cannot use `scan` with `output_attentions` set to `True`"
            assert not output_hidden_states, "cannot use `scan` with `output_hidden_states` set to `True`"
            # Wrap into the 1-tuple carry expected by the scanned layer.
            hidden_states = (hidden_states,)

            hidden_states, _ = scan_with_axes(
                BlockEncoderLayer,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=num_layers,
            )(self.config, dtype=self.dtype, name="FlaxWav2Vec2EncoderLayers",)(
                hidden_states, attention_mask, deterministic, output_attentions
            )
            # Unwrap the 1-tuple carry again.
            hidden_states = hidden_states[0]

        else:
            for layer in range(num_layers):
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                layer_outputs = BlockEncoderLayer(
                    self.config,
                    dtype=self.dtype,
                    name=str(layer),
                )(hidden_states, attention_mask, deterministic, output_attentions)

                hidden_states = layer_outputs[0]

                if output_attentions:
                    all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            # Drop the `None` placeholders when returning a plain tuple.
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
        )
|
|
|
|
class FlaxWav2Vec2StableLayerNormEncoder(nn.Module):
    """Wav2Vec2 encoder (pre-norm variant): positional conv embedding + layer stack.

    In the stable-layer-norm architecture the final LayerNorm is applied to the
    output of the last layer rather than inside each layer's residual path.
    """

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxWav2Vec2PositionalConvEmbedding(self.config, dtype=self.dtype)
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxWav2Vec2EncoderLayerStableLayerNormCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic=True,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):

        if attention_mask is not None:
            # Zero out padded positions so they don't leak into the
            # convolutional positional embeddings.
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape), hidden_states, 0
            )

        position_embeddings = self.pos_conv_embed(hidden_states)

        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        # Fix: the original did not forward `deterministic`, so the encoder
        # layers always ran with their default `deterministic=True` and dropout
        # was silently disabled even during training.
        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Final LayerNorm on the last layer's output (pre-norm architecture).
        last_hidden_state = self.layer_norm(outputs[0])

        # When collecting hidden states, replace the last (un-normalized) entry
        # with the normalized final hidden state.
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state, hidden_states=hidden_states, attentions=outputs.attentions
        )
|
|
|
|
class FlaxWav2Vec2Adapter(nn.Module):
    """Optional adapter network applied on top of the encoder output."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # A projection is only needed when the adapter operates at a different
        # width than the encoder hidden size.
        if self.config.output_hidden_size != self.config.hidden_size:
            self.proj = nn.Dense(
                self.config.output_hidden_size,
                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
                dtype=self.dtype,
            )
            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        else:
            self.proj = self.proj_layer_norm = None

        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, hidden_states, deterministic=True):
        # Project (if configured), then run the adapter conv stack.
        if self.proj is not None and self.proj_layer_norm is not None:
            hidden_states = self.proj_layer_norm(self.proj(hidden_states))
        return self.layers(hidden_states)
|
|
|
|
class FlaxWav2Vec2AdapterLayer(nn.Module):
    """Strided 1-D conv + GLU adapter layer (downsamples the time axis)."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Emit 2x channels so the GLU halves them back to `output_hidden_size`.
        self.conv = nn.Conv(
            features=2 * self.config.output_hidden_size,
            kernel_size=(self.config.adapter_kernel_size,),
            strides=(self.config.adapter_stride,),
            padding=((1, 1),),
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(self, hidden_states):
        # Gated linear unit over the channel axis.
        return nn.glu(self.conv(hidden_states), axis=2)
|
|
|
|
class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
    """Sequential stack of adapter layers."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        # Wrap each layer in `remat` when gradient checkpointing is enabled.
        if self.config.gradient_checkpointing:
            layer_cls = remat(FlaxWav2Vec2AdapterLayer)
        else:
            layer_cls = FlaxWav2Vec2AdapterLayer
        self.layers = [
            layer_cls(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_adapter_layers)
        ]

    def __call__(self, hidden_states):
        for adapter_layer in self.layers:
            hidden_states = adapter_layer(hidden_states)
        return hidden_states
|
|
|
|
class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config_class = Wav2Vec2Config
    base_model_prefix: str = "wav2vec2"
    main_input_name = "input_values"
    module_class: nn.Module = None

    def __init__(
        self,
        config: Wav2Vec2Config,
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        # Instantiate the concrete Flax module (set by the subclass) and hand it
        # to the generic FlaxPreTrainedModel machinery.
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
        """Initialize module parameters by tracing with dummy inputs of `input_shape`."""
        # NOTE(review): the dummy input is int32 ("i4") even though real inputs
        # are float waveforms; `__call__` below casts to "f4" — confirm intended.
        input_values = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_values)
        params_rng, dropout_rng = jax.random.split(rng, 2)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        return self.module.init(rngs, input_values, attention_mask, return_dict=False)["params"]

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_features: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        """Run the underlying module with bound (or supplied) parameters.

        `train=True` enables dropout (the module receives `deterministic=not train`);
        `dropout_rng` must then be provided for stochastic layers.
        """
        # Fall back to config defaults for the output flags.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            # No mask supplied: treat every sample position as valid.
            batch_size, sequence_length = input_values.shape
            attention_mask = jnp.ones((batch_size, sequence_length))

        if extract_features is not None:
            # Precomputed conv features bypass the feature extractor downstream.
            extract_features = jnp.array(extract_features, dtype="f4")

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # Positional argument order must match the module's `__call__` signature.
        return self.module.apply(
            inputs,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            extract_features,
            not train,
            output_attentions,
            output_hidden_states,
            output_features,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """Delegate to the module: output length after the conv feature extractor."""
        return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Delegate to the module: downsample a sample-level mask to feature frames."""
        return self.module._get_feature_vector_attention_mask(feature_vector_length, attention_mask, add_adapter=add_adapter)
|
|
|
|
class FlaxWav2Vec2Module(nn.Module):
    """Core Wav2Vec2 module: conv feature encoder -> projection -> transformer
    encoder (-> optional adapter)."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.feature_extractor = FlaxWav2Vec2FeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxWav2Vec2FeatureProjection(self.config, dtype=self.dtype)
        # Learned vector substituted at masked time steps (SpecAugment-style
        # masking via `mask_time_indices`).
        self.masked_spec_embed = self.param(
            "masked_spec_embed", jax.nn.initializers.uniform(), (self.config.hidden_size,)
        )

        # Only the pre-norm ("stable layer norm") encoder is implemented here.
        if self.config.do_stable_layer_norm:
            self.encoder = FlaxWav2Vec2StableLayerNormEncoder(self.config, dtype=self.dtype)
        else:
            raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")

        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        """Forward pass.

        If `extract_features` is provided, the conv feature extractor is
        skipped. If `output_features` is True, only the conv features are
        computed and returned (no transformer encoding).
        """

        # Conv features may be precomputed and passed in to skip the extractor.
        if extract_features is None:
            extract_features = self.feature_extractor(input_values, freeze_feature_encoder=freeze_feature_encoder)

        if output_features:
            return extract_features

        # Downsample the sample-level attention mask to feature-frame resolution.
        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        # `extract_features` is rebound to the *normalized* features here.
        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        if mask_time_indices is not None:
            # Replace masked time steps with the learned mask embedding
            # (applied for SpecAugment / pretraining).
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula (VALID padding).
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            # Adapter convs have stride `adapter_stride`; kernel treated as 1 here.
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Downsample a sample-level attention mask to feature-frame resolution."""
        # Number of non-padded samples per batch element (mask is 0/1).
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # Set a 1 at the last valid frame index, then back-fill with a reversed
        # cumulative sum so every frame up to (and including) it becomes True.
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
|
|
|
|
class FlaxWav2Vec2Model(FlaxWav2Vec2PreTrainedModel):
    """Bare Wav2Vec2 encoder model (no task-specific head)."""

    module_class = FlaxWav2Vec2Module
|
|
|
|
class FlaxWav2Vec2ForCTCModule(nn.Module):
    """Wav2Vec2 encoder with a CTC head (dropout + linear vocab projection) on top."""

    config: Wav2Vec2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.final_dropout)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        extract_features=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        output_features=False,
        freeze_feature_encoder=False,
        return_dict=None,
    ):
        # Fix: the original accepted `extract_features` and `output_features`
        # but silently dropped both instead of forwarding them to the
        # underlying wav2vec2 module.
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            mask_time_indices=mask_time_indices,
            extract_features=extract_features,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_features=output_features,
            freeze_feature_encoder=freeze_feature_encoder,
            return_dict=return_dict,
        )

        if output_features:
            # The inner module short-circuits and returns the raw conv features.
            return outputs

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        logits = self.lm_head(hidden_states)

        if not return_dict:
            # outputs[1] is `extract_features`; skip it to mirror the dict path.
            return (logits,) + outputs[2:]

        return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)

    def _get_feat_extract_output_lengths(
        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
    ):
        """
        Computes the output length of the convolutional layers
        """

        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter

        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolutional layer output length formula (VALID padding).
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        if add_adapter:
            for _ in range(self.config.num_adapter_layers):
                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray, add_adapter=None
    ):
        """Downsample a sample-level attention mask to feature-frame resolution."""
        # Number of non-padded samples per batch element (mask is 0/1).
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
        # Set a 1 at the last valid frame index, then back-fill with a reversed
        # cumulative sum so every frame up to (and including) it becomes True.
        attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
        return attention_mask
|
|
|
|
class FlaxWav2Vec2ForCTC(FlaxWav2Vec2PreTrainedModel):
    """Wav2Vec2 model with a CTC (connectionist temporal classification) head."""

    module_class = FlaxWav2Vec2ForCTCModule
|
|