# Source: Tsukihjy/testcase/methods/lcb/parallel_executor_debug.py (raw download, 11.6 kB)
import os
import json
from datetime import datetime
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import logging
from typing import List, Optional
import random
import re
def is_decimal(s):
    """Return True when *s* is a string written as a plain decimal number.

    Accepted form: an optional leading minus sign, one or more digits, a dot,
    and one or more digits (for example ``"3.14"`` or ``"-0.5"``). Integers
    (``"3"``), exponent notation (``"1e5"``) and forms such as ``".5"`` are
    rejected by the pattern.

    Args:
        s: The candidate value. Anything that is not a ``str`` yields False.

    Returns:
        bool: True if *s* matches the decimal pattern, otherwise False.
    """
    # The regex below only works on str; previously a non-str value that
    # float() accepts (e.g. 1.5 or b"1.5") reached re.match and raised
    # TypeError instead of returning False.
    if not isinstance(s, str):
        return False
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return bool(re.match(r"^-?\d+\.\d+$", s))
def get_testcases(testcase_path):
    """Load a JSON Lines file into a list of parsed values.

    Args:
        testcase_path: Path of the ``.jsonl`` file to read (UTF-8).

    Returns:
        list: One parsed JSON value per non-blank line, in file order.
        An empty list when the path does not exist.
    """
    if not os.path.exists(testcase_path):
        return []
    with open(testcase_path, 'r', encoding='utf-8') as handle:
        # Whitespace-only lines are skipped rather than parsed.
        return [json.loads(raw_line) for raw_line in handle if raw_line.strip()]
def check_difference(arr, tol=1e-6):
    """Tell whether all decimal strings in *arr* agree within a tolerance.

    Args:
        arr: Sequence of strings. Every element must satisfy ``is_decimal``.
        tol: Largest allowed absolute difference between any two values.
            Defaults to ``1e-6``.

    Returns:
        bool: False if any element is not a decimal string or if two values
        differ by more than *tol*; True otherwise (including for an empty
        sequence).
    """
    if not all(is_decimal(num) for num in arr):
        return False
    if not arr:
        return True
    # Each string is parsed once. Every pairwise difference is bounded by the
    # spread between the largest and smallest value, so checking that single
    # spread is equivalent to comparing all pairs.
    values = [float(num) for num in arr]
    return max(values) - min(values) <= tol
def append_dict_to_jsonl(file_path, data_dict):
    """Append *data_dict* to *file_path* as one JSON line.

    The file is opened in append mode with UTF-8 encoding, and non-ASCII
    characters are written as-is (``ensure_ascii=False``).

    Args:
        file_path: Destination JSON Lines file; created if missing.
        data_dict: JSON-serialisable object to write.
    """
    with open(file_path, 'a', encoding='utf-8') as sink:
        serialized = json.dumps(data_dict, ensure_ascii=False)
        sink.write(f"{serialized}\n")
def setup_logging():
    """Configure root logging to a timestamped file plus the console.

    Creates the ``logs`` directory in the current working directory if needed
    and calls ``logging.basicConfig`` with a file handler for
    ``logs/execution_<YYYYmmdd_HHMMSS>.log`` and a stream handler.

    Returns:
        logging.Logger: The root logger.
    """
    os.makedirs("logs", exist_ok=True)
    log_file = f"logs/execution_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log"
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # The messages logged by this script are Chinese text; pin the
            # file encoding so it does not depend on the platform locale.
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler(),
        ],
    )
    return logging.getLogger()
from execute_tool import function_execute_box_subprocess, function_execute_box_process
# Module-level logger used by run_all and process_code_with_logging.
# When the file is run as a script, the __main__ block rebinds this name to
# the root logger returned by setup_logging().
logger = logging.getLogger(__name__)
def process_function(data_item):
    """Generate test inputs for one item by running its edge-case generators.

    Each generator source in ``data_item['func_list']['eage']`` is executed
    through ``function_execute_box_process`` (entry point
    ``construct_inputs``, 30 s time limit). Generation stops once three
    generators have succeeded. If more than 200 inputs were collected, a
    random sample of 200 is kept.

    The generators under ``func_list['random']`` are not executed by this
    function.

    Args:
        data_item: Mapping with a ``tcb_id`` and a ``func_list`` mapping.

    Returns:
        dict: ``{"tcb_id": ..., "generate_testcases": [...]}``.
    """
    tcb_id = data_item['tcb_id']
    # NOTE(review): the key read here is spelled 'eage'. An earlier comment
    # claimed it had been corrected to 'edge', but the code never changed;
    # the producer of func_list is not visible here, so the spelling is kept.
    edge_func_list = data_item['func_list'].get('eage', [])

    edge_test_inputs = []
    edge_count = 0
    for edge_func in edge_func_list:
        try:
            res = function_execute_box_process(edge_func, funcname="construct_inputs", time_limit=30)
        except Exception:
            # Best effort: a failing generator is skipped. Unlike a bare
            # ``except:``, this lets KeyboardInterrupt/SystemExit propagate.
            continue
        if res.get('status') == 'success' and isinstance(res.get('result'), list):
            edge_test_inputs += res['result']
            edge_count += 1
            if edge_count >= 3:
                break

    generate_testcases = edge_test_inputs
    if len(generate_testcases) > 200:
        generate_testcases = random.sample(generate_testcases, 200)
    return {
        "tcb_id": tcb_id,
        "generate_testcases": generate_testcases
    }
