| """ |
| Take in a base YAML and write one YAML config per CSAT-QA subset, each of which `include`s the base YAML. |
| """ |
|
|
| import argparse |
| import os |
|
|
| import yaml |
| from tqdm import tqdm |
|
|
| from lm_eval.logger import eval_logger |
|
|
|
|
# CSAT-QA subset identifiers. Each entry becomes the `dataset_name` of one
# generated YAML and is also embedded in that YAML's task name and output
# file name.
SUBSETS = [
    "WR",
    "GR",
    "RCS",
    "RCSS",
    "RCH",
    "LI",
]
|
|
|
|
def parse_args():
    """Read this generator's command-line options from ``sys.argv``.

    Options:
        --base_yaml_path: path to the base YAML file (required).
        --save_prefix_path: prefix used to build output file paths
            (default ``"csatqa"``).
        --task_prefix: optional extra segment for generated task names
            (default ``""``).

    Returns:
        An ``argparse.Namespace`` with ``base_yaml_path``,
        ``save_prefix_path`` and ``task_prefix`` attributes.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--base_yaml_path", {"required": True}),
        ("--save_prefix_path", {"default": "csatqa"}),
        ("--task_prefix", {"default": ""}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
|
|
|
|
| if __name__ == "__main__": |
| args = parse_args() |
|
|
| |
| base_yaml_name = os.path.split(args.base_yaml_path)[-1] |
| with open(args.base_yaml_path, encoding="utf-8") as f: |
| base_yaml = yaml.full_load(f) |
|
|
| for name in tqdm(SUBSETS): |
| yaml_dict = { |
| "include": base_yaml_name, |
| "task": f"csatqa_{args.task_prefix}_{name}" |
| if args.task_prefix != "" |
| else f"csatqa_{name.lower()}", |
| "dataset_name": name, |
| } |
|
|
| file_save_path = args.save_prefix_path + f"_{name.lower()}.yaml" |
| eval_logger.info(f"Saving yaml for subset {name} to {file_save_path}") |
| with open(file_save_path, "w", encoding="utf-8") as yaml_file: |
| yaml.dump( |
| yaml_dict, |
| yaml_file, |
| width=float("inf"), |
| allow_unicode=True, |
| default_style='"', |
| ) |
|
|