| |
| |
| |
| |
|
|
| |
| |
|
|
| import logging |
| import math |
| import uuid |
| from dataclasses import dataclass, field |
| from enum import Enum, EnumMeta |
| from typing import Callable, Dict, List, Optional, Tuple |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch import Tensor |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def rotate_half(x):
    """Rotate the two halves of the last dimension: (a, b) -> (-b, a).

    Helper for rotary positional embeddings; equivalent to a 90-degree
    rotation within each 2-D subspace of the embedding.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    # The original passed dim=x1.ndim - 1, which is simply the last axis.
    return torch.cat((-second, first), dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    """Apply rotary positional embedding to query and key tensors.

    Args:
        q, k: Tensors whose leading axis is the sequence (time) dimension.
        cos, sin: Cached tables (as produced by RotaryPositionalEmbedding),
            indexed by absolute position along their first axis.
        offset: Starting position of ``q``/``k`` within the cached tables.
    Returns:
        Tuple of the rotated (q, k).
    """
    window = slice(offset, q.shape[0] + offset)
    cos, sin = cos[window, ...], sin[window, ...]

    def _rotate(t):
        # Inlined rotate_half: (a, b) -> (-b, a) along the last dimension.
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    return (q * cos) + (_rotate(q) * sin), (k * cos) + (_rotate(k) * sin)
|
|
|
|
class RotaryPositionalEmbedding(torch.nn.Module):
    """Rotary positional embedding (RoPE).

    Reference : https://blog.eleuther.ai/rotary-embeddings/
    Paper: https://arxiv.org/pdf/2104.09864.pdf
    """

    def __init__(self, dim, base=10000, precision=torch.half):
        """
        Args:
            dim: Dimension of embedding
            base: Base value for exponential
            precision: precision to use for numerical values
        """
        super().__init__()
        # One inverse frequency per pair of embedding dimensions.
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        # Lazily-built cos/sin tables, rebuilt whenever seq_len changes.
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None
        self.precision = precision  # NOTE: stored but never read by forward()

    def forward(self, x, seq_len=None):
        """Return cached cos/sin tables of shape (seq_len, 1, 1, dim).

        Args:
            x: Input with T X B X C layout (only its device is used here)
            seq_len: Sequence length of input x
        """
        if seq_len != self.seq_len_cached:
            self.seq_len_cached = seq_len
            positions = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
            # Outer product position x frequency == einsum("i,j->ij", ...).
            angles = torch.outer(positions, self.inv_freq)
            table = torch.cat((angles, angles), dim=-1).to(x.device)
            # Broadcastable over batch and head axes.
            self.cos_cached = table.cos()[:, None, None, :]
            self.sin_cached = table.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached
|
|
|
|
class ESPNETMultiHeadedAttention(nn.Module):
    """Multi-Head Attention layer (ESPnet style).

    Args:
        n_feat: The number of features.
        n_head: The number of heads.
        dropout: Dropout rate.
    """

    def __init__(self, n_feat, n_head, dropout):
        """Construct an MultiHeadedAttention object."""
        super().__init__()
        assert n_feat % n_head == 0
        # Per-head dimension; d_v is assumed equal to d_k.
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        # Most recent attention weights, kept for inspection/debugging.
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward_qkv(self, query, key, value, **kwargs):
        """Project and reshape query, key and value into per-head tensors.

        Args:
            query: Query tensor B X T1 X C
            key: Key tensor B X T2 X C
            value: Value tensor B X T2 X C
        Returns:
            torch.Tensor: Transformed query tensor B X n_head X T1 X d_k
            torch.Tensor: Transformed key tensor B X n_head X T2 X d_k
            torch.Tensor: Transformed value tensor B X n_head X T2 X d_k
        """
        n_batch = query.size(0)

        def _split_heads(linear, tensor):
            # (B, T, C) -> (B, n_head, T, d_k)
            return linear(tensor).view(n_batch, -1, self.h, self.d_k).transpose(1, 2)

        return (
            _split_heads(self.linear_q, query),
            _split_heads(self.linear_k, key),
            _split_heads(self.linear_v, value),
        )

    def forward_attention(self, value, scores, mask):
        """Compute the attention context vector.

        Args:
            value: Transformed value B X n_head X T2 X d_k.
            scores: Attention score B X n_head X T1 X T2
            mask: Mask T2 X B
        Returns:
            torch.Tensor: Transformed value B X T1 X d_model
                weighted by the attention score B X T1 X T2
        """
        n_batch = value.size(0)
        if mask is not None:
            # Padded keys get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(
                mask.unsqueeze(1).unsqueeze(2).to(bool),
                float("-inf"),
            )
        self.attn = torch.softmax(scores, dim=-1)
        p_attn = self.dropout(self.attn)
        context = torch.matmul(p_attn, value)
        # (B, n_head, T1, d_k) -> (B, T1, n_head * d_k)
        context = context.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        return self.linear_out(context)

    def forward(self, query, key, value, key_padding_mask=None, **kwargs):
        """Compute scaled dot product attention.

        Args:
            query (torch.Tensor): Query tensor T X B X C
            key (torch.Tensor): Key tensor T X B X C
            value (torch.Tensor): Value tensor T X B X C
            key_padding_mask (torch.Tensor): Mask tensor T X B
        Returns:
            torch.Tensor: Output tensor T X B X D.
        """
        # Time-first -> batch-first for the internal computation.
        query, key, value = (t.transpose(0, 1) for t in (query, key, value))
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        out = self.forward_attention(v, scores, key_padding_mask)
        return out.transpose(0, 1), None
|
|
|
|
class RelPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
    """Multi-Head Attention layer with relative position encoding.
    Paper: https://arxiv.org/abs/1901.02860
    Args:
        n_head: The number of heads.
        n_feat: The number of features.
        dropout: Dropout rate.
        zero_triu: Whether to zero the upper triangular part of attention matrix.
    """

    def __init__(self, n_feat, n_head, dropout, zero_triu=False):
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(n_feat, n_head, dropout)
        self.zero_triu = zero_triu
        # Linear transformation of the positional encodings (no bias, as in
        # Transformer-XL).
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        # Learnable per-head global biases from Transformer-XL: u biases the
        # content term (matrix_ac below), v biases the position term
        # (matrix_bd below).
        self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
        self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
        torch.nn.init.xavier_uniform_(self.pos_bias_u)
        torch.nn.init.xavier_uniform_(self.pos_bias_v)

    def rel_shift(self, x):
        """Shift relative-position scores to align with absolute key indices.

        Standard Transformer-XL "relative shift": pad one zero column, then
        reinterpret the flat memory layout so that row i is shifted by i
        positions.

        Args:
            x: Input tensor B X n_head X T X 2T-1
        Returns:
            torch.Tensor: Output tensor, truncated on the last axis to
                x.size(-1) // 2 + 1 positions (== T when the input covers
                2T-1 relative positions).
        """
        zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
        x_padded = torch.cat([zero_pad, x], dim=-1)

        # View with swapped trailing dims realigns rows by one step each.
        x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
        # Drop the first (all-zero) row, restore the original shape, and keep
        # only the valid leading positions.
        x = x_padded[:, :, 1:].view_as(x)[
            :, :, :, : x.size(-1) // 2 + 1
        ]

        if self.zero_triu:
            # Zero out scores for future (upper-triangular) positions.
            ones = torch.ones((x.size(2), x.size(3)), device=x.device)
            x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]

        return x

    def forward(self, query, key, value, pos_emb, key_padding_mask=None, **kwargs):
        """Compute scaled dot product attention with relative positions.
        Args:
            query: Query tensor T X B X C
            key: Key tensor T X B X C
            value: Value tensor T X B X C
            pos_emb: Positional embedding tensor; transposed to batch-first
                below, so presumably 2T-1 X B X C on input — TODO confirm
                against the caller (the original docstring said B X 2T-1 X C).
            key_padding_mask: Mask tensor T X B
        Returns:
            torch.Tensor: Output tensor T X B X C.
        """
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)
        pos_emb = pos_emb.transpose(0, 1)
        q, k, v = self.forward_qkv(query, key, value)
        # Back to B X T1 X n_head X d_k so the (h, d_k) bias terms broadcast.
        q = q.transpose(1, 2)
        n_batch_pos = pos_emb.size(0)
        p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
        p = p.transpose(1, 2)

        # B X n_head X T1 X d_k
        q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)

        # Content score — terms (a) + (c) of Transformer-XL's decomposition:
        # B X n_head X T1 X T2
        matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))

        # Position score — terms (b) + (d), realigned by rel_shift.
        matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
        matrix_bd = self.rel_shift(matrix_bd)

        scores = (matrix_ac + matrix_bd) / math.sqrt(
            self.d_k
        )  # B X n_head X T1 X T2

        scores = self.forward_attention(v, scores, key_padding_mask)
        scores = scores.transpose(0, 1)
        return scores, None
|
|
|
|
class RotaryPositionMultiHeadedAttention(ESPNETMultiHeadedAttention):
    """Multi-Head Attention layer with rotary positional embeddings (RoPE)."""

    def __init__(
        self,
        n_feat,
        n_head,
        dropout,
        precision,
        rotary_emd_base=10000,
    ):
        """Construct an RotaryPositionMultiHeadedAttention object.

        Args:
            n_feat: The number of features.
            n_head: The number of heads.
            dropout: Dropout rate.
            precision: "fp16" or a torch.dtype selecting the precision
                recorded on the rotary embedding; any other value (e.g. a
                boolean use_fp16 flag) falls back to torch.float.
            rotary_emd_base: Base of the rotary frequency exponentials.
        """
        super().__init__(n_feat, n_head, dropout)
        # BUG FIX: the previous code unconditionally rebound ``precision`` to
        # torch.float *before* comparing it against "fp16", so the
        # half-precision branch was unreachable and the caller's argument was
        # always ignored. Honor the argument, preserving the old torch.float
        # default for non-dtype values such as booleans.
        if precision == "fp16":
            precision = torch.half
        elif not isinstance(precision, torch.dtype):
            precision = torch.float
        # Rotate every dimension of each head.
        self.rotary_ndims = self.d_k
        self.rotary_emb = RotaryPositionalEmbedding(
            self.rotary_ndims, base=rotary_emd_base, precision=precision
        )

    def forward(self, query, key, value, key_padding_mask=None, **kwargs):
        """Compute rotary position attention.
        Args:
            query: Query tensor T X B X C
            key: Key tensor T X B X C
            value: Value tensor T X B X C
            key_padding_mask: Mask tensor T X B
        Returns:
            torch.Tensor: Output tensor T X B X D.
        Notes:
            Assumes self attn
        """
        T, B, C = value.size()
        # Expose the per-head layout so the (T, 1, 1, d_k) cos/sin tables
        # broadcast over batch and heads.
        query = query.view(T, B, self.h, self.d_k)
        key = key.view(T, B, self.h, self.d_k)
        value = value.view(T, B, self.h, self.d_k)
        cos, sin = self.rotary_emb(value, seq_len=T)
        query, key = apply_rotary_pos_emb(query, key, cos, sin, offset=0)

        # Flatten heads back before the standard projections.
        query = query.view(T, B, self.h * self.d_k)
        key = key.view(T, B, self.h * self.d_k)
        value = value.view(T, B, self.h * self.d_k)

        # TBC -> BTC for the base-class helpers.
        query = query.transpose(0, 1)
        key = key.transpose(0, 1)
        value = value.transpose(0, 1)

        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = self.forward_attention(v, scores, key_padding_mask)
        scores = scores.transpose(0, 1)
        return scores, None
|
|
|
|
class ConvolutionModule(torch.nn.Module):
    """Convolution block used in the conformer block"""

    def __init__(
        self,
        embed_dim,
        channels,
        depthwise_kernel_size,
        dropout,
        activation_fn="swish",
        bias=False,
        export=False,
    ):
        """
        Args:
            embed_dim: Embedding dimension
            channels: Number of channels in depthwise conv layers
            depthwise_kernel_size: Depthwise conv layer kernel size
            dropout: dropout value
            activation_fn: Activation function to use after depthwise convolution kernel
            bias: If bias should be added to conv layers
            export: If layernorm should be exported to jit
        """
        super().__init__()
        assert (depthwise_kernel_size - 1) % 2 == 0, (
            "kernel_size should be a odd number for 'SAME' padding"
        )
        self.layer_norm = LayerNorm(embed_dim, export=export)
        # First pointwise conv doubles the channels for the following GLU.
        self.pointwise_conv1 = torch.nn.Conv1d(
            embed_dim,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.glu = torch.nn.GLU(dim=1)
        # Depthwise conv (groups == channels) with SAME padding.
        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            depthwise_kernel_size,
            stride=1,
            padding=(depthwise_kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.batch_norm = torch.nn.BatchNorm1d(channels)
        self.activation = get_activation_fn(activation_fn)(channels)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            embed_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Input of shape B X T X C
        Returns:
            Tensor of shape B X T X C
        """
        x = self.layer_norm(x)
        # Conv1d wants channels second: B X T X C -> B X C X T.
        x = x.transpose(1, 2)
        # B X 2C X T -> GLU -> B X C X T
        x = self.glu(self.pointwise_conv1(x))
        x = self.activation(self.batch_norm(self.depthwise_conv(x)))
        x = self.dropout(self.pointwise_conv2(x))
        # Back to B X T X C.
        return x.transpose(1, 2)
