Tsukihjy's picture
download
raw
9.19 kB
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from response import TurboResponser, OpenResponser
from dataset_all import get_datasets_by_name
from is_correct import test_output_comparison
from config import cfg
from prompt import validator_prompt, input_generator_prompt
from typing import List, Optional
import json
import traceback
from excute_tool_linux import run_cpp_code_linux, run_func_code, run_python_code_linux
from parallel_exe import run_func_code_parallel
def write_json_to_file(data, filepath):
    """Serialize ``data`` to ``filepath`` as JSON, overwriting any existing file.

    The output is UTF-8, keeps non-ASCII characters as-is and is indented
    with 4 spaces.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(filepath, mode="w", encoding="utf-8") as out:
        json.dump(data, out, **dump_options)
import json
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as one compact JSON line.

    The file is opened in UTF-8 append mode (created if missing) and
    non-ASCII characters are written unescaped.
    """
    with open(file_path, mode='a', encoding='utf-8') as sink:
        print(json.dumps(data_dict, ensure_ascii=False), file=sink)
import datetime
def write_log(message: str, log_file: str = "log-lcb.txt") -> None:
    """Append a timestamped log line to a log file.

    Each call writes one line of the form ``[YYYY-mm-dd HH:MM:SS] message``
    using the local wall-clock time.

    Args:
        message: The message to log.
        log_file: Path of the log file (default ``"log-lcb.txt"``); it is
            created if it does not exist.

    Returns:
        None
    """
    timestamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    # Plain append mode is enough: the file is only written, never read back.
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"{timestamp} {message}\n")
import re
def extract_function_names(code_str: str):
    """Return the names of every Python ``def`` statement in ``code_str``, in source order.

    This is a plain textual scan: nested functions and methods are included,
    duplicates are kept, and ``def name(`` inside comments or strings also matches.
    """
    def_header = re.compile(r'\bdef\s+(\w+)\s*\(')
    return [hit.group(1) for hit in def_header.finditer(code_str)]
def extract_code(ans_str):
    """Return the body of the last ```json fenced block in ``ans_str``.

    The opening fence may be followed by trailing spaces/tabs and either a LF
    or a CRLF line ending. The returned text is everything between that line
    ending and the next closing fence, including the body's final newline.

    Raises:
        IndexError: if ``ans_str`` contains no ```json fenced block. The type
            is kept as IndexError for compatibility with the original
            ``matches[-1]`` behaviour, but the message now says what is missing.
    """
    pattern = r'```json[ \t]*\r?\n(.*?)```'
    matches = re.findall(pattern, ans_str, re.DOTALL)
    if not matches:
        raise IndexError("no ```json fenced block found in response")
    return matches[-1]
def extract_content_code(ans_str):
    """Return the text between the last ``<ASSISTANT>`` / ``</ASSISTANT>`` pair.

    Matching is non-greedy and spans newlines.

    Raises:
        IndexError: if ``ans_str`` contains no such tag pair.
    """
    tagged = re.compile(r'<ASSISTANT>(.*?)</ASSISTANT>', re.DOTALL)
    blocks = tagged.findall(ans_str)
    return blocks[-1]
import random
import os
from collections import Counter
import time
def _run_generator_functions(generator_code, function_names, repeats, source):
    """Call each generator function ``repeats`` times and wrap every result as a test case.

    Args:
        generator_code: Python source that defines the generator functions.
        function_names: Names of the functions in ``generator_code`` to call.
        repeats: Number of passes over ``function_names`` (no calls if <= 0).
        source: Value stored under the ``"source"`` key of each produced case.

    Returns:
        List of ``{"input": <generated input>, "source": source}`` dicts, in call order.
    """
    cases = []
    for _ in range(repeats):
        for funcname in function_names:
            # A single program is submitted per call and only the first result is kept.
            test_input = run_func_code_parallel([generator_code], funcname=funcname, time_limit=100)[0]
            cases.append({'input': test_input, "source": source})
    return cases


