| | import json |
| | import os |
| | import pickle |
| |
|
| | from rich.progress import track |
| |
|
| | from tools.tsr.utils import get_task_ids, to_path |
| |
|
| |
|
| | def collect_sample_info(sample_dir: str, sample_eval_dir: str, dataset: str): |
| | if os.path.exists(sample_dir) and len(os.listdir(sample_dir)) > 0: |
| | |
| | return |
| | task_ids = get_task_ids(dataset) |
| | assert os.path.exists(sample_eval_dir), "sample evaluation files missing" |
| | os.makedirs(sample_dir, exist_ok=True) |
| | kill_info = {task_id: {} for task_id in task_ids} |
| | model_paths = os.listdir(sample_eval_dir) |
| | for model_path in track(model_paths, description="Collecting sets..."): |
| | if not model_path[-1].isdigit(): |
| | continue |
| | eval_json_path = os.path.join(sample_eval_dir, model_path, "eval_results.json") |
| | if not os.path.exists(eval_json_path): |
| | continue |
| | with open(eval_json_path, "r") as f: |
| | res = json.load(f)["eval"] |
| | for task_id, v in res.items(): |
| | if task_id not in task_ids: |
| | continue |
| | for i_code, (status, res_list) in enumerate(v["plus"]): |
| | if status == "success": |
| | continue |
| | for i_test, res in enumerate(res_list): |
| | test_id = f"plus_{i_test}" |
| | if res == False: |
| | if "_" in task_id: |
| | task_id = task_id.replace("_", "/") |
| | kill_info[task_id].setdefault(test_id, []).append( |
| | (model_path, i_code) |
| | ) |
| | for task_id in task_ids: |
| | path_task_id = to_path(task_id) |
| | with open(os.path.join(sample_dir, f"{path_task_id}.pkl"), "wb") as f: |
| | pickle.dump(kill_info[task_id], f) |
| |
|
| |
|
| | if __name__ == "__main__": |
| | import argparse |
| |
|
| | parser = argparse.ArgumentParser() |
| | parser.add_argument("--report_dir", required=True, type=str) |
| | parser.add_argument("--dataset", type=str, choices=["humaneval", "mbpp"]) |
| | parser.add_argument("--sample_eval_dir", required=True, type=str) |
| | args = parser.parse_args() |
| |
|
| | sample_dir = os.path.join(args.report_dir, "sample_cache") |
| | collect_sample_info(sample_dir, args.sample_eval_dir, args.dataset) |
| |
|