# Uploaded to the Hugging Face Hub by Respair ("Upload folder using
# huggingface_hub", commit 59b7eeb, verified).
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass, field
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import MISSING
from fairseq import utils, checkpoint_utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.dataclass import ChoiceEnum, FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.modules import (
Fp32GroupNorm,
Fp32LayerNorm,
GradMultiply,
GumbelVectorQuantizer,
LayerNorm,
MultiheadAttention,
SamePad,
TransposeLast,
)
from fairseq.utils import buffered_arange, index_put, is_xla_tensor
# Allowed values for Wav2Vec2Config.extractor_mode.
EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
# Allowed values for Wav2Vec2Config.mask_selection and
# Wav2Vec2Config.mask_channel_selection.
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])
def init_bert_params(module):
    """
    Re-initialize one module's weights in the BERT fashion.

    Meant to be passed to ``nn.Module.apply``. What gets touched depends on
    the module type:
        1. ``nn.Linear``: the weight is redrawn from a normal distribution
           (mean 0.0, std 0.02) and the bias, when present, is zeroed.
        2. ``nn.Embedding``: the weight is redrawn from the same normal
           distribution and the ``padding_idx`` row, when set, is zeroed.
        3. ``MultiheadAttention``: the weights of ``q_proj``, ``k_proj``
           and ``v_proj`` are redrawn from the same normal distribution.
    Any other module is left untouched.
    """

    def _fill_normal(tensor):
        # with FSDP, module params will be on CUDA, so we sample on CPU and
        # copy back so that the RNG is consistent with and without FSDP
        sampled = tensor.cpu().normal_(mean=0.0, std=0.02)
        tensor.copy_(sampled.to(tensor.device))

    if isinstance(module, nn.Linear):
        _fill_normal(module.weight.data)
        bias = module.bias
        if bias is not None:
            bias.data.zero_()

    if isinstance(module, nn.Embedding):
        _fill_normal(module.weight.data)
        pad_row = module.padding_idx
        if pad_row is not None:
            module.weight.data[pad_row].zero_()

    if isinstance(module, MultiheadAttention):
        for proj_name in ("q_proj", "k_proj", "v_proj"):
            _fill_normal(getattr(module, proj_name).weight.data)
