Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from dataclasses import dataclass, field | |
| import torch | |
| from omegaconf import II | |
| from fairseq import metrics, utils | |
| from fairseq.dataclass import ChoiceEnum | |
| from fairseq.tasks import register_task | |
| from fairseq.tasks.translation import TranslationConfig, TranslationTask | |
| from .logsumexp_moe import LogSumExpMoE | |
| from .mean_pool_gating_network import MeanPoolGatingNetwork | |
# Supported training methods: "s"/"h" = soft/hard expert selection,
# "lp"/"up" = learned/uniform prior over experts.
METHOD_CHOICES = ChoiceEnum(["sMoElp", "sMoEup", "hMoElp", "hMoEup"])


@dataclass
class TranslationMoEConfig(TranslationConfig):
    """Configuration for the mixture-of-experts (MoE) translation task.

    Extends :class:`TranslationConfig` with the options that select the MoE
    training method, the number of experts, the optional mean-pooling gating
    network and the expert used at generation time.

    The ``@dataclass`` decorator is required: without it the ``field(...)``
    defaults below would remain raw ``dataclasses.Field`` objects instead of
    becoming real configuration values.
    """

    # One of METHOD_CHOICES; decides soft vs. hard selection and learned vs.
    # uniform prior.
    method: METHOD_CHOICES = field(
        default="hMoEup",
        metadata={"help": "MoE method"},
    )
    # Number of experts; one "<expert_i>" indicator token is added to the
    # dictionaries for each of them.
    num_experts: int = field(
        default=3,
        metadata={"help": "number of experts"},
    )
    # When set, a MeanPoolGatingNetwork is attached to models that do not
    # already provide a ``gating_network`` (only used with a learned prior).
    mean_pool_gating_network: bool = field(
        default=False,
        metadata={"help": "use a simple mean-pooling gating network"},
    )
    # Values <= 0 make the task fall back to the model's own ``dropout``.
    mean_pool_gating_network_dropout: float = field(
        default=0,
        metadata={"help": "dropout for mean-pooling gating network"},
    )
    # Values <= 0 make the task fall back to the model's ``encoder_embed_dim``.
    mean_pool_gating_network_encoder_dim: int = field(
        default=0,
        metadata={"help": "encoder output dim for mean-pooling gating network"},
    )
    # Index of the expert whose indicator token is used as BOS at inference.
    gen_expert: int = field(
        default=0,
        metadata={"help": "which expert to use for generation"},
    )
    # Resolved by interpolation from the optimization config group.
    sentence_avg: bool = II("optimization.sentence_avg")
