Tsukihjy/testcase / methods /ALGO /parallel_exe.py
Tsukihjy's picture
download
raw
7.49 kB
import os
import json
from datetime import datetime
from load_data import get_data, save_back_results
from excute_tool_linux_parallel import run_cpp_code_linux
from multiprocessing import Pool, cpu_count
from tqdm import tqdm
import logging
import re
def is_decimal(s):
    """Return True iff *s* is a string of the form ``-?digits.digits``.

    The float() probe guards against non-string / non-numeric inputs
    (e.g. None) before the regex runs; the regex then requires an
    explicit fractional part, so "3" and "1e3" are rejected.
    """
    try:
        float(s)
    except (TypeError, ValueError):  # narrow: was a bare except with an unused result
        return False
    return bool(re.match(r"^-?\d+\.\d+$", s))
def get_testcases(testcase_path):
    """Load one JSON object per non-blank line from a JSONL file.

    Returns an empty list when *testcase_path* does not exist.
    """
    if not os.path.exists(testcase_path):
        return []
    with open(testcase_path, 'r', encoding='utf-8') as fh:
        return [json.loads(raw) for raw in fh if raw.strip()]
def check_difference(arr):
    """Return True when arr[0] looks like a decimal string and every value
    in *arr* agrees with every other within 1e-6.

    Fixes: the original indexed arr[0] unconditionally (IndexError on an
    empty list) and compared all pairs in O(n^2); all pairwise gaps are
    <= 1e-6 exactly when max - min is, so one O(n) pass suffices.

    Raises ValueError if any element beyond arr[0] is not float-parsable
    (same as the original, which called float() on every element).
    """
    if not arr:
        return False  # nothing to validate — original crashed here
    if not is_decimal(arr[0]):
        return False
    values = [float(x) for x in arr]
    # Every pair differs by <= 1e-6 iff the extreme values do.
    return max(values) - min(values) <= 1e-6
def append_dict_to_jsonl(file_path, data_dict):
    """Append *data_dict* as a single JSON line (non-ASCII preserved)."""
    serialized = json.dumps(data_dict, ensure_ascii=False)
    with open(file_path, 'a', encoding='utf-8') as sink:
        sink.write(serialized + '\n')
def setup_logging():
    """Configure root logging to stream output and a timestamped file
    under ./logs/, then return the root logger."""
    os.makedirs("logs", exist_ok=True)
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    log_file = f"logs/execution_{stamp}.log"
    handlers = [
        logging.FileHandler(log_file),
        logging.StreamHandler(),
    ]
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=handlers,
    )
    return logging.getLogger()
def process_code_with_logging(data_item):
"""包装函数,用于添加日志"""
problem_id = data_item["problem_id"]
try:
result = run_cpp_code_linux(data_item)
status = result.get("error", "Unknown")
logger.info(f"执行完成 - 问题ID: {problem_id}, 状态: {status}")
return result
except Exception as e:
logger.error(f"执行异常 - 问题ID: {problem_id}, 错误: {str(e)}")
data_item["error"] = ["EXE"]
data_item["details"] = str(e)
return data_item
def save_results(results, correct_code_output_file, output_file):
"""保存结果到文件"""
# 初始化结果字典
problem_results = {}
status_counts = {"AC": 0, "CE": 0, "TLE": 0, "MLE": 0, "RE": 0, "WA": 0, "EXE": 0}
# 分类归整结果
for result in results:
problem_id = result["problem_id"]
code_id = result["code_id"]
status = result.get("error", "Unknown")
# 更新状态计数
# if status in status_counts:
# status_counts[status] += 1
# if len(status) == 0:
# status_counts["AC"] += 1
for status_name in status_counts.keys():
if status_name == "AC" and not all(sta == "AC" for sta in status):
continue
if status_name in status:
status_counts[status_name] += 1
# 加入问题结果集
if problem_id not in problem_results:
problem_results[problem_id] = {
"problem_id": problem_id,
"codes": [],
"time_limit": result["time_limit"],
"memory_limit": result["memory_limit"],
"test_cases": result["test_cases"]
}
# 添加代码执行结果
problem_results[problem_id]["codes"].append({
"code_id": code_id,
"code": result["code"],
"status": status,
"details": result.get("details", "")
})
# 保存完整结果
with open(output_file, "w", encoding="utf-8") as f:
json.dump(problem_results, f, indent=3)
# 保存正确的代码(AC状态)
correct_codes = {}
for problem_id, problem_data in problem_results.items():
correct_problem_codes = []
for code_info in problem_data["codes"]:
if all(status == "AC" for status in code_info["status"]):
correct_problem_codes.append({
"code_id": code_info["code_id"],
"code": code_info["code"]
})
if correct_problem_codes:
correct_codes[problem_id] = {
"problem_id": problem_id,
"codes": correct_problem_codes,
"time_limit": problem_data["time_limit"],
"memory_limit": problem_data["memory_limit"],
"test_cases": problem_data["test_cases"]
}
# 保存正确代码结果
with open(correct_code_output_file, "w", encoding="utf-8") as f:
json.dump(correct_codes, f, indent=3)
# 返回状态统计
return status_counts, problem_results
if __name__ == "__main__":
datasets_name = "tcb"
testcase_alg = "lcb"
save_dir = f"/home/i-luoxianzhen/data/TestCase-Gen/saved_test_filterd/{testcase_alg}/" + "tests-{}.jsonl"
test_dir = f"/home/i-luoxianzhen/data/TestCase-Gen/save_tests_scaling/{testcase_alg}/" + "tests-{}.jsonl"
pass_rate_save_file = f"/home/i-luoxianzhen/data/TestCase-Gen/saved_test_filterd/{testcase_alg}/test_pass_rate.jsonl"
logger = setup_logging()
logger.info("开始执行代码评估...")
data = get_data(name=datasets_name, prefix_dir=test_dir, save_dir=save_dir)
logger.info(f"加载了 {len(data)} 个代码项目")
cpu = 50
logger.info(f"使用 {cpu} 个CPU核心进行并行处理")
with Pool(cpu) as pool:
results = list(tqdm(
pool.imap_unordered(process_code_with_logging, data),
total=len(data),
desc="执行进度"
))
logger.info("所有代码执行完成 - 开始保存筛选后的测试样例")
result_dict = {}
for item in results:
if item['problem_id'] not in result_dict.keys():
result_dict[item['problem_id']] = [item, ]
else:
result_dict[item['problem_id']].append(item)
for k, v in result_dict.items():
testcases = get_testcases(test_dir.format(k))
save_path = save_dir.format(k)
status_array_length = len(v[0]["error"])
saved_nums = 0
for i in range(status_array_length):
all_AC = True
## 部分输出为空的,如果新的输出全部一致也可以保留
all_WA = True
output_list = []
for item in v:
if item["error"][i] == "EXE":
all_AC = False
all_WA = False
break
output_list.append(item['details'][i])
if not item["error"][i] == "AC":
all_AC = False
if not item["error"][i] == "WA":
all_WA = False
if all_AC:
append_dict_to_jsonl(save_path, testcases[i])
saved_nums += 1
if all_WA and (len(list(set(output_list))) == 1 or check_difference(output_list)):
testcases[i]['output'] = output_list[0]
append_dict_to_jsonl(save_path, testcases[i])
saved_nums += 1
append_dict_to_jsonl(pass_rate_save_file, {
"tcb_id": k,
"gen_nums": len(testcases),
"right_nums": saved_nums
})

Xet Storage Details

Size:
7.49 kB
·
Xet hash:
ae48aa84edeffec21dfcc84b505026fd8beac787afe7ef71ad576c7af8884849

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.