# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in
# the root directory of this source tree.

from argparse import Namespace
import contextlib
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass, field
from omegaconf import MISSING, II, open_dict
from typing import Any, Optional

from fairseq import checkpoint_utils, tasks, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.tasks import FairseqTask
from fairseq.models import (
    BaseFairseqModel,
    FairseqEncoder,
    FairseqEncoderDecoderModel,
    FairseqIncrementalDecoder,
    register_model,
)
from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES
from fairseq.modules import (
    LayerNorm,
    PositionalEmbedding,
    TransformerDecoderLayer,
)


@dataclass
class Wav2Vec2AsrConfig(FairseqDataclass):
    w2v_path: str = field(
        default=MISSING, metadata={"help": "path to wav2vec 2.0 model"}
    )
    no_pretrained_weights: bool = field(
        default=False, metadata={"help": "if true, does not load pretrained weights"}
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    final_dropout: float = field(
        default=0.0,
        metadata={"help": "dropout after transformer and before final projection"},
    )
    dropout: float = field(
        default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"}
    )
    attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside wav2vec 2.0 model"
        },
    )
    activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside wav2vec 2.0 model"
        },
    )

    # masking
    apply_mask: bool = field(
        default=False, metadata={"help": "apply masking during fine-tuning"}
    )
    mask_length: int = field(
        default=10, metadata={"help": "length of each masked span (in timesteps)"}
    )
    mask_prob: float = field(
        default=0.5,
        metadata={
            "help": "probability of replacing a token with mask (normalized by length)"
        },
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask lengths"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "if true, masked spans may not overlap"}
    )

    # channel masking
    mask_channel_length: int = field(
        default=10, metadata={"help": "length of the mask for features (channels)"}
    )
    mask_channel_prob: float = field(
        default=0.0, metadata={"help": "probability of replacing a feature with 0"}
    )
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    no_mask_channel_overlap: bool = field(
        default=False, metadata={"help": "if true, channel masks may not overlap"}
    )
    freeze_finetune_updates: int = field(
        default=0, metadata={"help": "don't finetune wav2vec for this many updates"}
    )
    feature_grad_mult: float = field(
        default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"}
    )
    layerdrop: float = field(
        default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"}
    )
    mask_channel_before: bool = False
    normalize: bool = II("task.normalize")
    data: str = II("task.data")
    # this holds the loaded wav2vec args
    w2v_args: Any = None
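    # Illustrative note (not from the original source): the masking options
    # above are forwarded to fairseq's compute_mask_indices inside the wrapped
    # wav2vec 2.0 model. Roughly, with the default "static" distribution,
    # about mask_prob * T / mask_length span starts are sampled for an
    # utterance of T frames, each masking mask_length consecutive timesteps,
    # so mask_prob=0.5 with mask_length=10 masks on the order of half the
    # frames (less wherever sampled spans overlap).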


@dataclass
class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig):
    blank_weight: float = 0
    blank_mode: str = "add"
    mask_min_space: Optional[int] = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )
    mask_channel_min_space: Optional[int] = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )
    conv_feature_layers: Optional[str] = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": (
                "string describing convolutional feature extraction "
                "layers in form of a python list that contains "
                "[(dim, kernel_size, stride), ...]"
            ),
        },
    )
    encoder_embed_dim: Optional[int] = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
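    # Illustrative note: the wav2vec 2.0 model evaluates this string as a
    # Python literal, e.g. (assuming the default above):
    #
    #   eval("[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]")
    #   # -> [(512, 10, 5), (512, 3, 2), (512, 3, 2), (512, 3, 2),
    #   #     (512, 3, 2), (512, 2, 2), (512, 2, 2)]
    #
    # where each tuple is (dim, kernel_size, stride); the strides compound to
    # a 320x downsampling of the 16 kHz waveform, i.e. one frame per ~20 ms.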


@register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig)
class Wav2VecCtc(BaseFairseqModel):
    def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel):
        super().__init__()
        self.cfg = cfg
        self.w2v_encoder = w2v_encoder
        self.blank_weight = cfg.blank_weight
        self.blank_mode = cfg.blank_mode

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask):
        """Build a new model instance."""
        w2v_encoder = Wav2VecEncoder(cfg, len(task.target_dictionary))
        return cls(cfg, w2v_encoder)

    def get_logits(self, net_output, normalize=False):
        logits = net_output["encoder_out"]
        if self.blank_weight != 0:
            if self.blank_mode == "add":
                logits[..., 0] += self.blank_weight
            elif self.blank_mode == "set":
                logits[..., 0] = self.blank_weight
            else:
                raise Exception(f"invalid blank mode {self.blank_mode}")

        if net_output["padding_mask"] is not None and net_output["padding_mask"].any():
            # force padded frames to emit blank. Note: chained advanced
            # indexing (logits[mask][..., 0] = ...) would write into a copy
            # and silently have no effect, so build the fill vector once and
            # assign in a single step instead
            fill = logits.new_full((logits.size(-1),), float("-inf"))
            fill[0] = float("inf")
            logits[net_output["padding_mask"].T] = fill

        if normalize:
            logits = utils.log_softmax(logits.float(), dim=-1)

        return logits
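    # Worked example (illustrative, not from the original source): with
    # blank_mode == "add" and blank_weight == 2.0, a blank logit of 1.5
    # becomes 3.5 while all other classes are untouched, biasing CTC decoding
    # toward emitting blank (index 0); blank_mode == "set" instead overwrites
    # the blank logit with blank_weight outright.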

    def get_normalized_probs(self, net_output, log_probs):
        """Get normalized probabilities (or log probs) from a net's output."""

        logits = self.get_logits(net_output)

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    def forward(self, **kwargs):
        x = self.w2v_encoder(**kwargs)
        return x
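    # Minimal usage sketch (hypothetical variables, shapes shown for
    # orientation only; cfg.w2v_path must point at a pretrained checkpoint):
    #
    #   model = Wav2VecCtc.build_model(cfg, task)
    #   net_output = model(source=waveform, padding_mask=pad)  # waveform: (B, T)
    #   lprobs = model.get_logits(net_output, normalize=True)  # (T', B, V) log-probs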


@dataclass
class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig):
    decoder_embed_dim: int = field(
        default=768, metadata={"help": "decoder embedding dimension"}
    )
    decoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "decoder embedding dimension for FFN"}
    )
    decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"})
    decoder_layerdrop: float = field(
        default=0.0, metadata={"help": "decoder layerdrop chance"}
    )
    decoder_attention_heads: int = field(
        default=4, metadata={"help": "num decoder attention heads"}
    )
    decoder_learned_pos: bool = field(
        default=False,
        metadata={"help": "use learned positional embeddings in the decoder"},
    )
    decoder_normalize_before: bool = field(
        default=False, metadata={"help": "apply layernorm before each decoder block"}
    )
    no_token_positional_embeddings: bool = field(
        default=False,
        metadata={
            "help": "if set, disables positional embeddings (outside self attention)"
        },
    )
    decoder_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability in the decoder"}
    )
    decoder_attention_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability for attention weights inside the decoder"
        },
    )
    decoder_activation_dropout: float = field(
        default=0.0,
        metadata={
            "help": "dropout probability after activation in FFN inside the decoder"
        },
    )
    max_target_positions: int = field(
        default=2048, metadata={"help": "max target positions"}
    )
    share_decoder_input_output_embed: bool = field(
        default=False, metadata={"help": "share decoder input and output embeddings"}
    )
    autoregressive: bool = II("task.autoregressive")


@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig)
class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel):
    def __init__(self, encoder, decoder):
        super().__init__(encoder, decoder)

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask):
        """Build a new model instance."""

        assert (
            cfg.autoregressive
        ), "Please set task.autoregressive=true for seq2seq asr models"

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

        def build_embedding(dictionary, embed_dim):
            num_embeddings = len(dictionary)
            padding_idx = dictionary.pad()
            emb = Embedding(num_embeddings, embed_dim, padding_idx)
            return emb

        decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim)

        encoder = cls.build_encoder(cfg)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)

        return Wav2Vec2Seq2SeqModel(encoder, decoder)

    @classmethod
    def build_encoder(cls, cfg: Wav2Vec2AsrConfig):
        return Wav2VecEncoder(cfg)

    @classmethod
    def build_decoder(cls, cfg: Wav2Vec2Seq2SeqConfig, tgt_dict, embed_tokens):
        return TransformerDecoder(cfg, tgt_dict, embed_tokens)

    def forward(self, **kwargs):
        # tbc=False keeps the encoder output batch-first (B x T x C)
        encoder_out = self.encoder(tbc=False, **kwargs)
        decoder_out = self.decoder(encoder_out=encoder_out, **kwargs)
        return decoder_out

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict


class Wav2VecEncoder(FairseqEncoder):
    def __init__(self, cfg: Wav2Vec2AsrConfig, output_size=None):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_before": cfg.mask_channel_before,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert cfg.normalize == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for both pre-training and here"
        )

        w2v_args.task.data = cfg.data
        task = tasks.setup_task(w2v_args.task)
        model = task.build_model(w2v_args.model)

        if state is not None and not cfg.no_pretrained_weights:
            model.load_state_dict(state["model"], strict=True)

        model.remove_pretraining_modules()

        super().__init__(task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

        targ_d = None
        self.proj = None

        if output_size is not None:
            targ_d = output_size
        elif getattr(cfg, "decoder_embed_dim", d) != d:
            targ_d = cfg.decoder_embed_dim

        if targ_d is not None:
            self.proj = Linear(d, targ_d)

    def set_num_updates(self, num_updates):
        """Set the number of parameter updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        # freeze the wav2vec body for the first freeze_finetune_updates steps
        with torch.no_grad() if not ft else contextlib.ExitStack():
            res = self.w2v_model.extract_features(**w2v_args)

            x = res["x"]
            padding_mask = res["padding_mask"]

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        x = self.final_dropout(x)

        if self.proj:
            x = self.proj(x)

        return {
            "encoder_out": x,  # T x B x C
            "encoder_padding_mask": padding_mask.transpose(0, 1)
            if padding_mask is not None
            else None,  # T x B
            "padding_mask": padding_mask,  # B x T
            "layer_results": res["layer_results"],
        }
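    # Shape conventions in the dict above (assuming tbc=True, the default):
    # "encoder_out" is (T, B, C) as fairseq decoders and the CTC criterion
    # expect, "padding_mask" stays (B, T), and "encoder_padding_mask" is its
    # (T, B) transpose. Until freeze_finetune_updates optimizer steps have
    # elapsed, the wav2vec body runs under torch.no_grad(), so only the
    # projection head is trained during that warm-up window.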

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = encoder_out["encoder_out"].index_select(
                1, new_order
            )
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = encoder_out[
                "encoder_padding_mask"
            ].index_select(0, new_order)
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


class TransformerDecoder(FairseqIncrementalDecoder):
    """
    Transformer decoder consisting of *cfg.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`.

    Args:
        cfg (Wav2Vec2Seq2SeqConfig): decoder configuration
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg: Wav2Vec2Seq2SeqConfig,
        dictionary,
        embed_tokens,
        no_encoder_attn=False,
    ):
        super().__init__(dictionary)

        self.dropout = cfg.decoder_dropout
        self.share_input_output_embed = cfg.share_decoder_input_output_embed

        input_embed_dim = embed_tokens.embedding_dim
        embed_dim = cfg.decoder_embed_dim
        self.output_embed_dim = cfg.decoder_embed_dim

        self.layerdrop = cfg.decoder_layerdrop

        padding_idx = embed_tokens.padding_idx
        self.max_target_positions = cfg.max_target_positions

        self.embed_tokens = embed_tokens
        self.embed_scale = math.sqrt(embed_dim)

        self.project_in_dim = (
            Linear(input_embed_dim, embed_dim, bias=False)
            if embed_dim != input_embed_dim
            else None
        )

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_target_positions,
                embed_dim,
                padding_idx,
                learned=cfg.decoder_learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )

        # map the decoder-specific dropout values onto the generic field
        # names that TransformerDecoderLayer reads
        transformer_cfg = copy.deepcopy(cfg)
        with open_dict(transformer_cfg):
            transformer_cfg.dropout = transformer_cfg.decoder_dropout
            transformer_cfg.attention_dropout = (
                transformer_cfg.decoder_attention_dropout
            )
            transformer_cfg.activation_dropout = (
                transformer_cfg.decoder_activation_dropout
            )

        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerDecoderLayer(transformer_cfg, no_encoder_attn)
                for _ in range(transformer_cfg.decoder_layers)
            ]
        )

        if not self.share_input_output_embed:
            self.embed_out = nn.Parameter(
                torch.Tensor(len(dictionary), self.output_embed_dim)
            )
            nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5)

        if transformer_cfg.decoder_normalize_before:
            self.layer_norm = LayerNorm(embed_dim)
        else:
            self.layer_norm = None

    def forward(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (Tensor, optional): output from the encoder, used for
                encoder-side attention
            incremental_state (dict): dictionary used for storing state during
                :ref:`Incremental decoding`

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        prev_output_tokens = prev_output_tokens.long()
        x, extra = self.extract_features(
            prev_output_tokens, encoder_out, incremental_state
        )
        x = self.output_layer(x)
        return x, extra
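    # Sketch of a teacher-forced call (hypothetical shapes, for orientation):
    #
    #   logits, extra = decoder(prev_output_tokens,       # (B, tgt_len) LongTensor
    #                           encoder_out=encoder_out)  # dict from Wav2VecEncoder
    #   # logits: (B, tgt_len, len(dictionary)); extra carries "attn" and
    #   # "inner_states" from extract_features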

    def extract_features(
        self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """

        # embed positions
        positions = (
            self.embed_positions(
                prev_output_tokens, incremental_state=incremental_state
            )
            if self.embed_positions is not None
            else None
        )

        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        attn = None

        inner_states = [x]

        # decoder layers (LayerDrop: each layer is skipped with probability
        # self.layerdrop during training)
        for layer in self.layers:
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, attn, _ = layer(
                    x,
                    encoder_out["encoder_out"] if encoder_out is not None else None,
                    encoder_out["padding_mask"] if encoder_out is not None else None,
                    incremental_state,
                    self_attn_mask=self.buffered_future_mask(x)
                    if incremental_state is None
                    else None,
                )
                inner_states.append(x)

        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, {"attn": attn, "inner_states": inner_states}

    def output_layer(self, features, **kwargs):
        """Project features to the vocabulary size."""
        # project back to size of vocabulary
        if self.share_input_output_embed:
            return F.linear(features, self.embed_tokens.weight)
        else:
            return F.linear(features, self.embed_out)

    def max_positions(self):
        """Maximum output length supported by the decoder."""
        if self.embed_positions is None:
            return self.max_target_positions
        return min(self.max_target_positions, self.embed_positions.max_positions)

    def buffered_future_mask(self, tensor):
        dim = tensor.size(0)
        if (
            not hasattr(self, "_future_mask")
            or self._future_mask is None
            or self._future_mask.device != tensor.device
            or self._future_mask.size(0) < dim
        ):
            self._future_mask = torch.triu(
                utils.fill_with_neg_inf(tensor.new(dim, dim)), 1
            )
        return self._future_mask[:dim, :dim]
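    # The cached mask is strictly upper-triangular -inf; for dim == 3:
    #
    #   [[0., -inf, -inf],
    #    [0.,   0., -inf],
    #    [0.,   0.,   0.]]
    #
    # added to the self-attention scores it lets position t attend only to
    # positions <= t; the buffer is rebuilt lazily whenever a longer sequence
    # or a different device is seen.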

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m