| |
| |
| |
| |
|
|
| import argparse |
| import random |
|
|
| import numpy as np |
| from fairseq import options |
|
|
| from examples.noisychannel import rerank, rerank_options |
|
|
|
|
# Parameters random_search can tune. The same order is used when unpacking the
# best values from rerank.rerank's result and when writing them back into args.
_TUNEABLE_PARAMETERS = ("lenpen", "weight1", "weight2", "weight3")


def _split_parameters(args):
    """Separate the searched parameters from the ones held fixed.

    Args:
        args: parsed tuning options; reads ``tune_param``, ``lower_bound``,
            ``upper_bound`` and every name in ``_TUNEABLE_PARAMETERS``.

    Returns:
        ``(tune_parameters, fixed_values)``: ``tune_parameters`` lists
        ``args.tune_param`` followed by the remaining (fixed) parameter names;
        ``fixed_values[j]`` is the value of the j-th fixed parameter, wrapped in
        a list when it was a scalar.

    Raises:
        ValueError: if there are fewer bounds than tuned parameters, if an
            upper bound is below (or not comparable to) its lower bound, or if a
            tuned name is unknown or repeated.
    """
    n_tuned = len(args.tune_param)
    if len(args.lower_bound) < n_tuned or len(args.upper_bound) < n_tuned:
        raise ValueError(
            "need a lower and an upper bound for each of the "
            f"{n_tuned} tuned parameters"
        )

    fixed_names = list(_TUNEABLE_PARAMETERS)
    # Scalars are wrapped so every value is a list; its first element is used
    # as the fixed value for all trials.
    fixed_values = [
        value if isinstance(value, list) else [value]
        for value in (getattr(args, name) for name in fixed_names)
    ]
    for name, lower, upper in zip(args.tune_param, args.lower_bound, args.upper_bound):
        # `not >=` (rather than `<`) also rejects NaN bounds.
        if not upper >= lower:
            raise ValueError(f"upper bound {upper} < lower bound {lower} for {name}")
        index = fixed_names.index(name)  # ValueError for unknown/repeated names
        del fixed_names[index]
        del fixed_values[index]

    return list(args.tune_param) + fixed_names, fixed_values


def _sample_trials(args, fixed_values):
    """Build the matrix of candidate parameter values, one row per trial.

    Columns follow the order returned by ``_split_parameters``. First comes one
    column per name in ``args.tune_param``, drawn with ``random.uniform`` between
    its bounds, so the global ``random`` state must already be seeded. Then comes
    one column per fixed parameter, holding the first element of its value.

    Returns:
        A numpy array of shape
        ``(args.num_trials, len(args.tune_param) + len(fixed_values))``.
    """
    num_trials = args.num_trials
    bounds = list(zip(args.lower_bound, args.upper_bound))[: len(args.tune_param)]
    # Draw in trial-major order: all of trial 0's parameters, then trial 1's, ...
    sampled = np.array(
        [
            [random.uniform(lower, upper) for lower, upper in bounds]
            for _ in range(num_trials)
        ]
    )
    fixed = np.array(
        [[values[0] for values in fixed_values] for _ in range(num_trials)]
    )
    # With zero trials np.array([]) is 1-D of shape (0,); reshape keeps both
    # parts 2-D so they can be concatenated column-wise.
    sampled = sampled.reshape(num_trials, len(bounds))
    fixed = fixed.reshape(num_trials, len(fixed_values))
    return np.concatenate((sampled, fixed), axis=1)


def _namespace_with_best(args, best, gen_subset):
    """Return a copy of ``args`` carrying the best parameters and a new subset.

    Each value in ``best`` (ordered as ``_TUNEABLE_PARAMETERS``) is stored as a
    one-element list, and ``gen_subset`` replaces ``args.gen_subset``.
    """
    overrides = vars(args).copy()
    overrides["gen_subset"] = gen_subset
    for name, value in zip(_TUNEABLE_PARAMETERS, best):
        overrides[name] = [value]
    return argparse.Namespace(**overrides)


def random_search(args):
    """Tune reranking hyper-parameters by random search, then apply the best.

    For ``args.num_trials`` trials (seeded with ``args.seed``), each name in
    ``args.tune_param`` gets a value drawn uniformly from
    ``[args.lower_bound[i], args.upper_bound[i]]``. The other tuneable parameters
    keep the first element of their value from ``args``. All trials are passed
    to one ``rerank.rerank`` call on the tuning subset: ``"test"`` when
    ``args.nbest_list`` is set, else ``args.tune_subset``. That call's result is
    unpacked as the best lenpen, weight1, weight2, weight3 and score.

    With the best values, ``rerank.rerank`` is then run on the ``"valid"``
    subset, unless ``args.gen_subset`` already is ``"valid"``. Finally it is run
    on ``args.gen_subset``.

    Raises:
        ValueError: on missing or inverted bounds, or an unknown tuned name.
    """
    tune_parameters, fixed_values = _split_parameters(args)
    random.seed(args.seed)
    trials = _sample_trials(args, fixed_values)

    search_args = vars(args).copy()
    search_args["gen_subset"] = "test" if args.nbest_list else args.tune_subset
    for column, name in enumerate(tune_parameters):
        search_args[name] = list(trials[:, column])
    if args.share_weights:
        # Tie weight3 to the weight2 column, whether weight2 was tuned or fixed.
        search_args["weight3"] = list(trials[:, tune_parameters.index("weight2")])

    best_lenpen, best_weight1, best_weight2, best_weight3, _best_score = rerank.rerank(
        argparse.Namespace(**search_args)
    )
    best = (best_lenpen, best_weight1, best_weight2, best_weight3)

    # Rerun on the valid subset with the best values, unless it is the final one.
    if args.gen_subset != "valid":
        rerank.rerank(_namespace_with_best(args, best, "valid"))

    # Apply the best hyper-parameters to the requested generation subset.
    rerank.rerank(_namespace_with_best(args, best, args.gen_subset))
|
|
|
|
def cli_main():
    """Parse the tuning command line and run ``random_search`` with it."""
    tuning_parser = rerank_options.get_tuning_parser()
    random_search(options.parse_args_and_arch(tuning_parser))
|
|
|
|
if __name__ == "__main__":
    # Only run the tuning search when this file is executed as a script.
    cli_main()
|
|