@dataclass
class Wav2Vec2Config(FairseqDataclass):
    """
    Configuration for ``Wav2Vec2Model``.

    Each field carries its own ``help`` text; the fields are grouped into
    encoder/extractor settings, dropouts, quantization, time masking,
    channel masking, negative sampling and positional embeddings.
    """

    pretrained_path: str = field(
        default="",
        metadata={"help": "pretrained unispeech path"}
    )
    extractor_mode: EXTRACTOR_MODE_CHOICES = field(
        default="default",
        metadata={
            "help": "mode for feature extractor. default has a single group norm with d "
            "groups in the first conv block, whereas layer_norm has layer norms in "
            "every block (meant to use with normalize=True)"
        },
    )
    encoder_layers: int = field(
        default=12, metadata={"help": "num encoder layers in the transformer"}
    )
    encoder_embed_dim: int = field(
        default=768, metadata={"help": "encoder embedding dimension"}
    )
    encoder_ffn_embed_dim: int = field(
        default=3072, metadata={"help": "encoder embedding dimension for FFN"}
    )
    encoder_attention_heads: int = field(
        default=12, metadata={"help": "num encoder attention heads"}
    )
    activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field(
        default="gelu", metadata={"help": "activation function to use"}
    )

    # dropouts
    dropout: float = field(
        default=0.1, metadata={"help": "dropout probability for the transformer"}
    )
    attention_dropout: float = field(
        default=0.1, metadata={"help": "dropout probability for attention weights"}
    )
    activation_dropout: float = field(
        default=0.0, metadata={"help": "dropout probability after activation in FFN"}
    )
    encoder_layerdrop: float = field(
        default=0.0, metadata={"help": "probability of dropping a transformer layer"}
    )
    dropout_input: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the input (after feat extr)"},
    )
    dropout_features: float = field(
        default=0.0,
        metadata={"help": "dropout to apply to the features (after feat extr)"},
    )

    final_dim: int = field(
        default=0,
        metadata={
            "help": "project final representations and targets to this many dimensions. "
            "set to encoder_embed_dim if <= 0"
        },
    )
    layer_norm_first: bool = field(
        default=False, metadata={"help": "apply layernorm first in the transformer"}
    )
    conv_feature_layers: str = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": "string describing convolutional feature extraction layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_bias: bool = field(
        default=False, metadata={"help": "include bias in conv encoder"}
    )
    logit_temp: float = field(
        default=0.1, metadata={"help": "temperature to divide logits by"}
    )
    quantize_targets: bool = field(
        default=False, metadata={"help": "use quantized targets"}
    )
    quantize_input: bool = field(
        default=False, metadata={"help": "use quantized inputs"}
    )
    same_quantizer: bool = field(
        default=False, metadata={"help": "use same quantizer for inputs and targets"}
    )
    target_glu: bool = field(
        default=False, metadata={"help": "adds projection + glu to targets"}
    )
    feature_grad_mult: float = field(
        default=1.0, metadata={"help": "multiply feature extractor var grads by this"}
    )
    quantizer_depth: int = field(
        default=1,
        metadata={"help": "number of quantizer layers"},
    )
    quantizer_factor: int = field(
        default=3,
        metadata={
            "help": "dimensionality increase for inner quantizer layers (if depth > 1)"
        },
    )
    latent_vars: int = field(
        default=320,
        metadata={"help": "number of latent variables V in each group of the codebook"},
    )
    latent_groups: int = field(
        default=2,
        metadata={"help": "number of groups G of latent variables in the codebook"},
    )
    latent_dim: int = field(
        default=0,
        metadata={
            "help": "if > 0, uses this dimensionality for latent variables. "
            "otherwise uses final_dim / latent_groups"
        },
    )
    transpose: bool = field(
        default=False, metadata={"help": "set to True for Unispeech"}
    )

    # masking
    mask_length: int = field(default=10, metadata={"help": "mask length"})
    mask_prob: float = field(
        default=0.65, metadata={"help": "probability of replacing a token with mask"}
    )
    mask_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static", metadata={"help": "how to choose mask length"}
    )
    mask_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    # Forwarded as ``no_overlap=`` to compute_mask_indices, so True means
    # overlapping spans are NOT allowed.
    no_mask_overlap: bool = field(
        default=False, metadata={"help": "whether to prevent masks from overlapping"}
    )
    mask_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # channel masking
    mask_channel_length: int = field(
        default=10, metadata={"help": "length of the mask for features (channels)"}
    )
    mask_channel_prob: float = field(
        default=0.0, metadata={"help": "probability of replacing a feature with 0"}
    )
    # When True, channel masking is applied before time masking in
    # Wav2Vec2Model.apply_mask; otherwise it is applied after.
    mask_channel_before: bool = False
    mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field(
        default="static",
        metadata={"help": "how to choose mask length for channel masking"},
    )
    mask_channel_other: float = field(
        default=0,
        metadata={
            "help": "secondary mask argument (used for more complex distributions), "
            "see help in compute_mask_indices"
        },
    )
    # Forwarded as ``no_overlap=`` for the channel mask; True disables overlap.
    no_mask_channel_overlap: bool = field(
        default=False,
        metadata={"help": "whether to prevent channel masks from overlapping"},
    )
    mask_channel_min_space: int = field(
        default=1,
        metadata={"help": "min space between spans (if no overlap is enabled)"},
    )

    # negative selection
    num_negatives: int = field(
        default=100,
        metadata={"help": "number of negative examples from the same sample"},
    )
    negatives_from_everywhere: bool = field(
        default=False,
        metadata={"help": "sample negatives from everywhere, not just masked states"},
    )
    cross_sample_negatives: int = field(
        default=0, metadata={"help": "number of negative examples from any sample"}
    )
    codebook_negatives: int = field(
        default=0, metadata={"help": "number of negative examples codebook"}
    )

    # positional embeddings
    conv_pos: int = field(
        default=128,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )

    latent_temp: Tuple[float, float, float] = field(
        default=(2, 0.5, 0.999995),
        metadata={
            "help": "temperature for latent variable sampling. "
            "can be tuple of 3 values (start, end, decay)"
        },
    )
