Tsukihjy's picture
download
raw
8.82 kB
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from response import TurboResponser, OpenResponser
from dataset_all import get_datasets_by_name
from is_correct import test_output_comparison
from config import cfg
from prompt import test_gen_prompt, predo_prompt, input_template, System_prompt
from typing import List, Optional
import json
from collections import Counter
from execute_tool_linux import run_cpp_code_linux, run_func_code, run_python_code_linux
import os
import re
import json
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as one JSON line.

    The file is opened in append mode with UTF-8 encoding, and non-ASCII
    characters are written verbatim rather than ASCII-escaped.
    """
    with open(file_path, mode='a', encoding='utf-8') as out:
        serialized = json.dumps(data_dict, ensure_ascii=False)
        out.write(f"{serialized}\n")
def extract_jsonl_from_markdown(markdown_text):
    """Return the first valid JSON value found in a fenced ```json block.

    Fenced blocks are tried in document order; blocks whose content is not
    valid JSON are skipped, and parsing stops at the first block that
    decodes successfully.

    Args:
        markdown_text (str): Markdown text that may contain ```json fences.

    Returns:
        The decoded JSON value of the first block that parses.

    Raises:
        IndexError: If no ```json block contains valid JSON. The exception
            type is the one the previous ``json_data[0]`` lookup raised.
    """
    # Match the content between a ```json opener and the next ``` closer.
    for block in re.finditer(r'```json(.*?)```', markdown_text, re.DOTALL):
        try:
            return json.loads(block.group(1).strip())
        except json.JSONDecodeError:
            # Malformed block: try the next one.
            continue
    raise IndexError("no valid ```json block found in markdown text")
import datetime
def write_log(message: str, log_file: str = "log-predo.txt"):
    """
    Write one timestamped line to the end of a log file.

    Each entry has the form ``[YYYY-MM-DD HH:MM:SS] <message>`` followed by
    a newline, using the local time at the moment of the call.

    Args:
        message (str): Text to record.
        log_file (str): Destination file, opened in append mode with UTF-8
            encoding (default is 'log-predo.txt').

    Returns:
        None
    """
    now = datetime.datetime.now()
    stamp = now.strftime("[%Y-%m-%d %H:%M:%S]")
    with open(log_file, mode="a", encoding="utf-8") as handle:
        entry = f"{stamp} {message}\n"
        handle.write(entry)
def extract_code_between_tags(text, start_tag="<BEGIN>", end_tag="<END>"):
    """
    Pull out the text enclosed by the first start/end tag pair.

    The tags are matched literally (regex metacharacters in them are
    escaped), the match may span multiple lines, and surrounding whitespace
    is trimmed from the result.

    Args:
        text (str): Text to search.
        start_tag (str): Opening marker (default is "<BEGIN>").
        end_tag (str): Closing marker (default is "<END>").

    Returns:
        str: The enclosed text, or an empty string when no pair is found.
    """
    pattern = f"{re.escape(start_tag)}\\s*(.*?)\\s*{re.escape(end_tag)}"
    found = re.search(pattern, text, re.DOTALL)
    if found is None:
        return ""
    return found.group(1).strip()
import json
import os
def extract_ids_from_jsonl(file_pattern, num_files=9):
    """Collect every ``tcb_id`` value from a numbered series of JSONL files.

    Args:
        file_pattern (str): Path template with one ``str.format``
            placeholder that receives the file index.
        num_files (int): How many indices to probe, starting at 0. Defaults
            to 9, the count that was previously hard-coded.

    Returns:
        list: The ``tcb_id`` values in file order, then line order. Indices
        whose file does not exist are skipped, as are records without a
        ``tcb_id`` key.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    ids = []
    for index in range(num_files):
        filepath = file_pattern.format(index)
        if not os.path.exists(filepath):
            continue
        with open(filepath, 'r', encoding='utf-8') as file:
            # Iterate lazily instead of loading the whole file at once.
            for line in file:
                # Blank lines (e.g. a trailing newline) are not records and
                # would make json.loads raise.
                if not line.strip():
                    continue
                data = json.loads(line)
                if 'tcb_id' in data:
                    ids.append(data['tcb_id'])
    return ids
def write_json_to_file(data, filepath):
    """Overwrite ``filepath`` with ``data`` serialized as indented JSON.

    The output is UTF-8 encoded, uses a 4-space indent and keeps non-ASCII
    characters unescaped.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(filepath, mode='w', encoding='utf-8') as sink:
        json.dump(data, sink, **dump_options)
