| import os | |
| import json | |
| from datetime import datetime | |
| from multiprocessing import Pool, cpu_count | |
| from tqdm import tqdm | |
| import logging | |
| from typing import List, Optional | |
| import random | |
| import re | |
def is_decimal(s):
    """Return True when ``s`` is a string written as a plain decimal number.

    A value qualifies only if it is a ``str`` that ``float()`` can parse and
    that matches ``-?<digits>.<digits>``: an optional minus sign and digits on
    both sides of a mandatory dot. Integers such as ``"3"``, exponent forms
    such as ``"1e5"`` and every non-string value yield False.

    Args:
        s: Candidate value to classify.

    Returns:
        bool: True for a plain decimal literal, False otherwise.
    """
    # re.match() raises TypeError for non-strings (e.g. an already parsed
    # float), so reject them up front instead of crashing the caller.
    if not isinstance(s, str):
        return False
    try:
        float(s)
    except ValueError:
        # Not numeric at all.
        return False
    return bool(re.match(r"^-?\d+\.\d+$", s))
def get_testcases(testcase_path):
    """Load a JSON Lines file into a list of parsed objects.

    Args:
        testcase_path: Path of the ``.jsonl`` file to read.

    Returns:
        list: One parsed JSON value per non-blank line, in file order.
        An empty list when the path does not exist.
    """
    if os.path.exists(testcase_path):
        with open(testcase_path, 'r', encoding='utf-8') as handle:
            # Whitespace-only lines are skipped rather than parsed.
            return [json.loads(raw) for raw in handle if raw.strip()]
    return []