|
|
|
|
class FeedForwardModule(torch.nn.Module):
    """Positionwise feed forward layer used in conformer"""

    def __init__(
        self,
        input_feat,
        hidden_units,
        dropout1,
        dropout2,
        activation_fn="swish",
        bias=True,
    ):
        """
        Args:
            input_feat: Input feature dimension
            hidden_units: Hidden unit dimension
            dropout1: dropout value for layer1
            dropout2: dropout value for layer2
            activation_fn: Name of activation function
            bias: If linear layers should have bias
        """
        super().__init__()
        self.layer_norm = LayerNorm(input_feat)
        self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
        self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias)
        self.dropout1 = torch.nn.Dropout(dropout1)
        self.dropout2 = torch.nn.Dropout(dropout2)
        self.activation = get_activation_fn(activation_fn)(hidden_units)

    def forward(self, x):
        """
        Args:
            x: Input Tensor of shape T X B X C
        Returns:
            Tensor of shape T X B X C
        """
        # norm -> expand -> activate -> dropout -> project back -> dropout
        x = self.layer_norm(x)
        x = self.dropout1(self.activation(self.w_1(x)))
        return self.dropout2(self.w_2(x))
|
|
|
|
class ConformerEncoderLayer(torch.nn.Module):
    """Conformer block based on https://arxiv.org/abs/2005.08100. We currently don't support relative positional encoding in MHA"""

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        dropout,
        use_fp16,
        depthwise_conv_kernel_size=31,
        activation_fn="swish",
        attn_type=None,
        pos_enc_type="abs",
    ):
        """
        Args:
            embed_dim: Input embedding dimension
            ffn_embed_dim: FFN layer dimension
            attention_heads: Number of attention heads in MHA
            dropout: dropout value
            use_fp16: forwarded to the rotary attention as its ``precision``
                argument (only used when pos_enc_type == "rope")
            depthwise_conv_kernel_size: Size of kernel in depthwise conv layer in convolution module
            activation_fn: Activation function name to use in convulation block and feed forward block
            attn_type: MHA implementation from ESPNET vs fairseq
            pos_enc_type: Positional encoding type - abs, rope, rel_pos
        """
        # Assigned before super().__init__(): a plain (non-module, non-tensor)
        # attribute, so nn.Module's __setattr__ falls through safely.
        self.pos_enc_type = pos_enc_type
        super(ConformerEncoderLayer, self).__init__()

        # First half-step feed-forward module (macaron style).
        # NOTE(review): ffn1 does not receive activation_fn (it always uses
        # the default "swish"), unlike ffn2 below — possibly unintentional;
        # confirm before relying on a non-default activation_fn.
        self.ffn1 = FeedForwardModule(
            embed_dim,
            ffn_embed_dim,
            dropout,
            dropout,
        )

        self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        if attn_type == "espnet":
            if self.pos_enc_type == "rel_pos":
                self.self_attn = RelPositionMultiHeadedAttention(
                    embed_dim,
                    attention_heads,
                    dropout=dropout,
                )
            elif self.pos_enc_type == "rope":
                self.self_attn = RotaryPositionMultiHeadedAttention(
                    embed_dim, attention_heads, dropout=dropout, precision=use_fp16
                )
            elif self.pos_enc_type == "abs":
                self.self_attn = ESPNETMultiHeadedAttention(
                    embed_dim,
                    attention_heads,
                    dropout=dropout,
                )
            else:
                raise Exception(f"Unsupported attention type {self.pos_enc_type}")
        else:
            # Default to the fairseq-style MultiheadAttention (absolute
            # positions only, as noted in the class docstring).
            self.self_attn = MultiheadAttention(
                embed_dim,
                attention_heads,
                dropout=dropout,
            )

        self.conv_module = ConvolutionModule(
            embed_dim=embed_dim,
            channels=embed_dim,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            activation_fn=activation_fn,
        )

        # Second half-step feed-forward module.
        self.ffn2 = FeedForwardModule(
            embed_dim,
            ffn_embed_dim,
            dropout,
            dropout,
            activation_fn=activation_fn,
        )
        self.final_layer_norm = LayerNorm(embed_dim, export=False)

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[torch.Tensor],
        position_emb: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: Tensor of shape T X B X C
            encoder_padding_mask: Optional mask tensor, passed to the
                attention as key_padding_mask
            position_emb: relative positional embedding; only forwarded when
                pos_enc_type == "rel_pos"
        Returns:
            Tuple of (output tensor T X B X C, (attn, layer_result))
        """
        # Macaron-style FFN with a half-step (0.5x) residual.
        residual = x
        x = self.ffn1(x)
        x = x * 0.5 + residual
        residual = x
        x = self.self_attn_layer_norm(x)
        if self.pos_enc_type == "rel_pos":
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                pos_emb=position_emb,
                need_weights=False,
            )
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
            )
        x = self.self_attn_dropout(x)
        x = x + residual

        residual = x
        # Convolution module expects batch-first input.
        x = x.transpose(0, 1)
        x = self.conv_module(x)
        x = x.transpose(0, 1)
        x = residual + x

        # Second half-step FFN.
        residual = x
        x = self.ffn2(x)

        # Pre-residual FFN output is exposed as this layer's feature.
        layer_result = x

        x = x * 0.5 + residual

        x = self.final_layer_norm(x)
        return x, (attn, layer_result)
|
|
|
|
class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
    """Encoder layer for Wav2vec2 encoder"""

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
        position_emb=None,
    ):
        # Adapter to the wav2vec2 encoder-layer calling convention: only the
        # padding mask and position embedding are forwarded; self_attn_mask,
        # need_weights and att_args are accepted but ignored.
        return super().forward(x, self_attn_padding_mask, position_emb)
|
|
|
|
class FairseqIncrementalState(object):
    """Mixin giving a module its own namespaced slot inside a shared
    incremental-state dictionary (used for incremental decoding)."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        # A unique id per instance keeps entries of different modules apart
        # inside the same shared dict.
        self._incremental_state_id = str(uuid.uuid4())

    def _get_full_incremental_state_key(self, key: str) -> str:
        return f"{self._incremental_state_id}.{key}"

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Helper for getting incremental state for an nn.Module."""
        if incremental_state is None:
            return None
        # Missing keys yield None, matching the original "not in" check.
        return incremental_state.get(self._get_full_incremental_state_key(key))

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Helper for setting incremental state for an nn.Module."""
        if incremental_state is not None:
            incremental_state[self._get_full_incremental_state_key(key)] = value
        return incremental_state
|
|
|
|
def with_incremental_state(cls):
    """Class decorator that prepends FairseqIncrementalState to cls's bases.

    Any existing occurrence of the mixin is removed first so it never appears
    twice in the MRO.
    """
    other_bases = tuple(b for b in cls.__bases__ if b != FairseqIncrementalState)
    cls.__bases__ = (FairseqIncrementalState,) + other_bases
    return cls
|
|
|
|
class FairseqDropout(nn.Module):
    """Dropout module that can optionally stay active at inference time."""

    def __init__(self, p, module_name=None):
        super().__init__()
        self.p = p
        self.module_name = module_name
        # Flipped on by make_generation_fast_ when dropout should be retained.
        self.apply_during_inference = False

    def forward(self, x, inplace: bool = False):
        # Dropout fires while training, or at inference when explicitly
        # enabled; note training=True is forced so F.dropout always applies.
        active = self.p > 0 and (self.training or self.apply_during_inference)
        if not active:
            return x
        return F.dropout(x, p=self.p, training=True, inplace=inplace)

    def make_generation_fast_(
        self,
        name: str,
        retain_dropout: bool = False,
        retain_dropout_modules: Optional[List[str]] = None,
        **kwargs,
    ):
        if not retain_dropout:
            logger.info("Disabling dropout for module: {}".format(name))
            return
        if retain_dropout_modules is not None and self.module_name is None:
            # Can't match against the retain list without a module name.
            logger.warning(
                "Cannot enable dropout during inference for module {} "
                "because module_name was not set".format(name)
            )
        elif (
            retain_dropout_modules is None
            or self.module_name in retain_dropout_modules
        ):
            logger.info(
                "Enabling dropout during inference for module: {}".format(name)
            )
            self.apply_during_inference = True
