import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
from fairseq.tasks import FairseqTask

from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality


@dataclass
class D2vTextConfig(D2vModalityConfig):
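    """Text-modality configuration for the data2vec multi-modal model.

    Extends the shared D2vModalityConfig with the token-embedding options
    (positional embeddings, embedding scaling, layer norm, dropout) consumed
    by TextLocalEncoder.
    """
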
    type: Modality = Modality.TEXT
    max_source_positions: int = 512
    learned_pos: bool = True
    dropout: float = 0.1

    no_scale_embedding: bool = True
    layernorm_embedding: bool = True
    no_token_positional_embeddings: bool = False


class TextEncoder(ModalitySpecificEncoder):
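    """Modality-specific encoder for text.

    Tokens are embedded by TextLocalEncoder, contextualized by a stack of
    prenet transformer blocks (BlockEncoder), and optionally passed through a
    Decoder1d. Relative position information is provided via ALiBi biases
    shared through the alibi_biases dict.
    """
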
    modality_cfg: D2vTextConfig

    def __init__(
        self,
        modality_cfg: D2vTextConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        self.pad_idx = task.source_dictionary.pad()
        self.vocab_size = len(task.source_dictionary)
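
        # Token-level front end: embedding table plus optional positional
        # embeddings, layer norm, and dropout (see TextLocalEncoder below).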
        local_encoder = TextLocalEncoder(
            vocab_size=self.vocab_size,
            embed_dim=embed_dim,
            max_source_positions=modality_cfg.max_source_positions,
            pad_idx=self.pad_idx,
            no_scale_embedding=modality_cfg.no_scale_embedding,
            layernorm_embedding=modality_cfg.layernorm_embedding,
            dropout=modality_cfg.dropout,
            no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
            learned_pos=modality_cfg.learned_pos,
        )
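
        # Drop-path (stochastic depth) rates, linearly spaced from
        # start_drop_path_rate to end_drop_path_rate across the prenet blocks.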
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
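
        # Modality-specific transformer prenet that contextualizes the embedded
        # tokens before they reach the shared encoder blocks.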
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim)
            if not layer_norm_first and modality_cfg.prenet_depth > 0
            else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
        )
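
        # Optional decoder over the encoded features; only built when a decoder
        # config is provided.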
        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )
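
        # Bind the shared alibi_biases dict so that ALiBi attention biases can
        # be reused across modality encoders instead of being recomputed.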
        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=nn.Identity(),
            fixed_positional_encoder=None,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        super().reset_parameters()

    def convert_padding_mask(self, x, padding_mask):
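        """Downsample a token-level padding mask to the model's sequence length.

        The mask is padded up to a multiple of ``self.downsample``, reshaped into
        windows of that size, and a downsampled position is marked as padding only
        if every token in its window is padding.
        """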
        if padding_mask is None or padding_mask.size(1) == x.size(1):
            return padding_mask

        diff = self.downsample - padding_mask.size(1) % self.downsample
        if 0 < diff < self.downsample:
            padding_mask = F.pad(padding_mask, (0, diff), value=True)

        padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
        padding_mask = padding_mask.all(-1)
        if padding_mask.size(1) > x.size(1):
            padding_mask = padding_mask[:, : x.size(1)]

        assert x.size(1) == padding_mask.size(
            1
        ), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"

        return padding_mask


class TextLocalEncoder(nn.Module):
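    """Embeds source tokens for the text modality.

    Applies an embedding table, optional sqrt(embed_dim) scaling, optional
    learned or sinusoidal positional embeddings, optional layer norm, and
    dropout.
    """
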
    def __init__(
        self,
        vocab_size,
        embed_dim,
        max_source_positions,
        pad_idx,
        no_scale_embedding,
        layernorm_embedding,
        dropout,
        no_token_positional_embeddings,
        learned_pos,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.dropout_module = FairseqDropout(dropout)

        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                max_source_positions,
                embed_dim,
                pad_idx,
                learned=learned_pos,
            )
            if not no_token_positional_embeddings
            else None
        )

        self.layernorm_embedding = None
        if layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(self, src_tokens):
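        # Scale token embeddings, add positional embeddings, then apply
        # layer norm (if enabled) and dropout.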
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = x + self.embed_positions(src_tokens)

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        return x