Tsukihjy/testcase / methods /ALGO /pipeline.py
Tsukihjy's picture
download
raw
7.98 kB
from prompt import CF_PROMPT
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from response import TurboResponser
from dataset_all import get_datasets_by_name
from is_correct import test_output_comparison
from config import cfg
from execute_tool_linux import run_func_code, run_cpp_code_linux, run_python_code_linux
import random
import json
def write_json_to_file(data, filepath):
    """Serialize ``data`` as pretty-printed JSON, replacing ``filepath``.

    Non-ASCII characters are written verbatim (UTF-8 file encoding) and
    nested values are indented by four spaces.
    """
    with open(filepath, mode="w", encoding="utf-8") as out_stream:
        json.dump(data, out_stream, indent=4, ensure_ascii=False)
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as a single JSON Lines record.

    The record is serialized on one line (non-ASCII kept as-is) and is
    terminated with a newline; the file is created if it does not exist.
    """
    with open(file_path, mode="a", encoding="utf-8") as sink:
        line = json.dumps(data_dict, ensure_ascii=False)
        sink.write(f"{line}\n")
class InputGenerator:
    """Chat wrapper that talks to the model under a software-tester persona.

    Every prompt is forwarded to ``responser.respond(system_prompt, prompt)``
    together with the fixed ``SYSTEM_PROMPT`` below.
    """

    def __init__(self, responser):
        self.responser = responser
        self.SYSTEM_PROMPT = "You are an expert software tester tasked with thoroughly testing a given piece of code. "

    def respond(self, prompt: str) -> str:
        """Return the model's reply to ``prompt`` under the tester persona."""
        system_prompt = self.SYSTEM_PROMPT
        return self.responser.respond(system_prompt, prompt)
class Verifier:
    """Chat wrapper that talks to the model as a programming assistant.

    Every prompt is forwarded to ``responser.respond(system_prompt, prompt)``
    together with the fixed ``SYSTEM_PROMPT`` below.
    """

    def __init__(self, responser):
        self.responser = responser
        self.SYSTEM_PROMPT = "You are an AI programming assistant."

    def respond(self, prompt: str) -> str:
        """Return the model's reply to ``prompt`` under the assistant persona."""
        system_prompt = self.SYSTEM_PROMPT
        return self.responser.respond(system_prompt, prompt)
import re
# Matches a fenced block tagged ``python`` or ``py`` (any case, optional
# trailing spaces/tabs on the fence line, LF or CRLF line endings) and
# captures the code between the opening and closing fences. Compiled once.
_PYTHON_FENCE_RE = re.compile(
    r'```(?:python|py)[ \t]*\r?\n(.*?)\r?\n```',
    re.DOTALL | re.IGNORECASE,
)


def extract_python_code_from_markdown(markdown_text):
    """Return the body of the first python-tagged fenced code block.

    Args:
        markdown_text (str): Text (typically a model reply) containing a
            block such as ```` ```python\\n...\\n``` ````.

    Returns:
        str: The code between the first matching opening fence and the next
        closing fence, without the surrounding newlines.

    Raises:
        IndexError: If no python-tagged fenced block is present. IndexError
            is kept for compatibility with the previous ``findall(...)[0]``
            lookup; callers wrap this call in try/except and retry.
    """
    match = _PYTHON_FENCE_RE.search(markdown_text)
    if match is None:
        raise IndexError("no ```python fenced code block found in model response")
    return match.group(1)
import datetime
def write_log(message: str, log_file: str = "log-algo.txt") -> None:
    """
    Append a timestamped log message to a log file.

    Each call appends exactly one line of the form
    ``[YYYY-mm-dd HH:MM:SS] <message>`` stamped with the local (naive)
    wall-clock time. The file is created if it does not exist.

    Args:
        message (str): The message to log.
        log_file (str): The path to the log file (default is 'log-algo.txt').

    Returns:
        None
    """
    timestamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"{timestamp} {message}\n")