|
|
|
|
def quant_noise(module, p, block_size):
    """
    Wraps modules and applies quantization noise to the weights for
    subsequent quantization with Iterative Product Quantization as
    described in "Training with Quantization Noise for Extreme Model Compression"

    Args:
        - module: nn.Module
        - p: amount of Quantization Noise
        - block_size: size of the blocks for subsequent quantization with iPQ

    Remarks:
        - Module weights must have the right sizes wrt the block size
        - Only Linear, Embedding and Conv2d modules are supported for the moment
        - For more detail on how to quantize by blocks with convolutional weights,
          see "And the Bit Goes Down: Revisiting the Quantization of Neural Networks"
        - We implement the simplest form of noise here as stated in the paper
          which consists in randomly dropping blocks
    """
    # No noise requested: hand the module back untouched.
    if p <= 0:
        return module

    # Only these module types are supported.
    assert isinstance(module, (nn.Linear, nn.Embedding, nn.Conv2d))

    # 4-D weight means a (non-1x1-collapsed) convolution.
    is_conv = module.weight.ndim == 4

    # Validate that the weight tiles into whole blocks.
    if not is_conv:
        assert (
            module.weight.size(1) % block_size == 0
        ), "Input features must be a multiple of block sizes"
    elif module.kernel_size == (1, 1):
        assert (
            module.in_channels % block_size == 0
        ), "Input channels must be a multiple of block sizes"
    else:
        k = module.kernel_size[0] * module.kernel_size[1]
        assert k % block_size == 0, "Kernel size must be a multiple of block size"

    def _forward_pre_hook(mod, input):
        # Noise is only injected while training.
        if not mod.training:
            return
        weight = mod.weight
        if not is_conv:
            # Linear/Embedding: drop contiguous blocks along input features.
            in_features = weight.size(1)
            out_features = weight.size(0)
            drop = torch.zeros(
                in_features // block_size * out_features, device=weight.device
            )
            drop.bernoulli_(p)
            drop = drop.repeat_interleave(block_size, -1).view(-1, in_features)
        elif mod.kernel_size == (1, 1):
            # 1x1 conv behaves like a linear layer over channels.
            drop = torch.zeros(
                int(mod.in_channels // block_size * mod.out_channels),
                device=weight.device,
            )
            drop.bernoulli_(p)
            drop = drop.repeat_interleave(block_size, -1).view(-1, mod.in_channels)
        else:
            # General conv: drop whole (out, in) filter slices.
            drop = torch.zeros(weight.size(0), weight.size(1), device=weight.device)
            drop.bernoulli_(p)
            drop = (
                drop.unsqueeze(2)
                .unsqueeze(3)
                .repeat(1, 1, mod.kernel_size[0], mod.kernel_size[1])
            )

        # Zero the selected blocks and rescale survivors by 1/(1-p) so the
        # expected magnitude is preserved.
        mask = drop.to(torch.bool)
        scale = 1 / (1 - p)
        mod.weight.data = scale * weight.masked_fill(mask, 0)

    module.register_forward_pre_hook(_forward_pre_hook)
    return module
|
|
|
|
| @with_incremental_state |
| class MultiheadAttention(nn.Module): |
| """Multi-headed attention. |
| |
| See "Attention Is All You Need" for more details. |
| """ |
|
|
    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        # xformers options are accepted for signature compatibility only;
        # using them is rejected by the assert below.
        xformers_att_config: Optional[str] = None,
        xformers_blocksparse_layout: Optional[
            torch.Tensor
        ] = None,
        xformers_blocksparse_blocksize: Optional[
            int
        ] = 16,
    ):
        """Construct a MultiheadAttention module.

        Args:
            embed_dim: total embedding dimension (split across num_heads).
            num_heads: number of attention heads; must divide embed_dim.
            kdim/vdim: key/value input dimensions (default: embed_dim).
            dropout: attention-weight dropout probability.
            bias: add bias to the q/k/v/out projections.
            add_bias_kv: add learnable bias vectors appended to keys/values.
            add_zero_attn: append an all-zero key/value position.
            self_attention: q/k/v all come from the same input.
            encoder_decoder_attention: q from decoder, k/v from encoder.
            q_noise / qn_block_size: quantization-noise amount and block size
                applied to the projections via quant_noise().
            xformers_*: must be left at their defaults in this build.
        """
        super().__init__()

        def eval_str_dict(x, type=dict):
            # Parse a dict passed as its string representation.
            # SECURITY NOTE: uses eval(); only acceptable because the config
            # string comes from trusted code, and any non-None value is
            # rejected by the assert below anyway.
            if x is None:
                return None
            if isinstance(x, str):
                x = eval(x)
            return x

        xformers_att_config = eval_str_dict(xformers_att_config)
        self.use_xformers = xformers_att_config is not None
        # This port does not support the xformers backend at all.
        assert not self.use_xformers, "Do not use xformers in S3PRL"

        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        # True when q/k/v can share a fused initialization scheme.
        self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim

        self.num_heads = num_heads
        self.dropout_module = FairseqDropout(
            dropout, module_name=self.__class__.__name__
        )

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == self.embed_dim
        ), "embed_dim must be divisible by num_heads"
        # Pre-computed 1/sqrt(d_k) scaling for the dot-product scores.
        self.scaling = self.head_dim**-0.5

        self.self_attention = self_attention
        self.encoder_decoder_attention = encoder_decoder_attention

        assert not self.self_attention or self.qkv_same_dim, (
            "Self-attention requires query, key and " "value to be of the same size"
        )

        # Projections optionally wrapped with iPQ quantization noise.
        self.k_proj = quant_noise(
            nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.v_proj = quant_noise(
            nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
        )
        self.q_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        self.out_proj = quant_noise(
            nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
        )

        if add_bias_kv:
            # Extra learned key/value appended along the sequence dimension.
            self.bias_k = nn.Parameter(torch.Tensor(1, 1, embed_dim))
            self.bias_v = nn.Parameter(torch.Tensor(1, 1, embed_dim))
        else:
            self.bias_k = self.bias_v = None

        self.add_zero_attn = add_zero_attn
        self.beam_size = 1
        self.reset_parameters()

        self.onnx_trace = False
        # When True, forward() skips its embed_dim consistency assertions.
        self.skip_embed_dim_check = False
|
|
    def prepare_for_onnx_export_(self):
        """Switch the module into ONNX-export mode (forward() checks onnx_trace)."""
        self.onnx_trace = True
|
|
| def reset_parameters(self): |
| if self.qkv_same_dim: |
| |
| |
| nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2)) |
| nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2)) |
| nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2)) |
| else: |
| nn.init.xavier_uniform_(self.k_proj.weight) |
| nn.init.xavier_uniform_(self.v_proj.weight) |
| nn.init.xavier_uniform_(self.q_proj.weight) |
|
|
| nn.init.xavier_uniform_(self.out_proj.weight) |
| if self.out_proj.bias is not None: |
| nn.init.constant_(self.out_proj.bias, 0.0) |
| if self.bias_k is not None: |
| nn.init.xavier_normal_(self.bias_k) |
| if self.bias_v is not None: |
| nn.init.xavier_normal_(self.bias_v) |
|
|
| def _get_reserve_head_index(self, num_heads_to_keep: int): |
| k_proj_heads_norm = [] |
| q_proj_heads_norm = [] |
| v_proj_heads_norm = [] |
|
|
| for i in range(self.num_heads): |
| start_idx = i * self.head_dim |
| end_idx = (i + 1) * self.head_dim |
| k_proj_heads_norm.append( |
| torch.sum( |
| torch.abs( |
| self.k_proj.weight[ |
| start_idx:end_idx, |
| ] |
| ) |
| ).tolist() |
| + torch.sum(torch.abs(self.k_proj.bias[start_idx:end_idx])).tolist() |
| ) |
| q_proj_heads_norm.append( |
| torch.sum( |
| torch.abs( |
| self.q_proj.weight[ |
| start_idx:end_idx, |
| ] |
| ) |
| ).tolist() |
| + torch.sum(torch.abs(self.q_proj.bias[start_idx:end_idx])).tolist() |
| ) |
| v_proj_heads_norm.append( |
| torch.sum( |
| torch.abs( |
| self.v_proj.weight[ |
| start_idx:end_idx, |
| ] |
| ) |
| ).tolist() |
| + torch.sum(torch.abs(self.v_proj.bias[start_idx:end_idx])).tolist() |
| ) |
|
|
| heads_norm = [] |
| for i in range(self.num_heads): |
| heads_norm.append( |
| k_proj_heads_norm[i] + q_proj_heads_norm[i] + v_proj_heads_norm[i] |
| ) |
|
|
| sorted_head_index = sorted( |
| range(self.num_heads), key=lambda k: heads_norm[k], reverse=True |
| ) |
| reserve_head_index = [] |
| for i in range(num_heads_to_keep): |
| start = sorted_head_index[i] * self.head_dim |
| end = (sorted_head_index[i] + 1) * self.head_dim |
| reserve_head_index.append((start, end)) |
| return reserve_head_index |
|
|
| def _adaptive_prune_heads(self, reserve_head_index: List[Tuple[int, int]]): |
| new_q_weight = [] |
| new_q_bias = [] |
| new_k_weight = [] |
| new_k_bias = [] |
| new_v_weight = [] |
| new_v_bias = [] |
| new_out_proj_weight = [] |
|
|
| for ele in reserve_head_index: |
| start_idx, end_idx = ele |
| new_q_weight.append( |
| self.q_proj.weight[ |
| start_idx:end_idx, |
| ] |
| ) |
| new_q_bias.append(self.q_proj.bias[start_idx:end_idx]) |
|
|
| new_k_weight.append( |
| self.k_proj.weight[ |
| start_idx:end_idx, |
| ] |
| ) |
|
|
| new_k_bias.append(self.k_proj.bias[start_idx:end_idx]) |
|
|
| new_v_weight.append( |
| self.v_proj.weight[ |
| start_idx:end_idx, |
| ] |
| ) |
| new_v_bias.append(self.v_proj.bias[start_idx:end_idx]) |
|
|
| new_out_proj_weight.append(self.out_proj.weight[:, start_idx:end_idx]) |
|
|
| new_q_weight = torch.cat(new_q_weight).detach() |
| new_k_weight = torch.cat(new_k_weight).detach() |
| new_v_weight = torch.cat(new_v_weight).detach() |
| new_out_proj_weight = torch.cat(new_out_proj_weight, dim=-1).detach() |
| new_q_weight.requires_grad = True |
| new_k_weight.requires_grad = True |
| new_v_weight.requires_grad = True |
| new_out_proj_weight.requires_grad = True |
|
|
| new_q_bias = torch.cat(new_q_bias).detach() |
| new_q_bias.requires_grad = True |
|
|
| new_k_bias = torch.cat(new_k_bias).detach() |
| new_k_bias.requires_grad = True |
|
|
| new_v_bias = torch.cat(new_v_bias).detach() |
| new_v_bias.requires_grad = True |
|
|
| self.q_proj.weight = torch.nn.Parameter(new_q_weight) |
| self.q_proj.bias = torch.nn.Parameter(new_q_bias) |
|
|
| self.k_proj.weight = torch.nn.Parameter(new_k_weight) |
| self.k_proj.bias = torch.nn.Parameter(new_k_bias) |
|
|
| self.v_proj.weight = torch.nn.Parameter(new_v_weight) |
| self.v_proj.bias = torch.nn.Parameter(new_v_bias) |
|
|
| self.out_proj.weight = torch.nn.Parameter(new_out_proj_weight) |
|
|
| self.num_heads = len(reserve_head_index) |
| self.embed_dim = self.head_dim * self.num_heads |
| self.q_proj.out_features = self.embed_dim |
| self.k_proj.out_features = self.embed_dim |
| self.v_proj.out_features = self.embed_dim |
|
|
    def _set_skip_embed_dim_check(self):
        """Disable the embed_dim consistency assertions in forward()."""
        self.skip_embed_dim_check = True