@register_task("translation_moe", dataclass=TranslationMoEConfig)
class TranslationMoETask(TranslationTask):
    """
    Translation task for Mixture of Experts (MoE) models.

    See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
    (Shen et al., 2019) <https://arxiv.org/abs/1902.07816>`_.

    Args:
        cfg (TranslationMoEConfig): task configuration
        src_dict (~fairseq.data.Dictionary): dictionary for the source language
        tgt_dict (~fairseq.data.Dictionary): dictionary for the target language

    .. note::

        The translation task is compatible with :mod:`fairseq-train`,
        :mod:`fairseq-generate` and :mod:`fairseq-interactive`.

    The translation task provides the following additional command-line
    arguments:

    .. argparse::
        :ref: fairseq.tasks.translation_parser
        :prog:
    """

    cfg: TranslationMoEConfig

    def __init__(self, cfg: TranslationMoEConfig, src_dict, tgt_dict):
        """Derive the MoE flags from ``cfg.method`` and register expert tokens.

        Raises:
            ValueError: if ``cfg.method`` is not one of the supported methods.
        """
        if cfg.method == "sMoElp":
            # soft MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = False
        elif cfg.method == "sMoEup":
            # soft MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = False
        elif cfg.method == "hMoElp":
            # hard MoE with learned prior
            self.uniform_prior = False
            self.hard_selection = True
        elif cfg.method == "hMoEup":
            # hard MoE with uniform prior
            self.uniform_prior = True
            self.hard_selection = True
        else:
            # Fail here rather than later with an AttributeError on the
            # missing uniform_prior / hard_selection attributes.
            raise ValueError("Unknown MoE method: {}".format(cfg.method))

        # add indicator tokens for each expert
        for i in range(cfg.num_experts):
            # add to both dictionaries in case we're sharing embeddings
            src_dict.add_symbol("<expert_{}>".format(i))
            tgt_dict.add_symbol("<expert_{}>".format(i))

        super().__init__(cfg, src_dict, tgt_dict)

    def build_model(self, cfg):
        """Build the model and, for learned priors, ensure a gating network.

        When the prior is learned and the model has no ``gating_network``
        attribute, a :class:`MeanPoolGatingNetwork` is attached if
        ``mean_pool_gating_network`` is enabled in the task config.

        Args:
            cfg: model configuration passed through to ``models.build_model``.

        Returns:
            The built model (with ``gating_network`` set when required).

        Raises:
            ValueError: if a learned prior is requested but no gating network
                is available, or if the gating network's encoder dimension or
                dropout cannot be determined.
        """
        from fairseq import models

        model = models.build_model(cfg, self)
        if not self.uniform_prior and not hasattr(model, "gating_network"):
            if self.cfg.mean_pool_gating_network:
                if self.cfg.mean_pool_gating_network_encoder_dim > 0:
                    encoder_dim = self.cfg.mean_pool_gating_network_encoder_dim
                elif getattr(cfg, "encoder_embed_dim", None):
                    # assume that encoder_embed_dim is the encoder's output dimension
                    encoder_dim = cfg.encoder_embed_dim
                else:
                    raise ValueError(
                        "Must specify --mean-pool-gating-network-encoder-dim"
                    )

                if self.cfg.mean_pool_gating_network_dropout > 0:
                    dropout = self.cfg.mean_pool_gating_network_dropout
                elif getattr(cfg, "dropout", None) is not None:
                    # ``is not None`` (not truthiness): a model dropout of
                    # 0.0 is a valid fallback and must not raise.
                    dropout = cfg.dropout
                else:
                    raise ValueError(
                        "Must specify task.mean_pool_gating_network_dropout"
                    )

                model.gating_network = MeanPoolGatingNetwork(
                    encoder_dim,
                    self.cfg.num_experts,
                    dropout,
                )
            else:
                raise ValueError(
                    "translation_moe task with learned prior requires the model to "
                    "have a gating network; try using --mean-pool-gating-network"
                )
        return model

    def expert_index(self, i):
        """Return the target-dictionary index of expert ``i``'s indicator token.

        ``i`` is added to the index of ``"<expert_0>"``, so it may be an int
        or an integer tensor. This relies on the expert tokens occupying
        consecutive dictionary indices, in the order ``__init__`` adds them.
        """
        return i + self.tgt_dict.index("<expert_0>")

    def _get_loss(self, sample, model, criterion):
        """Compute the MoE loss for one batch.

        Expert responsibilities are first computed with dropout and autograd
        disabled. The loss is then computed either for the single
        most-responsible expert per sentence (hard selection) or over all
        experts weighted by the responsibilities (soft selection).

        Returns:
            tuple: ``(loss, sample_size, logging_output)`` where ``loss`` is
            summed over the batch and ``logging_output["posterior"]`` holds
            the responsibilities summed over the batch dimension.
        """
        assert hasattr(
            criterion, "compute_loss"
        ), "translation_moe task requires the criterion to implement the compute_loss() method"

        k = self.cfg.num_experts
        bsz = sample["target"].size(0)

        def get_lprob_y(encoder_out, prev_output_tokens_k):
            # Sentence-level log-likelihood given the expert token placed in
            # the first decoder-input position.
            net_output = model.decoder(
                prev_output_tokens=prev_output_tokens_k,
                encoder_out=encoder_out,
            )
            loss, _ = criterion.compute_loss(model, net_output, sample, reduce=False)
            loss = loss.view(bsz, -1)
            return -loss.sum(dim=1, keepdim=True)  # -> B x 1

        def get_lprob_yz(winners=None):
            # winners=None scores every expert (B x K); otherwise only the
            # selected expert of each sentence is scored (B x 1).
            encoder_out = model.encoder(
                src_tokens=sample["net_input"]["src_tokens"],
                src_lengths=sample["net_input"]["src_lengths"],
            )

            if winners is None:
                lprob_y = []
                for i in range(k):
                    prev_output_tokens_k = sample["net_input"][
                        "prev_output_tokens"
                    ].clone()
                    assert not prev_output_tokens_k.requires_grad
                    prev_output_tokens_k[:, 0] = self.expert_index(i)
                    lprob_y.append(get_lprob_y(encoder_out, prev_output_tokens_k))
                lprob_y = torch.cat(lprob_y, dim=1)  # -> B x K
            else:
                prev_output_tokens_k = sample["net_input"]["prev_output_tokens"].clone()
                prev_output_tokens_k[:, 0] = self.expert_index(winners)
                lprob_y = get_lprob_y(encoder_out, prev_output_tokens_k)  # -> B x 1

            if self.uniform_prior:
                # a uniform prior adds the same constant to every expert, so
                # it is omitted
                lprob_yz = lprob_y
            else:
                lprob_z = model.gating_network(encoder_out)  # B x K
                if winners is not None:
                    lprob_z = lprob_z.gather(dim=1, index=winners.unsqueeze(-1))
                lprob_yz = lprob_y + lprob_z.type_as(lprob_y)  # B x K (or B x 1)

            return lprob_yz

        # compute responsibilities without dropout
        with utils.model_eval(model):  # disable dropout
            with torch.no_grad():  # disable autograd
                lprob_yz = get_lprob_yz()  # B x K
                prob_z_xy = torch.nn.functional.softmax(lprob_yz, dim=1)
        assert not prob_z_xy.requires_grad

        # compute loss with dropout
        if self.hard_selection:
            winners = prob_z_xy.max(dim=1)[1]
            loss = -get_lprob_yz(winners)
        else:
            lprob_yz = get_lprob_yz()  # B x K
            loss = -LogSumExpMoE.apply(lprob_yz, prob_z_xy, 1)

        loss = loss.sum()
        sample_size = (
            sample["target"].size(0) if self.cfg.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": bsz,
            "sample_size": sample_size,
            "posterior": prob_z_xy.float().sum(dim=0).cpu(),
        }
        return loss, sample_size, logging_output

    def train_step(
        self, sample, model, criterion, optimizer, update_num, ignore_grad=False
    ):
        """Run one training step: forward pass, MoE loss and backward pass.

        When ``ignore_grad`` is true the loss is multiplied by zero before the
        backward pass, so the backward pass still runs but contributes no
        gradient.

        Returns:
            tuple: ``(loss, sample_size, logging_output)``.
        """
        model.train()
        loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        if ignore_grad:
            loss *= 0
        optimizer.backward(loss)
        return loss, sample_size, logging_output

    def valid_step(self, sample, model, criterion):
        """Compute the MoE loss in eval mode with autograd disabled.

        Returns:
            tuple: ``(loss, sample_size, logging_output)``.
        """
        model.eval()
        with torch.no_grad():
            loss, sample_size, logging_output = self._get_loss(sample, model, criterion)
        return loss, sample_size, logging_output

    def inference_step(
        self,
        generator,
        models,
        sample,
        prefix_tokens=None,
        expert=None,
        constraints=None,
    ):
        """Generate with one expert by using its indicator token as BOS.

        Args:
            expert (int, optional): index of the expert to generate with.
                Defaults to ``cfg.gen_expert`` when ``None``.

        Returns:
            Whatever ``generator.generate`` returns.
        """
        # Compare against None explicitly: ``expert or default`` would treat
        # a requested expert 0 as "not given" and silently use gen_expert.
        if expert is None:
            expert = self.cfg.gen_expert
        with torch.no_grad():
            return generator.generate(
                models,
                sample,
                prefix_tokens=prefix_tokens,
                constraints=constraints,
                bos_token=self.expert_index(expert),
            )

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and log the summed expert posterior."""
        super().reduce_metrics(logging_outputs, criterion)
        metrics.log_scalar(
            "posterior",
            sum(log["posterior"] for log in logging_outputs if "posterior" in log),
        )