def predo(batch, key):
    """Generate, label and validate test cases for one batch of problems.

    For every dataset item in the batch this function:

    1. Asks the test-generation model for test inputs, retrying up to three
       times while 50 or fewer inputs have been collected.
    2. Asks the answer-generation model for ``cfg.predo_count`` candidate
       solutions (temperature 0.7 for diversity).
    3. Runs every candidate on every input via ``run_python_code_linux``
       and takes the most common output as the expected output.
    4. Keeps a test case only if one of the item's first three reference
       solutions (run via ``run_cpp_code_linux``) produces an output that
       ``test_output_comparison`` accepts.

    Args:
        batch (int): Zero-based batch index; the batch covers dataset
            positions ``batch * 100`` up to ``(batch + 1) * 100``, clipped
            to the dataset length.
        key (str): API key passed to both responders.

    Returns:
        None. Results are appended to the files named by
        ``cfg.tests_path``, ``cfg.response_file`` and ``cfg.pass_rate_file``;
        progress and errors go to ``cfg.log_file``.
    """
    batch_size = 100
    max_generation_tries = 3

    tests_gen_responser = TurboResponser(api_key=key, api_base=cfg.api_base, model=cfg.model_name_test_gen)
    predo_gen_responser = TurboResponser(api_key=key, api_base=cfg.api_base, model=cfg.model_name_answer_gen)
    log_file = cfg.log_file.format(batch)
    orginal_response_file = cfg.response_file.format("predo", cfg.model_name_test_gen)

    al_dataset = get_datasets_by_name(cfg.dataset_name)
    start_pos = batch * batch_size
    end_pos = min((batch + 1) * batch_size, len(al_dataset))
    al_dataset = al_dataset[start_pos:end_pos]
    print(f"start {start_pos}, end {end_pos}")
    print(f"datasets len {len(al_dataset)}")

    # NOTE: there is no resume logic here; every item in the batch is
    # processed again on each run.
    for idx, item in enumerate(al_dataset):
        tcb_id = item['tcb_id']
        orginal_response = {
            "tcb_id": tcb_id
        }
        gen_test_count = 0
        testcase_list = []
        testcase_json = ""
        try:
            write_log(f"{tcb_id} - {idx} - start", log_file)

            # Stage 1: generate test inputs.
            tries_left = max_generation_tries
            while len(testcase_list) <= 50 and tries_left > 0:
                tries_left -= 1
                try:
                    testcase_json = tests_gen_responser.respond(
                        system_info=test_gen_prompt,
                        user_prompt=input_template.replace("[[Question]]", item["query_en"]),
                    )
                    orginal_response['testcase_json'] = testcase_json
                    # The model may wrap its JSON in a markdown fence or
                    # return it bare; either way the values are the inputs.
                    if '```json' in testcase_json:
                        parsed = extract_jsonl_from_markdown(testcase_json)
                    else:
                        parsed = json.loads(testcase_json)
                    testcase_list += list(parsed.values())
                except Exception as e:
                    # Best effort: record the failure and use the next try.
                    write_log(f"{tcb_id} test generation attempt failed: {e}", log_file)

            # Stage 2: generate several candidate solutions, one request
            # each, with a raised temperature (0.7) for diversity.
            orginal_response['code_gen'] = []
            code_gen_list = []
            for _ in range(cfg.predo_count):
                try:
                    code_gen = predo_gen_responser.respond(
                        system_info=System_prompt,
                        user_prompt=predo_prompt.replace("{question}", item['query_en']),
                        temperature=0.7,
                    )
                    orginal_response['code_gen'].append(code_gen)
                    code_str = extract_code_between_tags(code_gen)
                    if not code_str:
                        # No tagged code in the response: an empty program
                        # must not take part in the output vote.
                        write_log(f"{tcb_id} candidate response contained no code", log_file)
                        continue
                    code_gen_list.append(code_str)
                except Exception as e:
                    write_log(f"{tcb_id} candidate generation failed: {e}", log_file)

            # Stage 3: run each input through every candidate and vote on
            # the expected output.
            for test in testcase_list:
                try:
                    temp_outputs = []
                    for code_str in code_gen_list:
                        res = run_python_code_linux(code_str, test, time_limit=5)
                        # NOTE(review): any result containing 'Error' is
                        # treated as a failed run - confirm that a valid
                        # program output can never contain that word.
                        if 'Error' in res:
                            continue
                        # Normalize: strip each line and drop empty ones.
                        stripped = [line.strip() for line in res.splitlines()]
                        temp_outputs.append("\n".join(line for line in stripped if line))

                    if not temp_outputs:
                        # Nothing to vote on; previously this surfaced as an
                        # opaque "list index out of range" error.
                        write_log(f"{tcb_id} error no candidate produced an output", log_file)
                        continue

                    most_common = Counter(temp_outputs).most_common(1)[0][0]
                    current_test = {
                        'input': test,
                        'output': most_common
                    }

                    # Validate: accept the test if any of the first three
                    # reference solutions agrees with the voted output.
                    for solution in item["solutions"][0:3]:
                        try:
                            execute_res = run_cpp_code_linux(solution['code'], current_test["input"], time_limit=3)
                        except Exception as e:
                            write_log(f"{tcb_id} solution execute fail", log_file)
                            # Skip this solution: without a fresh result the
                            # check below would use an unbound or stale
                            # execute_res from a previous iteration.
                            continue
                        if "stdout" in execute_res.keys() and test_output_comparison(current_test['output'], execute_res["stdout"]):
                            append_dict_to_jsonl(cfg.tests_path.format(tcb_id), current_test)
                            gen_test_count += 1
                            write_log(f"{tcb_id} +1", log_file)
                            break
                except Exception as e:
                    write_log(f"{tcb_id} error {e}", log_file)

            append_dict_to_jsonl(orginal_response_file, orginal_response)
        except Exception as e:
            write_log(f"{tcb_id} error {e}", log_file)
        finally:
            # Always record the per-item statistics, even after a failure.
            testcases_pass_rate = {
                "tcb_id": tcb_id,
                "gen_nums": len(testcase_list),
                "right_nums": gen_test_count
            }
            append_dict_to_jsonl(cfg.pass_rate_file.format(batch), testcases_pass_rate)
            if not testcase_list:
                write_log(f"{tcb_id} generate no tests response: \n{testcase_json}", log_file)
if __name__ == "__main__":
    # Script entry point: takes the batch index as the single positional
    # argument and runs predo() for that batch.
    import argparse

    # SECURITY(review): API keys are hard-coded in source. They should be
    # loaded from the environment or a secret store, and any real key that
    # was committed should be rotated.
    keys = [
        "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
        "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
        "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
        "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
    ]
    parser = argparse.ArgumentParser(description="接收1个命令行参数")
    # type=int lets argparse reject a non-numeric batch with a usage error
    # instead of an uncaught ValueError traceback.
    parser.add_argument("param1", type=int, help="第一个参数")
    args = parser.parse_args()
    batch = args.param1
    # Each key serves three consecutive batches. The modulo wraps around so
    # that batch indices past len(keys) * 3 reuse keys instead of raising
    # IndexError; batches 0-11 select the same key as before.
    key = keys[(batch // 3) % len(keys)]
    predo(batch, key)

Xet Storage Details

Size:
8.82 kB
·
Xet hash:
476ff4fb57b4f1df6307bb17757d0f3d8abb8afc653c11ab23a781554408f5e7

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.