def _run_input_generator(func_src, time_limit=10):
    """Run one generator source and return its list of inputs, or [] on failure."""
    try:
        res = function_execute_box_process(func_src, funcname="construct_inputs", time_limit=time_limit)
    except Exception:
        # Best effort: a generator that raises contributes nothing.
        return []
    if res.get('status') == 'success' and isinstance(res.get('result'), list):
        return res['result']
    return []


def process_function_add(data_item):
    """Generate test inputs from random and edge generators up to a quota.

    The random and edge generator lists are walked in lockstep: for each
    index the random generator (if any) runs first, then the edge generator
    (if any), each with a 10 s time limit. The walk stops as soon as at least
    ``limit_nums`` inputs have been collected. If more than 200 inputs were
    collected, a random sample of 200 is kept.

    Args:
        data_item: Mapping with ``tcb_id``, ``limit_nums`` and a ``func_list``
            mapping holding the ``random`` and ``eage`` generator lists.

    Returns:
        dict: ``{"tcb_id": ..., "generate_testcases": [...]}``.
    """
    tcb_id = data_item['tcb_id']
    random_func_list = data_item['func_list'].get('random', [])
    # NOTE(review): the key is spelled 'eage' in the code although an earlier
    # comment claimed it was corrected to 'edge'; kept as-is — confirm against
    # the producer of func_list.
    edge_func_list = data_item['func_list'].get('eage', [])
    limit_nums = data_item['limit_nums']

    random_test_inputs = []
    edge_test_inputs = []

    # Number of passes over the generator lists; a single pass is made.
    gen_times = 1
    while len(random_test_inputs) + len(edge_test_inputs) < limit_nums and gen_times > 0:
        gen_times -= 1
        for i in range(max(len(random_func_list), len(edge_func_list))):
            # A failing random generator no longer skips the edge generator
            # at the same index (the old ``continue`` did).
            if i < len(random_func_list):
                random_test_inputs += _run_input_generator(random_func_list[i])
            if i < len(edge_func_list):
                edge_test_inputs += _run_input_generator(edge_func_list[i])
            # Stop once the quota is reached, matching the ``<`` guard of the
            # enclosing loop (the old check used ``>``).
            if len(random_test_inputs) + len(edge_test_inputs) >= limit_nums:
                break

    generate_testcases = random_test_inputs + edge_test_inputs
    if len(generate_testcases) > 200:
        generate_testcases = random.sample(generate_testcases, 200)
    return {
        "tcb_id": tcb_id,
        "generate_testcases": generate_testcases
    }
from concurrent.futures import ThreadPoolExecutor, as_completed
def run_all(data, max_workers: int | None = None):
    """Run ``process_function`` over every item of *data* on a thread pool.

    Args:
        data: Indexable collection of items accepted by ``process_function``.
        max_workers: Thread count. When None, ``os.cpu_count() - 2`` is used
            (8 is assumed if the CPU count is unknown), with a minimum of 1.

    Returns:
        list: One result per input item, in input order. An item whose
        processing raised gets ``{"tcb_id": ..., "generate_testcases": []}``.
    """
    if max_workers is None:
        # Threads only schedule the work; leave a little headroom below the
        # CPU count.
        max_workers = max(1, (os.cpu_count() or 8) - 2)
    logger.info(f"使用 {max_workers} 个并发 worker 进行处理(线程池调度,任务内子进程可强杀)")

    def _task(idx_item):
        position, payload = idx_item
        try:
            outcome = process_function(payload)
        except Exception:
            # Safety net: any uncaught error yields an empty result so one
            # bad item cannot abort the whole run.
            logger.exception("处理单条数据时发生异常")
            return position, {"tcb_id": payload.get("tcb_id"), "generate_testcases": []}
        return position, outcome

    results = [None] * len(data)
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(_task, (i, data[i])) for i in range(len(data))]
        for finished in tqdm(as_completed(pending), total=len(pending), desc="执行进度"):
            position, outcome = finished.result()
            # Futures complete in arbitrary order; the index restores input order.
            results[position] = outcome
    logger.info("所有代码执行完成 - 开始保存筛选后的测试样例")
    return results