|
|
| def _pad_masks( |
| self, |
| key_padding_mask: Optional[Tensor], |
| attn_mask: Optional[Tensor], |
| ) -> Tuple[Optional[Tensor], Optional[Tensor]]: |
| if attn_mask is not None: |
| shape = attn_mask.size()[:-1] + torch.Size([1]) |
| attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(shape)], dim=-1) |
| if key_padding_mask is not None: |
| shape = key_padding_mask.size()[:-1] + torch.Size([1]) |
| key_padding_mask = torch.cat( |
| [ |
| key_padding_mask, |
| key_padding_mask.new_zeros(shape), |
| ], |
| dim=-1, |
| ) |
| return key_padding_mask, attn_mask |
|
|
| def _add_bias( |
| self, |
| k: Tensor, |
| v: Tensor, |
| key_padding_mask: Optional[Tensor], |
| attn_mask: Optional[Tensor], |
| bsz: int, |
| ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: |
| assert self.bias_k is not None |
| assert self.bias_v is not None |
| k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) |
| v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) |
| key_padding_mask, attn_mask = self._pad_masks( |
| key_padding_mask=key_padding_mask, attn_mask=attn_mask |
| ) |
| return k, v, key_padding_mask, attn_mask |
|
|
| def _append_zero_attn( |
| self, |
| k: Tensor, |
| v: Tensor, |
| key_padding_mask: Optional[Tensor], |
| attn_mask: Optional[Tensor], |
| ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: |
| zero_attn_shape = k.size()[:-2] + torch.Size([1]) + k.size()[-1:] |
| k = torch.cat( |
| [k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=-2 |
| ) |
| v = torch.cat( |
| [v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=-2 |
| ) |
| key_padding_mask, attn_mask = self._pad_masks( |
| key_padding_mask=key_padding_mask, attn_mask=attn_mask |
| ) |
| return k, v, key_padding_mask, attn_mask |
|
|
| def forward( |
| self, |
| query, |
| key: Optional[Tensor], |
| value: Optional[Tensor], |
| key_padding_mask: Optional[Tensor] = None, |
| incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, |
| need_weights: bool = True, |
| static_kv: bool = False, |
| attn_mask: Optional[Tensor] = None, |
| before_softmax: bool = False, |
| need_head_weights: bool = False, |
| ) -> Tuple[Tensor, Optional[Tensor]]: |
| """Input shape: Time x Batch x Channel |
| |
| Args: |
| key_padding_mask (ByteTensor, optional): mask to exclude |
| keys that are pads, of shape `(batch, src_len)`, where |
| padding elements are indicated by 1s. |
| need_weights (bool, optional): return the attention weights, |
| averaged over heads (default: False). |
| attn_mask (ByteTensor, optional): typically used to |
| implement causal attention, where the mask prevents the |
| attention from looking forward in time (default: None). |
| before_softmax (bool, optional): return the raw attention |
| weights and values before the attention softmax. |
| need_head_weights (bool, optional): return the attention |
| weights for each head. Implies *need_weights*. Default: |
| return the average attention weights over all heads. |
| """ |
| if need_head_weights: |
| need_weights = True |
|
|
| is_tpu = query.device.type == "xla" |
|
|
| tgt_len, bsz, embed_dim = query.size() |
| src_len = tgt_len |
| if not self.skip_embed_dim_check: |
| assert ( |
| embed_dim == self.embed_dim |
| ), f"query dim {embed_dim} != {self.embed_dim}" |
| assert list(query.size()) == [tgt_len, bsz, embed_dim] |
| if key is not None: |
| src_len, key_bsz, _ = key.size() |
| if not torch.jit.is_scripting(): |
| assert value is not None |
| assert src_len, key_bsz == value.shape[:2] |
|
|
| if ( |
| not self.onnx_trace |
| and not is_tpu |
| and incremental_state is None |
| and not static_kv |
| |
| |
| and not torch.jit.is_scripting() |
| |
| |
| |
| |
| and not self.skip_embed_dim_check |
| ): |
| assert key is not None and value is not None |
|
|
| if self.use_xformers: |
| return self._xformers_attn_forward( |
| query, key, value, key_padding_mask, need_weights, attn_mask |
| ) |
|
|
| else: |
| return F.multi_head_attention_forward( |
| query, |
| key, |
| value, |
| self.embed_dim, |
| self.num_heads, |
| torch.empty([0]), |
| torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)), |
| self.bias_k, |
| self.bias_v, |
| self.add_zero_attn, |
| self.dropout_module.p, |
| self.out_proj.weight, |
| self.out_proj.bias, |
| self.training or self.dropout_module.apply_during_inference, |
| key_padding_mask, |
| need_weights, |
| attn_mask, |
| use_separate_proj_weight=True, |
| q_proj_weight=self.q_proj.weight, |
| k_proj_weight=self.k_proj.weight, |
| v_proj_weight=self.v_proj.weight, |
| ) |
|
|
| if incremental_state is not None: |
| saved_state = self._get_input_buffer(incremental_state) |
| if saved_state is not None and "prev_key" in saved_state: |
| |
| |
| if static_kv: |
| assert self.encoder_decoder_attention and not self.self_attention |
| key = value = None |
| else: |
| saved_state = None |
|
|
| if self.self_attention: |
| q = self.q_proj(query) |
| k = self.k_proj(query) |
| v = self.v_proj(query) |
| elif self.encoder_decoder_attention: |
| |
| q = self.q_proj(query) |
| if key is None: |
| assert value is None |
| k = v = None |
| else: |
| if self.beam_size > 1 and bsz == key.size(1): |
| |
| key = key.view(key.size(0), -1, self.beam_size, key.size(2))[ |
| :, :, 0, : |
| ] |
| if key_padding_mask is not None: |
| key_padding_mask = key_padding_mask.view( |
| -1, self.beam_size, key_padding_mask.size(1) |
| )[:, 0, :] |
| k = self.k_proj(key) |
| v = self.v_proj(key) |
|
|
| else: |
| assert key is not None and value is not None |
| q = self.q_proj(query) |
| k = self.k_proj(key) |
| v = self.v_proj(value) |
| q *= self.scaling |
|
|
| if self.bias_k is not None: |
| assert self.bias_v is not None |
| k, v, attn_mask, key_padding_mask = self._add_bias( |
| k, v, attn_mask, key_padding_mask, bsz |
| ) |
|
|
| q = ( |
| q.contiguous() |
| .view(tgt_len, bsz * self.num_heads, self.head_dim) |
| .transpose(0, 1) |
| ) |
| kv_bsz = bsz |
| if k is not None: |
| kv_bsz = k.size(1) |
| k = ( |
| k.contiguous() |
| .view(-1, kv_bsz * self.num_heads, self.head_dim) |
| .transpose(0, 1) |
| ) |
| if v is not None: |
| v = ( |
| v.contiguous() |
| .view(-1, kv_bsz * self.num_heads, self.head_dim) |
| .transpose(0, 1) |
| ) |
|
|
| if saved_state is not None: |
| |
| if "prev_key" in saved_state: |
| _prev_key = saved_state["prev_key"] |
| assert _prev_key is not None |
| kv_bsz = _prev_key.size(0) |
| prev_key = _prev_key.view(kv_bsz * self.num_heads, -1, self.head_dim) |
| if static_kv: |
| k = prev_key |
| else: |
| assert k is not None |
| k = torch.cat([prev_key, k], dim=1) |
| src_len = k.size(1) |
| if "prev_value" in saved_state: |
| _prev_value = saved_state["prev_value"] |
| assert _prev_value is not None |
| assert kv_bsz == _prev_value.size(0) |
| prev_value = _prev_value.view( |
| kv_bsz * self.num_heads, -1, self.head_dim |
| ) |
| if static_kv: |
| v = prev_value |
| else: |
| assert v is not None |
| v = torch.cat([prev_value, v], dim=1) |
| prev_key_padding_mask: Optional[Tensor] = None |
| if "prev_key_padding_mask" in saved_state: |
| prev_key_padding_mask = saved_state["prev_key_padding_mask"] |
| assert k is not None and v is not None |
| key_padding_mask = MultiheadAttention._append_prev_key_padding_mask( |
| key_padding_mask=key_padding_mask, |
| prev_key_padding_mask=prev_key_padding_mask, |
| batch_size=kv_bsz, |
| src_len=k.size(1), |
| static_kv=static_kv, |
| ) |
|
|
| saved_state["prev_key"] = k.view(kv_bsz, self.num_heads, -1, self.head_dim) |
| saved_state["prev_value"] = v.view( |
| kv_bsz, self.num_heads, -1, self.head_dim |
| ) |
| saved_state["prev_key_padding_mask"] = key_padding_mask |
| |
| assert incremental_state is not None |
| incremental_state = self._set_input_buffer(incremental_state, saved_state) |
| assert k is not None |
| assert k.size(1) == src_len |
|
|
| |
| |
| if key_padding_mask is not None and key_padding_mask.dim() == 0: |
| key_padding_mask = None |
|
|
| if key_padding_mask is not None: |
| assert key_padding_mask.size(0) == kv_bsz |
| assert key_padding_mask.size(1) == src_len |
|
|
| if self.add_zero_attn: |
| assert v is not None |
| src_len += 1 |
| k, v, key_padding_mask, attn_mask = self._append_zero_attn( |
| k=k, v=v, key_padding_mask=key_padding_mask, attn_mask=attn_mask |
| ) |
|
|
| if self.encoder_decoder_attention and bsz != kv_bsz: |
| attn_weights = torch.einsum( |
| "bxhtd,bhsd->bxhts", |
| q.view((kv_bsz, -1, self.num_heads) + q.size()[1:]), |
| k.view((kv_bsz, self.num_heads) + k.size()[1:]), |
| ) |
| attn_weights = attn_weights.reshape((-1,) + attn_weights.size()[-2:]) |
| else: |
| attn_weights = torch.bmm(q, k.transpose(1, 2)) |
| attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) |
|
|
| assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] |
|
|
| if attn_mask is not None: |
| attn_mask = attn_mask.unsqueeze(0) |
| if self.onnx_trace: |
| attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1) |
| attn_weights += attn_mask |
|
|
| if key_padding_mask is not None: |
| |
| attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
| if not is_tpu: |
| attn_weights = attn_weights.view( |
| kv_bsz, -1, self.num_heads, tgt_len, src_len |
| ) |
| attn_weights = attn_weights.masked_fill( |
| key_padding_mask.unsqueeze(1) |
| .unsqueeze(2) |
| .unsqueeze(3) |
| .to(torch.bool), |
| float("-inf"), |
| ) |
| else: |
| attn_weights = attn_weights.transpose(0, 2) |
| attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf")) |
| attn_weights = attn_weights.transpose(0, 2) |
| attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
|
|
| if before_softmax: |
| return attn_weights, v |
|
|
| def softmax_supporting_onnx_trace(x, dim: int, onnx_trace: bool = False): |
| if onnx_trace: |
| return F.softmax(x.float(), dim=dim) |
| else: |
| return F.softmax(x, dim=dim, dtype=torch.float32) |
|
|
| attn_weights_float = softmax_supporting_onnx_trace( |
| attn_weights, dim=-1, onnx_trace=self.onnx_trace |
| ) |
| attn_weights = attn_weights_float.type_as(attn_weights) |
| attn_probs = self.dropout_module(attn_weights) |
|
|
| assert v is not None |
| if self.encoder_decoder_attention and bsz != kv_bsz: |
| attn = torch.einsum( |
| "bxhts,bhsd->bxhtd", |
| attn_probs.view( |
| ( |
| kv_bsz, |
| -1, |
| self.num_heads, |
| ) |
| + attn_probs.size()[1:] |
| ), |
| v.view( |
| ( |
| kv_bsz, |
| self.num_heads, |
| ) |
| + v.size()[1:] |
| ), |
| ) |
| attn = attn.reshape((-1,) + attn.size()[-2:]) |
| else: |
| attn = torch.bmm(attn_probs, v) |
| assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] |
| if self.onnx_trace and attn.size(1) == 1: |
| |
| |
| attn = attn.contiguous().view(tgt_len, bsz, self.embed_dim) |
| else: |
| attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim) |
| attn = self.out_proj(attn) |
| attn_weights: Optional[Tensor] = None |
| if need_weights: |
| attn_weights = attn_weights_float.view( |
| bsz, self.num_heads, tgt_len, src_len |
| ).transpose(1, 0) |
| if not need_head_weights: |
| |
| attn_weights = attn_weights.mean(dim=0) |
|
|
| return attn, attn_weights |
|
|
| @staticmethod |
| def _append_prev_key_padding_mask( |
| key_padding_mask: Optional[Tensor], |
| prev_key_padding_mask: Optional[Tensor], |
| batch_size: int, |
| src_len: int, |
| static_kv: bool, |
| ) -> Optional[Tensor]: |
| |
| if prev_key_padding_mask is not None and static_kv: |
| new_key_padding_mask = prev_key_padding_mask |
| elif prev_key_padding_mask is not None and key_padding_mask is not None: |
| new_key_padding_mask = torch.cat( |
| [prev_key_padding_mask.float(), key_padding_mask.float()], dim=1 |
| ) |
| |
| |
| |
| elif prev_key_padding_mask is not None: |
| if src_len > prev_key_padding_mask.size(1): |
| filler = torch.zeros( |
| (batch_size, src_len - prev_key_padding_mask.size(1)), |
| device=prev_key_padding_mask.device, |
| ) |
| new_key_padding_mask = torch.cat( |
| [prev_key_padding_mask.float(), filler.float()], dim=1 |
| ) |
| else: |
| new_key_padding_mask = prev_key_padding_mask.float() |
| elif key_padding_mask is not None: |
| if src_len > key_padding_mask.size(1): |
| filler = torch.zeros( |
| (batch_size, src_len - key_padding_mask.size(1)), |
| device=key_padding_mask.device, |
| ) |
| new_key_padding_mask = torch.cat( |
| [filler.float(), key_padding_mask.float()], dim=1 |
| ) |
| else: |
| new_key_padding_mask = key_padding_mask.float() |
| else: |
| new_key_padding_mask = prev_key_padding_mask |
| return new_key_padding_mask |
|
|
    @torch.jit.export
    def reorder_incremental_state(
        self,
        incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
        new_order: Tensor,
    ):
        """Reorder buffered internal state (for incremental generation).

        Args:
            incremental_state: per-layer cache (e.g. "prev_key"/"prev_value").
            new_order: indices selecting the new batch (or batch*beam) order.

        Returns:
            The (possibly updated) incremental_state.
        """
        input_buffer = self._get_input_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer_k = input_buffer[k]
                if input_buffer_k is not None:
                    if self.encoder_decoder_attention:
                        # Beamable enc-dec attention caches one entry per source
                        # sentence rather than per beam; if the buffer is already
                        # at that reduced size, no reordering is needed.
                        if input_buffer_k.size(0) * self.beam_size == new_order.size(0):
                            return incremental_state
                        elif self.beam_size > 1:
                            # Collapse the beam dimension: take beam 0 of each
                            # sentence and map beam-level ids to sentence ids.
                            input_buffer[k] = input_buffer_k.index_select(
                                0,
                                new_order.reshape(-1, self.beam_size)[:, 0]
                                // self.beam_size,
                            )
                        else:
                            input_buffer[k] = input_buffer_k.index_select(0, new_order)
                    else:
                        input_buffer[k] = input_buffer_k.index_select(0, new_order)
            incremental_state = self._set_input_buffer(incremental_state, input_buffer)
        return incremental_state
|
|
| def set_beam_size(self, beam_size): |
| """Used for effiecient beamable enc-dec attention""" |
| self.beam_size = beam_size |
|
|
| def _get_input_buffer( |
| self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] |
| ) -> Dict[str, Optional[Tensor]]: |
| result = self.get_incremental_state(incremental_state, "attn_state") |
| if result is not None: |
| return result |
| else: |
| empty_result: Dict[str, Optional[Tensor]] = {} |
| return empty_result |
|
|
| def _set_input_buffer( |
| self, |
| incremental_state: Dict[str, Dict[str, Optional[Tensor]]], |
| buffer: Dict[str, Optional[Tensor]], |
| ): |
| return self.set_incremental_state(incremental_state, "attn_state", buffer) |
|
|
| def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int): |
| return attn_weights |
|
|
| def upgrade_state_dict_named(self, state_dict, name): |
| prefix = name + "." if name != "" else "" |
| items_to_add = {} |
| keys_to_remove = [] |
| for k in state_dict.keys(): |
| if k.endswith(prefix + "in_proj_weight"): |
| |
| dim = int(state_dict[k].shape[0] / 3) |
| items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim] |
| items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim] |
| items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :] |
|
|
| keys_to_remove.append(k) |
|
|
| k_bias = prefix + "in_proj_bias" |
| if k_bias in state_dict.keys(): |
| dim = int(state_dict[k].shape[0] / 3) |
| items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim] |
| items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][ |
| dim : 2 * dim |
| ] |
| items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :] |
|
|
| keys_to_remove.append(prefix + "in_proj_bias") |
|
|
| for k in keys_to_remove: |
| del state_dict[k] |
|
|
| for key, value in items_to_add.items(): |
| state_dict[key] = value |
|
|
|
|
class RelPositionalEncoding(nn.Module):
    """Relative positional encoding module (new implementation).

    For an input of length T, produces sinusoidal embeddings for relative
    offsets T-1, ..., 1, 0, -1, ..., -(T-1), i.e. 2*T-1 positions.

    Args:
        max_len: Maximum input length.
        d_model: Embedding dimension.
    """

    def __init__(self, max_len, d_model):
        """Construct a RelPositionalEncoding object and pre-build the table."""
        super().__init__()
        self.d_model = d_model
        self.pe = None
        # pre-populate the cache for inputs up to max_len
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x):
        """(Re)build the cached table so it covers 2*x.size(1)-1 offsets."""
        seq_len = x.size(1)
        if self.pe is not None and self.pe.size(1) >= seq_len * 2 - 1:
            # Table is large enough; only align dtype/device with the input.
            if self.pe.dtype != x.dtype or self.pe.device != x.device:
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
            return
        # pe_positive holds offsets 0..seq_len-1, pe_negative offsets 0..-(seq_len-1)
        pe_positive = torch.zeros(seq_len, self.d_model)
        pe_negative = torch.zeros(seq_len, self.d_model)
        position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Concatenate so the table reads [offset seq_len-1, ..., 1, 0, -1, ..., -(seq_len-1)];
        # the duplicate offset-0 row of pe_negative is dropped.
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor):
        """Return relative position embeddings for x.

        Args:
            x : Input tensor T X B X C.
        Returns:
            torch.Tensor: positional embeddings of shape (2*T-1) X 1 X C.
        """
        x = x.transpose(0, 1)
        self.extend_pe(x)
        center = self.pe.size(1) // 2  # index of relative offset 0
        pos_emb = self.pe[:, center - x.size(1) + 1 : center + x.size(1)]
        return pos_emb.transpose(0, 1)
|
|
|
|
class GumbelVectorQuantizer(nn.Module):
    def __init__(
        self,
        dim,
        num_vars,
        temp,
        groups,
        combine_groups,
        vq_dim,
        time_first,
        activation=nn.GELU(),
        weight_proj_depth=1,
        weight_proj_factor=1,
    ):
        """Vector quantization using gumbel softmax

        Args:
            dim: input dimension (channels)
            num_vars: number of quantized vectors per group
            temp: temperature for training. this should be a tuple of 3 elements: (start, stop, decay factor)
            groups: number of groups for vector quantization
            combine_groups: whether to use the vectors for all groups
            vq_dim: dimensionality of the resulting quantized vector
            time_first: if true, expect input in BxTxC format, otherwise in BxCxT
            activation: what activation to use (should be a module). this is only used if weight_proj_depth is > 1
            weight_proj_depth: number of layers (with activation in between) to project input before computing logits
            weight_proj_factor: this is used only if weight_proj_depth is > 1. scales the inner dimensionality of
                projections by this factor
        """
        super().__init__()

        self.groups = groups
        self.combine_groups = combine_groups
        self.input_dim = dim
        self.num_vars = num_vars
        self.time_first = time_first

        assert (
            vq_dim % groups == 0
        ), f"dim {vq_dim} must be divisible by groups {groups} for concatenation"

        var_dim = vq_dim // groups
        # with combine_groups a single codebook is shared across all groups
        num_groups = groups if not combine_groups else 1

        # codebook: num_vars entries of size var_dim per (possibly shared) group
        self.vars = nn.Parameter(torch.FloatTensor(1, num_groups * num_vars, var_dim))
        nn.init.uniform_(self.vars)

        if weight_proj_depth > 1:

            def block(input_dim, output_dim):
                # one MLP layer of the logit projection
                return nn.Sequential(nn.Linear(input_dim, output_dim), activation)

            inner_dim = self.input_dim * weight_proj_factor
            self.weight_proj = nn.Sequential(
                *[
                    block(self.input_dim if i == 0 else inner_dim, inner_dim)
                    for i in range(weight_proj_depth - 1)
                ],
                nn.Linear(inner_dim, groups * num_vars),
            )
        else:
            self.weight_proj = nn.Linear(self.input_dim, groups * num_vars)
            nn.init.normal_(self.weight_proj.weight, mean=0, std=1)
            nn.init.zeros_(self.weight_proj.bias)

        if isinstance(temp, str):
            import ast

            # allow temp to be passed as a string like "(2.0, 0.5, 0.999995)"
            temp = ast.literal_eval(temp)
        assert len(temp) == 3, f"{temp}, {len(temp)}"

        self.max_temp, self.min_temp, self.temp_decay = temp
        self.curr_temp = self.max_temp
        self.codebook_indices = None

    def set_num_updates(self, num_updates):
        """Anneal the gumbel temperature exponentially, floored at min_temp."""
        self.curr_temp = max(
            self.max_temp * self.temp_decay**num_updates, self.min_temp
        )

    def get_codebook_indices(self):
        """Return (and cache) flat codebook indices for every possible
        combination of per-group codewords (num_vars ** groups rows)."""
        if self.codebook_indices is None:
            from itertools import product

            # cartesian product over groups of all codeword choices
            p = [range(self.num_vars)] * self.groups
            inds = list(product(*p))
            self.codebook_indices = torch.tensor(
                inds, dtype=torch.long, device=self.vars.device
            ).flatten()

            if not self.combine_groups:
                # offset each group's indices into its own codebook slice
                self.codebook_indices = self.codebook_indices.view(
                    self.num_vars**self.groups, -1
                )
                for b in range(1, self.groups):
                    self.codebook_indices[:, b] += self.num_vars * b
                self.codebook_indices = self.codebook_indices.flatten()
        return self.codebook_indices

    def codebook(self):
        """Materialize the full codebook: one concatenated vector per
        codeword combination, shape (num_vars ** groups, vq_dim)."""
        indices = self.get_codebook_indices()
        return (
            self.vars.squeeze(0)
            .index_select(0, indices)
            .view(self.num_vars**self.groups, -1)
        )

    def sample_from_codebook(self, b, n):
        """Sample n quantized vectors for each of b examples (with replacement
        across the b*n draws), returning shape (b, n, vq_dim)."""
        indices = self.get_codebook_indices()
        indices = indices.view(-1, self.groups)
        cb_size = indices.size(0)
        assert (
            n < cb_size
        ), f"sample size {n} is greater than size of codebook {cb_size}"
        sample_idx = torch.randint(low=0, high=cb_size, size=(b * n,))
        indices = indices[sample_idx]

        z = self.vars.squeeze(0).index_select(0, indices.flatten()).view(b, n, -1)
        return z

    def to_codebook_index(self, indices):
        """Collapse per-group indices (last dim = groups) into a single
        flat codebook index, treating groups as base-num_vars digits."""
        res = indices.new_full(indices.shape[:-1], 0)
        for i in range(self.groups):
            exponent = self.groups - i - 1
            res += indices[..., i] * (self.num_vars**exponent)
        return res

    def forward_idx(self, x):
        """Forward pass that also returns the selected codeword targets."""
        res = self.forward(x, produce_targets=True)
        return res["x"], res["targets"]

    def forward(self, x, produce_targets=False):
        """Quantize x; returns a dict with the quantized output "x",
        perplexity diagnostics, the current temperature, and optionally
        the per-group codeword targets."""
        result = {"num_vars": self.num_vars * self.groups}

        if not self.time_first:
            x = x.transpose(1, 2)

        bsz, tsz, fsz = x.shape
        x = x.reshape(-1, fsz)
        # logits over codewords, one row per (frame, group)
        x = self.weight_proj(x)
        x = x.view(bsz * tsz * self.groups, -1)

        # hard (argmax) one-hot selection, used at eval time and for diagnostics
        _, k = x.max(-1)
        hard_x = (
            x.new_zeros(*x.shape)
            .scatter_(-1, k.view(-1, 1), 1.0)
            .view(bsz * tsz, self.groups, -1)
        )
        hard_probs = torch.mean(hard_x.float(), dim=0)
        # perplexity of the empirical (argmax) codeword usage
        result["code_perplexity"] = torch.exp(
            -torch.sum(hard_probs * torch.log(hard_probs + 1e-7), dim=-1)
        ).sum()

        avg_probs = torch.softmax(
            x.view(bsz * tsz, self.groups, -1).float(), dim=-1
        ).mean(dim=0)
        # perplexity of the average soft codeword distribution (diversity loss input)
        result["prob_perplexity"] = torch.exp(
            -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
        ).sum()

        result["temp"] = self.curr_temp

        if self.training:
            # straight-through gumbel-softmax: hard forward, soft gradients
            x = F.gumbel_softmax(x.float(), tau=self.curr_temp, hard=True).type_as(x)
        else:
            x = hard_x

        x = x.view(bsz * tsz, -1)

        vars = self.vars
        if self.combine_groups:
            # replicate the shared codebook for every group
            vars = vars.repeat(1, self.groups, 1)

        if produce_targets:
            # index of the selected codeword in each group
            result["targets"] = (
                x.view(bsz * tsz * self.groups, -1)
                .argmax(dim=-1)
                .view(bsz, tsz, self.groups)
                .detach()
            )

        # select codewords via the (one-hot) weights and concatenate groups
        x = x.unsqueeze(-1) * vars
        x = x.view(bsz * tsz, self.groups, self.num_vars, -1)
        x = x.sum(-2)
        x = x.view(bsz, tsz, -1)

        if not self.time_first:
            x = x.transpose(1, 2)

        result["x"] = x

        return result
