# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in
# the LICENSE file in the root directory of this source tree.
"""
Base classes for various fairseq models.
"""

import logging
from argparse import Namespace
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from fairseq.data import Dictionary
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    gen_parser_from_dataclass,
)
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig
from torch import Tensor


logger = logging.getLogger(__name__)


def check_type(module, expected_type):
    # wrappers (e.g. distributed wrappers) may expose the wrapped module via
    # an ``unwrapped_module`` attribute; type-check that instead
    if hasattr(module, "unwrapped_module"):
        assert isinstance(
            module.unwrapped_module, expected_type
        ), f"{type(module.unwrapped_module)} != {expected_type}"
    else:
        assert isinstance(module, expected_type), f"{type(module)} != {expected_type}"


class BaseFairseqModel(nn.Module):
    """Base class for fairseq models."""

    def __init__(self):
        super().__init__()
        self._is_generation_fast = False

    @classmethod
    def add_args(cls, parser):
        """Add model-specific arguments to the parser."""
        dc = getattr(cls, "__dataclass", None)
        if dc is not None:
            # do not set defaults here so that architecture-specific defaults
            # can still be applied later
            gen_parser_from_dataclass(parser, dc(), delete_default=True)
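
    # Registration sketch: models typically gain a dataclass of options via
    # ``@register_model(..., dataclass=...)``; ``add_args`` then exposes each
    # field as a CLI flag. The names below (``my_model``, ``MyModelConfig``,
    # ``embed_dim``) are assumed for illustration:
    #
    #   @dataclass
    #   class MyModelConfig(FairseqDataclass):
    #       embed_dim: int = field(default=512, metadata={"help": "embed dim"})
    #
    #   @register_model("my_model", dataclass=MyModelConfig)
    #   class MyModel(BaseFairseqModel):
    #       ...
    #
    #   # MyModel.add_args(parser) now registers an ``--embed-dim`` flag.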

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance."""
        raise NotImplementedError("Model must implement the build_model method")

    def get_targets(self, sample, net_output):
        """Get targets from either the sample or the net's output."""
        return sample["target"]

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Get normalized probabilities (or log probs) from a net's output."""
        return self.get_normalized_probs_scriptable(net_output, log_probs, sample)

    # TorchScript doesn't support ``super()`` calls, so subclasses that
    # override *get_normalized_probs* can call this helper directly to reuse
    # the base implementation.
    def get_normalized_probs_scriptable(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        """Scriptable helper function for *get_normalized_probs* in
        :class:`~BaseFairseqModel`."""
        if hasattr(self, "decoder"):
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)
        elif torch.is_tensor(net_output):
            # syntactic sugar for simple models that return raw logits rather
            # than a (logits, extra) tuple
            logits = net_output.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError
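
    # Minimal usage sketch (assumes ``model`` is a built fairseq model and
    # ``sample`` a batch produced by its task):
    #
    #   net_output = model(**sample["net_input"])
    #   lprobs = model.get_normalized_probs(net_output, log_probs=True)
    #   # lprobs: (batch, tgt_len, vocab); normalized over the last dim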

    def extract_features(self, *args, **kwargs):
        """Similar to *forward* but only return features."""
        return self(*args, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return None

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg: Optional[DictConfig] = None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """

        if model_cfg is None and args is not None:
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # local import to avoid a circular dependency
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)
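
    # Usage sketch (assumes a standard fairseq checkpoint whose "model" key
    # holds the state dict, and an optional ``cfg.model`` dataclass config):
    #
    #   state = torch.load("checkpoint.pt", map_location="cpu")
    #   model.load_state_dict(state["model"], strict=True, model_cfg=cfg.model)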

    def upgrade_state_dict(self, state_dict):
        """Upgrade old state dicts to work with newer code."""
        self.upgrade_state_dict_named(state_dict, "")

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade old state dicts to work with newer code.

        Args:
            state_dict (dict): state dictionary to upgrade, in place
            name (str): the state dict key corresponding to the current module
        """
        assert state_dict is not None

        def do_upgrade(m, prefix):
            if len(prefix) > 0:
                prefix += "."

            for n, c in m.named_children():
                name = prefix + n
                if hasattr(c, "upgrade_state_dict_named"):
                    c.upgrade_state_dict_named(state_dict, name)
                elif hasattr(c, "upgrade_state_dict"):
                    c.upgrade_state_dict(state_dict)
                do_upgrade(c, name)

        do_upgrade(self, name)
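
    # Override sketch: subclasses hook in here to migrate old checkpoint keys.
    # The parameter names below are hypothetical:
    #
    #   def upgrade_state_dict_named(self, state_dict, name):
    #       prefix = name + "." if name else ""
    #       old_key = prefix + "old_proj.weight"
    #       if old_key in state_dict:
    #           state_dict[prefix + "out_proj.weight"] = state_dict.pop(old_key)
    #       super().upgrade_state_dict_named(state_dict, name)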

    def set_num_updates(self, num_updates):
        """State from trainer to pass along to model at every update."""
        for m in self.modules():
            if hasattr(m, "set_num_updates") and m != self:
                m.set_num_updates(num_updates)

    def prepare_for_inference_(self, cfg: DictConfig):
        """Prepare model for inference."""
        kwargs = {}
        kwargs["beamable_mm_beam_size"] = (
            None
            if getattr(cfg.generation, "no_beamable_mm", False)
            else getattr(cfg.generation, "beam", 5)
        )
        kwargs["need_attn"] = getattr(cfg.generation, "print_alignment", False)
        if getattr(cfg.generation, "retain_dropout", False):
            kwargs["retain_dropout"] = cfg.generation.retain_dropout
            kwargs["retain_dropout_modules"] = cfg.generation.retain_dropout_modules
        self.make_generation_fast_(**kwargs)

    def make_generation_fast_(self, **kwargs):
        """
        Legacy entry point to optimize model for faster generation.
        Prefer *prepare_for_inference_* instead.
        """
        if self._is_generation_fast:
            return  # only apply once
        self._is_generation_fast = True

        # remove weight normalization wrappers, which are only needed during
        # training
        def apply_remove_weight_norm(module):
            try:
                nn.utils.remove_weight_norm(module)
            except (AttributeError, ValueError):
                return

        self.apply(apply_remove_weight_norm)

        def apply_make_generation_fast_(module, prefix):
            if len(prefix) > 0:
                prefix += "."

            base_func = BaseFairseqModel.make_generation_fast_
            for n, m in module.named_modules():
                if (
                    m != self
                    and hasattr(m, "make_generation_fast_")
                    # don't call this implementation again, e.g., if
                    # children modules also inherit from BaseFairseqModel
                    and m.make_generation_fast_.__func__ is not base_func
                ):
                    name = prefix + n
                    m.make_generation_fast_(name=name, **kwargs)

        apply_make_generation_fast_(self, "")

        def train(mode=True):
            if mode:
                raise RuntimeError("cannot train after make_generation_fast")

        # this model should no longer be used for training
        self.eval()
        self.train = train
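
    # Resulting contract: once generation has been optimized the model is
    # pinned to eval mode and further training is an error, e.g.
    #
    #   model.make_generation_fast_()
    #   model.train()  # raises RuntimeError: cannot train after make_generation_fast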

    def prepare_for_onnx_export_(self, **kwargs):
        """Make model exportable via ONNX trace."""
        seen = set()

        def apply_prepare_for_onnx_export_(module):
            if (
                module != self
                and hasattr(module, "prepare_for_onnx_export_")
                and module not in seen
            ):
                seen.add(module)
                module.prepare_for_onnx_export_(**kwargs)

        self.apply(apply_prepare_for_onnx_export_)

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        **kwargs,
    ):
        """
        Load a :class:`~fairseq.models.FairseqModel` from a pre-trained model
        file. Downloads and caches the pre-trained model file if needed.

        The base implementation returns a
        :class:`~fairseq.hub_utils.GeneratorHubInterface`, which can be used to
        generate translations or sample from language models. The underlying
        :class:`~fairseq.models.FairseqModel` can be accessed via the
        *generator.models* attribute.

        Other models may override this to implement custom hub interfaces.

        Args:
            model_name_or_path (str): either the name of a pre-trained model to
                load or a path/URL to a pre-trained model state dict
            checkpoint_file (str, optional): colon-separated list of checkpoint
                files in the model archive to ensemble (default: 'model.pt')
            data_name_or_path (str, optional): point args.data to the archive
                at the given path/URL. Can start with '.' or './' to reuse the
                model archive path.
        """
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            **kwargs,
        )
        logger.info(x["args"])
        return hub_utils.GeneratorHubInterface(x["args"], x["task"], x["models"])
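
    # Usage sketch ("transformer.wmt14.en-fr" stands in for any key returned
    # by ``cls.hub_models()``; substitute a real archive name or local path):
    #
    #   en2fr = TransformerModel.from_pretrained(
    #       "transformer.wmt14.en-fr",
    #       checkpoint_file="model.pt",
    #   )
    #   print(en2fr.translate("Hello world!"))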

    @classmethod
    def hub_models(cls):
        return {}


class FairseqEncoderDecoderModel(BaseFairseqModel):
    """Base class for encoder-decoder models.

    Args:
        encoder (FairseqEncoder): the encoder
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, encoder, decoder):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        check_type(self.encoder, FairseqEncoder)
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Run the forward pass for an encoder-decoder model.

        First feed a batch of source tokens through the encoder. Then, feed the
        encoder output and previous decoder outputs (i.e., teacher forcing) to
        the decoder to produce the next outputs::

            encoder_out = self.encoder(src_tokens, src_lengths)
            return self.decoder(prev_output_tokens, encoder_out)

        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        decoder_out = self.decoder(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return decoder_out
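
    # Teacher-forcing sketch with assumed toy shapes (batch=2, src_len=5,
    # tgt_len=4; ``model`` is a built FairseqEncoderDecoderModel);
    # ``prev_output_tokens`` is the target shifted right so the decoder
    # predicts position t from positions < t:
    #
    #   src_tokens = torch.randint(4, 100, (2, 5))
    #   src_lengths = torch.tensor([5, 3])
    #   prev_output_tokens = torch.randint(4, 100, (2, 4))
    #   logits, extra = model(src_tokens, src_lengths, prev_output_tokens)
    #   # logits: (2, 4, vocab)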

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        encoder_out = self.encoder(src_tokens, src_lengths=src_lengths, **kwargs)
        features = self.decoder.extract_features(
            prev_output_tokens, encoder_out=encoder_out, **kwargs
        )
        return features

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return (self.encoder.max_positions(), self.decoder.max_positions())

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()


class FairseqModel(FairseqEncoderDecoderModel):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        utils.deprecation_warning(
            "FairseqModel is deprecated, please use FairseqEncoderDecoderModel "
            "or BaseFairseqModel instead",
            stacklevel=4,
        )


class FairseqMultiModel(BaseFairseqModel):
    """Base class for combining multiple encoder-decoder models."""

    def __init__(self, encoders, decoders):
        super().__init__()
        assert encoders.keys() == decoders.keys()
        self.keys = list(encoders.keys())
        for key in self.keys:
            check_type(encoders[key], FairseqEncoder)
            check_type(decoders[key], FairseqDecoder)

        self.models = nn.ModuleDict(
            {
                key: FairseqEncoderDecoderModel(encoders[key], decoders[key])
                for key in self.keys
            }
        )

    @staticmethod
    def build_shared_embeddings(
        dicts: Dict[str, Dictionary],
        langs: List[str],
        embed_dim: int,
        build_embedding: callable,
        pretrained_embed_path: Optional[str] = None,
    ):
        """
        Helper function to build shared embeddings for a set of languages after
        checking that all dicts corresponding to those languages are equivalent.

        Args:
            dicts: Dict of lang_id to its corresponding Dictionary
            langs: languages that we want to share embeddings for
            embed_dim: embedding dimension
            build_embedding: callable function to actually build the embedding
            pretrained_embed_path: Optional path to load pretrained embeddings
        """
        shared_dict = dicts[langs[0]]
        if any(dicts[lang] != shared_dict for lang in langs):
            raise ValueError(
                "--share-*-embeddings requires a joined dictionary: "
                "--share-encoder-embeddings requires a joined source "
                "dictionary, --share-decoder-embeddings requires a joined "
                "target dictionary, and --share-all-embeddings requires a "
                "joint source + target dictionary."
            )
        return build_embedding(shared_dict, embed_dim, pretrained_embed_path)
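
    # Usage sketch (assumes ``dicts`` maps "en"/"de" to one joined Dictionary
    # and a ``build_embedding(dictionary, dim, path)`` callable, as used by
    # translation models):
    #
    #   shared_embed = FairseqMultiModel.build_shared_embeddings(
    #       dicts, langs=["en", "de"], embed_dim=512,
    #       build_embedding=build_embedding,
    #   )
    #   # one embedding table then backs every language's encoder/decoder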

    def forward(self, src_tokens, src_lengths, prev_output_tokens, **kwargs):
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model."""
        return {
            key: (
                self.models[key].encoder.max_positions(),
                self.models[key].decoder.max_positions(),
            )
            for key in self.keys
        }

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return min(model.decoder.max_positions() for model in self.models.values())

    @property
    def encoder(self):
        return self.models[self.keys[0]].encoder

    @property
    def decoder(self):
        return self.models[self.keys[0]].decoder

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def load_state_dict(
        self,
        state_dict,
        strict=True,
        model_cfg=None,
        args: Optional[Namespace] = None,
    ):
        """Copies parameters and buffers from *state_dict* into this module and
        its descendants.

        Overrides the method in :class:`nn.Module`. Compared with that method
        this additionally "upgrades" *state_dicts* from old checkpoints.
        """

        if model_cfg is None and args is not None:
            logger.warning(
                "using 'args' is deprecated, please update your code to use dataclass config"
            )
            model_cfg = convert_namespace_to_omegaconf(args).model

        self.upgrade_state_dict(state_dict)

        # local import to avoid a circular dependency
        from fairseq.checkpoint_utils import prune_state_dict

        new_state_dict = prune_state_dict(state_dict, model_cfg)
        return super().load_state_dict(new_state_dict, strict)


class FairseqLanguageModel(BaseFairseqModel):
    """Base class for decoder-only models.

    Args:
        decoder (FairseqDecoder): the decoder
    """

    def __init__(self, decoder):
        super().__init__()
        self.decoder = decoder
        check_type(self.decoder, FairseqDecoder)

    def forward(self, src_tokens, **kwargs):
        """
        Run the forward pass for a decoder-only model.

        Feeds a batch of tokens through the decoder to predict the next tokens.

        Args:
            src_tokens (LongTensor): tokens on which to condition the decoder,
                of shape `(batch, tgt_len)`
            src_lengths (LongTensor): source sentence lengths of shape
                `(batch)`, if needed by the decoder (passed via *kwargs*)

        Returns:
            tuple:
                - the decoder's output of shape `(batch, seq_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder(src_tokens, **kwargs)
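
    # Next-token scoring sketch with assumed toy shapes (batch=2, seq_len=8;
    # ``lm`` is a built FairseqLanguageModel):
    #
    #   tokens = torch.randint(4, 100, (2, 8))
    #   logits, extra = lm(tokens)
    #   lprobs = lm.get_normalized_probs((logits, extra), log_probs=True)
    #   # lprobs[:, t] scores the token expected at position t + 1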

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def extract_features(self, src_tokens, **kwargs):
        """
        Similar to *forward* but only return features.

        Returns:
            tuple:
                - the decoder's features of shape `(batch, seq_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        return self.decoder.extract_features(src_tokens, **kwargs)

    def output_layer(self, features, **kwargs):
        """Project features to the default output size (typically vocabulary size)."""
        return self.decoder.output_layer(features, **kwargs)

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.decoder.max_positions()

    def max_decoder_positions(self):
        """Maximum length supported by the decoder."""
        return self.decoder.max_positions()

    @property
    def supported_targets(self):
        return {"future"}


class FairseqEncoderModel(BaseFairseqModel):
    """Base class for encoder-only models.

    Args:
        encoder (FairseqEncoder): the encoder
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        check_type(self.encoder, FairseqEncoder)

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Run the forward pass for an encoder-only model.

        Feeds a batch of tokens through the encoder to generate features.

        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            src_lengths (LongTensor): source sentence lengths of shape `(batch)`

        Returns:
            the encoder's output, typically of shape `(batch, src_len, features)`
        """
        return self.encoder(src_tokens, src_lengths, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        encoder_out = net_output["encoder_out"]
        if torch.is_tensor(encoder_out):
            logits = encoder_out.float()
            if log_probs:
                return F.log_softmax(logits, dim=-1)
            else:
                return F.softmax(logits, dim=-1)
        raise NotImplementedError

    def max_positions(self):
        """Maximum length supported by the model."""
        return self.encoder.max_positions()