| |
| |
| |
| |
|
|
| import os |
| from contextlib import redirect_stdout |
|
|
| from fairseq import options |
| from fairseq_cli import generate |
|
|
| from examples.noisychannel import rerank_options, rerank_utils |
|
|
|
|
def _scorer_lang_pair(args, backwards):
    """Return the ``(source_lang, target_lang)`` pair to pass to a scorer.

    When ``backwards`` is truthy the pair read from ``args.source_lang`` /
    ``args.target_lang`` is swapped; otherwise it is returned unchanged.
    """
    if backwards:
        return args.target_lang, args.source_lang
    return args.source_lang, args.target_lang


def _select_rerank_data(right_to_left, backwards, l2r_dir, r2l_dir, backwards_dir):
    """Pick the data directory for a scorer.

    The right-to-left flag takes precedence over the backwards flag; when
    neither is set the left-to-right directory is used.
    """
    if right_to_left:
        return r2l_dir
    if backwards:
        return backwards_dir
    return l2r_dir


def _score_translations(model_index, model_path, data_dir, src_lang, tgt_lang, score_file):
    """Run ``generate.main`` for one scorer and store its stdout in ``score_file``.

    Args:
        model_index: 1 or 2, used only in the progress message.
        model_path: value passed to ``--path``.
        data_dir: positional data directory argument for the generation parser.
        src_lang: value passed to ``--source-lang``.
        tgt_lang: value passed to ``--target-lang``.
        score_file: destination path for the captured output.

    Any exception raised by ``generate.main`` propagates to the caller.
    """
    print("STEP 4: score the translations for model {}".format(model_index))

    gen_args = [
        data_dir,
        "--batch-size",
        str(128),
        "--score-reference",
        "--gen-subset",
        "train",
        "--path",
        model_path,
        "--source-lang",
        src_lang,
        "--target-lang",
        tgt_lang,
    ]
    gen_parser = options.get_generation_parser()
    input_args = options.parse_args_and_arch(gen_parser, gen_args)

    # Callers skip scoring whenever ``score_file`` already exists, so a run
    # that dies midway must never leave a truncated file at that path.
    # Write to a temporary sibling and move it into place only on success.
    tmp_file = "{}.tmp".format(score_file)
    try:
        with open(tmp_file, "w") as f, redirect_stdout(f):
            generate.main(input_args)
        os.replace(tmp_file, score_file)
    finally:
        # After a successful replace the temp path no longer exists; this
        # only cleans up the partial output of a failed run.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)


def score_bw(args):
    """Score generated translations with up to two rescoring models.

    For each configured scoring model, the captured output of
    ``generate.main`` (run with ``--score-reference``) is written to the file
    named by ``rerank_utils.rescore_file_name``. A model is skipped when its
    score file already exists, or when it is the generation model itself and
    ``args.source_prefix_frac`` is None.

    Args:
        args: parsed reranking options. Attributes read here: ``source_lang``,
            ``target_lang``, ``gen_model``, ``gen_model_name``,
            ``score_model1``/``score_model2``, ``model1_name``/``model2_name``,
            ``backwards1``/``backwards2``, ``right_to_left1``/``right_to_left2``,
            ``data_dir_name``, ``num_rescore``, ``gen_subset``, ``shard_id``,
            ``num_shards``, ``sampling``, ``prefix_len``,
            ``target_prefix_frac`` and ``source_prefix_frac``.
    """
    has_model2 = args.score_model2 is not None

    # A scorer identical to the generation model needs no separate pass
    # unless only a source prefix fraction was used.
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        _lm_preprocessed_dir,  # returned by get_directories but unused here
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    if has_model2:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    if not rerank1_is_gen and not os.path.isfile(score1_file):
        scorer1_src, scorer1_tgt = _scorer_lang_pair(args, args.backwards1)
        rerank_data1 = _select_rerank_data(
            args.right_to_left1,
            args.backwards1,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
        )
        _score_translations(
            1, args.score_model1, rerank_data1, scorer1_src, scorer1_tgt, score1_file
        )

    # ``score2_file`` is only bound when a second model is configured, so
    # ``has_model2`` must be tested first.
    if has_model2 and not rerank2_is_gen and not os.path.isfile(score2_file):
        scorer2_src, scorer2_tgt = _scorer_lang_pair(args, args.backwards2)
        rerank_data2 = _select_rerank_data(
            args.right_to_left2,
            args.backwards2,
            left_to_right_preprocessed_dir,
            right_to_left_preprocessed_dir,
            backwards_preprocessed_dir,
        )
        _score_translations(
            2, args.score_model2, rerank_data2, scorer2_src, scorer2_tgt, score2_file
        )
|
|
|
|
def cli_main():
    """Parse the reranking command-line options and run ``score_bw`` on them."""
    reranking_parser = rerank_options.get_reranking_parser()
    score_bw(options.parse_args_and_arch(reranking_parser))
|
|
|
|
# Run the scoring step only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli_main()
|
|