|
|
|
|
class GradMultiply(torch.autograd.Function):
    """Identity in the forward pass; scales the gradient by `scale` in backward."""

    @staticmethod
    def forward(ctx, x, scale):
        # remember the multiplier for the backward pass
        ctx.scale = scale
        return x.new(x)

    @staticmethod
    def backward(ctx, grad):
        # no gradient flows to the `scale` argument
        return grad * ctx.scale, None
|
|
|
|
class SamePad(nn.Module):
    """Trim trailing frames left over by "same" convolution padding.

    Even kernels leave one extra frame; causal convolutions padded with
    kernel_size - 1 leave that many extra frames. Input is (B, C, T).
    """

    def __init__(self, kernel_size, causal=False):
        super().__init__()
        if causal:
            self.remove = kernel_size - 1
        elif kernel_size % 2 == 0:
            self.remove = 1
        else:
            self.remove = 0

    def forward(self, x):
        return x[:, :, : -self.remove] if self.remove > 0 else x
|
|
|
|
class TransposeLast(nn.Module):
    """Swap the last two dimensions, optionally indexing the input first.

    If `deconstruct_idx` is given, the module first selects
    ``x[deconstruct_idx]`` (e.g. to unpack a tuple-valued input).
    """

    def __init__(self, deconstruct_idx=None):
        super().__init__()
        self.deconstruct_idx = deconstruct_idx

    def forward(self, x):
        inp = x if self.deconstruct_idx is None else x[self.deconstruct_idx]
        return inp.transpose(-2, -1)
|
|
|
|
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False):
    """Build a LayerNorm module.

    The `export` flag is accepted for interface compatibility but has no
    effect here.
    """
    return torch.nn.LayerNorm(
        normalized_shape, eps=eps, elementwise_affine=elementwise_affine
    )
|
|
|
|
class Fp32LayerNorm(nn.LayerNorm):
    """LayerNorm that always computes in float32 (for mixed-precision training),
    then casts the result back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(
            input.float(), self.normalized_shape, weight, bias, self.eps
        )
        return out.type_as(input)
|
|
|
|
class Fp32GroupNorm(nn.GroupNorm):
    """GroupNorm that always computes in float32 (for mixed-precision training),
    then casts the result back to the input dtype."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input):
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.group_norm(
            input.float(), self.num_groups, weight, bias, self.eps
        )
        return out.type_as(input)
|
|
|
|
class StrEnumMeta(EnumMeta):
    # NOTE(review): this makes isinstance(x, <StrEnum subclass>) return True for
    # *any* object whose type name contains "enum" — presumably a workaround for
    # TorchScript not supporting EnumMeta's __instancecheck__; confirm before
    # tightening.
    @classmethod
    def __instancecheck__(cls, other):
        return "enum" in str(type(other))
|
|
|
|
class StrEnum(Enum, metaclass=StrEnumMeta):
    """Enum whose members stringify to, and compare equal with, their value."""

    def __str__(self):
        return self.value

    def __eq__(self, other: str):
        # compare directly against plain strings (e.g. argparse choices)
        return self.value == other

    def __repr__(self):
        return self.value

    def __hash__(self):
        # __eq__ is overridden above, so __hash__ must be redefined
        # to keep members usable as dict/set keys
        return hash(str(self))
|
|
|
|
def ChoiceEnum(choices: List[str]):
    """return the Enum class used to enforce list of choices"""
    # functional Enum creation: each choice maps to itself
    return StrEnum("Choices", dict(zip(choices, choices)))
|
|
|
|
def relu_squared(x: torch.Tensor):
    """Elementwise squared ReLU: max(x, 0) ** 2."""
    return F.relu(x) ** 2
