# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
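"""Text modality encoder for the multi-modal data2vec (data2vec 2.0) models."""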
import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from fairseq.modules import FairseqDropout, LayerNorm, PositionalEmbedding
from fairseq.tasks import FairseqTask
from examples.data2vec.data.modality import Modality

from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d


@dataclass
class D2vTextConfig(D2vModalityConfig):
type: Modality = Modality.TEXT
max_source_positions: int = 512
learned_pos: bool = True
    # used for both the local encoder and the contextualized encoder;
    # tied to the global transformer dropout in data2vec_text
    dropout: float = 0.1
no_scale_embedding: bool = True
layernorm_embedding: bool = True
    no_token_positional_embeddings: bool = False


class TextEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vTextConfig

    def __init__(
self,
modality_cfg: D2vTextConfig,
embed_dim: int,
make_block: Callable[[float], nn.ModuleList],
norm_layer: Callable[[int], nn.LayerNorm],
layer_norm_first: bool,
alibi_biases: Dict,
task: Optional[FairseqTask],
):
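        # the text encoder reads the vocabulary size and padding index off the
        # task's source dictionary, so `task` must not be None here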
self.pad_idx = task.source_dictionary.pad()
        self.vocab_size = len(task.source_dictionary)

        local_encoder = TextLocalEncoder(
vocab_size=self.vocab_size,
embed_dim=embed_dim,
max_source_positions=modality_cfg.max_source_positions,
pad_idx=self.pad_idx,
no_scale_embedding=modality_cfg.no_scale_embedding,
layernorm_embedding=modality_cfg.layernorm_embedding,
dropout=modality_cfg.dropout,
no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
learned_pos=modality_cfg.learned_pos,
)
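        # linearly spaced per-block drop-path (stochastic depth) rates for the
        # prenet blocks, from start_drop_path_rate to end_drop_path_rate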
dpr = np.linspace(
modality_cfg.start_drop_path_rate,
modality_cfg.end_drop_path_rate,
modality_cfg.prenet_depth,
)
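        # the prenet: modality-specific transformer blocks (with layerdrop and
        # drop-path) applied before the modality-shared transformer blocks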
context_encoder = BlockEncoder(
nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
norm_layer(embed_dim)
if not layer_norm_first and modality_cfg.prenet_depth > 0
else None,
layer_norm_first,
modality_cfg.prenet_layerdrop,
modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
)
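        # optional decoder over the contextualized features; in data2vec 2.0
        # this is used to predict the teacher targets at masked positions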
decoder = (
Decoder1d(modality_cfg.decoder, embed_dim)
if modality_cfg.decoder is not None
else None
)
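        # `alibi_biases` is a shared cache dict, so ALiBi attention biases
        # computed by get_alibi_bias can be reused across calls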
alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)
super().__init__(
modality_cfg=modality_cfg,
embed_dim=embed_dim,
local_encoder=local_encoder,
project_features=nn.Identity(),
fixed_positional_encoder=None,
relative_positional_encoder=None,
context_encoder=context_encoder,
decoder=decoder,
get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        super().reset_parameters()

    def convert_padding_mask(self, x, padding_mask):
if padding_mask is None or padding_mask.size(1) == x.size(1):
return padding_mask
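        # the features are shorter than the mask (the input was downsampled);
        # pad the mask to a multiple of the downsample factor, then mark a
        # pooled step as padding only if all steps it covers are padding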
diff = self.downsample - padding_mask.size(1) % self.downsample
if 0 < diff < self.downsample:
padding_mask = F.pad(padding_mask, (0, diff), value=True)
padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
padding_mask = padding_mask.all(-1)
if padding_mask.size(1) > x.size(1):
padding_mask = padding_mask[:, : x.size(1)]
assert x.size(1) == padding_mask.size(
1
), f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
        return padding_mask


class TextLocalEncoder(nn.Module):
def __init__(
self,
vocab_size,
embed_dim,
max_source_positions,
pad_idx,
no_scale_embedding,
layernorm_embedding,
dropout,
no_token_positional_embeddings,
learned_pos,
):
super().__init__()
self.pad_idx = pad_idx
self.dropout_module = FairseqDropout(dropout)
self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
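        # optionally scale token embeddings by sqrt(embed_dim) as in the
        # original Transformer (disabled by default, see no_scale_embedding)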
self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
self.embed_positions = (
PositionalEmbedding(
max_source_positions,
embed_dim,
pad_idx,
learned=learned_pos,
)
if not no_token_positional_embeddings
else None
)
        self.layernorm_embedding = None
        if layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(self, src_tokens):
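        # token embedding -> positional embedding -> layer norm -> dropout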
x = self.embed_scale * self.embed_tokens(src_tokens)
if self.embed_positions is not None:
x = x + self.embed_positions(src_tokens)
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
return x
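
# Minimal usage sketch of TextLocalEncoder in isolation (illustrative only:
# the vocabulary size, dimensions, pad index and flags below are made-up
# values, not taken from any particular data2vec configuration):
#
#   import torch
#
#   enc = TextLocalEncoder(
#       vocab_size=1000,
#       embed_dim=64,
#       max_source_positions=512,
#       pad_idx=1,
#       no_scale_embedding=True,
#       layernorm_embedding=True,
#       dropout=0.1,
#       no_token_positional_embeddings=False,
#       learned_pos=True,
#   )
#   tokens = torch.randint(2, 1000, (2, 16))  # (batch, seq_len) token ids
#   features = enc(tokens)  # -> (batch, seq_len, embed_dim)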