# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality


@dataclass
class D2vTextConfig(D2vModalityConfig):
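    """Text-specific modality configuration (extends the shared D2vModalityConfig)."""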
type: Modality = Modality.TEXT
max_source_positions: int = 512
learned_pos: bool = True
dropout: float = 0.1 # used for both local_encoder and contextualized encoder. tied with global transformer in data2vec_text
no_scale_embedding: bool = True
layernorm_embedding: bool = True
no_token_positional_embeddings: bool = False


class TextEncoder(ModalitySpecificEncoder):
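    """Text modality encoder: embeds tokens locally, then runs the prenet transformer blocks."""
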
modality_cfg: D2vTextConfig

    def __init__(
self,
modality_cfg: D2vTextConfig,
embed_dim: int,
make_block: Callable[[float], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases: Dict,
task: Optional[FairseqTask],
):
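        # Vocabulary size and padding index are taken from the task's source dictionary.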
self.pad_idx = task.source_dictionary.pad()
self.vocab_size = len(task.source_dictionary)
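        # Local encoder: token embeddings (+ optional positional embeddings,
        # embedding LayerNorm and dropout) applied before the transformer blocks.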
local_encoder = TextLocalEncoder(
vocab_size=self.vocab_size,
embed_dim=embed_dim,
max_source_positions=modality_cfg.max_source_positions,
pad_idx=self.pad_idx,
no_scale_embedding=modality_cfg.no_scale_embedding,
layernorm_embedding=modality_cfg.layernorm_embedding,
dropout=modality_cfg.dropout,
no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
learned_pos=modality_cfg.learned_pos,
)
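        # Linearly spaced stochastic-depth (drop path) rates across the prenet blocks.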
dpr = np.linspace(
modality_cfg.start_drop_path_rate,
modality_cfg.end_drop_path_rate,
modality_cfg.prenet_depth,
)
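        # Modality-specific transformer blocks run before the shared encoder; the
        # extra LayerNorm is only needed in the post-LN (not layer_norm_first) setup.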
context_encoder = BlockEncoder(
nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
norm_layer(embed_dim)
if not layer_norm_first and modality_cfg.prenet_depth > 0
else None,
layer_norm_first,
modality_cfg.prenet_layerdrop,
modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
)
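        # Optional decoder for predicting targets at masked positions (only built if configured).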
decoder = (
Decoder1d(modality_cfg.decoder, embed_dim)
if modality_cfg.decoder is not None
else None
)
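        # ALiBi relative-position biases are built lazily and cached in the shared `alibi_biases` dict.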
alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
super().__init__(
modality_cfg=modality_cfg,
embed_dim=embed_dim,
local_encoder=local_encoder,
project_features=nn.Identity(),
fixed_positional_encoder=None,
relative_positional_encoder=None,
context_encoder=context_encoder,
decoder=decoder,
get_alibi_bias=alibi_bias_fn,
)

    def reset_parameters(self):
super().reset_parameters()

    def convert_padding_mask(self, x, padding_mask):
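        """Resize a token-level padding mask to match the (possibly downsampled) length of x.

        The mask is padded up to a multiple of `self.downsample`, grouped into
        blocks of that size, and a block counts as padding only if every position
        in it is padding.
        """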
if padding_mask is None or padding_mask.size(1) == x.size(1):
return padding_mask
diff = self.downsample - padding_mask.size(1) % self.downsample
if 0 < diff < self.downsample:
padding_mask = F.pad(padding_mask, (0, diff), value=True)
padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
padding_mask = padding_mask.all(-1)
if padding_mask.size(1) > x.size(1):
padding_mask = padding_mask[:, : x.size(1)]
        assert x.size(1) == padding_mask.size(1), (
            f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
        )
return padding_mask


class TextLocalEncoder(nn.Module):
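    """Maps source token ids to embeddings: scaled token embeddings plus optional
    positional embeddings, optional embedding LayerNorm, and dropout."""
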
def __init__(
self,
vocab_size,
embed_dim,
max_source_positions,
pad_idx,
no_scale_embedding,
layernorm_embedding,
dropout,
no_token_positional_embeddings,
learned_pos,
):
super().__init__()
self.pad_idx = pad_idx
self.dropout_module = FairseqDropout(dropout)
self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
max_source_positions,
embed_dim,
pad_idx,
learned=learned_pos,
)
if not no_token_positional_embeddings
else None
)
self.layernorm_embedding = None
if layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(self, src_tokens):
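        # Scale token embeddings, add positional embeddings if enabled, then apply
        # embedding LayerNorm and dropout.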
x = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x = x + self.embed_positions(src_tokens)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x
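

# Usage sketch (illustrative, not part of the original module): TextLocalEncoder can be
# exercised on its own with dummy hyper-parameters, e.g.
#
#   import torch
#   enc = TextLocalEncoder(
#       vocab_size=1000, embed_dim=64, max_source_positions=512, pad_idx=1,
#       no_scale_embedding=True, layernorm_embedding=True, dropout=0.1,
#       no_token_positional_embeddings=False, learned_pos=True,
#   )
#   tokens = torch.randint(2, 1000, (4, 16))   # (batch, seq_len)
#   out = enc(tokens)                          # -> (4, 16, 64)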