import json
import os
def extract_ids_from_jsonl(file_pattern, num_files=8):
    """Collect every ``tcb_id`` recorded in a numbered family of JSONL files.

    Args:
        file_pattern (str): Path template with one positional ``{}``
            placeholder that is filled with the file index.
        num_files (int): Indices ``0 .. num_files - 1`` are tried. The default
            of 8 matches the previously hard-coded count. Missing files are
            skipped.

    Returns:
        list: ``tcb_id`` values in file-index order, then line order.
        Duplicates are kept. Blank lines, records that are not JSON objects,
        and objects without ``tcb_id`` are ignored.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    ids = []
    for i in range(num_files):
        filepath = file_pattern.format(i)
        if not os.path.exists(filepath):
            continue
        with open(filepath, 'r', encoding='utf-8') as file:
            for line in file:
                # Skip empty lines instead of failing inside json.loads.
                if not line.strip():
                    continue
                data = json.loads(line)
                # `'tcb_id' in data` is only meaningful for objects (dicts).
                if isinstance(data, dict) and 'tcb_id' in data:
                    ids.append(data['tcb_id'])
    return ids
import argparse
# API keys for the model endpoint; the script selects one by batch index.
# Set ALGO_API_KEYS (comma-separated) in the environment to override them.
# NOTE(review): the literal fallback below commits credentials to source
# control; rotate them and remove the fallback once the env var is in use.
keys = [
    candidate.strip()
    for candidate in os.environ.get("ALGO_API_KEYS", "").split(",")
    if candidate.strip()
] or [
    "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
    "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
    "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
    "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
]
# Dataset rows handled per batch index: batch i covers rows [i*100, (i+1)*100).
BATCH_SIZE = 100
# Number of consecutive batch indices that share one entry of ``keys``.
BATCHES_PER_KEY = 3

parser = argparse.ArgumentParser(description="接收1个命令行参数")
# param1 is the batch index; type=int lets argparse reject non-integers with a
# usage message instead of a raw ValueError traceback.
parser.add_argument("param1", type=int, help="第一个参数")
args = parser.parse_args()
batch = args.param1
# A negative batch would silently select an empty dataset slice (and a key
# picked by negative indexing); reject it up front.
if batch < 0:
    parser.error(f"batch index must be non-negative, got {batch}")
if batch // BATCHES_PER_KEY >= len(keys):
    parser.error(
        f"batch {batch} needs API key #{batch // BATCHES_PER_KEY} "
        f"but only {len(keys)} keys are configured"
    )
key = keys[batch // BATCHES_PER_KEY]
log_file = cfg.log_file.format(batch)
orginal_response_file = cfg.response_file.format("algo", cfg.model_name)
write_log(f"model {cfg.model_name}", log_file)
print(f"model {cfg.model_name}")
al_dataset = get_datasets_by_name(cfg.dataset_name)
start_pos = batch * BATCH_SIZE
end_pos = min((batch + 1) * BATCH_SIZE, len(al_dataset))
al_dataset = al_dataset[start_pos:end_pos]
print(f"start {start_pos} end {end_pos}")
print(f"datasets length {len(al_dataset)}")
input_gen = InputGenerator(TurboResponser(api_key=key, api_base=cfg.api_base, model=cfg.model_name))
oracle_gen = Verifier(TurboResponser(api_key=key, api_base=cfg.api_base, model=cfg.model_name))
# Resume hook (disabled): ids already recorded in the pass-rate files.
# finished_list = extract_ids_from_jsonl(cfg.pass_rate_file)
# Attempts allowed per LLM stage (input generator, oracle) for one problem.
MAX_LLM_ATTEMPTS = 5


def _build_input_generator(item, response_record):
    """Ask the model for an input-generator program and smoke-test it.

    The raw reply of each attempt is stored in
    ``response_record['test_case_gen_code']`` (the last attempt wins).

    Returns:
        The extracted generator source, or None when no attempt produced code
        whose ``gen_input`` returned a str free of "Error"/"error".
    """
    prompt = CF_PROMPT["data_generator_original"].format(problem=item["query_en"])
    for _ in range(MAX_LLM_ATTEMPTS):
        try:
            reply = input_gen.respond(prompt)
            response_record['test_case_gen_code'] = reply
            code = extract_python_code_from_markdown(reply)
            sample = run_func_code(code, funcname="gen_input", time_limit=10)
        except Exception as e:
            write_log(f"{item['tcb_id']} fail to construct data_generator: {e}", log_file)
            continue
        # Check the type first: the substring tests are only meaningful on a str.
        if isinstance(sample, str) and 'Error' not in sample and 'error' not in sample:
            return code
    return None


def _build_oracle(item, response_record):
    """Ask the model for a brute-force oracle solution.

    The raw reply of each attempt is stored in ``response_record['oracle']``.

    Returns:
        The extracted oracle source, or None after MAX_LLM_ATTEMPTS failures.
    """
    prompt = CF_PROMPT["oracle_generator"].format(problem=item["query_en"])
    for _ in range(MAX_LLM_ATTEMPTS):
        try:
            reply = oracle_gen.respond(prompt)
            response_record['oracle'] = reply
            return extract_python_code_from_markdown(reply)
        except Exception as e:
            write_log(f"{item['tcb_id']} fail to oracle: {e}", log_file)
    return None


def _make_test_case(item, generator_code, oracle_code):
    """Generate one input and run the oracle on it.

    Returns:
        ``{'input': ..., 'output': ...}``, or None when generation/execution
        raised, the oracle runner returned a str (presumably an error
        description), the oracle exited non-zero, or its stdout contains
        "Error"/"error".
    """
    try:
        input_str = run_func_code(generator_code, funcname="gen_input", time_limit=10)
        result = run_python_code_linux(oracle_code, input_str, time_limit=10)
    except Exception as e:
        write_log(f"{item['tcb_id']} fail to generate testcase {e}", log_file)
        return None
    if isinstance(result, str):
        return None
    # NOTE(review): this heuristic also rejects problems whose legitimate
    # expected output contains the word "error".
    if 'error' in result.stdout or 'Error' in result.stdout or result.returncode != 0:
        return None
    return {'input': input_str, 'output': result.stdout}


def _matches_reference(item, test_case):
    """Return True if one of the first three reference solutions agrees.

    Each solution in ``item["solutions"][:3]`` is executed with
    ``run_cpp_code_linux`` on the test input; the first run that exits 0 and
    whose stdout passes ``test_output_comparison`` against the oracle output
    confirms the test case.
    """
    for solution in item["solutions"][:3]:
        try:
            execute_res = run_cpp_code_linux(solution['code'], test_case['input'], time_limit=3)
        except Exception as e:
            # Move on to the next solution; never reuse an earlier run's result.
            write_log(f"{item['tcb_id']} solution execute fail {e}", log_file)
            continue
        # Presumably an error description, mirroring how a str from the python
        # runner is treated; a str has no dict keys to inspect.
        if isinstance(execute_res, str):
            continue
        if ("stdout" in execute_res and execute_res['returncode'] == 0
                and test_output_comparison(execute_res["stdout"], test_case['output'])):
            return True
    return False


for idx, item in enumerate(al_dataset):
    tcb_id = item['tcb_id']
    problem_test_count = 0
    write_log(f"{tcb_id} - {idx} - start", log_file)
    orginal_response = {"tcb_id": tcb_id}
    # Resume hook (disabled): skip problems whose tests file already exists.
    # if os.path.exists(cfg.tests_path.format(tcb_id)):
    #     write_log(f"{tcb_id} exist", log_file)
    #     continue
    try:
        # Input generator: regenerate until its gen_input runs cleanly.
        test_case_gen_code = _build_input_generator(item, orginal_response)
        if test_case_gen_code is None:
            # Never fall back to a generator left over from a previous problem.
            write_log(f"{tcb_id} no working data_generator, skip", log_file)
            continue
        # Brute-force oracle solution.
        oracle = _build_oracle(item, orginal_response)
        if oracle is None:
            continue
        append_dict_to_jsonl(orginal_response_file, orginal_response)
        # Produce cfg.gen_test_nums candidates; keep those a reference solution confirms.
        for _ in range(cfg.gen_test_nums):
            test_case = _make_test_case(item, test_case_gen_code, oracle)
            if test_case is None:
                continue
            try:
                if not _matches_reference(item, test_case):
                    continue
                append_dict_to_jsonl(cfg.tests_path.format(tcb_id), test_case)
                problem_test_count += 1
                write_log(f"{tcb_id} +1", log_file)
            except Exception as e:
                write_log(f"{tcb_id} fail to generate golden-code output {e}", log_file)
    except Exception as e:
        write_log(f"{tcb_id} fail: {e}", log_file)
    finally:
        # Always record the per-problem tally, including skipped problems.
        testcases_pass_rate = {
            "tcb_id": tcb_id,
            "right_nums": problem_test_count,
            "gen_nums": cfg.gen_test_nums,
        }
        append_dict_to_jsonl(cfg.pass_rate_file.format(batch), testcases_pass_rate)

Xet Storage Details

Size:
7.98 kB
·
Xet hash:
01d06621d80c5ff181298a2e0ea37ddbb7f6ecbb9db634b573fb764473aa3651

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.