@register_model("wav2vec2", dataclass=Wav2Vec2Config)
class Wav2Vec2Model(BaseFairseqModel):
    """
    wav2vec 2.0 style model assembled from a ``Wav2Vec2Config``.

    It chains a convolutional feature extractor, optional Gumbel vector
    quantizers (for targets and/or inputs), a transformer encoder and a
    projection head whose outputs are scored against sampled negatives
    with a temperature-scaled cosine similarity.
    """

    def __init__(self, cfg: Wav2Vec2Config):
        """Build all sub-modules from ``cfg`` and optionally load a checkpoint.

        When ``cfg.pretrained_path`` is non-empty, the checkpoint's ``model``
        state is loaded non-strictly after stripping the
        ``w2v_encoder.w2v_model.`` prefix from its keys.
        """
        super().__init__()
        self.cfg = cfg

        # SECURITY: eval() executes arbitrary Python taken from the config
        # string, so only use configs from trusted sources. It cannot simply
        # be swapped for ast.literal_eval because the default value relies on
        # list ``+`` / ``*`` arithmetic.
        feature_enc_layers = eval(cfg.conv_feature_layers)
        # channel count produced by the last conv layer
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )

        # Only needed when the extractor width differs from the encoder width
        # and the input is not going through the input quantizer.
        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_before = cfg.mask_channel_before
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.quantizer = None
        self.input_quantizer = None

        self.n_negatives = cfg.num_negatives
        self.cross_sample_negatives = cfg.cross_sample_negatives
        self.codebook_negatives = cfg.codebook_negatives
        self.negatives_from_everywhere = cfg.negatives_from_everywhere

        self.logit_temp = cfg.logit_temp
        self.transpose = cfg.transpose

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        if cfg.quantize_targets:
            vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim
            self.quantizer = GumbelVectorQuantizer(
                dim=self.embed,
                num_vars=cfg.latent_vars,
                temp=cfg.latent_temp,
                groups=cfg.latent_groups,
                combine_groups=False,
                vq_dim=vq_dim,
                time_first=True,
                weight_proj_depth=cfg.quantizer_depth,
                weight_proj_factor=cfg.quantizer_factor,
            )
            self.project_q = nn.Linear(vq_dim, final_dim)
        else:
            self.project_q = nn.Linear(self.embed, final_dim)

        if cfg.quantize_input:
            if cfg.same_quantizer and self.quantizer is not None:
                vq_dim = final_dim
                self.input_quantizer = self.quantizer
            else:
                vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim
                self.input_quantizer = GumbelVectorQuantizer(
                    dim=self.embed,
                    num_vars=cfg.latent_vars,
                    temp=cfg.latent_temp,
                    groups=cfg.latent_groups,
                    combine_groups=False,
                    vq_dim=vq_dim,
                    time_first=True,
                    weight_proj_depth=cfg.quantizer_depth,
                    weight_proj_factor=cfg.quantizer_factor,
                )
            self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim)

        # learned vector written over masked time steps in apply_mask
        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        # With cfg.transpose the head maps targets up to the encoder width;
        # otherwise it maps encoder outputs down to final_dim (see forward).
        if cfg.transpose:
            self.final_proj = nn.Linear(final_dim, cfg.encoder_embed_dim)
        else:
            self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)

        if cfg.pretrained_path:
            state = checkpoint_utils.load_checkpoint_to_cpu(
                cfg.pretrained_path)['model']
            # Snapshot the matching keys first: the dict is mutated below.
            prefixed = [key for key in state if 'w2v_encoder.w2v_model.' in key]
            for name in prefixed:
                newname = name.replace('w2v_encoder.w2v_model.', '')
                state[newname] = state.pop(name)
            # strict=False: missing or unexpected keys are not treated as errors.
            self.load_state_dict(state, strict=False)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: Wav2Vec2Config, task=None):
        """Build a new model instance."""
        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        """Apply time masking and channel masking to ``x``.

        Args:
            x: rank-3 tensor unpacked as (B, T, C).
            padding_mask: forwarded to ``compute_mask_indices`` for the time
                mask.
            mask_indices: precomputed time mask; computed when None and
                ``mask_prob > 0``.
            mask_channel_indices: precomputed channel mask, only honoured on
                the "channel mask after time mask" path.

        Returns:
            ``(x, mask_indices)``; ``mask_indices`` is None when
            ``mask_prob`` is not positive.
        """
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            # (B, C) -> (B, T, C): the same channels are masked at every step
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            # in-place write into the caller's tensor
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=2,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def sample_negatives(self, y, num, padding_count=None):
        """Sample negative examples from ``y``.

        Args:
            y: rank-3 tensor unpacked as (bsz, tsz, fsz).
            num: number of positions to draw negatives for.
            padding_count: when given, subtracted from ``tsz`` to bound the
                within-sample index range.

        Returns:
            ``(negs, neg_idxs)`` where ``negs`` is permuted to N x B x T x C.

        NOTE(review): when both ``n_negatives`` and ``cross_sample_negatives``
        are 0 a single empty tensor is returned instead of a pair, although
        the callers in ``forward`` unpack two values. Left unchanged here.
        """
        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
            return y.new(0)

        bsz, tsz, fsz = y.shape
        y = y.view(-1, fsz)  # BTC => (BxT)C

        # FIXME: what happens if padding_count is specified?
        cross_high = tsz * bsz
        high = tsz - (padding_count or 0)
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"

            if self.n_negatives > 0:
                # position index of each drawn negative, repeated per negative
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.n_negatives)
                    .flatten()
                )

                neg_idxs = torch.randint(
                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
                )
                # shift draws at or past the position's own index by one so a
                # position is never chosen as its own negative
                neg_idxs[neg_idxs >= tszs] += 1

            if self.cross_sample_negatives > 0:
                tszs = (
                    buffered_arange(num)
                    .unsqueeze(-1)
                    .expand(-1, self.cross_sample_negatives)
                    .flatten()
                )

                cross_neg_idxs = torch.randint(
                    low=0,
                    high=cross_high - 1,
                    size=(bsz, self.cross_sample_negatives * num),
                )
                cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # turn per-sample indices into indices of the flattened ``y``
            for i in range(1, bsz):
                neg_idxs[i] += i * high
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

        negs = y[neg_idxs.view(-1)]
        negs = negs.view(
            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
        ).permute(
            2, 0, 1, 3
        )  # to NxBxTxC
        return negs, neg_idxs

    def compute_preds(self, x, y, negatives):
        """Score ``x`` against the positive ``y`` and the ``negatives``.

        The positive is stacked in front of the negatives along dim 0, the
        cosine similarity over the last dim is divided by ``logit_temp``, and
        entries where a negative equals the positive are overwritten with
        -inf (or a large negative constant on XLA).
        """
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)
        logits = logits / self.logit_temp

        if is_xla_tensor(logits) or neg_is_pos.any():
            fillval = -float(2 ** 30)
            # built once and cached on the instance
            if not hasattr(self, "_inftensor"):
                self._inftensor = (
                    torch.tensor(fillval).to(x.device)
                    if is_xla_tensor(logits)
                    else float("-inf")
                )
            # index 0 is the positive, so only the negative rows are touched
            logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)

        return logits

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        """Run the model on ``source``.

        Args:
            source: input passed to the convolutional feature extractor.
            padding_mask: optional sample-level padding mask; it is
                down-sampled to the feature frame rate.
            mask: whether to call ``apply_mask`` on the features.
            features_only: return encoder outputs without computing the
                contrastive predictions.
            layer: forwarded to the encoder as its ``layer`` argument.
            mask_indices, mask_channel_indices: optional precomputed masks
                forwarded to ``apply_mask``.
            padding_count: forwarded to ``sample_negatives``.

        Returns:
            With ``features_only``: a dict with keys ``x``, ``padding_mask``,
            ``features`` and ``layer_results``. Otherwise a dict with
            ``features``, ``feature_padding_mask``, ``x`` (the logits from
            ``compute_preds``), ``padding_mask``, ``features_pen``, ``q``
            (only when a target quantizer exists) and, when a quantizer ran,
            ``prob_perplexity``, ``code_perplexity``, ``num_vars``, ``temp``.
        """
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            # a non-positive multiplier freezes the feature extractor
            with torch.no_grad():
                features = self.feature_extractor(source)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            # trim the sample-level mask to a multiple of the frame count,
            # then mark a frame as padded only if all of its samples are
            extra = padding_mask.size(1) % features.size(1)
            if extra > 0:
                padding_mask = padding_mask[:, :-extra]
            padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
            padding_mask = padding_mask.all(-1)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        num_vars = None
        code_ppl = None
        prob_ppl = None
        curr_temp = None

        if self.input_quantizer:
            q = self.input_quantizer(features, produce_targets=False)
            features = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]
            features = self.project_inp(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
            if not is_xla_tensor(x) and mask_indices is not None:
                # tpu-comment: reducing the size in a dynamic way causes
                # too many recompilations on xla.
                y = unmasked_features[mask_indices].view(
                    unmasked_features.size(0), -1, unmasked_features.size(-1)
                )
            else:
                y = unmasked_features
        else:
            x = features
            y = unmasked_features
            mask_indices = None

        x, layer_results = self.encoder(x, padding_mask=padding_mask, layer=layer)

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "features": unmasked_features,
                "layer_results": layer_results,
            }

        results = {"features": x, "feature_padding_mask": padding_mask}

        if self.quantizer:
            q = self.quantizer(y, produce_targets=False)
            y = q["x"]
            num_vars = q["num_vars"]
            code_ppl = q["code_perplexity"]
            prob_ppl = q["prob_perplexity"]
            curr_temp = q["temp"]

            y = self.project_q(y)

            # quantize every frame, not only the masked ones
            q = self.quantizer(unmasked_features, produce_targets=False)
            results['q'] = q['x']

            if self.negatives_from_everywhere:
                negs = self.project_q(q['x'])
                negs, _ = self.sample_negatives(
                    negs,
                    y.size(1),
                    padding_count=padding_count,
                )
            else:
                negs, _ = self.sample_negatives(
                    y,
                    y.size(1),
                    padding_count=padding_count,
                )

            if self.codebook_negatives > 0:
                cb_negs = self.quantizer.sample_from_codebook(
                    y.size(0) * y.size(1), self.codebook_negatives
                )
                cb_negs = cb_negs.view(
                    self.codebook_negatives, y.size(0), y.size(1), -1
                )  # order doesnt matter
                cb_negs = self.project_q(cb_negs)
                negs = torch.cat([negs, cb_negs], dim=0)
        else:
            y = self.project_q(y)

            if self.negatives_from_everywhere:
                negs, _ = self.sample_negatives(
                    unmasked_features,
                    y.size(1),
                    padding_count=padding_count,
                )
                negs = self.project_q(negs)
            else:
                negs, _ = self.sample_negatives(
                    y,
                    y.size(1),
                    padding_count=padding_count,
                )

        if not is_xla_tensor(x):
            # tpu-comment: reducing the size in a dynamic way causes
            # too many recompilations on xla.
            x = x[mask_indices].view(x.size(0), -1, x.size(-1))

        if self.target_glu:
            y = self.target_glu(y)
            negs = self.target_glu(negs)

        if self.transpose:
            y = self.final_proj(y)
            negs = self.final_proj(negs)
            # NOTE(review): results["q"] only exists when a target quantizer
            # is configured, so transpose=True presumably requires
            # quantize_targets=True - confirm.
            results["q"] = self.final_proj(self.project_q(results["q"]))
        else:
            x = self.final_proj(x)
        x = self.compute_preds(x, y, negs)

        results['x'] = x
        results['padding_mask'] = padding_mask
        results['features_pen'] = features_pen

        if prob_ppl is not None:
            results["prob_perplexity"] = prob_ppl
            results["code_perplexity"] = code_ppl
            results["num_vars"] = num_vars
            results["temp"] = curr_temp

        return results

    def quantize(self, x):
        """Extract and layer-normalize features, then return ``quantizer.forward_idx`` of them."""
        assert self.quantizer is not None
        x = self.feature_extractor(x)
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return self.quantizer.forward_idx(x)

    def extract_features(self, source, padding_mask, mask=False, layer=None):
        """Call ``forward`` with ``features_only=True`` and return its dict."""
        res = self.forward(
            source, padding_mask, mask=mask, features_only=True, layer=layer
        )
        return res

    def get_logits(self, net_output):
        """Return ``net_output["x"]`` with dims 0 and 2 swapped, flattened to 2-D."""
        logits = net_output["x"]
        logits = logits.transpose(0, 2)
        logits = logits.reshape(-1, logits.size(-1))
        return logits

    def get_targets(self, sample, net_output, expand_steps=True):
        """Return an all-zero long tensor with ``x.size(1) * x.size(2)`` entries.

        Index 0 is where ``compute_preds`` places the positive.
        """
        x = net_output["x"]
        return x.new_zeros(x.size(1) * x.size(2), dtype=torch.long)

    def get_extra_losses(self, net_output):
        """Collect auxiliary penalties present in ``net_output``.

        Returns a list holding, when available, the codebook diversity term
        ``(num_vars - prob_perplexity) / num_vars`` followed by
        ``features_pen``.
        """
        pen = []

        if "prob_perplexity" in net_output:
            pen.append(
                (net_output["num_vars"] - net_output["prob_perplexity"])
                / net_output["num_vars"]
            )

        if "features_pen" in net_output:
            pen.append(net_output["features_pen"])

        return pen

    def remove_pretraining_modules(self):
        """Drop the quantizer and projection heads that only ``forward``'s prediction path uses."""
        self.quantizer = None
        self.project_q = None
        self.target_glu = None
        self.final_proj = None
class ConvFeatureExtractionModel(nn.Module):
    """Stack of 1-D convolution blocks applied to a B x T input."""

    def __init__(
        self,
        conv_layers: List[Tuple[int, int, int]],
        dropout: float = 0.0,
        mode: str = "default",
        conv_bias: bool = False,
    ):
        """
        Args:
            conv_layers: one ``(dim, kernel_size, stride)`` triple per block.
            dropout: dropout probability applied after every convolution.
            mode: ``"default"`` puts a group norm in the first block only;
                ``"layer_norm"`` puts a layer norm in every block.
            conv_bias: whether the convolutions have a bias term.

        Raises:
            AssertionError: on an unknown ``mode`` or a layer spec that does
                not have exactly three entries.
        """
        super().__init__()

        assert mode in {"default", "layer_norm"}

        def block(
            n_in,
            n_out,
            k,
            stride,
            is_layer_norm=False,
            is_group_norm=False,
            conv_bias=False,
        ):
            """Build conv -> dropout -> (optional norm) -> GELU."""

            def make_conv():
                conv = nn.Conv1d(n_in, n_out, k, stride=stride, bias=conv_bias)
                nn.init.kaiming_normal_(conv.weight)
                return conv

            assert not (
                is_layer_norm and is_group_norm
            ), "layer norm and group norm are exclusive"

            # The norms are sized from this block's own ``n_out`` rather than
            # from a variable captured from the enclosing loop.
            if is_layer_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    nn.Sequential(
                        # layer norm works on the last dim, so move channels
                        # there and back
                        TransposeLast(),
                        Fp32LayerNorm(n_out, elementwise_affine=True),
                        TransposeLast(),
                    ),
                    nn.GELU(),
                )
            elif is_group_norm:
                return nn.Sequential(
                    make_conv(),
                    nn.Dropout(p=dropout),
                    Fp32GroupNorm(n_out, n_out, affine=True),
                    nn.GELU(),
                )
            else:
                return nn.Sequential(make_conv(), nn.Dropout(p=dropout), nn.GELU())

        # the input gets a single channel added in forward()
        in_d = 1
        self.conv_layers = nn.ModuleList()
        for i, cl in enumerate(conv_layers):
            assert len(cl) == 3, "invalid conv definition: " + str(cl)
            (dim, k, stride) = cl

            self.conv_layers.append(
                block(
                    in_d,
                    dim,
                    k,
                    stride,
                    is_layer_norm=mode == "layer_norm",
                    is_group_norm=mode == "default" and i == 0,
                    conv_bias=conv_bias,
                )
            )
            in_d = dim

    def forward(self, x):
        """Add a channel dim to ``x`` and run it through every block in order."""
        # BxT -> BxCxT
        x = x.unsqueeze(1)

        for conv in self.conv_layers:
            x = conv(x)

        return x
class TransformerEncoder(nn.Module):
    """Convolutional positional embedding followed by a stack of transformer layers."""

    def __init__(self, args):
        """Build the positional convolution and ``args.encoder_layers`` layers.

        ``args`` is kept on the instance because ``max_positions`` reads it.
        """
        super().__init__()

        # max_positions() reads self.args, so it has to be stored.
        self.args = args

        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        self.layers = nn.ModuleList(
            [
                TransformerSentenceEncoderLayer(
                    embedding_dim=self.embedding_dim,
                    ffn_embedding_dim=args.encoder_ffn_embed_dim,
                    num_attention_heads=args.encoder_attention_heads,
                    dropout=self.dropout,
                    attention_dropout=args.attention_dropout,
                    activation_dropout=args.activation_dropout,
                    activation_fn=args.activation_fn,
                    layer_norm_first=args.layer_norm_first,
                )
                for _ in range(args.encoder_layers)
            ]
        )

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop

        # re-initializes the Linear / Embedding / MultiheadAttention weights
        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, streaming_mask=None, layer=None):
        """Encode ``x`` and return ``(x, layer_results)``.

        In pre-norm mode the final layer norm is applied only when no
        specific ``layer`` was requested.
        """
        x, layer_results = self.extract_features(x, padding_mask, streaming_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, streaming_mask=None, tgt_layer=None):
        """Add positional information and run the transformer layers.

        Args:
            x: B x T x C input.
            padding_mask: positions to zero out before encoding; also passed
                to every layer as ``self_attn_padding_mask``.
            streaming_mask: passed to every layer as ``self_attn_mask``.
            tgt_layer: a list of 1-based layer numbers whose ``(x, z)``
                outputs are collected, or an int 0-based layer index after
                which encoding stops.

        Returns:
            ``(x, layer_results)`` with ``x`` back in B x T x C.
        """
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        layer_results = []
        r = None
        pos_bias = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            # LayerDrop: during training a layer is skipped with probability
            # ``self.layerdrop``; at eval time every layer runs.
            if not self.training or (dropout_probability > self.layerdrop):
                x, z, pos_bias = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    self_attn_mask=streaming_mask,
                    pos_bias=pos_bias,
                )
                # NOTE(review): the nesting of this chain was reconstructed
                # from de-indented source. It is kept under the layer-drop
                # check so ``z`` is always bound when used - confirm against
                # the upstream file.
                if isinstance(tgt_layer, list) and i + 1 in tgt_layer:
                    layer_results.append((x, z))
                elif isinstance(tgt_layer, int) and i == tgt_layer:
                    r = x
                    break
                else:
                    continue

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        return x, layer_results

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.args.max_positions

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        return state_dict
class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: int = 768,
        ffn_embedding_dim: int = 3072,
        num_attention_heads: int = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
    ) -> None:
        """
        Args:
            embedding_dim: model width (input and output size of the layer).
            ffn_embedding_dim: hidden width of the feed-forward block.
            num_attention_heads: number of self-attention heads.
            dropout: dropout after self-attention and after the FFN output.
            attention_dropout: dropout handed to ``MultiheadAttention``.
            activation_dropout: dropout after the FFN activation.
            activation_fn: name resolved through ``utils.get_activation_fn``.
            layer_norm_first: apply layer norm before (True) or after (False)
                each sub-block.
        """
        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)

    def _feed_forward(self, x):
        """fc1 -> activation -> dropout2 -> fc2 -> dropout3."""
        x = self.activation_fn(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return self.dropout3(x)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
        pos_bias=None
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.

        ``need_weights`` and ``att_args`` are accepted but not used by this
        method. ``pos_bias`` is handed to the attention module as
        ``position_bias`` and the value it returns is passed back out.

        Returns:
            ``(x, attn, pos_bias)``.
        """
        residual = x

        # pre-norm: normalize before attention
        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)

        x, attn, pos_bias = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=self_attn_padding_mask,
            attn_mask=self_attn_mask,
            position_bias=pos_bias
        )
        x = self.dropout1(x)
        x = residual + x

        if self.layer_norm_first:
            # pre-norm: the residual branches off before the FFN norm
            residual = x
            x = self.final_layer_norm(x)
            x = residual + self._feed_forward(x)
        else:
            # post-norm: normalize after each residual sum
            x = self.self_attn_layer_norm(x)
            residual = x
            x = residual + self._feed_forward(x)
            x = self.final_layer_norm(x)

        return x, attn, pos_bias