| | |
| | |
| | |
| | |
| | |
"""
Encode pre-processed data with a trained model and save the mean-pooled
encoder outputs, together with the source sentences, to sharded files.
"""
| |
|
| | import numpy as np |
| | import torch |
| | from fairseq import checkpoint_utils, options, progress_bar, tasks, utils |
| | from fairseq.sequence_generator import EnsembleModel |
| | from fairseq.utils import safe_hasattr |
| |
|
| |
|
def get_avg_pool(
    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
    """Mean-pool the encoder states of the first model over non-padded positions.

    Args:
        models: models wrapped in an ``EnsembleModel``; only the first
            model's encoder output is pooled.
        sample: batch dict; every entry of ``sample["net_input"]`` except
            ``prev_output_tokens`` is fed to the encoder.
        prefix_tokens: unused; kept so existing callers keep working.
        src_dict: unused; kept so existing callers keep working.
        remove_bpe: unused; kept so existing callers keep working.
        has_langtok: when true, the first encoder position is excluded from
            the average (presumably a prepended language token -- confirm
            against the task configuration).

    Returns:
        A float32 numpy array holding the masked mean of the encoder states
        over axis 0. Assumes ``encoder_out`` is (time, batch, channel) and
        ``encoder_padding_mask`` is (batch, time) -- TODO confirm for
        non-standard encoders.
    """
    model = EnsembleModel(models)

    # The decoder is never run, so drop its input before calling the encoder.
    encoder_input = {
        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
    }

    encoder_outs = model.forward_encoder(encoder_input)
    np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
    # 1.0 at real tokens, 0.0 at padding.
    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
        np.float32
    )
    # Transpose and add a trailing axis so the mask broadcasts over channels.
    encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
    if has_langtok:
        # Drop the first position from BOTH arrays. The states must be sliced
        # with ``1:`` (not indexed with ``1``): integer indexing would keep a
        # single position, drop a dimension, and silently broadcast it over
        # every remaining position.
        encoder_mask = encoder_mask[1:, :, :]
        np_encoder_outs = np_encoder_outs[1:, :, :]
    masked_encoder_outs = encoder_mask * np_encoder_outs
    avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
    return avg_pool
| |
|
| |
|
# Number of pooled vectors accumulated before a shard is written to disk.
_MAX_SHARD_ROWS = 1000000


def _save_shard(save_dir, lang, shard_id, pooled_chunks, sentences):
    """Write one shard: the pooled vectors and the matching sentence lines."""
    import os

    pooled = np.concatenate(pooled_chunks)
    # ``tofile`` emits raw array bytes, so the handle must be binary.
    pool_path = os.path.join(save_dir, f"all_avg_pool.{lang}.{shard_id}")
    with open(pool_path, "wb") as avg_pool_file:
        pooled.tofile(avg_pool_file)
    # Explicit UTF-8 so the output does not depend on the locale.
    sentence_path = os.path.join(save_dir, f"sentences.{lang}.{shard_id}")
    with open(sentence_path, "w", encoding="utf-8") as sentence_file:
        sentence_file.writelines(f"{line}\n" for line in sentences)


def main(args):
    """Encode ``args.gen_subset`` and dump mean-pooled encoder outputs in shards.

    For every batch the encoder output is mean-pooled with ``get_avg_pool``
    and the source sentences are recorded as ``"<id>\\t<text>"`` lines. Once
    at least ``_MAX_SHARD_ROWS`` vectors are buffered, and once more at the
    end for any remainder, two files are written into
    ``args.encoder_save_dir``:

    * ``all_avg_pool.<source_lang>.<shard>``: raw array bytes from
      ``ndarray.tofile``.
    * ``sentences.<source_lang>.<shard>``: one line per sentence, in the
      same order as the vectors.

    An empty ``args.encoder_save_dir`` writes into the current working
    directory.

    Args:
        args: parsed command-line namespace (generation options plus
            ``encoder_save_dir``). ``args.beam`` is forced to 1 and
            ``args.max_tokens`` defaults to 12000 when unset.
    """
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"

    # Only the encoder is run, so no beam search is needed.
    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load the dataset split.
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Some tasks do not implement a source dictionary.
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load the model ensemble.
    print("| loading model(s) from {}".format(args.path))
    # NOTE(security): eval() runs arbitrary Python taken from the
    # --model-overrides command-line value; only use with trusted input.
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Prepare the models for inference.
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # None unless --replace-unk was given.
    align_dict = utils.load_align_dict(args.replace_unk)

    # Iterate over the split in its stored order (no shuffling).
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    source_sentences = []
    # Per-batch arrays are collected and concatenated once per shard, which
    # avoids re-copying the whole buffer on every batch.
    pooled_chunks = []
    pooled_rows = 0
    shard_id = 0
    encoder_has_langtok = (
        safe_hasattr(task.args, "encoder_langtok")
        and task.args.encoder_langtok is not None
        and safe_hasattr(task.args, "lang_tok_replacing_bos_eos")
        and not task.args.lang_tok_replacing_bos_eos
    )
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            if sample is None:
                print("Skipping None")
                continue
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            with torch.no_grad():
                avg_pool = get_avg_pool(
                    models,
                    sample,
                    prefix_tokens,
                    src_dict,
                    args.post_process,
                    has_langtok=encoder_has_langtok,
                )
            pooled_chunks.append(avg_pool)
            pooled_rows += avg_pool.shape[0]

            if isinstance(sample["id"], list):
                sample_ids = sample["id"]
            else:
                sample_ids = sample["id"].tolist()
            for i, sample_id in enumerate(sample_ids):
                # Remove padding from the source tokens.
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )

                # Prefer the original raw text when an alignment dictionary
                # is in use; otherwise detokenize from the source dictionary.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(
                        sample_id
                    )
                elif src_dict is not None:
                    src_str = src_dict.string(src_tokens, args.post_process)
                else:
                    src_str = ""

                if not args.quiet and src_dict is not None:
                    print("S-{}\t{}".format(sample_id, src_str))

                source_sentences.append(f"{sample_id}\t{src_str}")

            if pooled_rows >= _MAX_SHARD_ROWS:
                _save_shard(
                    args.encoder_save_dir,
                    args.source_lang,
                    shard_id,
                    pooled_chunks,
                    source_sentences,
                )
                pooled_chunks = []
                pooled_rows = 0
                source_sentences = []
                shard_id += 1

        # Flush whatever is left after the last full shard.
        if pooled_chunks:
            _save_shard(
                args.encoder_save_dir,
                args.source_lang,
                shard_id,
                pooled_chunks,
                source_sentences,
            )
| |
|
| |
|
def cli_main():
    """Command-line entry point.

    Builds the generation argument parser, registers the extra
    ``--encoder-save-dir`` option, parses the arguments and hands them to
    ``main``.
    """
    generation_parser = options.get_generation_parser()
    save_dir_option = {
        "default": "",
        "type": str,
        "metavar": "N",
        "help": "directory to save encoder outputs",
    }
    generation_parser.add_argument("--encoder-save-dir", **save_dir_option)
    main(options.parse_args_and_arch(generation_parser))
| |
|
| |
|
# Run the command-line entry point only when executed as a script,
# not when the module is imported.
if __name__ == "__main__":
    cli_main()
| |
|