def check_difference(arr):
    """Tell whether all entries are decimal strings that agree within 1e-6.

    Args:
        arr: Sequence of candidate values; each one is checked with
            ``is_decimal`` before being converted with ``float()``.

    Returns:
        bool: False if any entry is not accepted by ``is_decimal`` or if any
        two values differ by more than 1e-6; True otherwise (including for
        an empty sequence).
    """
    if not all(is_decimal(num) for num in arr):
        return False
    values = [float(num) for num in arr]
    if not values:
        return True
    # The largest pairwise gap is always max - min, so one linear pass
    # replaces comparing every pair of elements.
    return max(values) - min(values) <= 1e-6
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as one JSON line.

    The file is opened in append mode with UTF-8 encoding, and non-ASCII
    characters are written verbatim rather than escaped.

    Args:
        file_path: Destination ``.jsonl`` file.
        data_dict: JSON-serialisable object to write.
    """
    serialized = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as sink:
        # Single write so the record and its newline go out together.
        sink.write(f"{serialized}\n")
def setup_logging():
    """Configure root logging to a timestamped file plus the console.

    Creates the ``logs`` directory if needed, opens a log file named after
    the current local time, and calls ``logging.basicConfig`` with INFO level
    and a file handler and a stream handler.

    Returns:
        logging.Logger: The root logger.
    """
    os.makedirs("logs", exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f"logs/execution_{stamp}.log"
    file_handler = logging.FileHandler(log_file)
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[file_handler, console_handler],
    )
    return logging.getLogger()
| from execute_tool import function_execute_box_subprocess, function_execute_box_process | |
# Module-level logger named after this module. When the file is run as a
# script, the __main__ section rebinds this name to the logger returned by
# setup_logging().
logger = logging.getLogger(__name__)
def process_function(data_item):
    """Build test inputs for one problem by running its edge-case generators.

    Each generator source in ``data_item['func_list']['eage']`` is executed
    through ``function_execute_box_process`` (entry point
    ``construct_inputs``, 30 second limit). Results are collected until three
    generators have succeeded. The ``'random'`` generators in ``func_list``
    are not executed by this function.

    Args:
        data_item: Mapping with ``'tcb_id'`` and ``'func_list'`` entries.

    Returns:
        dict: ``{"tcb_id": ..., "generate_testcases": [...]}`` holding at
        most 200 inputs; when more were produced, 200 are sampled at random.
    """
    tcb_id = data_item['tcb_id']
    # NOTE(review): the key is read as 'eage', presumably a misspelling of
    # 'edge' in the data that feeds this function. It is kept as-is so
    # existing inputs keep loading -- confirm with whoever builds func_list.
    edge_func_list = data_item['func_list'].get('eage', [])

    edge_test_inputs: list = []
    edge_count = 0
    for edge_func in edge_func_list:
        try:
            res = function_execute_box_process(edge_func, funcname="construct_inputs", time_limit=30)
        except Exception:
            # A failing generator is skipped; KeyboardInterrupt/SystemExit
            # are no longer swallowed here.
            continue
        if res.get('status') == 'success' and isinstance(res.get('result'), list):
            edge_test_inputs += res['result']
            edge_count += 1
            # Stop once three generators have contributed inputs.
            if edge_count >= 3:
                break

    generate_testcases = list(edge_test_inputs)
    if len(generate_testcases) > 200:
        generate_testcases = random.sample(generate_testcases, 200)
    return {
        "tcb_id": tcb_id,
        "generate_testcases": generate_testcases
    }
def process_function_add(data_item):
    """Build test inputs for one problem from random and edge generators.

    Generators from ``func_list['random']`` and ``func_list['eage']`` are run
    alternately, index by index, through ``function_execute_box_process``
    (entry point ``construct_inputs``, 10 second limit). The pass stops early
    once more than ``limit_nums`` inputs have been collected.

    Args:
        data_item: Mapping with ``'tcb_id'``, ``'func_list'`` and
            ``'limit_nums'`` entries.

    Returns:
        dict: ``{"tcb_id": ..., "generate_testcases": [...]}`` with the random
        inputs followed by the edge inputs, capped at 200 entries; when more
        were produced, 200 are sampled at random.
    """
    tcb_id = data_item['tcb_id']
    func_list = data_item['func_list']
    random_func_list = func_list.get('random', [])
    # NOTE(review): 'eage' is presumably a misspelling of 'edge' in the
    # incoming data; kept so existing inputs keep loading -- confirm upstream.
    edge_func_list = func_list.get('eage', [])
    limit_nums = data_item['limit_nums']

    def _collect(func_code):
        """Run one generator and return its inputs, or [] on any failure."""
        try:
            res = function_execute_box_process(func_code, funcname="construct_inputs", time_limit=10)
        except Exception:
            return []
        if res.get('status') == 'success' and isinstance(res.get('result'), list):
            return res['result']
        return []

    random_test_inputs: list = []
    edge_test_inputs: list = []
    # Number of passes over the generator lists; currently a single pass.
    gen_times = 1
    while len(random_test_inputs) + len(edge_test_inputs) < limit_nums and gen_times > 0:
        gen_times -= 1
        for i in range(max(len(random_func_list), len(edge_func_list))):
            # A failing random generator must not prevent the edge generator
            # at the same index from running, so failures yield [] instead
            # of skipping the rest of the iteration.
            if i < len(random_func_list):
                random_test_inputs += _collect(random_func_list[i])
            if i < len(edge_func_list):
                edge_test_inputs += _collect(edge_func_list[i])
            if len(random_test_inputs) + len(edge_test_inputs) > limit_nums:
                break

    generate_testcases = random_test_inputs + edge_test_inputs
    if len(generate_testcases) > 200:
        generate_testcases = random.sample(generate_testcases, 200)
    return {
        "tcb_id": tcb_id,
        "generate_testcases": generate_testcases
    }
| from concurrent.futures import ThreadPoolExecutor, as_completed | |
def run_all(data, max_workers: int | None = None):
    """Run ``process_function`` over every item of ``data`` in a thread pool.

    Args:
        data: Sequence of items, each passed to ``process_function``.
        max_workers: Thread count. When None, uses the CPU count minus two
            (falling back to 8 CPUs if the count is unknown), never below 1.

    Returns:
        list: One result per input item, in input order. An item whose
        processing raised is represented by a dict carrying its ``tcb_id``
        and an empty ``generate_testcases`` list.
    """
    if max_workers is None:
        max_workers = max(1, (os.cpu_count() or 8) - 2)
    logger.info(f"使用 {max_workers} 个并发 worker 进行处理(线程池调度,任务内子进程可强杀)")

    def _safe_process(position, record):
        # Any failure is logged and turned into an empty result so a single
        # bad item cannot abort the whole run.
        try:
            outcome = process_function(record)
        except Exception:
            logger.exception("处理单条数据时发生异常")
            outcome = {"tcb_id": record.get("tcb_id"), "generate_testcases": []}
        return position, outcome

    ordered = [None] * len(data)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_safe_process, pos, rec) for pos, rec in enumerate(data)]
        # Futures finish in arbitrary order; the position restores input order.
        for done in tqdm(as_completed(pending), total=len(pending), desc="执行进度"):
            pos, outcome = done.result()
            ordered[pos] = outcome

    logger.info("所有代码执行完成 - 开始保存筛选后的测试样例")
    return ordered
| from execute_tool import run_cpp_code_linux | |
def process_code_with_logging(data_item):
    """Run ``run_cpp_code_linux`` on ``data_item`` and log the outcome.

    Args:
        data_item: Mapping with at least a ``"problem_id"`` entry.

    Returns:
        The value returned by ``run_cpp_code_linux`` on success. If anything
        raises, ``data_item`` itself with ``"error"`` set to ``["EXE"]`` and
        ``"details"`` set to the exception text.
    """
    pid = data_item["problem_id"]
    try:
        outcome = run_cpp_code_linux(data_item)
        state = outcome.get("error", "Unknown")
        logger.info(f"执行完成 - 问题ID: {pid}, 状态: {state}")
    except Exception as exc:
        logger.error(f"执行异常 - 问题ID: {pid}, 错误: {str(exc)}")
        # Mark the input itself as failed so it can still be returned.
        data_item["error"] = ["EXE"]
        data_item["details"] = str(exc)
        return data_item
    return outcome
| from load_response import get_response_function, load_data, load_qwen3_result | |
if __name__ == "__main__":
    # Script entry point:
    #   stage 1 - run the input generators for every loaded item (run_all);
    #   stage 2 - execute the loaded code on those inputs in a process pool;
    #   stage 3 - keep only the inputs on which every remaining solution is
    #             accepted and all outputs agree, and write them as JSONL.
    import argparse
    import time

    testcase_alg = "lcb"

    parser = argparse.ArgumentParser(description="Process testcase algorithm and model name.")
    parser.add_argument('--model_name', type=str, default="gpt-4o", help="Model name.")
    parser.add_argument('--workers', type=int, default=100,
                        help="Worker threads for stage 1 and worker processes for stage 2.")
    args = parser.parse_args()
    model_name = args.model_name
    cpu = args.workers

    # makedirs creates both directory levels and tolerates existing ones.
    out_dir = f"/home/luoxianzhen/yang/save_tests_{model_name}-add/{testcase_alg}-edge/"
    os.makedirs(out_dir, exist_ok=True)
    test_dir = out_dir + "tests-{}.jsonl"
    pass_rate_save_file = out_dir + "test_pass_rate.jsonl"

    # Rebinds the module-level logger that run_all and
    # process_code_with_logging use.
    logger = setup_logging()
    logger.info("开始执行代码评估...")
    data = get_response_function(repsonse_path="/home/luoxianzhen/yang/data/response-orginal/orginal_response_{}_{}.jsonl", model_name=model_name, test_al=testcase_alg)
    logger.info(f"加载了 {len(data)} 个代码项目")

    logger.info(f"使用 {cpu} 个CPU核心进行并行处理 === Stage 1 == 生成函数")
    results_input = run_all(data, cpu)

    logger.info(f"使用 {cpu} 个CPU核心进行并行处理 === Stage 2 == 执行得到输出")
    testcase_data = load_data(results_input)
    with Pool(cpu) as pool:
        results = list(tqdm(
            pool.imap_unordered(process_code_with_logging, testcase_data),
            total=len(testcase_data),
            desc="执行进度"
        ))
    logger.info("所有代码执行完成 - 开始保存筛选后的测试样例")

    # Group execution results by problem id.
    result_dict = {}
    for item in results:
        result_dict.setdefault(item['problem_id'], []).append(item)

    # Index generated inputs by id once (first occurrence wins) instead of
    # scanning results_input for every problem.
    testcases_by_id = {}
    for entry in results_input:
        testcases_by_id.setdefault(entry["tcb_id"], entry['generate_testcases'])

    # A result whose statuses are all in this set passed no test at all.
    # The former second filter (the same set without "WA") could never match
    # anything after this one, so it was dropped.
    no_pass_statuses = ("TLE", "RE", "MLE", "WA", "EXE")

    for k, v in result_dict.items():
        testcases = testcases_by_id[k]
        save_path = test_dir.format(k)

        solutions = [
            item for item in v
            if not all(e in no_pass_statuses for e in item["error"])
        ]

        # Derive the number of tests from the surviving results rather than
        # from v[0]: results arrive unordered, and a crashed entry
        # (error == ["EXE"]) in first position used to cut evaluation down
        # to a single test. Never index past the generated inputs.
        status_array_length = max((len(item["error"]) for item in solutions), default=0)
        status_array_length = min(status_array_length, len(testcases))

        saved_nums = 0
        for i in range(status_array_length):
            all_AC = True
            output_list = []
            for item in solutions:
                if i >= len(item["error"]):
                    logger.info(f"{k} error list not the same")
                    continue
                if item["error"][i] == "EXE":
                    all_AC = False
                    break
                output_list.append(item['details'][i])
                if item["error"][i] != "AC":
                    all_AC = False

            # Nothing to save without outputs, with an empty expected output,
            # or when some solution was not accepted on this test.
            if not all_AC or not output_list or output_list[0] == "":
                continue
            # Outputs must be identical, or decimals within tolerance.
            if len(set(output_list)) == 1 or check_difference(output_list):
                curr_tests = {
                    'input': testcases[i],
                    'output': output_list[0]
                }
                append_dict_to_jsonl(save_path, curr_tests)
                saved_nums += 1

        append_dict_to_jsonl(pass_rate_save_file, {
            "tcb_id": k,
            "gen_nums": len(testcases),
            "right_nums": saved_nums
        })
Xet Storage Details
- Size:
- 11.6 kB
- Xet hash:
- 777ee606dbcdef859bc85023044ad69a20d81db816957e484f8d0ec636d96cb4
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.