def _majority_output(solutions, test, tcb_id, log_file):
    """Return the most common non-empty stdout of the reference solutions on ``test['input']``.

    A solution whose execution raises is logged and skipped. It therefore never
    votes, and never re-uses the result of a previous solution. Ties go to the
    output seen first, per ``Counter.most_common``.

    Raises:
        ValueError: if no solution produced non-empty stdout.
    """
    outputs = []
    for solution in solutions:
        try:
            execute_res = run_cpp_code_linux(solution['code'], test['input'], time_limit=3)
        except Exception as e:
            write_log(f"{tcb_id} solution execute fail {e}", log_file)
            write_log(f"{test}", log_file)
            continue
        if execute_res.get("stdout", "") != "":
            outputs.append(execute_res["stdout"])
    if not outputs:
        raise ValueError("no reference solution produced output")
    output, _ = Counter(outputs).most_common(1)[0]
    return output


def ht_pipeline(batch, api_key, batch_size=100, max_attempts=3, min_tests=10):
    """Generate, validate and label test cases for one slice of the dataset.

    For every problem in the slice ``[batch * batch_size, (batch + 1) * batch_size)``:

    1. Ask the LLM for an input validator and input generators. These yield
       directly generated inputs, "regular" generator functions and "hacking"
       generator functions.
    2. Keep only inputs for which the validator's ``validate_input`` returns a
       truthy value.
    3. Label each kept input with the majority stdout of ``item["solutions"]``
       and append it to ``cfg.tests_path.format(tcb_id)``.

    Problems whose tests file already exists are skipped. A problem is retried
    (up to ``max_attempts`` times) while fewer than ``min_tests`` tests have
    been written for it.

    Args:
        batch: Index of the slice to process; also used to format the log and
            pass-rate file paths.
        api_key: Key passed to ``TurboResponser``.
        batch_size: Problems per slice (default 100, the original hard-coded value).
        max_attempts: Maximum generation attempts per problem (default 3).
        min_tests: Tests required before a problem counts as done (default 10).
    """
    responser = TurboResponser(api_key=api_key, api_base=cfg.api_base, model=cfg.model_name)
    log_file = cfg.log_file.format(batch)
    original_response_file = cfg.response_file.format("ht", cfg.model_name)
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    start_pos = batch * batch_size
    end_pos = min((batch + 1) * batch_size, len(al_dataset))
    al_dataset = al_dataset[start_pos:end_pos]
    write_log(f"Model: {cfg.model_name} - data {cfg.dataset_name} -batch {batch} start", log_file)
    print(f"Model: {cfg.model_name} - data {cfg.dataset_name} -batch {batch} start")
    print(f"datasets lenght {len(al_dataset)}")
    nR = cfg.nR
    nH = cfg.nH
    for item in al_dataset:
        tcb_id = item["tcb_id"]
        write_log(f"{tcb_id} start", log_file)
        original_response = {"tcb_id": tcb_id}
        tests_file = cfg.tests_path.format(tcb_id)
        if os.path.exists(tests_file):
            write_log(f"{tcb_id} exist", log_file)
            continue
        time_of_gen = 0
        # NOTE(review): this counter (and the tests file) is NOT reset between
        # attempts, so tests written by a "too few" attempt are kept and count
        # toward later attempts.
        problem_test_count = 0
        success_gen = False
        # Generation quality is sometimes poor, so regenerate the functions,
        # but at most `max_attempts` times.
        while time_of_gen < max_attempts and not success_gen:
            time_of_gen += 1
            generate_testcases = []
            try:
                oracle_code = item['solutions'][0]['code']
                # Stage 1a: ask for the input validator.
                response_validator = responser.respond(
                    system_info="You are a helpful code assistant.",
                    user_prompt=validator_prompt
                    .replace("{{ problem_specification }}", item['query_en'])
                    .replace("{{ oracle_program }}", oracle_code),
                )
                original_response['response_validator'] = response_validator
                response_validator_dict = json.loads(extract_code(response_validator))
                # Stage 1b: ask for the input generators, conditioned on the validator.
                response_generator = responser.respond(
                    system_info="You are a helpful code assistant.",
                    user_prompt=input_generator_prompt
                    .replace("{{ problem_specification }}", item['query_en'])
                    .replace("{{ oracle_program }}", oracle_code)
                    .replace("{{ input_validator }}", response_validator_dict['input_validator']),
                )
                original_response['response_generator'] = response_generator
                response_generator_dict = json.loads(extract_code(response_generator))

                # Inputs written directly by the LLM.
                for test_input in response_generator_dict['directly_generated_inputs']:
                    generate_testcases.append({'input': test_input, "source": "prompt"})

                # Random inputs: call each regular generator nR // len(funcs) times (at least once).
                regular_code = response_generator_dict['regular_input_generator']
                regular_funcs = extract_function_names(regular_code)
                if not regular_funcs:
                    # Aborts this attempt (logged below), like the former ZeroDivisionError did.
                    raise ValueError("no functions found in regular_input_generator")
                rounds = max(nR // len(regular_funcs), 1)
                generate_testcases.extend(
                    _run_generator_functions(regular_code, regular_funcs, rounds, "random")
                )

                # Hacking (edge-case) inputs: call each hacking generator nH times.
                hacking_code = response_generator_dict['hacking_input_generator']
                hacking_funcs = extract_function_names(hacking_code)
                generate_testcases.extend(
                    _run_generator_functions(hacking_code, hacking_funcs, nH, "edge")
                )

                # Stage 2: keep only inputs accepted by the validator.
                validator_code = response_validator_dict['input_validator']
                generate_testcases_passed = [
                    test_dict
                    for test_dict in generate_testcases
                    if run_func_code(validator_code, funcname="validate_input", param=test_dict['input'], time_limit=100)
                ]

                # Stage 3: label each input by majority vote over the reference solutions.
                for test in generate_testcases_passed:
                    try:
                        test_output = _majority_output(item["solutions"], test, tcb_id, log_file)
                    except Exception as e:
                        write_log(f"{tcb_id} fail to generate {e}", log_file)
                        continue
                    test['output'] = test_output
                    append_dict_to_jsonl(tests_file, test)
                    write_log(f"{tcb_id} + 1", log_file)
                    problem_test_count += 1

                if problem_test_count < min_tests:
                    # Too few tests generated; try again.
                    continue
                success_gen = True
                # Both counts refer to the final (successful) attempt only.
                testcases_pass_rate = {
                    'tcb_id': tcb_id,
                    "gen_nums": len(generate_testcases),
                    "right_nums": len(generate_testcases_passed),
                }
                append_dict_to_jsonl(cfg.pass_rate_file.format(batch), testcases_pass_rate)
                append_dict_to_jsonl(original_response_file, original_response)
            except Exception as e:
                write_log(f"{tcb_id} error!: {e}", log_file)
import argparse
if __name__ == "__main__":
    # Usage: enter the directory (cd data/TestGen), then run
    #   python /home/i-luoxianzhen/data/TestCase-Gen/methods/lcb/pipeline.py 0
    # where the argument is the batch index.
    #
    # API keys may be supplied as a comma-separated list in the
    # PIPELINE_API_KEYS environment variable; the list below is only a fallback.
    # NOTE(review): credentials should not be committed to source control —
    # move them out of this file.
    default_keys = [
        "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
        "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
        "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
        "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
    ]
    env_keys = os.environ.get("PIPELINE_API_KEYS", "")
    keys = [k.strip() for k in env_keys.split(",") if k.strip()] or default_keys
    parser = argparse.ArgumentParser(description="接收1个命令行参数")
    # Parse as int directly so a bad value yields a usage error, not a traceback.
    parser.add_argument("param1", type=int, help="第一个参数")
    args = parser.parse_args()
    batch = args.param1
    # Three consecutive batches share one key. Wrap around instead of raising
    # IndexError once batch // 3 runs past the end of the key list.
    key = keys[(batch // 3) % len(keys)]
    ht_pipeline(batch, key)

Xet Storage Details

Size:
9.19 kB
·
Xet hash:
a6505dbc4ff8de9c410b93da5385759962de44d2bc0b6681f54cada8792ea311

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.