| import os | |
| import json | |
| from datetime import datetime | |
| from load_data import get_data, save_back_results | |
| from excute_tool_linux_parallel import run_cpp_code_linux | |
| from multiprocessing import Pool, cpu_count | |
| from tqdm import tqdm | |
| import logging | |
| import re | |
def is_decimal(s):
    """Return True iff *s* is a string holding a plain decimal number.

    Requires an optional leading minus, at least one digit, a decimal
    point, and at least one trailing digit (e.g. "-1.25"). Integers
    ("3"), scientific notation ("1e5"), "nan"/"inf", and non-numeric
    input all yield False.
    """
    try:
        # Cheap sanity check that s is number-like at all; also rejects
        # None and other non-convertible inputs.
        float(s)
    except (ValueError, TypeError):  # narrowed from a bare `except:`
        return False
    return bool(re.match(r"^-?\d+\.\d+$", s))
def get_testcases(testcase_path):
    """Load a JSONL file and return its records as a list of dicts.

    Blank lines are skipped; a missing file yields an empty list.
    """
    if not os.path.exists(testcase_path):
        return []
    with open(testcase_path, 'r', encoding='utf-8') as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
def check_difference(arr):
    """Return True iff *arr* holds decimal strings that all agree to 1e-6.

    The first entry must look like a decimal (per ``is_decimal``) and every
    entry must be float-convertible. Empty input and unparseable entries
    return False instead of raising (the original indexed ``arr[0]``
    unconditionally and called ``float`` on unvalidated entries, so it
    crashed on ``[]`` or ``["1.0", "abc"]``).
    """
    if not arr or not is_decimal(arr[0]):
        return False
    try:
        values = [float(s) for s in arr]
    except (ValueError, TypeError):
        return False
    # All pairwise differences are <= 1e-6 exactly when the two extreme
    # values are — O(n) instead of the original O(n^2) double loop.
    return max(values) - min(values) <= 1e-6
def append_dict_to_jsonl(file_path, data_dict):
    """Serialize *data_dict* as one JSON line and append it to *file_path*."""
    line = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as sink:
        sink.write(line + '\n')
def setup_logging():
    """Configure root logging to a timestamped file plus the console.

    Creates a ``logs/`` directory next to the working directory if needed
    and returns the root logger.
    """
    os.makedirs("logs", exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f"logs/execution_{stamp}.log"
    handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
    )
    return logging.getLogger()
def process_code_with_logging(data_item):
    """Run one code item through the C++ executor, logging the outcome.

    On any exception the input item itself is returned with ``error`` set
    to ``["EXE"]`` and ``details`` carrying the exception message, so the
    downstream aggregation never sees a missing result.
    """
    problem_id = data_item["problem_id"]
    try:
        outcome = run_cpp_code_linux(data_item)
        status = outcome.get("error", "Unknown")
        logger.info(f"执行完成 - 问题ID: {problem_id}, 状态: {status}")
        return outcome
    except Exception as exc:
        logger.error(f"执行异常 - 问题ID: {problem_id}, 错误: {str(exc)}")
        data_item["error"] = ["EXE"]
        data_item["details"] = str(exc)
        return data_item
def _tally_status(status_counts, status):
    """Increment each bucket whose status appears in *status*.

    *status* is the per-test-case status list of one code submission.
    "AC" is only counted when every test case passed; any other status
    counts once per submission if it occurred at all.
    """
    for status_name in status_counts:
        if status_name == "AC" and not all(sta == "AC" for sta in status):
            continue
        if status_name in status:
            status_counts[status_name] += 1


def _append_code_result(problem_results, result, status):
    """Register one execution *result* under its problem entry, creating
    the per-problem record (limits + test cases) on first sight."""
    problem_id = result["problem_id"]
    entry = problem_results.setdefault(problem_id, {
        "problem_id": problem_id,
        "codes": [],
        "time_limit": result["time_limit"],
        "memory_limit": result["memory_limit"],
        "test_cases": result["test_cases"],
    })
    entry["codes"].append({
        "code_id": result["code_id"],
        "code": result["code"],
        "status": status,
        "details": result.get("details", ""),
    })


def _collect_correct_codes(problem_results):
    """Return the subset of *problem_results* whose codes passed every
    test case (all-"AC" status), keeping only id + source per code."""
    correct_codes = {}
    for problem_id, problem_data in problem_results.items():
        ac_codes = [
            {"code_id": c["code_id"], "code": c["code"]}
            for c in problem_data["codes"]
            if all(sta == "AC" for sta in c["status"])
        ]
        if ac_codes:
            correct_codes[problem_id] = {
                "problem_id": problem_id,
                "codes": ac_codes,
                "time_limit": problem_data["time_limit"],
                "memory_limit": problem_data["memory_limit"],
                "test_cases": problem_data["test_cases"],
            }
    return correct_codes