|
|
|
|
def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`.

    Args:
        activation: one of "relu", "relu_squared", "gelu", "gelu_fast",
            "gelu_accurate", "tanh", "linear", "swish".

    Returns:
        A callable mapping a tensor to a tensor.

    Raises:
        RuntimeError: if `activation` is not a recognized name.
    """

    def gelu_accurate(x):
        # cache the constant sqrt(2/pi) on the function object
        if not hasattr(gelu_accurate, "_a"):
            gelu_accurate._a = math.sqrt(2 / math.pi)
        return (
            0.5
            * x
            * (1 + torch.tanh(gelu_accurate._a * (x + 0.044715 * torch.pow(x, 3))))
        )

    def gelu(x: torch.Tensor) -> torch.Tensor:
        # compute in float32 and cast back (matters under mixed precision)
        return torch.nn.functional.gelu(x.float()).type_as(x)

    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        return gelu
    elif activation == "gelu_fast":
        # alias of gelu_accurate
        return gelu_accurate
    elif activation == "gelu_accurate":
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # Bug fix: previously returned the torch.nn.SiLU *class* (calling it
        # would construct a module, not apply the activation); return the
        # functional form like every other branch.
        return torch.nn.functional.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
|
|
|
|
def get_available_activation_fns() -> List:
    """List all activation names accepted by `get_activation_fn`.

    Previously this list omitted "relu_squared" and "swish" even though
    `get_activation_fn` supports them; they are now advertised too.
    """
    return [
        "relu",
        "relu_squared",
        "gelu",
        "gelu_fast",  # alias of gelu_accurate in get_activation_fn
        "gelu_accurate",
        "tanh",
        "linear",
        "swish",
    ]
|
|
|
|
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from possion distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if true, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    # expected number of spans; the random term dithers the rounding
    all_num_mask = int(
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            # only the unpadded prefix of this sample is eligible for masking
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            # force at least one span
            # NOTE(review): if num_mask == 0 this indexing raises IndexError;
            # presumably callers always request at least one span — confirm.
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                # Place one span of `length` inside [s, e) and return the
                # leftover sub-intervals that can still host spans.
                # Bug fix: when e - length == s the only legal start is s;
                # np.random.randint(s, s) would raise ValueError.
                if e - length <= s:
                    span_start = s
                else:
                    span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                # Bug fix: np.int was removed in NumPy 1.24; use builtin int.
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    # no remaining interval can host this span length
                    break
                # pick an interval with probability proportional to its size
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            # choose span starts, then expand each start to its full span
            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            # subsample so every example has the same number of masked steps
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask
|
|
|
|
def index_put(tensor, indices, value):
    """Write `value` at `indices` in-place and return the mutated tensor."""
    tensor[indices] = value
    return tensor
|
|
|
|
def buffered_arange(max):
    """Return torch.arange(max) as a view into a cached, growing buffer.

    The buffer lives on the function object so repeated calls avoid
    reallocating for any `max` not larger than the biggest seen so far.
    """
    buf = getattr(buffered_arange, "buf", None)
    if buf is None:
        buf = torch.LongTensor()
        buffered_arange.buf = buf
    if max > buf.numel():
        buf.resize_(max)
        torch.arange(max, out=buf)
    return buf[:max]
|
|
|
|
| def pad_to_multiple(x, multiple, dim=-1, value=0): |
| |
| if x is None: |
| return None, 0 |
| tsz = x.size(dim) |
| m = tsz / multiple |
| remainder = math.ceil(m) * multiple - tsz |
| if m.is_integer(): |
| return x, 0 |
| pad_offset = (0,) * (-1 - dim) * 2 |
|
|
| return F.pad(x, (*pad_offset, 0, remainder), value=value), remainder |
|
|
|
|
# Closed sets of string choices used by the wav2vec 2.0 config dataclass below.
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer"])
|
|
|
|
| @dataclass |
| class Wav2Vec2Config: |
| extractor_mode: EXTRACTOR_MODE_CHOICES = field( |
| default="default", |
| metadata={ |
| "help": "mode for feature extractor. default has a single group norm with d " |
| "groups in the first conv block, whereas layer_norm has layer norms in " |
| "every block (meant to use with normalize=True)" |
| }, |
| ) |
| encoder_layers: int = field( |
| default=12, metadata={"help": "num encoder layers in the transformer"} |
| ) |
| encoder_embed_dim: int = field( |
| default=768, metadata={"help": "encoder embedding dimension"} |
| ) |
| encoder_ffn_embed_dim: int = field( |
| default=3072, metadata={"help": "encoder embedding dimension for FFN"} |
| ) |
| encoder_attention_heads: int = field( |
| default=12, metadata={"help": "num encoder attention heads"} |
| ) |
| activation_fn: ChoiceEnum(get_available_activation_fns()) = field( |
| default="gelu", metadata={"help": "activation function to use"} |
| ) |
| layer_type: LAYER_TYPE_CHOICES = field( |
| default="transformer", metadata={"help": "layer type in encoder"} |
| ) |
| |
| dropout: float = field( |
| default=0.1, metadata={"help": "dropout probability for the transformer"} |
| ) |
| attention_dropout: float = field( |
| default=0.1, metadata={"help": "dropout probability for attention weights"} |
| ) |
| activation_dropout: float = field( |
| default=0.0, metadata={"help": "dropout probability after activation in FFN"} |
| ) |
| encoder_layerdrop: float = field( |
| default=0.0, metadata={"help": "probability of dropping a tarnsformer layer"} |
| ) |
| dropout_input: float = field( |
| default=0.0, |
| metadata={"help": "dropout to apply to the input (after feat extr)"}, |
| ) |
| dropout_features: float = field( |
| default=0.0, |
| metadata={"help": "dropout to apply to the features (after feat extr)"}, |
| ) |
|
|
| final_dim: int = field( |
| default=0, |
| metadata={ |
| "help": "project final representations and targets to this many dimensions." |
| "set to encoder_embed_dim is <= 0" |
| }, |
| ) |
| layer_norm_first: bool = field( |
| default=False, metadata={"help": "apply layernorm first in the transformer"} |
| ) |
| conv_feature_layers: str = field( |
| default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", |
| metadata={ |
| "help": "string describing convolutional feature extraction layers in form of a python list that contains " |
| "[(dim, kernel_size, stride), ...]" |
| }, |
| ) |
| conv_bias: bool = field( |
| default=False, metadata={"help": "include bias in conv encoder"} |
| ) |
| logit_temp: float = field( |
| default=0.1, metadata={"help": "temperature to divide logits by"} |
| ) |
| quantize_targets: bool = field( |
| default=False, metadata={"help": "use quantized targets"} |
| ) |
| quantize_input: bool = field( |
| default=False, metadata={"help": "use quantized inputs"} |
| ) |
| same_quantizer: bool = field( |
| default=False, metadata={"help": "use same quantizer for inputs and targets"} |
| ) |
| target_glu: bool = field( |
| default=False, metadata={"help": "adds projection + glu to targets"} |
| ) |
| feature_grad_mult: float = field( |
| default=1.0, metadata={"help": "multiply feature extractor var grads by this"} |
| ) |
| quantizer_depth: int = field( |
| default=1, |
| metadata={"help": "number of quantizer layers"}, |
| ) |
| quantizer_factor: int = field( |
| default=3, |
| metadata={ |
| "help": "dimensionality increase for inner quantizer layers (if depth > 1)" |
| }, |
| ) |
| latent_vars: int = field( |
| default=320, |
| metadata={"help": "number of latent variables V in each group of the codebook"}, |
| ) |
| latent_groups: int = field( |
| default=2, |
| metadata={"help": "number of groups G of latent variables in the codebook"}, |
| ) |
| latent_dim: int = field( |
| default=0, |
| metadata={ |
| "help": "if > 0, uses this dimensionality for latent variables. " |
| "otherwise uses final_dim / latent_groups" |
| }, |
| ) |
|
|
| |
| mask_length: int = field(default=10, metadata={"help": "mask length"}) |
| mask_prob: float = field( |
| default=0.65, metadata={"help": "probability of replacing a token with mask"} |
| ) |
| mask_selection: MASKING_DISTRIBUTION_CHOICES = field( |
| default="static", metadata={"help": "how to choose mask length"} |
| ) |
| mask_other: float = field( |
| default=0, |
| metadata={ |
| "help": "secondary mask argument (used for more complex distributions), " |
| "see help in compute_mask_indices" |
| }, |
| ) |
| no_mask_overlap: bool = field( |
| default=False, metadata={"help": "whether to allow masks to overlap"} |
| ) |
| mask_min_space: int = field( |
| default=1, |
| metadata={"help": "min space between spans (if no overlap is enabled)"}, |
| ) |
| require_same_masks: bool = field( |
| default=True, |
| metadata={ |
| "help": "whether to number of masked timesteps must be the same across all " |
| "examples in a batch" |
| }, |
| ) |
| mask_dropout: float = field( |
| default=0.0, |
| metadata={"help": "percent of masks to unmask for each sample"}, |
| ) |
|
|
| |
| mask_channel_length: int = field( |
| default=10, metadata={"help": "length of the mask for features (channels)"} |
| ) |
| mask_channel_prob: float = field( |
| default=0.0, metadata={"help": "probability of replacing a feature with 0"} |
| ) |
| mask_channel_before: bool = False |
| mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( |
| default="static", |
| metadata={"help": "how to choose mask length for channel masking"}, |
| ) |
| mask_channel_other: float = field( |
| default=0, |
| metadata={ |
| "help": "secondary mask argument (used for more complex distributions), " |
| "see help in compute_mask_indicesh" |
| }, |
| ) |
| no_mask_channel_overlap: bool = field( |
| default=False, metadata={"help": "whether to allow channel masks to overlap"} |
| ) |
| mask_channel_min_space: int = field( |
| default=1, |
| metadata={"help": "min space between spans (if no overlap is enabled)"}, |
| ) |
|
|
| |
| num_negatives: int = field( |
| default=100, |
| metadata={"help": "number of negative examples from the same sample"}, |
| ) |
| negatives_from_everywhere: bool = field( |
| default=False, |
| metadata={"help": "sample negatives from everywhere, not just masked states"}, |
| ) |
| cross_sample_negatives: int = field( |
| default=0, metadata={"help": "number of negative examples from the any sample"} |
| ) |
| codebook_negatives: int = field( |
| default=0, metadata={"help": "number of negative examples codebook"} |
| ) |
|
|
| |
| conv_pos: int = field( |
| default=128, |
| metadata={"help": "number of filters for convolutional positional embeddings"}, |
| ) |
| conv_pos_groups: int = field( |
| default=16, |
| metadata={"help": "number of groups for convolutional positional embedding"}, |
| ) |
| pos_conv_depth: int = field( |
| default=1, |
| metadata={"help": "depth of positional encoder network"}, |
| ) |
|
|
| latent_temp: Tuple[float, float, float] = field( |
| default=(2, 0.5, 0.999995), |
| metadata={ |
| "help": "temperature for latent variable sampling. " |
| "can be tuple of 3 values (start, end, decay)" |
| }, |
| ) |
| max_positions: int = field(default=100000, metadata={"help": "Max positions"}) |
| checkpoint_activations: bool = field( |
| default=False, |
| metadata={"help": "recompute activations and save memory for extra compute"}, |
| ) |
|
|
| |
| required_seq_len_multiple: int = field( |
| default=2, |
| metadata={ |
| "help": "pad the input to encoder such that the sequence length is divisible by multiple" |
| }, |
| ) |
| crop_seq_to_multiple: int = field( |
| default=1, |
| metadata={ |
| "help": "crop convolutional feature extractor output such that the sequence length is divisible by multiple" |
| }, |
| ) |
|
|
| |
| depthwise_conv_kernel_size: int = field( |
| default=31, |
| metadata={ |
| "help": "depthwise-conv-kernel-size for convolution in conformer layer" |
| }, |
| ) |
| attn_type: str = field( |
| default="", |
| metadata={"help": "if espnet use ESPNET MHA"}, |
| ) |
| pos_enc_type: str = field( |
| default="abs", |
| metadata={"help": "Positional encoding type to use in conformer"}, |
| ) |
| fp16: bool = field(default=False, metadata={"help": "If fp16 is being used"}) |
|
|
|
|
class Wav2Vec2Model(nn.Module):
    """wav2vec 2.0 model (Baevski et al., 2020).

    A convolutional feature extractor followed by a transformer (or
    conformer) context network. During pre-training, spans of the latent
    features are masked and the model is trained contrastively to identify
    the true (optionally quantized) latent for each masked timestep among
    sampled negatives.
    """

    def __init__(self, cfg: Wav2Vec2Config):
        super().__init__()
        self.cfg = cfg

        # cfg.conv_feature_layers is an expression such as
        # "[(512, 10, 5)] + [(512, 3, 2)] * 4" (uses `+` and `*`), so
        # ast.literal_eval cannot parse it; eval() is retained — configs
        # must therefore come from trusted sources only.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        self.embed = feature_enc_layers[-1][0]  # channels of the last conv layer

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        # project conv features to the encoder width; skipped when the input
        # quantizer does the projection instead (see self.project_inp below)
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        self.crop_seq_to_multiple = cfg.crop_seq_to_multiple

        # time-masking (SpecAugment-style) hyper-parameters
        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        # channel-masking hyper-parameters
        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        # contrastive-task negative-sampling configuration
        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=cfg.quantizer_depth,
                weight_proj_factor=cfg.quantizer_factor,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                    weight_proj_depth=cfg.quantizer_depth,
                    weight_proj_factor=cfg.quantizer_factor,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        # learned embedding that replaces masked timesteps
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )
        encoder_cls = TransformerEncoder
        if cfg.layer_type == "conformer" and cfg.pos_enc_type in ["rel_pos", "rope"]:
            encoder_cls = ConformerEncoder

        self.encoder = encoder_cls(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        # fix: nn.Module defines no upgrade_state_dict_named, so the original
        # unconditional super() call raised AttributeError whenever this hook
        # was invoked; only delegate when a parent class actually provides it.
        parent = super()
        if hasattr(parent, "upgrade_state_dict_named"):
            parent.upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Config, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        """Mask x (B x T x C) along time and/or channels.

        Returns the masked features and the boolean time-mask indices
        (None when no time masking was applied).
        """
        B, T, C = x.shape

        # optional channel masking applied *before* time masking; zeroes the
        # chosen channels across all timesteps (in place)
        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=2,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            # masked timesteps are replaced by the learned mask embedding
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        # channel masking applied *after* time masking (default path)
        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def sample_negatives(self, y, num, padding_count=None):
        """Sample negative targets for the contrastive task.

        y: B x T x C targets; num: number of (masked) timesteps per sample.
        Returns (negs, neg_idxs) where negs is
        (n_negatives + cross_sample_negatives) x B x num x C.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # B x T x C -> (B*T) x C

        # within-sample negatives index into [0, high); cross-sample
        # negatives index into the whole flattened batch
        cross_high = tsz * bsz
        high = tsz - (padding_count or 0)
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )

                # draw from high-1 values, then shift draws at/after the
                # positive's position so the positive is never sampled
                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # offset per-sample indices into the flattened (B*T) view
            neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(
            2, 0, 1, 3
        )  # -> N x B x num x C
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Cosine-similarity logits of x against [positive; negatives].

        Returns logits shaped (1 + num_negatives) x B x T; index 0 along the
        first dim is the positive target.
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
        logits = logits / self.logit_temp
        logits = logits.type_as(x)

        # fix: neg_is_pos was computed but never used — negatives that are
        # identical to the positive must not compete in the contrastive
        # softmax (matches upstream fairseq behavior)
        if neg_is_pos.any():
            logits[1:][neg_is_pos] = float("-inf")

        return logits

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            # standard conv output-length formula (no padding, dilation=1)
            return torch.floor((input_length - kernel_size) / stride + 1)

        # see __init__: eval is required because the config string is an
        # expression, not a literal; trusted configs only
        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        """Full pre-training forward pass.

        With features_only=True, returns encoder features and padding mask
        (used for fine-tuning / feature extraction); otherwise returns the
        contrastive logits plus quantizer diagnostics.
        """
        # optionally scale (or stop) gradients flowing into the feature
        # extractor
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)

        # L2 penalty over pre-norm features (returned as "features_pen")
        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)  # B x C x T -> B x T x C
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # downsample the waveform-level padding mask to frame resolution
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # mark the last valid frame per sample, then a reversed cumsum
            # turns that into True for every frame at/after the end
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        # optionally crop so the frame count is a multiple of
        # crop_seq_to_multiple
        time_steps_to_drop = features.size(1) % self.crop_seq_to_multiple
        if time_steps_to_drop != 0:
            features = features[:, :-time_steps_to_drop]
            unmasked_features = unmasked_features[:, :-time_steps_to_drop]
            if padding_mask is not None:
                padding_mask = padding_mask[:, :-time_steps_to_drop]

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        # quantizer diagnostics (populated below when a quantizer runs)
        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
            if mask_indices is not None:
                # contrastive targets y are the *unmasked* features at the
                # masked positions
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer)

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "features": unmasked_features,
                "layer_results": layer_results,
            }

        if self.quantizer:
            if self.negatives_from_everywhere:
                # quantize all timesteps, sample negatives anywhere, then
                # keep only the masked positions as positives
                q = self.quantizer(unmasked_features, produce_targets=False)
                y = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]
                y = self.project_q(y)

                negs, _ = self.sample_negatives(
                    y,
                    mask_indices[0].sum(),
                    padding_count=padding_count,
                )
                y = y[mask_indices].view(y.size(0), -1, y.size(-1))

            else:
                q = self.quantizer(y, produce_targets=False)
                y = q["x"]
                num_vars = q["num_vars"]
                code_ppl = q["code_perplexity"]
                prob_ppl = q["prob_perplexity"]
                curr_temp = q["temp"]

                y = self.project_q(y)

                negs, _ = self.sample_negatives(
                    y,
                    y.size(1),
                    padding_count=padding_count,
                )

            if self.codebook_negatives > 0:
                # additional negatives drawn directly from the codebook
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(
                    unmasked_features,
                    y.size(1),
                    padding_count=padding_count,
                )
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(
                    y,
                    y.size(1),
                    padding_count=padding_count,
                )

        # keep only the encoder outputs at masked positions
        x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        result = {
            "x": x,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if prob_ppl is not None:
            result["prob_perplexity"] = prob_ppl
            result["code_perplexity"] = code_ppl
            result["num_vars"] = num_vars
            result["temp"] = curr_temp

        return result

    def quantize(self, x):
        """Quantize raw audio x and return quantizer indices."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False, layer=None):
        """Convenience wrapper: run forward with features_only=True."""
        res = self.forward(
            source, padding_mask, mask=mask, features_only=True, layer=layer
        )
        return res

    def get_logits(self, net_output):
        """Flatten contrastive logits to (T*B) x (1 + num_negatives)."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)  # (N+1) x B x T -> T x B x (N+1)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """The positive is always at index 0 of the logits, so targets are 0."""
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Auxiliary losses: codebook-diversity penalty and feature L2 penalty."""
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self, last_layer=None):
        """Drop pre-training-only modules (and optionally trailing encoder
        layers) before fine-tuning."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None

        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )
|
|
|
|
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1d conv blocks turning a raw waveform (B x T) into frame
    features (B x C x T').

    mode="default": group norm on the first block only;
    mode="layer_norm": fp32 layer norm (over channels) on every block.
    """

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        """conv_layers: list of (dim, kernel_size, stride) per block."""
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert (
                is_layer_norm and is_group_norm
            ) == False, "layer norm and group norm are exclusive"

            # fix: the norm layers below previously referenced the enclosing
            # loop variable `dim` via closure instead of the parameter
            # `n_out`; that only worked because block() happens to be called
            # while dim == n_out. Using n_out makes block() self-contained.
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        in_d = 1  # raw waveform has a single input channel
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        # B x T -> B x 1 x T (single input channel for the first conv)
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
|
|
|
|
def make_conv_pos(e, k, g):
    """Build the convolutional positional embedding used by the transformer
    encoder: one grouped 1d conv (weight-normalized), SamePad trimming, and
    a GELU.

    e: embedding dim (in == out channels), k: kernel size, g: conv groups.
    """
    conv = nn.Conv1d(
        e,
        e,
        kernel_size=k,
        padding=k // 2,
        groups=g,
    )

    # init std scales inversely with kernel size and width; the dropout term
    # is kept at 0 as in the original formulation
    dropout = 0
    init_std = math.sqrt((4 * (1.0 - dropout)) / (k * e))
    nn.init.normal_(conv.weight, mean=0, std=init_std)
    nn.init.constant_(conv.bias, 0)

    # weight norm over dim=2 (the kernel dimension)
    conv = nn.utils.weight_norm(conv, name="weight", dim=2)

    return nn.Sequential(conv, SamePad(k), nn.GELU())
|
|
|
|
class TransformerEncoder(nn.Module):
    """Context network: optional convolutional positional embedding followed
    by a stack of transformer (or conformer) layers.

    forward/extract_features take batch-first input (B x T x C); the
    attention layers run time-first internally.
    """

    def build_encoder_layer(self, args: Wav2Vec2Config):
        """Create a single encoder layer according to args.layer_type."""
        if args.layer_type == "transformer":
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
            )
        elif args.layer_type == "conformer":
            layer = ConformerWav2Vec2EncoderLayer(
                embed_dim=self.embedding_dim,
                ffn_embed_dim=args.encoder_ffn_embed_dim,
                attention_heads=args.encoder_attention_heads,
                dropout=args.dropout,
                depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
                activation_fn="swish",
                attn_type=args.attn_type,
                use_fp16=args.fp16,
                pos_enc_type="abs",
            )
        else:
            # fix: previously fell through and raised UnboundLocalError on
            # `return layer`; fail with an explicit error instead
            raise ValueError(f"unknown layer type: {args.layer_type}")
        return layer

    def __init__(
        self,
        args: Wav2Vec2Config,
        skip_pos_conv: bool = False,
        override_encoder_layer: int = None,
    ):
        super().__init__()

        # fix: max_positions() reads self.args, but it was never stored here
        # (only the ConformerEncoder subclass set it), so calling
        # max_positions() on a plain TransformerEncoder raised AttributeError
        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple

        pos_conv_depth = getattr(args, "pos_conv_depth", 1)
        if pos_conv_depth > 1:
            # deep positional encoder: several smaller grouped convs with a
            # non-affine layer norm between them
            num_layers = args.pos_conv_depth
            k = max(3, args.conv_pos // num_layers)

            def make_conv_block(e, k, g, l):
                return nn.Sequential(
                    *[
                        nn.Sequential(
                            nn.Conv1d(
                                e,
                                e,
                                kernel_size=k,
                                padding=k // 2,
                                groups=g,
                            ),
                            SamePad(k),
                            TransposeLast(),
                            LayerNorm(e, elementwise_affine=False),
                            TransposeLast(),
                            nn.GELU(),
                        )
                        for _ in range(l)
                    ]
                )

            self.pos_conv = make_conv_block(
                self.embedding_dim, k, args.conv_pos_groups, num_layers
            )

        elif skip_pos_conv:
            self.pos_conv = None
        else:
            self.pos_conv = make_conv_pos(
                self.embedding_dim,
                args.conv_pos,
                args.conv_pos_groups,
            )

        if override_encoder_layer is None:
            encoder_layers = args.encoder_layers
        else:
            encoder_layers = override_encoder_layer

        self.layers = nn.ModuleList(
            [self.build_encoder_layer(args) for _ in range(encoder_layers)]
        )
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

    def forward(self, x, padding_mask=None, layer=None):
        """Run the encoder; returns (x, layer_results).

        layer: if set, stop after (and return) that layer's output.
        """
        x, layer_results = self.extract_features(x, padding_mask, layer)

        # for pre-norm models the final layer norm is applied here, unless an
        # intermediate layer's output was requested
        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(
        self,
        x,
        padding_mask=None,
        tgt_layer=None,
        min_layer=0,
    ):
        """Core encoder loop; x is B x T x C.

        Returns (x, layer_results) where layer_results contains
        (x, attn, ffn_result) per layer from min_layer onward (time-first).
        """
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        # add convolutional positional embeddings (conv operates on B x C x T)
        if self.pos_conv is not None:
            x_conv = self.pos_conv(x.transpose(1, 2))
            x_conv = x_conv.transpose(1, 2)
            x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to a multiple of required_seq_len_multiple and extend/create
        # the padding mask to cover the padding; it is stripped again below
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            # layerdrop: randomly skip layers during training
            dropout_probability = np.random.random() if self.layerdrop > 0 else 1
            if not self.training or (dropout_probability > self.layerdrop):
                x, (z, lr) = layer(
                    x, self_attn_padding_mask=padding_mask, need_weights=False
                )
                if i >= min_layer:
                    layer_results.append((x, z, lr))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo the padding added above, both on x and on the saved per-layer
        # results (time-first, so slice along dim 0)
        if pad_length > 0:
            x = x[:, :-pad_length]

            def undo_pad(a, b, c):
                return (
                    a[:-pad_length],
                    b[:-pad_length] if b is not None else b,
                    c[:-pad_length],
                )

            layer_results = [undo_pad(*u) for u in layer_results]

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict
|
|
|
|
class ConformerEncoder(TransformerEncoder):
    """Conformer variant of the context network.

    There is no convolutional positional embedding in front of the layers;
    positions come from a relative positional encoding ("rel_pos") or —
    presumably inside the attention modules — from rotary embeddings
    ("rope"); confirm against ConformerWav2Vec2EncoderLayer.
    """

    def build_encoder_layer(self, args):
        """Create one conformer layer (also invoked by the parent __init__)."""
        layer = ConformerWav2Vec2EncoderLayer(
            embed_dim=self.embedding_dim,
            ffn_embed_dim=args.encoder_ffn_embed_dim,
            attention_heads=args.encoder_attention_heads,
            dropout=args.dropout,
            depthwise_conv_kernel_size=args.depthwise_conv_kernel_size,
            activation_fn="swish",
            attn_type=args.attn_type,
            pos_enc_type=args.pos_enc_type,
            use_fp16=args.fp16,
        )
        return layer

    def __init__(self, args):
        # NOTE(review): super().__init__ already builds self.layers (and a
        # pos_conv this class never uses in extract_features); self.layers is
        # rebuilt below, so the layers are constructed twice. Left as-is
        # because changing it would alter the RNG state at init — confirm
        # before simplifying.
        super().__init__(args)
        self.args = args
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.pos_enc_type = args.pos_enc_type
        max_source_positions = self.max_positions()

        if self.pos_enc_type == "rel_pos":
            self.embed_positions = RelPositionalEncoding(
                max_source_positions, self.embedding_dim
            )
        elif self.pos_enc_type == "rope":
            # no separate positional module for rope
            self.embed_positions = None
        else:
            raise Exception("Unsupported positional encoding type")

        self.layers = nn.ModuleList(
            [self.build_encoder_layer(args) for _ in range(args.encoder_layers)]
        )
        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        """Run the conformer layers; x is B x T x C.

        Returns (x, layer_results); layer_results is only populated when
        tgt_layer is given.
        """
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        # the relative positional embedding is computed once and shared by
        # all layers
        position_emb = None
        if self.pos_enc_type == "rel_pos":
            position_emb = self.embed_positions(x)

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            # layerdrop: randomly skip layers during training
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    position_emb=position_emb,
                )
            if tgt_layer is not None:
                layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results
|
|
|
|
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:
        super().__init__()

        # basic hyper-parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout
        self.layer_norm_first = layer_norm_first

        # self-attention sub-block
        self.activation_fn = get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)

        # position-wise feed-forward sub-block
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
        self.final_layer_norm = LayerNorm(self.embedding_dim)

        # separate dropout modules: attention output / FFN hidden / FFN output
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
    ):
        """
        LayerNorm is applied either before (pre-norm) or after (post-norm)
        the self-attention/FFN sub-blocks, as in the original Transformer
        implementation. Returns (x, (attn, layer_result)) where layer_result
        is the FFN output before the final dropout/residual.
        """
        if self.layer_norm_first:
            # pre-norm: LN -> attn -> residual, then LN -> FFN -> residual
            shortcut = x
            y = self.self_attn_layer_norm(x)
            y, attn = self.self_attn(
                query=y,
                key=y,
                value=y,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                need_weights=False,
            )
            x = shortcut + self.dropout1(y)

            shortcut = x
            y = self.final_layer_norm(x)
            y = self.dropout2(self.activation_fn(self.fc1(y)))
            y = self.fc2(y)

            layer_result = y

            x = shortcut + self.dropout3(y)
        else:
            # post-norm: attn -> residual -> LN, then FFN -> residual -> LN
            # (note: self_attn_mask is intentionally not forwarded here,
            # matching the original code path)
            shortcut = x
            y, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                need_weights=False,
            )
            x = self.self_attn_layer_norm(shortcut + self.dropout1(y))

            shortcut = x
            y = self.dropout2(self.activation_fn(self.fc1(x)))
            y = self.fc2(y)

            layer_result = y

            x = self.final_layer_norm(shortcut + self.dropout3(y))

        return x, (attn, layer_result)
|
|
|
|
| @dataclass |
| class AudioPretrainingConfig: |
| sample_rate: int = field( |
| default=16_000, |
| metadata={ |
| "help": "target sample rate. audio files will be up/down sampled to this rate" |
| }, |
| ) |
| normalize: bool = field( |
| default=False, |
| metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, |
| ) |
| enable_padding: bool = field( |
| default=False, metadata={"help": "pad shorter samples instead of cropping"} |
| ) |
| max_sample_size: Optional[int] = field( |
| default=None, metadata={"help": "max sample size to crop to for batching"} |
| ) |
| min_sample_size: Optional[int] = field( |
| default=None, metadata={"help": "min sample size to skip small examples"} |
| ) |