| |
| |
| |
| |
|
|
| from dataclasses import dataclass, field |
| import logging |
| import math |
| from typing import Optional, Tuple |
| from omegaconf import II |
| import sys |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from fairseq.dataclass import ChoiceEnum, FairseqDataclass |
| from fairseq.models import BaseFairseqModel, register_model |
| from fairseq.modules import ( |
| Fp32GroupNorm, |
| Fp32LayerNorm, |
| GumbelVectorQuantizer, |
| KmeansVectorQuantizer, |
| TransposeLast, |
| ) |
| from fairseq.tasks import FairseqTask |
| from fairseq.utils import buffered_arange |
|
|
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"]) |
| PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"]) |
| ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"]) |
| VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"]) |
|
|
|
|
| @dataclass |
| class Wav2VecConfig(FairseqDataclass): |
| prediction_steps: int = field( |
| default=12, metadata={"help": "number of steps ahead to predict"} |
| ) |
| sample_distance: Optional[int] = field( |
| default=None, |
| metadata={ |
| "help": "sample distance from target. does not work properly with cross-sampling" |
| }, |
| ) |
| cross_sample_negatives: int = field( |
| default=0, metadata={"help": "num of cross sampled negatives"} |
| ) |
| num_negatives: int = field( |
| default=10, metadata={"help": "num of sampled negatives"} |
| ) |
| conv_feature_layers: str = field( |
| default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]", |
| metadata={ |
| "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]" |
| }, |
| ) |
| conv_aggregator_layers: str = field( |
| default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]", |
| metadata={ |
| "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]" |
| }, |
| ) |
| dropout: float = field( |
| default=0.0, metadata={"help": "dropout to apply within the model"} |
| ) |
| dropout_features: float = field( |
| default=0.0, metadata={"help": "dropout to apply to the features"} |
| ) |
| dropout_agg: float = field( |
| default=0.0, metadata={"help": "dropout to apply after aggregation step"} |
| ) |
| aggregator: AGGREGATOR_CHOICES = field( |
| default="cnn", metadata={"help": "type of aggregator to use"} |
| ) |
| gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"}) |
| no_conv_bias: bool = field( |
| default=False, metadata={"help": "if set, does not learn bias for conv layers"} |
| ) |
| agg_zero_pad: bool = field( |
| default=False, |
| metadata={"help": "if set, zero pads in aggregator instead of repl pad"}, |
| ) |
| skip_connections_feat: bool = field( |
| default=False, |
| metadata={"help": "if set, adds skip connections to the feature extractor"}, |
| ) |
| skip_connections_agg: bool = field( |
| default=True, |
| metadata={"help": "if set, adds skip connections to the aggregator"}, |
| ) |
| residual_scale: float = field( |
| default=0.5, metadata={"help": "scales residual by sqrt(value)"} |
| ) |
| log_compression: bool = field( |
| default=True, |
| metadata={"help": "if set, adds a log compression to feature extractor"}, |
| ) |
| balanced_classes: bool = field( |
| default=False, |
| metadata={"help": "if set, loss is scaled to balance for number of negatives"}, |
| ) |
| project_features: PROJECT_FEATURES_CHOICES = field( |
| default="none", |
| metadata={ |
| "help": "if not none, features are projected using the (same or new) aggregator" |
| }, |
| ) |
| non_affine_group_norm: bool = field( |
| default=False, metadata={"help": "if set, group norm is not affine"} |
| ) |
| offset: str = field( |
| default="auto", |
| metadata={ |
| "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" |
| }, |
| ) |
| activation: ACTIVATION_CHOICES = field( |
| default="relu", |
| metadata={ |
| "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" |
| }, |
| ) |
| vq_type: VQ_TYPE_CHOICES = field( |
| default="none", metadata={"help": "which type of quantizer to use"} |
| ) |
| vq_vars: int = field( |
| default=320, |
| metadata={"help": "project to this many vector quantized variables per group"}, |
| ) |
| vq_groups: int = field( |
| default=2, metadata={"help": "number of groups of latent variables"} |
| ) |
| vq_dim: int = field( |
| default=0, |
| metadata={ |
| "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups" |
| }, |
| ) |
| vq_depth: int = field( |
| default=1, metadata={"help": "number of layers for vq weight projection"} |
| ) |
| combine_groups: bool = field( |
| default=False, metadata={"help": "if set, variables are shared among groups"} |
| ) |
| vq_temp: Tuple[float, float, float] = field( |
| default=(2.0, 0.5, 0.999995), |
| metadata={ |
| "help": "temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)" |
| }, |
| ) |
| vq_gamma: float = field( |
| default=0.25, |
| metadata={"help": "gamma parameter for kmeans style vector quantization"}, |
| ) |
| infonce: bool = II("criterion.infonce") |
|
|
|
|
| @register_model("wav2vec", dataclass=Wav2VecConfig) |
| class Wav2VecModel(BaseFairseqModel): |
| @classmethod |
| def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask): |
| """Build a new model instance.""" |
|
|
| model = Wav2VecModel(cfg) |
| logger.info(model) |
| return model |
|
|
| def __init__(self, cfg: Wav2VecConfig): |
| super().__init__() |
|
|
| self.prediction_steps = cfg.prediction_steps |
| offset = cfg.offset |
|
|
| if cfg.activation == "relu": |
| activation = nn.ReLU() |
| elif cfg.activation == "gelu": |
| activation = nn.GELU() |
| else: |
| raise Exception("unknown activation " + cfg.activation) |
|
|
| feature_enc_layers = eval(cfg.conv_feature_layers) |
| self.feature_extractor = ConvFeatureExtractionModel( |
| conv_layers=feature_enc_layers, |
| dropout=0.0, |
| log_compression=cfg.log_compression, |
| skip_connections=cfg.skip_connections_feat, |
| residual_scale=cfg.residual_scale, |
| non_affine_group_norm=cfg.non_affine_group_norm, |
| activation=activation, |
| ) |
| embed = feature_enc_layers[-1][0] |
|
|
| self.vector_quantizer = None |
| if cfg.vq_type == "gumbel": |
| self.vector_quantizer = GumbelVectorQuantizer( |
| dim=embed, |
| num_vars=cfg.vq_vars, |
| temp=cfg.vq_temp, |
| groups=cfg.vq_groups, |
| combine_groups=cfg.combine_groups, |
| vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, |
| time_first=False, |
| activation=activation, |
| weight_proj_depth=cfg.vq_depth, |
| weight_proj_factor=2, |
| ) |
| elif cfg.vq_type == "kmeans": |
| self.vector_quantizer = KmeansVectorQuantizer( |
| dim=embed, |
| num_vars=cfg.vq_vars, |
| groups=cfg.vq_groups, |
| combine_groups=cfg.combine_groups, |
| vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, |
| time_first=False, |
| gamma=cfg.vq_gamma, |
| ) |
| else: |
| assert ( |
| cfg.vq_type == "none" or cfg.vq_type is None |
| ), "Unknown quantizer type" |
|
|
| if cfg.offset == "auto": |
| jin = 0 |
| rin = 0 |
| for _, k, stride in feature_enc_layers: |
| if rin == 0: |
| rin = k |
| rin = rin + (k - 1) * jin |
| if jin == 0: |
| jin = stride |
| else: |
| jin *= stride |
| offset = math.ceil(rin / jin) |
|
|
| offset = int(offset) |
|
|
| def make_aggregator(): |
| if cfg.aggregator == "cnn": |
| agg_layers = eval(cfg.conv_aggregator_layers) |
| agg_dim = agg_layers[-1][0] |
| feature_aggregator = ConvAggegator( |
| conv_layers=agg_layers, |
| embed=embed, |
| dropout=cfg.dropout, |
| skip_connections=cfg.skip_connections_agg, |
| residual_scale=cfg.residual_scale, |
| non_affine_group_norm=cfg.non_affine_group_norm, |
| conv_bias=not cfg.no_conv_bias, |
| zero_pad=cfg.agg_zero_pad, |
| activation=activation, |
| ) |
| elif cfg.aggregator == "gru": |
| agg_dim = cfg.gru_dim |
| feature_aggregator = nn.Sequential( |
| TransposeLast(), |
| nn.GRU( |
| input_size=embed, |
| hidden_size=agg_dim, |
| num_layers=1, |
| dropout=cfg.dropout, |
| ), |
| TransposeLast(deconstruct_idx=0), |
| ) |
| else: |
| raise Exception("unknown aggregator type " + cfg.aggregator) |
|
|
| return feature_aggregator, agg_dim |
|
|
| self.feature_aggregator, agg_dim = make_aggregator() |
|
|
| self.wav2vec_predictions = Wav2VecPredictionsModel( |
| in_dim=agg_dim, |
| out_dim=embed, |
| prediction_steps=cfg.prediction_steps, |
| n_negatives=cfg.num_negatives, |
| cross_sample_negatives=cfg.cross_sample_negatives, |
| sample_distance=cfg.sample_distance, |
| dropout=cfg.dropout, |
| offset=offset, |
| balanced_classes=cfg.balanced_classes, |
| infonce=cfg.infonce, |
| ) |
|
|
| self.dropout_feats = nn.Dropout(p=cfg.dropout_features) |
| self.dropout_agg = nn.Dropout(p=cfg.dropout_agg) |
|
|
| if cfg.project_features == "none": |
| self.project_features = None |
| elif cfg.project_features == "same": |
| self.project_features = self.feature_aggregator |
| elif cfg.project_features == "new": |
| self.project_features, _ = make_aggregator() |
|
|
| def forward(self, source): |
| result = {} |
|
|
| features = self.feature_extractor(source) |
| if self.vector_quantizer: |
| q_res = self.vector_quantizer(features) |
| features = q_res["x"] |
| for k in q_res.keys(): |
| if k != "x": |
| result[k] = q_res[k] |
|
|
| x = self.dropout_feats(features) |
| x = self.feature_aggregator(x) |
| x = self.dropout_agg(x) |
|
|
| if self.project_features is not None: |
| features = self.project_features(features) |
| x, targets = self.wav2vec_predictions(x, features) |
| result["cpc_logits"] = x |
| result["cpc_targets"] = targets |
|
|
| return result |
|
|
| def upgrade_state_dict_named(self, state_dict, name): |
| super().upgrade_state_dict_named(state_dict, name) |
|
|
| def max_positions(self): |
| """Maximum length supported by the model.""" |
| return sys.maxsize |
|
|
| def get_logits(self, net_output): |
| logits = net_output["cpc_logits"] |
| return logits |
|
|
| def get_targets(self, sample, net_output): |
| t = net_output["cpc_targets"] |
| if isinstance(t, tuple): |
| t = t[0] |
| return t.contiguous() |
|
|
| def get_target_weights(self, targets, net_output): |
| targets = net_output["cpc_targets"] |
| if isinstance(targets, tuple) and targets[-1] is not None: |
| return targets[-1] |
| return None |
|
|
| def get_extra_losses(self, net_output): |
| loss = None |
| if "prob_perplexity" in net_output: |
| loss = net_output["num_vars"] - net_output["prob_perplexity"] |
| elif "kmeans_loss" in net_output: |
| loss = net_output["kmeans_loss"] |
|
|
| return loss |
|
|
|
|
| def norm_block(is_layer_norm, dim, affine=True): |
| if is_layer_norm: |
| mod = nn.Sequential( |
| TransposeLast(), |
| Fp32LayerNorm(dim, elementwise_affine=affine), |
| TransposeLast(), |
| ) |
| else: |
| mod = Fp32GroupNorm(1, dim, affine=affine) |
|
|
| return mod |
|
|
|
|
| class ConvFeatureExtractionModel(nn.Module): |
| def __init__( |
| self, |
| conv_layers, |
| dropout, |
| log_compression, |
| skip_connections, |
| residual_scale, |
| non_affine_group_norm, |
| activation, |
| ): |
| super().__init__() |
|
|
| def block(n_in, n_out, k, stride): |
| return nn.Sequential( |
| nn.Conv1d(n_in, n_out, k, stride=stride, bias=False), |
| nn.Dropout(p=dropout), |
| norm_block( |
| is_layer_norm=False, dim=n_out, affine=not non_affine_group_norm |
| ), |
| activation, |
| ) |
|
|
| in_d = 1 |
| self.conv_layers = nn.ModuleList() |
| for dim, k, stride in conv_layers: |
| self.conv_layers.append(block(in_d, dim, k, stride)) |
| in_d = dim |
|
|
| self.log_compression = log_compression |
| self.skip_connections = skip_connections |
| self.residual_scale = math.sqrt(residual_scale) |
|
|
| def forward(self, x): |
| |
| x = x.unsqueeze(1) |
|
|
| for conv in self.conv_layers: |
| residual = x |
| x = conv(x) |
| if self.skip_connections and x.size(1) == residual.size(1): |
| tsz = x.size(2) |
| r_tsz = residual.size(2) |
| residual = residual[..., :: r_tsz // tsz][..., :tsz] |
| x = (x + residual) * self.residual_scale |
|
|
| if self.log_compression: |
| x = x.abs() |
| x = x + 1 |
| x = x.log() |
|
|
| return x |
|
|
|
|
| class ZeroPad1d(nn.Module): |
| def __init__(self, pad_left, pad_right): |
| super().__init__() |
| self.pad_left = pad_left |
| self.pad_right = pad_right |
|
|
| def forward(self, x): |
| return F.pad(x, (self.pad_left, self.pad_right)) |
|
|
|
|
| class ConvAggegator(nn.Module): |
| def __init__( |
| self, |
| conv_layers, |
| embed, |
| dropout, |
| skip_connections, |
| residual_scale, |
| non_affine_group_norm, |
| conv_bias, |
| zero_pad, |
| activation, |
| ): |
| super().__init__() |
|
|
| def block(n_in, n_out, k, stride): |
| |
| ka = k // 2 |
| kb = ka - 1 if k % 2 == 0 else ka |
|
|
| pad = ( |
| ZeroPad1d(ka + kb, 0) if zero_pad else nn.ReplicationPad1d((ka + kb, 0)) |
| ) |
|
|
| return nn.Sequential( |
| pad, |
| nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias), |
| nn.Dropout(p=dropout), |
| norm_block(False, n_out, affine=not non_affine_group_norm), |
| activation, |
| ) |
|
|
| in_d = embed |
| self.conv_layers = nn.ModuleList() |
| self.residual_proj = nn.ModuleList() |
| for dim, k, stride in conv_layers: |
| if in_d != dim and skip_connections: |
| self.residual_proj.append(nn.Conv1d(in_d, dim, 1, bias=False)) |
| else: |
| self.residual_proj.append(None) |
|
|
| self.conv_layers.append(block(in_d, dim, k, stride)) |
| in_d = dim |
| self.conv_layers = nn.Sequential(*self.conv_layers) |
| self.skip_connections = skip_connections |
| self.residual_scale = math.sqrt(residual_scale) |
|
|
| def forward(self, x): |
| for rproj, conv in zip(self.residual_proj, self.conv_layers): |
| residual = x |
| x = conv(x) |
| if self.skip_connections: |
| if rproj is not None: |
| residual = rproj(residual) |
| x = (x + residual) * self.residual_scale |
| return x |
|
|
|
|
| class Wav2VecPredictionsModel(nn.Module): |
| def __init__( |
| self, |
| in_dim, |
| out_dim, |
| prediction_steps, |
| n_negatives, |
| cross_sample_negatives, |
| sample_distance, |
| dropout, |
| offset, |
| balanced_classes, |
| infonce, |
| ): |
| super().__init__() |
|
|
| self.n_negatives = n_negatives |
| self.cross_sample_negatives = cross_sample_negatives |
| self.sample_distance = sample_distance |
| self.project_to_steps = nn.ConvTranspose2d( |
| in_dim, out_dim, (1, prediction_steps) |
| ) |
| self.dropout = nn.Dropout(p=dropout) |
| self.offset = offset |
| self.balanced_classes = balanced_classes |
| self.infonce = infonce |
|
|
| def sample_negatives(self, y): |
| bsz, fsz, tsz = y.shape |
|
|
| y = y.transpose(0, 1) |
| y = y.contiguous().view(fsz, -1) |
|
|
| cross_high = tsz * bsz |
| high = tsz if self.sample_distance is None else min(tsz, self.sample_distance) |
| assert high > 1 |
|
|
| neg_idxs = torch.randint(low=0, high=high, size=(bsz, self.n_negatives * tsz)) |
|
|
| with torch.no_grad(): |
| if self.n_negatives > 0: |
| tszs = ( |
| buffered_arange(tsz) |
| .unsqueeze(-1) |
| .expand(-1, self.n_negatives) |
| .flatten() |
| ) |
|
|
| neg_idxs = torch.randint( |
| low=0, high=high - 1, size=(bsz, self.n_negatives * tsz) |
| ) |
| neg_idxs[neg_idxs >= tszs] += 1 |
|
|
| if self.cross_sample_negatives > 0: |
| tszs = ( |
| buffered_arange(tsz) |
| .unsqueeze(-1) |
| .expand(-1, self.cross_sample_negatives) |
| .flatten() |
| ) |
|
|
| cross_neg_idxs = torch.randint( |
| low=0, |
| high=cross_high - 1, |
| size=(bsz, self.cross_sample_negatives * tsz), |
| ) |
| cross_neg_idxs[cross_neg_idxs >= tszs] += 1 |
|
|
| if self.n_negatives > 0: |
| for i in range(1, bsz): |
| neg_idxs[i] += i * high |
| else: |
| neg_idxs = cross_neg_idxs |
|
|
| if self.cross_sample_negatives > 0 and self.n_negatives > 0: |
| neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) |
|
|
| negs = y[..., neg_idxs.view(-1)] |
| negs = negs.view( |
| fsz, bsz, self.n_negatives + self.cross_sample_negatives, tsz |
| ).permute( |
| 2, 1, 0, 3 |
| ) |
|
|
| return negs |
|
|
| def forward(self, x, y): |
|
|
| x = x.unsqueeze(-1) |
| x = self.project_to_steps(x) |
| x = self.dropout(x) |
|
|
| negatives = self.sample_negatives(y) |
| y = y.unsqueeze(0) |
| targets = torch.cat([y, negatives], dim=0) |
|
|
| copies = targets.size(0) |
| bsz, dim, tsz, steps = x.shape |
| steps = min(steps, tsz - self.offset) |
|
|
| predictions = x.new( |
| bsz * copies * (tsz - self.offset + 1) * steps |
| - ((steps + 1) * steps // 2) * copies * bsz |
| ) |
| if self.infonce: |
| labels = predictions.new_full( |
| (predictions.shape[0] // copies,), 0, dtype=torch.long |
| ) |
| else: |
| labels = torch.zeros_like(predictions) |
| weights = ( |
| torch.full_like(labels, 1 / self.n_negatives) |
| if self.balanced_classes and not self.infonce |
| else None |
| ) |
|
|
| start = end = 0 |
| for i in range(steps): |
| offset = i + self.offset |
| end = start + (tsz - offset) * bsz * copies |
| if self.infonce: |
| predictions[start:end] = torch.einsum( |
| "bct,nbct->tbn", x[..., :-offset, i], targets[..., offset:] |
| ).flatten() |
| else: |
| pos_num = (end - start) // copies |
| predictions[start:end] = torch.einsum( |
| "bct,nbct->nbt", x[..., :-offset, i], targets[..., offset:] |
| ).flatten() |
| labels[start : start + pos_num] = 1.0 |
| if weights is not None: |
| weights[start : start + pos_num] = 1.0 |
| start = end |
| assert end == predictions.numel(), "{} != {}".format(end, predictions.numel()) |
|
|
| if self.infonce: |
| predictions = predictions.view(-1, copies) |
| else: |
| if weights is not None: |
| labels = (labels, weights) |
|
|
| return predictions, labels |
|
|