def save_results(results, correct_code_output_file, output_file):
    """Aggregate execution results, persist them, and tally status counts.

    Writes the full per-problem results to *output_file* and the all-AC
    subset to *correct_code_output_file* (both JSON). Returns a tuple
    ``(status_counts, problem_results)``.
    """
    status_counts = {"AC": 0, "CE": 0, "TLE": 0, "MLE": 0, "RE": 0, "WA": 0, "EXE": 0}
    problem_results = {}
    for result in results:
        # "error" is expected to be a list of per-test-case statuses; the
        # "Unknown" string fallback matches none of the counted statuses.
        status = result.get("error", "Unknown")
        _tally_status(status_counts, status)
        _append_code_result(problem_results, result, status)
    # Full results (ensure_ascii=False for consistency with the JSONL writer).
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(problem_results, f, indent=3, ensure_ascii=False)
    # All-AC subset only.
    correct_codes = _collect_correct_codes(problem_results)
    with open(correct_code_output_file, "w", encoding="utf-8") as f:
        json.dump(correct_codes, f, indent=3, ensure_ascii=False)
    return status_counts, problem_results
| if __name__ == "__main__": | |
| datasets_name = "tcb" | |
| testcase_alg = "lcb" | |
| save_dir = f"/home/i-luoxianzhen/data/TestCase-Gen/saved_test_filterd/{testcase_alg}/" + "tests-{}.jsonl" | |
| test_dir = f"/home/i-luoxianzhen/data/TestCase-Gen/save_tests_scaling/{testcase_alg}/" + "tests-{}.jsonl" | |
| pass_rate_save_file = f"/home/i-luoxianzhen/data/TestCase-Gen/saved_test_filterd/{testcase_alg}/test_pass_rate.jsonl" | |
| logger = setup_logging() | |
| logger.info("开始执行代码评估...") | |
| data = get_data(name=datasets_name, prefix_dir=test_dir, save_dir=save_dir) | |
| logger.info(f"加载了 {len(data)} 个代码项目") | |
| cpu = 50 | |
| logger.info(f"使用 {cpu} 个CPU核心进行并行处理") | |
| with Pool(cpu) as pool: | |
| results = list(tqdm( | |
| pool.imap_unordered(process_code_with_logging, data), | |
| total=len(data), | |
| desc="执行进度" | |
| )) | |
| logger.info("所有代码执行完成 - 开始保存筛选后的测试样例") | |
| result_dict = {} | |
| for item in results: | |
| if item['problem_id'] not in result_dict.keys(): | |
| result_dict[item['problem_id']] = [item, ] | |
| else: | |
| result_dict[item['problem_id']].append(item) | |
| for k, v in result_dict.items(): | |
| testcases = get_testcases(test_dir.format(k)) | |
| save_path = save_dir.format(k) | |
| status_array_length = len(v[0]["error"]) | |
| saved_nums = 0 | |
| for i in range(status_array_length): | |
| all_AC = True | |
| ## 部分输出为空的,如果新的输出全部一致也可以保留 | |
| all_WA = True | |
| output_list = [] | |
| for item in v: | |
| if item["error"][i] == "EXE": | |
| all_AC = False | |
| all_WA = False | |
| break | |
| output_list.append(item['details'][i]) | |
| if not item["error"][i] == "AC": | |
| all_AC = False | |
| if not item["error"][i] == "WA": | |
| all_WA = False | |
| if all_AC: | |
| append_dict_to_jsonl(save_path, testcases[i]) | |
| saved_nums += 1 | |
| if all_WA and (len(list(set(output_list))) == 1 or check_difference(output_list)): | |
| testcases[i]['output'] = output_list[0] | |
| append_dict_to_jsonl(save_path, testcases[i]) | |
| saved_nums += 1 | |
| append_dict_to_jsonl(pass_rate_save_file, { | |
| "tcb_id": k, | |
| "gen_nums": len(testcases), | |
| "right_nums": saved_nums | |
| }) | |
Xet Storage Details
- Size:
- 7.49 kB
- Xet hash:
- ae48aa84edeffec21dfcc84b505026fd8beac787afe7ef71ad576c7af8884849
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.