from execute_tool import run_cpp_code_linux
def process_code_with_logging(data_item):
    """Run ``run_cpp_code_linux`` on *data_item* and log the outcome.

    Args:
        data_item: Mapping that contains at least ``problem_id``.

    Returns:
        The value returned by ``run_cpp_code_linux``. If anything in the run
        or its logging raises, *data_item* itself is returned with
        ``error`` set to ``["EXE"]`` and ``details`` set to the exception text.
    """
    problem_id = data_item["problem_id"]
    try:
        outcome = run_cpp_code_linux(data_item)
        logger.info(f"执行完成 - 问题ID: {problem_id}, 状态: {outcome.get('error', 'Unknown')}")
        return outcome
    except Exception as exc:
        reason = str(exc)
        logger.error(f"执行异常 - 问题ID: {problem_id}, 错误: {reason}")
        # Mark the input item as failed so downstream filtering can drop it.
        data_item["error"] = ["EXE"]
        data_item["details"] = reason
        return data_item
from load_response import get_response_function, load_data, load_qwen3_result
if __name__ == "__main__":
    # Script entry point:
    #   Stage 1 - run the generator functions to produce test inputs.
    #   Stage 2 - run the C++ solutions on those inputs.
    #   Stage 3 - keep only the inputs on which the surviving solutions agree
    #             and write them, plus a per-problem summary, as JSON Lines.
    import argparse
    import time

    datasets_name = "tcb"
    testcase_alg = "lcb"

    parser = argparse.ArgumentParser(description="Process testcase algorithm and model name.")
    parser.add_argument('--model_name', type=str, default="gpt-4o", help="Model name.")
    args = parser.parse_args()
    model_name = args.model_name

    # makedirs(exist_ok=True) replaces the check-then-mkdir pairs, which could
    # fail if the directory appeared between the check and the creation.
    save_dir = f"/home/luoxianzhen/yang/save_tests_{model_name}-add/{testcase_alg}-edge/"
    os.makedirs(save_dir, exist_ok=True)
    test_dir = save_dir + "tests-{}.jsonl"
    pass_rate_save_file = save_dir + "test_pass_rate.jsonl"

    logger = setup_logging()
    logger.info("开始执行代码评估...")
    # 'repsonse_path' (sic) is the keyword spelling used by the existing call
    # to get_response_function.
    data = get_response_function(repsonse_path="/home/luoxianzhen/yang/data/response-orginal/orginal_response_{}_{}.jsonl", model_name=model_name, test_al=testcase_alg)
    logger.info(f"加载了 {len(data)} 个代码项目")

    cpu = 100
    logger.info(f"使用 {cpu} 个CPU核心进行并行处理 === Stage 1 == 生成函数")
    results_input = run_all(data, cpu)

    logger.info(f"使用 {cpu} 个CPU核心进行并行处理 === Stage 2 == 执行得到输出")
    testcase_data = load_data(results_input)
    with Pool(cpu) as pool:
        results = list(tqdm(
            pool.imap_unordered(process_code_with_logging, testcase_data),
            total=len(testcase_data),
            desc="执行进度"
        ))
    logger.info("所有代码执行完成 - 开始保存筛选后的测试样例")

    # Group the execution results by problem id.
    result_dict = {}
    for item in results:
        result_dict.setdefault(item['problem_id'], []).append(item)

    # Index the generated inputs once instead of scanning results_input for
    # every problem; the first entry wins when a tcb_id occurs more than once.
    testcases_by_id = {}
    for entry in results_input:
        testcases_by_id.setdefault(entry["tcb_id"], entry['generate_testcases'])

    # A run whose every per-testcase status is one of these is discarded.
    # (A second pass over the subset TLE/RE/MLE/EXE could never remove
    # anything after this one, so it was dropped.)
    status_wrong = ("TLE", "RE", "MLE", "WA", "EXE")

    for k, v in result_dict.items():
        testcases = testcases_by_id[k]
        save_path = test_dir.format(k)

        survivors = [
            item for item in v
            if not all(e in status_wrong for e in item["error"])
        ]

        # Take the testcase count from the surviving runs. It used to come
        # from v[0] before filtering; results arrive in arbitrary order
        # (imap_unordered), so a crashed run with error == ["EXE"] landing
        # first truncated the check to a single testcase.
        status_array_length = max((len(item["error"]) for item in survivors), default=0)

        saved_nums = 0
        for i in range(status_array_length):
            all_AC = True
            output_list = []
            for item in survivors:
                if i >= len(item["error"]):
                    logger.info(f"{k} error list not the same")
                    continue
                if item["error"][i] == "EXE":
                    all_AC = False
                    break
                output_list.append(item['details'][i])
                if item["error"][i] != "AC":
                    all_AC = False
            if not all_AC:
                continue
            if not output_list or output_list[0] == "":
                continue
            # Keep the testcase when all outputs are identical, or when they
            # are decimal strings that agree within check_difference's tolerance.
            if len(set(output_list)) == 1 or check_difference(output_list):
                curr_tests = {
                    'input': testcases[i],
                    'output': output_list[0]
                }
                append_dict_to_jsonl(save_path, curr_tests)
                saved_nums += 1

        append_dict_to_jsonl(pass_rate_save_file, {
            "tcb_id": k,
            "gen_nums": len(testcases),
            "right_nums": saved_nums
        })

# Xet Storage Details
# Size: 11.6 kB · Xet hash: 777ee606dbcdef859bc85023044ad69a20d81db816957e484f8d0ec636d96cb4
# Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads.