| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | """
|
| | Translate raw text with a trained model. Batches data on-the-fly.
|
| | """
|
| |
|
| | import os
|
| | import ast
|
| | from collections import namedtuple
|
| |
|
| | import torch
|
| | from fairseq import checkpoint_utils, options, tasks, utils
|
| | from fairseq.dataclass.utils import convert_namespace_to_omegaconf
|
| | from fairseq.token_generation_constraints import pack_constraints, unpack_constraints
|
| | from fairseq_cli.generate import get_symbols_to_strip_from_output
|
| |
|
| | import codecs
|
| |
|
# Directory that contains this module.
PWD = os.path.dirname(__file__)

# One inference mini-batch as produced by ``make_batches``.
Batch = namedtuple(
    "Batch",
    ["ids", "src_tokens", "src_lengths", "constraints"],
)

# Record describing one translated sentence.
Translation = namedtuple(
    "Translation",
    ["src_str", "hypos", "pos_scores", "alignments"],
)
|
| |
|
| |
|
def make_batches(
    lines, cfg, task, max_positions, encode_fn, constrainted_decoding=False
):
    """Yield inference batches built on the fly from raw input lines.

    Args:
        lines: Iterable of input strings. When ``constrainted_decoding`` is
            true, a line may carry constraints after the source text, each
            separated by a tab character.
        cfg: Configuration object; ``cfg.dataset.max_tokens``,
            ``cfg.dataset.batch_size`` and
            ``cfg.dataset.skip_invalid_size_inputs_valid_test`` are read.
        task: Task object providing ``target_dictionary``,
            ``get_interactive_tokens_and_lengths``,
            ``build_dataset_for_inference`` and ``get_batch_iterator``.
        max_positions: Forwarded unchanged to ``task.get_batch_iterator``.
        encode_fn: Callable applied to raw text (source lines and constraint
            strings) before dictionary encoding.
        constrainted_decoding: Enables constraint parsing. The misspelled
            name is kept because callers may pass it by keyword.

    Yields:
        Batch: ``ids``, ``src_tokens`` and ``src_lengths`` taken from each
        batch of the iterator, plus ``constraints`` (``None`` when the batch
        carries none).
    """
    # Work on a private copy: stripping the constraints below rewrites
    # entries, and the caller's sequence must not be modified (this also
    # lets immutable sequences such as tuples be passed in).
    lines = list(lines)

    constraints_tensor = None
    if constrainted_decoding:
        # Split "source<TAB>constraint<TAB>..." lines, keeping only the
        # source text in ``lines``.
        batch_constraints = [[] for _ in lines]
        for i, line in enumerate(lines):
            if "\t" in line:
                lines[i], *batch_constraints[i] = line.split("\t")

        # Turn every constraint string into target-dictionary token ids.
        batch_constraints = [
            [
                task.target_dictionary.encode_line(
                    encode_fn(constraint),
                    append_eos=False,
                    add_if_not_exist=False,
                )
                for constraint in constraint_list
            ]
            for constraint_list in batch_constraints
        ]
        constraints_tensor = pack_constraints(batch_constraints)

    tokens, lengths = task.get_interactive_tokens_and_lengths(lines, encode_fn)

    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(
            tokens, lengths, constraints=constraints_tensor
        ),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=max_positions,
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
    ).next_epoch_itr(shuffle=False)

    for batch in itr:
        net_input = batch["net_input"]
        yield Batch(
            ids=batch["id"],
            src_tokens=net_input["src_tokens"],
            src_lengths=net_input["src_lengths"],
            constraints=batch.get("constraints", None),
        )
|
| |
|
| |
|
class Translator:
    """
    Wrapper class to handle the interaction with fairseq model class for translation

    The constructor parses fairseq generation options, loads the model
    ensemble from ``checkpoint_path`` and builds a generator. ``translate``
    then turns a list of input strings into a list of output strings.
    """

    def __init__(
        self, data_dir, checkpoint_path, batch_size=25, constrained_decoding=False
    ):
        """Load the models and build the generator.

        Args:
            data_dir: Passed as the only positional argument to the fairseq
                generation parser.
            checkpoint_path: Checkpoint path(s), set as the parser's ``path``
                default and later split with ``utils.split_paths``.
            batch_size: Sentences per batch; the buffer size is set to
                ``batch_size + 1``.
            constrained_decoding: When true, the parser default
                ``constraints="ordered"`` is set and ``translate`` requires
                constraints. When false, ``remove_bpe="subword_nmt"`` is set
                instead.

        Raises:
            AssertionError: If sampling is enabled with ``nbest != beam``, or
                if the batch size exceeds the buffer size.
        """
        self.constrained_decoding = constrained_decoding
        self.parser = options.get_generation_parser(interactive=True)

        # Defaults shared by both modes; each mode adds exactly one extra key.
        parser_defaults = {
            "path": checkpoint_path,
            "num_workers": -1,
            "batch_size": batch_size,
            "buffer_size": batch_size + 1,
        }
        if self.constrained_decoding:
            parser_defaults["constraints"] = "ordered"
        else:
            parser_defaults["remove_bpe"] = "subword_nmt"
        self.parser.set_defaults(**parser_defaults)
        args = options.parse_args_and_arch(self.parser, input_args=[data_dir])

        # Fixed language-pair names; presumably these match how the data in
        # ``data_dir`` was binarized -- verify against the preprocessing step.
        args.source_lang = "SRC"
        args.target_lang = "TGT"

        args.skip_invalid_size_inputs_valid_test = False

        # User-defined fairseq modules are looked up next to this file.
        args.user_dir = os.path.join(PWD, "model_configs")
        self.cfg = convert_namespace_to_omegaconf(args)

        utils.import_user_module(self.cfg.common)

        if self.cfg.interactive.buffer_size < 1:
            self.cfg.interactive.buffer_size = 1
        if self.cfg.dataset.max_tokens is None and self.cfg.dataset.batch_size is None:
            self.cfg.dataset.batch_size = 1

        assert (
            not self.cfg.generation.sampling
            or self.cfg.generation.nbest == self.cfg.generation.beam
        ), "--sampling requires --nbest to be equal to --beam"
        assert (
            not self.cfg.dataset.batch_size
            or self.cfg.dataset.batch_size <= self.cfg.interactive.buffer_size
        ), "--batch-size cannot be larger than --buffer-size"

        self.use_cuda = torch.cuda.is_available() and not self.cfg.common.cpu

        self.task = tasks.setup_task(self.cfg.task)

        # ``model_overrides`` is a string holding a Python literal.
        overrides = ast.literal_eval(self.cfg.common_eval.model_overrides)
        self.models, self._model_args = checkpoint_utils.load_model_ensemble(
            utils.split_paths(self.cfg.common_eval.path),
            arg_overrides=overrides,
            task=self.task,
            suffix=self.cfg.checkpoint.checkpoint_suffix,
            strict=(self.cfg.checkpoint.checkpoint_shard_count == 1),
            num_shards=self.cfg.checkpoint.checkpoint_shard_count,
        )

        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Prepare every loaded model for inference (fp16 / GPU as configured).
        for model in self.models:
            if model is None:
                continue
            if self.cfg.common.fp16:
                model.half()
            if (
                self.use_cuda
                and not self.cfg.distributed_training.pipeline_model_parallel
            ):
                model.cuda()
            model.prepare_for_inference_(self.cfg)

        self.generator = self.task.build_generator(self.models, self.cfg.generation)

        # No tokenizer / BPE is configured here, so ``encode_fn`` and
        # ``decode_fn`` pass text through unchanged unless these are set later.
        self.tokenizer = None
        self.bpe = None

        self.align_dict = utils.load_align_dict(self.cfg.generation.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(), *[model.max_positions() for model in self.models]
        )

    def encode_fn(self, x):
        """Apply the tokenizer, then BPE, to ``x``; skip whichever is ``None``."""
        if self.tokenizer is not None:
            x = self.tokenizer.encode(x)
        if self.bpe is not None:
            x = self.bpe.encode(x)
        return x

    def decode_fn(self, x):
        """Undo ``encode_fn``: BPE decode first, then tokenizer decode."""
        if self.bpe is not None:
            x = self.bpe.decode(x)
        if self.tokenizer is not None:
            x = self.tokenizer.decode(x)
        return x

    def translate(self, inputs, constraints=None):
        """Translate ``inputs`` and return the hypothesis strings.

        Args:
            inputs: Input sentences (strings).
            constraints: One constraint entry per input sentence. Required
                (not ``None``) when the instance was built with
                ``constrained_decoding=True`` and forbidden otherwise. Each
                entry is appended to its sentence after a tab character.

        Returns:
            list: Decoded hypothesis strings, ordered by input position; up
            to ``cfg.generation.nbest`` consecutive entries per input.

        Raises:
            ValueError: If ``constraints`` does not match the decoding mode,
                or if ``inputs`` and ``constraints`` differ in length.
        """
        if self.constrained_decoding and constraints is None:
            raise ValueError("Constraints cant be None in constrained decoding mode")
        if not self.constrained_decoding and constraints is not None:
            raise ValueError("Cannot pass constraints during normal translation")

        constrained_decoding = bool(constraints)
        if constrained_decoding:
            inputs = list(inputs)
            constraints = list(constraints)
            # zip() would silently drop the unmatched tail and lose
            # sentences, so a length mismatch is reported instead.
            if len(inputs) != len(constraints):
                raise ValueError(
                    f"Got {len(inputs)} inputs but {len(constraints)} constraints; "
                    "exactly one constraint entry is required per input"
                )
            inputs = [
                sentence + f"\t{constraint}"
                for sentence, constraint in zip(inputs, constraints)
            ]

        # Collect (sample id, stripped source tokens, hypotheses) per sentence.
        results = []
        for batch in make_batches(
            inputs,
            self.cfg,
            self.task,
            self.max_positions,
            self.encode_fn,
            constrained_decoding,
        ):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            batch_constraints = batch.constraints
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()
                if batch_constraints is not None:
                    batch_constraints = batch_constraints.cuda()

            sample = {
                "net_input": {
                    "src_tokens": src_tokens,
                    "src_lengths": src_lengths,
                },
            }

            translations = self.task.inference_step(
                self.generator, self.models, sample, constraints=batch_constraints
            )

            for row, (sample_id, hypos) in enumerate(
                zip(batch.ids.tolist(), translations)
            ):
                src_tokens_i = utils.strip_pad(src_tokens[row], self.tgt_dict.pad())
                results.append((sample_id, src_tokens_i, hypos))

        # The set of symbols to strip does not change between hypotheses.
        symbols_to_strip = get_symbols_to_strip_from_output(self.generator)
        nbest = self.cfg.generation.nbest

        final_translations = []
        # Sorting by sample id restores the original input order.
        for _, src_tokens_i, hypos in sorted(results, key=lambda item: item[0]):
            src_str = ""
            if self.src_dict is not None:
                src_str = self.src_dict.string(
                    src_tokens_i, self.cfg.common_eval.post_process
                )

            for hypo in hypos[:nbest]:
                _, hypo_str, _ = utils.post_process_prediction(
                    hypo_tokens=hypo["tokens"].int().cpu(),
                    src_str=src_str,
                    alignment=hypo["alignment"],
                    align_dict=self.align_dict,
                    tgt_dict=self.tgt_dict,
                    extra_symbols_to_ignore=symbols_to_strip,
                )
                final_translations.append(self.decode_fn(hypo_str))
        return final_translations
|
| |
|