Spaces:
Runtime error
Runtime error
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| import os | |
| from fairseq import options | |
| from examples.noisychannel import rerank_options, rerank_utils | |
def score_lm(args):
    """Score generated hypotheses with the target-side language model (step 4.5).

    Resolves the generation-output directory from ``args``. When
    ``args.language_model`` is set and no LM score file exists yet, the
    function does two things:

    * loads the hypotheses, either from the default generation output or from
      the user-supplied n-best list;
    * hands them to ``rerank_utils.lm_scoring`` together with the target score
      file path.

    Args:
        args: parsed reranking options (from ``rerank_options.get_reranking_parser``
            followed by ``options.parse_args_and_arch``).

    Returns:
        None. Any output is produced by ``rerank_utils.lm_scoring``, which
        presumably writes ``lm_score_file`` (TODO confirm in rerank_utils).
    """
    using_nbest = args.nbest_list is not None

    # Only the generation dir and the LM-preprocessed dir are used here; the
    # other three directories returned by get_directories are ignored.
    pre_gen, _, _, _, lm_preprocessed_dir = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list

    if args.language_model is None:
        # No language model configured: nothing to score.
        return

    lm_score_file = rerank_utils.rescore_file_name(
        pre_gen, args.prefix_len, args.lm_name, lm_file=True
    )
    if os.path.isfile(lm_score_file):
        # Scores already exist from an earlier run; skip re-scoring.
        return

    # Parse the hypotheses only once we know they will actually be scored,
    # instead of unconditionally on every call.
    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file, bpe_symbol=args.post_process, nbest=using_nbest
    )

    print("STEP 4.5: language modeling for P(T)")

    # Describe how the LM's BPE relates to the generation output.
    if args.lm_bpe_code is None:
        bpe_status = "no bpe"
    elif args.lm_bpe_code == "shared":
        bpe_status = "shared"
    else:
        bpe_status = "different"

    rerank_utils.lm_scoring(
        lm_preprocessed_dir,
        bpe_status,
        gen_output,
        pre_gen,
        args.lm_dict,
        args.lm_name,
        args.language_model,
        args.lm_bpe_code,
        128,  # NOTE(review): presumably the scoring batch size — confirm against lm_scoring
        lm_score_file,
        args.target_lang,
        args.source_lang,
        prefix_len=args.prefix_len,
    )
def cli_main():
    """Command-line entry point: parse the reranking options and run LM scoring."""
    reranking_parser = rerank_options.get_reranking_parser()
    score_lm(options.parse_args_and_arch(reranking_parser))


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    cli_main()