import logging
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch import Tensor

from fairseq import utils
from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqEncoderDecoderModel
from fairseq.models.transformer import (
    TransformerConfig,
    TransformerDecoderBase,
    TransformerEncoderBase,
)

logger = logging.getLogger(__name__)


class TransformerModelBase(FairseqEncoderDecoderModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani et al., 2017)
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        encoder (TransformerEncoder): the encoder
        decoder (TransformerDecoder): the decoder

    The Transformer model provides the following named architectures and
    command-line arguments:

    .. argparse::
        :ref: fairseq.models.transformer_parser
        :prog:
    """

    def __init__(self, cfg, encoder, decoder):
        super().__init__(encoder, decoder)
        self.cfg = cfg
        self.supports_align_args = True

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        gen_parser_from_dataclass(
            parser, TransformerConfig(), delete_default=False, with_prefix=""
        )

    @classmethod
    def build_model(cls, cfg, task):
        """Build a new model instance."""

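        # These dimensions can arrive as strings when the config has been
        # round-tripped through argparse/OmegaConf, so cast defensively.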
        cfg.decoder.input_dim = int(cfg.decoder.input_dim)
        cfg.decoder.output_dim = int(cfg.decoder.output_dim)

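        # When layer pruning is requested (e.g. --encoder-layers-to-keep=0,2,4),
        # keep the configured layer count in sync with the pruned list.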
        if cfg.encoder.layers_to_keep:
            cfg.encoder.layers = len(cfg.encoder.layers_to_keep.split(","))
        if cfg.decoder.layers_to_keep:
            cfg.decoder.layers = len(cfg.decoder.layers_to_keep.split(","))

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary

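        # Three embedding set-ups: one table shared by encoder input, decoder
        # input and decoder output (--share-all-embeddings); dictionaries merged
        # at build time so a shared table can be used (--merge-src-tgt-embed);
        # or fully separate encoder/decoder tables.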
        if cfg.share_all_embeddings:
            if src_dict != tgt_dict:
                raise ValueError("--share-all-embeddings requires a joined dictionary")
            if cfg.encoder.embed_dim != cfg.decoder.embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if cfg.decoder.embed_path and (
                cfg.decoder.embed_path != cfg.encoder.embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = cls.build_embedding(
                cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
            )
            decoder_embed_tokens = encoder_embed_tokens
            cfg.share_decoder_input_output_embed = True
        elif cfg.merge_src_tgt_embed:
            logger.info(f"source dict size: {len(src_dict)}")
            logger.info(f"target dict size: {len(tgt_dict)}")
            src_dict.update(tgt_dict)
            task.src_dict = src_dict
            task.tgt_dict = src_dict
            logger.info(f"merged dict size: {len(src_dict)}")
            encoder_embed_tokens = cls.build_embedding(
                cfg, src_dict, cfg.encoder.embed_dim
            )
            decoder_embed_tokens = encoder_embed_tokens
            cfg.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = cls.build_embedding(
                cfg, src_dict, cfg.encoder.embed_dim, cfg.encoder.embed_path
            )
            decoder_embed_tokens = cls.build_embedding(
                cfg, tgt_dict, cfg.decoder.embed_dim, cfg.decoder.embed_path
            )

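        # Offloading activations to CPU requires activation checkpointing,
        # so enabling the former forces the latter.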
        if cfg.offload_activations:
            cfg.checkpoint_activations = True
        encoder = cls.build_encoder(cfg, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens)
        return cls(cfg, encoder, decoder)

    @classmethod
    def build_embedding(cls, cfg, dictionary, embed_dim, path=None):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()

        emb = Embedding(num_embeddings, embed_dim, padding_idx)
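        # If a pretrained embedding file is given, copy vectors for the tokens
        # it covers; tokens missing from the file keep their random init.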
        if path:
            embed_dict = utils.parse_embedding(path)
            utils.load_embedding(embed_dict, dictionary, emb)
        return emb

    @classmethod
    def build_encoder(cls, cfg, src_dict, embed_tokens):
        return TransformerEncoderBase(cfg, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, cfg, tgt_dict, embed_tokens):
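        # --no-cross-attention builds decoder layers without encoder-decoder
        # attention, so the decoder conditions only on its own past outputs.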
        return TransformerDecoderBase(
            cfg,
            tgt_dict,
            embed_tokens,
            no_encoder_attn=cfg.no_cross_attention,
        )

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        return_all_hiddens: bool = True,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
    ):
        """
        Run the forward pass for an encoder-decoder model.

        Copied from the base class, but without ``**kwargs``,
        which are not supported by TorchScript.
        """
        encoder_out = self.encoder(
            src_tokens, src_lengths=src_lengths, return_all_hiddens=return_all_hiddens
        )
        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            alignment_layer=alignment_layer,
            alignment_heads=alignment_heads,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )
        return decoder_out

    @torch.jit.export
    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
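    # Initialise with N(0, embedding_dim ** -0.5) and zero the padding row,
    # the standard fairseq transformer embedding initialisation.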
    nn.init.normal_(m.weight, mean=0, std=embedding_dim**-0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m