import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoder
from fairseq.models.transformer import TransformerConfig
from fairseq.modules import (
    FairseqDropout,
    LayerDropModuleList,
    LayerNorm,
    PositionalEmbedding,
    SinusoidalPositionalEmbedding,
    transformer_layer,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_


def module_name_fordropout(module_name: str) -> str:
    """Map the new class name back to the legacy one so dropout logging stays consistent."""
    if module_name == "TransformerEncoderBase":
        return "TransformerEncoder"
    else:
        return module_name


class TransformerEncoderBase(FairseqEncoder):
    """
    Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`.

    Args:
        cfg (TransformerConfig): encoder configuration
        dictionary (~fairseq.data.Dictionary): encoding dictionary
        embed_tokens (torch.nn.Embedding): input embedding
    """

    def __init__(self, cfg, dictionary, embed_tokens, return_fc=False):
        self.cfg = cfg
        super().__init__(dictionary)
        self.register_buffer("version", torch.Tensor([3]))

        self.dropout_module = FairseqDropout(
            cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
        )
        self.encoder_layerdrop = cfg.encoder.layerdrop
        self.return_fc = return_fc

        embed_dim = embed_tokens.embedding_dim
        self.padding_idx = embed_tokens.padding_idx
        self.max_source_positions = cfg.max_source_positions

        self.embed_tokens = embed_tokens

        # Scale embeddings by sqrt(embed_dim), as in "Attention Is All You Need",
        # unless scaling is explicitly disabled.
        self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)

        self.embed_positions = (
            PositionalEmbedding(
                cfg.max_source_positions,
                embed_dim,
                self.padding_idx,
                learned=cfg.encoder.learned_pos,
            )
            if not cfg.no_token_positional_embeddings
            else None
        )
        if cfg.layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layernorm_embedding = None

        if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
            self.quant_noise = apply_quant_noise_(
                nn.Linear(embed_dim, embed_dim, bias=False),
                cfg.quant_noise.pq,
                cfg.quant_noise.pq_block_size,
            )
        else:
            self.quant_noise = None

        # LayerDrop: randomly drop entire encoder layers during training.
        if self.encoder_layerdrop > 0.0:
            self.layers = LayerDropModuleList(p=self.encoder_layerdrop)
        else:
            self.layers = nn.ModuleList([])
        self.layers.extend(
            [self.build_encoder_layer(cfg) for i in range(cfg.encoder.layers)]
        )
        self.num_layers = len(self.layers)

        if cfg.encoder.normalize_before:
            self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
        else:
            self.layer_norm = None

    def build_encoder_layer(self, cfg):
        layer = transformer_layer.TransformerEncoderLayerBase(
            cfg, return_fc=self.return_fc
        )
        checkpoint = cfg.checkpoint_activations
        if checkpoint:
            offload_to_cpu = cfg.offload_activations
            layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
        # If activation checkpointing is enabled, force FSDP to always wrap the
        # checkpointed layer, regardless of its parameter count.
        min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
        layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
        return layer

    def forward_embedding(
        self, src_tokens, token_embedding: Optional[torch.Tensor] = None
    ):
        # embed tokens and positions
        if token_embedding is None:
            token_embedding = self.embed_tokens(src_tokens)
        x = embed = self.embed_scale * token_embedding
        if self.embed_positions is not None:
            x = embed + self.embed_positions(src_tokens)
        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)
        x = self.dropout_module(x)
        if self.quant_noise is not None:
            x = self.quant_noise(x)
        return x, embed

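    # Illustrative only (hypothetical names, assuming `encoder` is an instance
    # of this class): forward_embedding returns both the processed input to the
    # first layer and the bare scaled token embedding.
    #
    #     x, embed = encoder.forward_embedding(src_tokens)
    #     x      # (batch, src_len, embed_dim): scaled embedding + positions, LayerNorm, dropout
    #     embed  # (batch, src_len, embed_dim): embed_scale * embed_tokens(src_tokens) only
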
    def forward(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                if `None` (default), embeddings are recomputed from *src_tokens*

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        return self.forward_scriptable(
            src_tokens, src_lengths, return_all_hiddens, token_embeddings
        )

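    # Hedged usage sketch (assumes `encoder` is a constructed instance and
    # `src_tokens`/`src_lengths` are tensors built from this encoder's dictionary):
    #
    #     out = encoder(src_tokens, src_lengths, return_all_hiddens=True)
    #     out["encoder_out"][0]           # (src_len, batch, embed_dim)
    #     out["encoder_padding_mask"][0]  # (batch, src_len), True at padding positions
    #     out["encoder_states"]           # embedding output plus one entry per layer
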
    # TorchScript cannot call super() from a scripted subclass, so the actual
    # implementation lives in forward_scriptable() and forward() simply
    # delegates to it.
    def forward_scriptable(
        self,
        src_tokens,
        src_lengths: Optional[torch.Tensor] = None,
        return_all_hiddens: bool = False,
        token_embeddings: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).
            token_embeddings (torch.Tensor, optional): precomputed embeddings;
                if `None` (default), embeddings are recomputed from *src_tokens*

        Returns:
            dict:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
        """
        # compute padding mask
        encoder_padding_mask = src_tokens.eq(self.padding_idx)
        has_pads = (
            torch.tensor(src_tokens.device.type == "xla") or encoder_padding_mask.any()
        )
        # TorchScript does not handle bool Tensors well, so convert explicitly.
        if torch.jit.is_scripting():
            has_pads = torch.tensor(1) if has_pads else torch.tensor(0)

        x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)

        # account for padding while computing the representation
        x = x * (
            1 - encoder_padding_mask.unsqueeze(-1).type_as(x) * has_pads.type_as(x)
        )

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_states = []
        fc_results = []

        if return_all_hiddens:
            encoder_states.append(x)

        # encoder layers
        for layer in self.layers:
            lr = layer(
                x, encoder_padding_mask=encoder_padding_mask if has_pads else None
            )

            if isinstance(lr, tuple) and len(lr) == 2:
                x, fc_result = lr
            else:
                x = lr
                fc_result = None

            if return_all_hiddens and not torch.jit.is_scripting():
                assert encoder_states is not None
                encoder_states.append(x)
                fc_results.append(fc_result)

        if self.layer_norm is not None:
            x = self.layer_norm(x)

        # A dictionary is returned instead of a NamedTuple for TorchScript /
        # mobile-interpreter compatibility. TorchScript does not support mixed
        # value types, so every value is a list; an empty list stands for None.
        src_lengths = (
            src_tokens.ne(self.padding_idx)
            .sum(dim=1, dtype=torch.int32)
            .reshape(-1, 1)
            .contiguous()
        )
        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [encoder_padding_mask],  # B x T
            "encoder_embedding": [encoder_embedding],  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "fc_results": fc_results,
            "src_tokens": [],
            "src_lengths": [src_lengths],
        }

    @torch.jit.export
    def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """
        Reorder encoder output according to *new_order*.

        Args:
            encoder_out: output from the ``forward()`` method
            new_order (LongTensor): desired order

        Returns:
            *encoder_out* rearranged according to *new_order*
        """
        if len(encoder_out["encoder_out"]) == 0:
            new_encoder_out = []
        else:
            # encoder_out is T x B x C, so the batch dimension is dim 1
            new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)]
        if len(encoder_out["encoder_padding_mask"]) == 0:
            new_encoder_padding_mask = []
        else:
            new_encoder_padding_mask = [
                encoder_out["encoder_padding_mask"][0].index_select(0, new_order)
            ]
        if len(encoder_out["encoder_embedding"]) == 0:
            new_encoder_embedding = []
        else:
            new_encoder_embedding = [
                encoder_out["encoder_embedding"][0].index_select(0, new_order)
            ]

        if len(encoder_out["src_tokens"]) == 0:
            src_tokens = []
        else:
            src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)]

        if len(encoder_out["src_lengths"]) == 0:
            src_lengths = []
        else:
            src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)]

        encoder_states = encoder_out["encoder_states"]
        if len(encoder_states) > 0:
            for idx, state in enumerate(encoder_states):
                encoder_states[idx] = state.index_select(1, new_order)

        return {
            "encoder_out": new_encoder_out,  # T x B x C
            "encoder_padding_mask": new_encoder_padding_mask,  # B x T
            "encoder_embedding": new_encoder_embedding,  # B x T x C
            "encoder_states": encoder_states,  # List[T x B x C]
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        }

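    # Example (sketch, hypothetical names): during beam search with beam size k,
    # `new_order` typically repeats each batch index k times, e.g. for batch=2,
    # beam=2:
    #
    #     new_order = torch.tensor([0, 0, 1, 1])
    #     reordered = encoder.reorder_encoder_out(out, new_order)
    #
    # which expands every tensor along its batch dimension (dim 1 for the
    # T x B x C tensors, dim 0 for the B x T masks).
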
    @torch.jit.export
    def _reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order):
        """Dummy re-order function for beamable enc-dec attention"""
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        if self.embed_positions is None:
            return self.max_source_positions
        return min(self.max_source_positions, self.embed_positions.max_positions)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""
        for i in range(self.num_layers):
            # update layer norms
            self.layers[i].upgrade_state_dict_named(
                state_dict, "{}.layers.{}".format(name, i)
            )

        version_key = "{}.version".format(name)
        if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) < 2:
            # earlier checkpoints did not normalize after the stack of layers
            self.layer_norm = None
            self.normalize = False
            state_dict[version_key] = torch.Tensor([1])
        return state_dict


class TransformerEncoder(TransformerEncoderBase):
    """Legacy wrapper that builds the encoder from an argparse ``args`` namespace."""

    def __init__(self, args, dictionary, embed_tokens, return_fc=False):
        self.args = args
        super().__init__(
            TransformerConfig.from_namespace(args),
            dictionary,
            embed_tokens,
            return_fc=return_fc,
        )

    def build_encoder_layer(self, args):
        return super().build_encoder_layer(
            TransformerConfig.from_namespace(args),
        )
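
# Minimal construction sketch (illustrative only; assumes `dictionary` is a
# fairseq Dictionary and `cfg` is a populated TransformerConfig):
#
#     embed_tokens = nn.Embedding(
#         len(dictionary), cfg.encoder.embed_dim, dictionary.pad()
#     )
#     encoder = TransformerEncoderBase(cfg, dictionary, embed_tokens)
#     out = encoder(src_tokens, src_lengths)  # src_tokens: (batch, src_len) LongTensor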