| """ Flax Hubert model.""" |
|
|
from functools import partial
from typing import Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
from flax.linen.attention import dot_product_attention_weights
from flax.traverse_util import flatten_dict, unflatten_dict
from jax import lax
from transformers import HubertConfig
from transformers.modeling_flax_outputs import FlaxBaseModelOutput
from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel
from transformers.utils import ModelOutput, logging


logger = logging.get_logger(__name__)
|
|
|
|
@flax.struct.dataclass
class FlaxHubertOutput(ModelOutput):
    """
    Output type of `FlaxHubertModel`: the encoder outputs plus the raw
    convolutional features extracted from the audio.
    """

    last_hidden_state: jnp.ndarray = None
    hidden_states: Optional[Tuple[jnp.ndarray]] = None
    attentions: Optional[Tuple[jnp.ndarray]] = None
    extract_features: jnp.ndarray = None
|
|
|
|
class FlaxConvWithWeightNorm(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = nn.Conv(
            features=self.config.hidden_size,
            kernel_size=(self.config.num_conv_pos_embeddings,),
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            feature_group_count=self.config.num_conv_pos_embedding_groups,
            dtype=self.dtype,
        )
        weight_shape = (
            self.conv.features,
            self.conv.features // self.conv.feature_group_count,
            self.conv.kernel_size[0],
        )
        self.weight_v = self.param(
            "weight_v", jax.nn.initializers.he_normal(), weight_shape
        )
        # weight_g is initialized to the norm of weight_v, so the effective kernel
        # equals weight_v at initialization.
        self.weight_g = self.param(
            "weight_g",
            lambda _: jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :],
        )
        self.bias = self.param("bias", jax.nn.initializers.zeros, (self.conv.features,))
        self.prev_padding = self.conv.kernel_size[0] // 2

    def _get_normed_weights(self):
        # Weight normalization: kernel = weight_g * weight_v / ||weight_v||.
        weight_v_norm = jnp.linalg.norm(self.weight_v, axis=(0, 1))[None, None, :]
        normed_weight_v = jnp.divide(self.weight_v, weight_v_norm)
        normed_kernel = jnp.multiply(normed_weight_v, self.weight_g)
        return normed_kernel

    def __call__(self, hidden_states):
        kernel = self._get_normed_weights()
        hidden_states = jnp.pad(
            hidden_states, ((0, 0), (self.prev_padding, self.prev_padding), (0, 0))
        )
        hidden_states = self.conv.apply(
            {"params": {"kernel": kernel.T, "bias": self.bias}}, hidden_states
        )
        return hidden_states
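# Note on the reparameterization above: weight normalization decouples the kernel's
# direction (weight_v) from its magnitude (weight_g). A toy self-check of the
# initialization identity (illustrative only; the shapes are arbitrary and this
# snippet is not used by the model):
#
#     v = jax.random.normal(jax.random.PRNGKey(0), (3, 4, 5))
#     g = jnp.linalg.norm(v, axis=(0, 1))[None, None, :]
#     kernel = g * (v / jnp.linalg.norm(v, axis=(0, 1))[None, None, :])
#     assert jnp.allclose(kernel, v, atol=1e-6)  # kernel == v when g == ||v||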
|
|
|
|
class FlaxHubertNoLayerNormConvLayer(nn.Module):
    config: HubertConfig
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.in_conv_dim = (
            self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
        )
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
|
|
|
|
class FlaxHubertLayerNormConvLayer(nn.Module):
    config: HubertConfig
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.in_conv_dim = (
            self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
        )
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )
        self.activation = ACT2FN[self.config.feat_extract_activation]

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
|
|
|
|
class FlaxHubertGroupNormConvLayer(nn.Module):
    config: HubertConfig
    layer_id: int = 0
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.in_conv_dim = (
            self.config.conv_dim[self.layer_id - 1] if self.layer_id > 0 else 1
        )
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(),
            padding="VALID",
            dtype=self.dtype,
        )
        self.activation = ACT2FN[self.config.feat_extract_activation]

        # With num_groups equal to the number of channels, each channel is normalized
        # on its own, matching the per-channel GroupNorm of the PyTorch feature encoder.
        self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, dtype=self.dtype)

    def __call__(self, hidden_states):
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states
|
|
|
|
class FlaxHubertPositionalConvEmbedding(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv = FlaxConvWithWeightNorm(self.config, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
        # For an even kernel the symmetric padding yields one extra frame, which is
        # trimmed off below.
        self.num_pad_remove = 1 if self.config.num_conv_pos_embeddings % 2 == 0 else 0

    def __call__(self, hidden_states):
        # Flax convolutions are channel-last, so the channel-first transposes of the
        # PyTorch implementation reduce to no-ops here; they are kept for parity.
        hidden_states = hidden_states.transpose((0, 1, 2))

        hidden_states = self.conv(hidden_states)

        if self.num_pad_remove > 0:
            hidden_states = hidden_states[:, : -self.num_pad_remove, :]
        hidden_states = self.activation(hidden_states)

        hidden_states = hidden_states.transpose((0, 1, 2))
        return hidden_states
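# Padding arithmetic for the positional convolution (worked example, assuming the
# default num_conv_pos_embeddings=128): prev_padding = 128 // 2 = 64 on each side,
# so a VALID convolution over a length-L input yields L + 2 * 64 - 128 + 1 = L + 1
# frames; trimming num_pad_remove = 1 restores the original length L.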
|
|
|
|
class FlaxConvLayersCollection(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        if self.config.feat_extract_norm == "layer":
            self.layers = [
                FlaxHubertLayerNormConvLayer(
                    self.config, layer_id=i, name=str(i), dtype=self.dtype
                )
                for i in range(self.config.num_feat_extract_layers)
            ]
        elif self.config.feat_extract_norm == "group":
            self.layers = [
                FlaxHubertGroupNormConvLayer(
                    self.config, layer_id=0, name=str(0), dtype=self.dtype
                )
            ] + [
                FlaxHubertNoLayerNormConvLayer(
                    self.config, layer_id=i, name=str(i), dtype=self.dtype
                )
                for i in range(1, self.config.num_feat_extract_layers)
            ]
        else:
            raise ValueError(
                f"`config.feat_extract_norm` is {self.config.feat_extract_norm}, but has to be one of ['group',"
                " 'layer']"
            )

    def __call__(self, hidden_states):
        for conv_layer in self.layers:
            hidden_states = conv_layer(hidden_states)
        return hidden_states
|
|
|
|
class FlaxHubertFeatureEncoder(nn.Module):
    """Construct the features from raw audio waveform."""

    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.conv_layers = FlaxConvLayersCollection(self.config, dtype=self.dtype)

    def __call__(self, input_values, freeze_feature_encoder=False):
        # Add a channel dimension: (batch, samples) -> (batch, samples, 1).
        hidden_states = input_values[:, :, None]
        hidden_states = self.conv_layers(hidden_states)
        if freeze_feature_encoder:
            # Stop gradients so the convolutional feature encoder is not updated.
            hidden_states = jax.lax.stop_gradient(hidden_states)
        return hidden_states
|
|
|
|
class FlaxHubertFeatureProjection(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.feat_proj_layer_norm = self.config.feat_proj_layer_norm
        if self.feat_proj_layer_norm:
            self.layer_norm = nn.LayerNorm(
                epsilon=self.config.layer_norm_eps, dtype=self.dtype
            )
        self.projection = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(rate=self.config.feat_proj_dropout)

    def __call__(self, hidden_states, deterministic=True):
        if self.feat_proj_layer_norm:
            hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.projection(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        return hidden_states
|
|
|
|
class FlaxHubertAttention(nn.Module):
    config: HubertConfig
    embed_dim: int
    num_heads: int
    dropout: float = 0.0
    bias: bool = True
    dtype: jnp.dtype = jnp.float32

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Note: the 1/sqrt(head_dim) scaling is applied internally by
        # dot_product_attention_weights; this attribute is kept for reference.
        self.scaling = self.head_dim**-0.5

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def _split_heads(self, hidden_states):
        return hidden_states.reshape(
            hidden_states.shape[:2] + (self.num_heads, self.head_dim)
        )

    def _merge_heads(self, hidden_states):
        return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
        """Input shape: Batch x Time x Channel"""

        # get query, key and value projections
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        if attention_mask is not None:
            # Convert the padding mask into an additive bias: 0 for attended
            # positions, a large negative value for padded ones.
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(
                    self.dtype
                ),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
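# The additive attention bias above pushes softmax inputs for padded positions to the
# most negative finite float, so their attention probability underflows to ~0. A toy
# illustration (illustrative only, float32 assumed):
#
#     mask = jnp.array([[1, 1, 0]])
#     bias = lax.select(
#         mask > 0,
#         jnp.full(mask.shape, 0.0),
#         jnp.full(mask.shape, jnp.finfo(jnp.float32).min),
#     )
#     # bias -> [[0.0, 0.0, -3.4028235e+38]]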
|
|
|
|
class FlaxHubertFeedForward(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.intermediate_dropout = nn.Dropout(self.config.activation_dropout)

        self.intermediate_dense = nn.Dense(
            self.config.intermediate_size, dtype=self.dtype
        )
        if isinstance(self.config.hidden_act, str):
            self.intermediate_activation = ACT2FN[self.config.hidden_act]
        else:
            self.intermediate_activation = self.config.hidden_act

        self.output_dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
        self.output_dropout = nn.Dropout(self.config.activation_dropout)

    def __call__(self, hidden_states, deterministic=True):
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_activation(hidden_states)
        hidden_states = self.intermediate_dropout(
            hidden_states, deterministic=deterministic
        )

        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states, deterministic=deterministic)

        return hidden_states
|
|
|
|
class FlaxHubertEncoderLayer(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = FlaxHubertAttention(
            config=self.config,
            embed_dim=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(self.config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )
        self.feed_forward = FlaxHubertFeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )

    def __call__(
        self,
        hidden_states,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        deterministic=True,
    ):
        attn_residual = hidden_states
        hidden_states, attn_weights = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            deterministic=deterministic,
        )

        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = attn_residual + hidden_states

        hidden_states = self.layer_norm(hidden_states)
        hidden_states = hidden_states + self.feed_forward(
            hidden_states, deterministic=deterministic
        )
        hidden_states = self.final_layer_norm(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
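# The layer above is the post-layer-norm variant (used when
# config.do_stable_layer_norm is False): normalization follows each residual
# connection. The "stable" variant below is pre-layer-norm: inputs are normalized
# before attention and before the feed-forward block, which tends to be more stable
# to train for deep stacks.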
|
|
|
|
class FlaxHubertEncoderLayerStableLayerNorm(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.attention = FlaxHubertAttention(
            config=self.config,
            embed_dim=self.config.hidden_size,
            num_heads=self.config.num_attention_heads,
            dropout=self.config.attention_dropout,
            dtype=self.dtype,
        )
        self.dropout = nn.Dropout(self.config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )
        self.feed_forward = FlaxHubertFeedForward(self.config, dtype=self.dtype)
        self.final_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )

    def __call__(
        self,
        hidden_states,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        deterministic=True,
    ):
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights = self.attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            deterministic=deterministic,
        )

        hidden_states = self.dropout(hidden_states, deterministic=deterministic)
        hidden_states = attn_residual + hidden_states

        hidden_states = hidden_states + self.feed_forward(
            self.final_layer_norm(hidden_states), deterministic=deterministic
        )

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs
|
|
|
|
class FlaxHubertLayerCollection(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxHubertEncoderLayer(self.config, name=str(i), dtype=self.dtype)
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
class FlaxHubertEncoder(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxHubertPositionalConvEmbedding(
            self.config, dtype=self.dtype
        )
        self.layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxHubertLayerCollection(self.config, dtype=self.dtype)

    def __call__(
        self,
        hidden_states,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        if attention_mask is not None:
            # make sure padded tokens are not attended to
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape),
                hidden_states,
                0,
            )

        position_embeddings = self.pos_conv_embed(hidden_states)

        hidden_states = hidden_states + position_embeddings
        # In the post-layer-norm encoder, layer norm is applied once before the
        # transformer layers, mirroring the PyTorch HubertEncoder (the stable
        # variant below instead normalizes after the last layer).
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        return self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
|
|
|
|
class FlaxHubertLayerStableLayerNormCollection(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.layers = [
            FlaxHubertEncoderLayerStableLayerNorm(
                self.config, name=str(i), dtype=self.dtype
            )
            for i in range(self.config.num_hidden_layers)
        ]

    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        all_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = layer(
                hidden_states,
                attention_mask,
                deterministic=deterministic,
                output_attentions=output_attentions,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        outputs = (hidden_states, all_hidden_states, all_attentions)

        if not return_dict:
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
        )
|
|
|
|
class FlaxHubertEncoderStableLayerNorm(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.pos_conv_embed = FlaxHubertPositionalConvEmbedding(
            self.config, dtype=self.dtype
        )
        self.layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype
        )
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout)
        self.layers = FlaxHubertLayerStableLayerNormCollection(
            self.config, dtype=self.dtype
        )

    def __call__(
        self,
        hidden_states,
        attention_mask: Optional[jnp.ndarray] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        if attention_mask is not None:
            # make sure padded tokens are not attended to
            hidden_states = jnp.where(
                jnp.broadcast_to(attention_mask[:, :, None], hidden_states.shape),
                hidden_states,
                0,
            )

        position_embeddings = self.pos_conv_embed(hidden_states)

        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The pre-layer-norm encoder applies the final layer norm after the last
        # transformer layer.
        last_hidden_state = self.layer_norm(outputs[0])

        # Update the last element in `hidden_states` so that it matches the final
        # layer-normed hidden state.
        hidden_states = None
        if output_hidden_states:
            hidden_states = outputs[1]
            hidden_states = hidden_states[:-1] + (last_hidden_state,)

        if not return_dict:
            outputs = (last_hidden_state, hidden_states) + (
                outputs[2:] if output_hidden_states else outputs[1:]
            )
            return tuple(v for v in outputs if v is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=last_hidden_state,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
        )
|
|
|
|
class FlaxHubertPreTrainedModel(FlaxPreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for
    downloading and loading pretrained models.
    """

    config_class = HubertConfig
    base_model_prefix = "hubert"
    main_input_name = "input_values"
    module_class: nn.Module = None
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(
        self,
        config: HubertConfig,
        input_shape: Tuple = (1, 1024),
        seed: int = 0,
        dtype: jnp.dtype = jnp.float32,
        _do_init: bool = True,
        **kwargs,
    ):
        module = self.module_class(config=config, dtype=dtype, **kwargs)
        super().__init__(
            config,
            module,
            input_shape=input_shape,
            seed=seed,
            dtype=dtype,
            _do_init=_do_init,
        )

    def init_weights(
        self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None
    ) -> FrozenDict:
        input_values = jnp.zeros(input_shape, dtype="i4")
        attention_mask = jnp.ones_like(input_values)
        params_rng, dropout_rng = jax.random.split(rng, 2)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        random_params = self.module.init(
            rngs, input_values, attention_mask, return_dict=False
        )["params"]

        if params is not None:
            random_params = flatten_dict(unfreeze(random_params))
            params = flatten_dict(unfreeze(params))
            for missing_key in self._missing_keys:
                params[missing_key] = random_params[missing_key]
            self._missing_keys = set()
            return freeze(unflatten_dict(params))
        else:
            return random_params

    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.return_dict
        )

        batch_size, sequence_length = input_values.shape

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed.
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        return self.module.apply(
            inputs,
            jnp.array(input_values, dtype="f4"),
            jnp.array(attention_mask, dtype="i4"),
            mask_time_indices,
            not train,  # deterministic
            output_attentions,
            output_hidden_states,
            freeze_feature_encoder,
            return_dict,
            rngs=rngs,
        )
|
|
|
|
class FlaxHubertModule(nn.Module):
    config: HubertConfig
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.feature_extractor = FlaxHubertFeatureEncoder(self.config, dtype=self.dtype)
        self.feature_projection = FlaxHubertFeatureProjection(
            self.config, dtype=self.dtype
        )

        if self.config.mask_time_prob > 0.0 or self.config.mask_feature_prob > 0.0:
            # Learned embedding used to replace masked time steps during
            # SpecAugment-style masking (only needed when masking is enabled).
            self.masked_spec_embed = self.param(
                "masked_spec_embed",
                nn.initializers.uniform(dtype=self.dtype),
                (self.config.hidden_size,),
            )

        if self.config.do_stable_layer_norm:
            self.encoder = FlaxHubertEncoderStableLayerNorm(self.config, dtype=self.dtype)
        else:
            self.encoder = FlaxHubertEncoder(self.config, dtype=self.dtype)

    def __call__(
        self,
        input_values: Optional[jnp.ndarray],
        attention_mask: Optional[jnp.ndarray] = None,
        mask_time_indices: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        freeze_feature_encoder: bool = False,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, FlaxHubertOutput]:
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        extract_features = self.feature_extractor(input_values, freeze_feature_encoder)

        if attention_mask is not None:
            # Compute a reduced attention mask that corresponds to the downsampled
            # feature frames rather than the raw audio samples.
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask
            )

        hidden_states = self.feature_projection(
            extract_features, deterministic=deterministic
        )
        if mask_time_indices is not None:
            # Apply SpecAugment along the time axis using the learned embedding.
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(
                    self.masked_spec_embed[None, None, :], hidden_states.shape
                ),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return FlaxHubertOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            extract_features=extract_features,
        )

    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
        def _conv_out_length(input_length, kernel_size, stride):
            # 1D convolution output length, with no padding:
            # (input_length - kernel_size) // stride + 1
            return (input_length - kernel_size) // stride + 1

        for kernel_size, stride in zip(
            self.config.conv_kernel, self.config.conv_stride
        ):
            input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

        return input_lengths

    def _get_feature_vector_attention_mask(
        self, feature_vector_length: int, attention_mask: jnp.ndarray
    ):
        # Number of non-padded samples per batch element.
        non_padded_lengths = attention_mask.cumsum(axis=-1)[:, -1]

        output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths)

        batch_size = attention_mask.shape[0]

        attention_mask = jnp.zeros(
            (batch_size, feature_vector_length), dtype=attention_mask.dtype
        )
        # Mark the last valid frame of each sequence, then a reversed cumulative sum
        # sets every frame up to (and including) that index to 1.
        attention_mask = attention_mask.at[
            jnp.arange(attention_mask.shape[0]), output_lengths - 1
        ].set(1)
        attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype(
            "bool"
        )
        return attention_mask
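# Worked example of the frame-rate arithmetic above, assuming the default Hubert
# feature extractor (conv_kernel = (10, 3, 3, 3, 3, 2, 2), conv_stride =
# (5, 2, 2, 2, 2, 2, 2)): 16000 samples (1 s of 16 kHz audio) reduce as
# 16000 -> 3199 -> 1599 -> 799 -> 399 -> 199 -> 99 -> 49 frames, i.e. roughly one
# feature frame every 20 ms.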
|
|
|
|
class FlaxHubertModel(FlaxHubertPreTrainedModel):
    """The bare Hubert model transformer outputting raw hidden-states without any specific head on top."""

    module_class = FlaxHubertModule
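# Minimal usage sketch (illustrative only): this initializes random weights rather
# than a converted checkpoint, since loading pretrained PyTorch Hubert weights
# requires a separate conversion step not shown here.
#
#     import jax.numpy as jnp
#     from transformers import HubertConfig
#
#     config = HubertConfig()
#     model = FlaxHubertModel(config, input_shape=(1, 16000))
#     input_values = jnp.zeros((1, 16000), dtype=jnp.float32)  # 1 s of 16 kHz audio
#     outputs = model(input_values)
#     outputs.last_hidden_state.shape  # (1, 49, config.hidden_size)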
|
|