Tsukihjy/testcase / methods /lcb /pipeline.py
Tsukihjy's picture
download
raw
8.78 kB
import sys
sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils")
from response import TurboResponser, OpenResponser
from dataset_all import get_datasets_by_name
from is_correct import test_output_comparison
from config import cfg
from prompt import random_input_prompt, adversarial_prompt, random_input_template, adversarial_input_template, shots
from typing import List, Optional
import json
import traceback
from excute_tool_linux import run_cpp_code_linux, run_func_code, run_python_code_linux
from parallel_exe import run_func_code_parallel
# Example text (the `shots` string from the prompt module) substituted for the
# {EXAMPLE_PROBLEM} placeholder in the random/adversarial generator system prompts.
example = shots
def write_json_to_file(data, filepath):
    """Serialize *data* to *filepath* as UTF-8 JSON, replacing any existing file.

    Non-ASCII characters are written verbatim and the output is indented by
    four spaces.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(filepath, mode="w", encoding="utf-8") as out_file:
        json.dump(data, out_file, **dump_options)
import json
def append_dict_to_jsonl(file_path, data_dict):
    """Append *data_dict* to *file_path* as a single JSON Lines record.

    The record is encoded without ASCII escaping and terminated by a newline;
    the file is created if it does not exist.
    """
    with open(file_path, mode="a", encoding="utf-8") as sink:
        record = json.dumps(data_dict, ensure_ascii=False)
        sink.write(f"{record}\n")
import datetime
def write_log(message: str, log_file: str = "log-lcb.txt") -> None:
    """
    Append a timestamped log message to a log file.

    Each call writes one line of the form ``[YYYY-mm-dd HH:MM:SS] <message>``
    using the local wall-clock time. The file is created if it does not exist.

    Args:
        message (str): The message to log.
        log_file (str): The path to the log file (default is 'log-lcb.txt').
    Returns:
        None
    """
    timestamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    # Plain append mode: the file is only ever written to, never read back.
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"{timestamp} {message}\n")
import re
def extract_code(ans_str):
    """Return the body of the last ```python fenced block found in *ans_str*.

    The opening fence must be immediately followed by a newline; the body runs
    (across lines) up to the next closing ``` fence. Raises IndexError when no
    such block exists.
    """
    fence = re.compile(r'```python\n(.*?)```', re.DOTALL)
    fenced_bodies = fence.findall(ans_str)
    return fenced_bodies[-1]
def extract_content_code(ans_str):
    """Return the text between the final <ASSISTANT> and </ASSISTANT> tags.

    Matching is non-greedy and may span newlines. Raises IndexError when the
    tag pair never appears in *ans_str*.
    """
    tagged_spans = re.compile(r'<ASSISTANT>(.*?)</ASSISTANT>', re.DOTALL).findall(ans_str)
    return tagged_spans[-1]
import random
import os
from collections import Counter
import time
def _extract_generator_code(response):
    """Return the generator source embedded in a raw model *response*.

    Uses the last ```python fenced block when the response contains one,
    otherwise the last <ASSISTANT>...</ASSISTANT> span. Raises IndexError when
    the chosen pattern does not match.
    """
    if "```python" in response:
        return extract_code(response)
    return extract_content_code(response)


def _request_generators(responser, item, raw_record, log_file):
    """Ask the model for 3 random and 2 adversarial input generators for *item*.

    Every raw response is appended to ``raw_record['random_generator']`` or
    ``raw_record['adversarial_generator']`` before code extraction, so it is kept
    even when extraction fails. Failed requests/extractions are logged and
    skipped after a 1 s pause. Returns the list of extracted generator sources.
    """
    tcb_id = item["tcb_id"]
    random_system = random_input_prompt.replace("{EXAMPLE_PROBLEM}", example)
    random_user = random_input_template.replace("{PROBLEM}", item["query_en"])
    adversarial_system = adversarial_prompt.replace("{EXAMPLE_PROBLEM}", example)
    adversarial_user = adversarial_input_template.replace("{PROBLEM}", item["query_en"])

    generators = []
    for _ in range(3):
        try:
            # NOTE(review): the system prompt is passed positionally here but as
            # system_info= below (as in the original calls); confirm both bind to
            # the same parameter of TurboResponser.respond.
            response = responser.respond(random_system, user_prompt=random_user)
            raw_record['random_generator'].append(response)
            generators.append(_extract_generator_code(response))
        except Exception as e:
            time.sleep(1)
            write_log(f"{tcb_id} fail to generate input func {e}", log_file)
    for _ in range(2):
        try:
            response = responser.respond(
                system_info=adversarial_system,
                user_prompt=adversarial_user,
            )
            raw_record['adversarial_generator'].append(response)
            generators.append(_extract_generator_code(response))
        except Exception as e:
            time.sleep(1)
            write_log(f"{tcb_id} fail to generate input func_ {e}", log_file)
    return generators


def _collect_test_inputs(inputs_list, limit=100):
    """Flatten generator results into one list of input strings, sampled to *limit*.

    Only non-empty lists whose first element is a ``str`` are kept; any other
    result (e.g. from a generator that crashed) is ignored.
    """
    candidates = []
    for produced in inputs_list:
        if isinstance(produced, list) and produced and isinstance(produced[0], str):
            candidates.extend(produced)
    if len(candidates) > limit:
        candidates = random.sample(candidates, limit)
    return candidates


def _majority_output(solutions, test, tcb_id, log_file):
    """Run each reference solution on *test* and return the most common stdout.

    Solutions whose execution raises are logged and skipped, and empty stdout is
    ignored. Ties go to the output seen first (Counter.most_common ordering).
    Raises ValueError when no solution produced any output.
    """
    outputs = []
    for solution in solutions:
        try:
            result = run_cpp_code_linux(solution['code'], test, time_limit=3)
        except Exception as e:
            write_log(f"{tcb_id} solution execute fail {e}", log_file)
            write_log(f"{test}", log_file)
            # Skip this solution: falling through would reuse a stale result from
            # an earlier run and count its vote twice.
            continue
        stdout = result.get("stdout")
        if stdout:
            outputs.append(stdout)
    if not outputs:
        raise ValueError("no reference solution produced output")
    return Counter(outputs).most_common(1)[0][0]


def lcb_pipeline(batch, api_key):
    """Generate test cases for one 100-item batch of ``cfg.dataset_name``.

    For every item whose tests file (``cfg.tests_path``) does not exist yet, up
    to three attempts are made to:

    1. ask the model for 3 random and 2 adversarial input-generator functions,
    2. run them in parallel and keep at most 100 of the produced inputs,
    3. run every reference solution on each input and append the majority-vote
       stdout as a test case to the item's tests file.

    An item is done once at least 10 test cases have been written (summed over
    attempts); its raw generator responses are then appended to
    ``cfg.response_file``. Failures inside an attempt are logged to the batch
    log file and the attempt is retried.

    Args:
        batch: Batch index; selects dataset items [100 * batch, 100 * (batch + 1))
            and is substituted into ``cfg.log_file``.
        api_key: API key passed to ``TurboResponser``.
    """
    responser = TurboResponser(api_key=api_key, api_base=cfg.api_base, model=cfg.model_name)
    log_file = cfg.log_file.format(batch)
    print(cfg.model_name)
    write_log(f"Model {cfg.model_name} batch {batch} start", log_file)
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    original_response_file = cfg.response_file.format("lcb", cfg.model_name)
    start_pos = batch * 100
    end_pos = min((batch + 1) * 100, len(al_dataset))
    al_dataset = al_dataset[start_pos:end_pos]
    print(f"datasets lenght {len(al_dataset)}")

    for item in al_dataset:
        tcb_id = item["tcb_id"]
        write_log(f"{tcb_id} start", log_file)
        tests_file = cfg.tests_path.format(tcb_id)
        if os.path.exists(tests_file):
            write_log(f"{tcb_id} exist", log_file)
            continue

        original_response = {
            "tcb_id": tcb_id,
            "random_generator": [],
            "adversarial_generator": [],
        }
        problem_test_count = 0
        attempts = 0
        success = False
        # Generation quality varies, so regenerate the generators when too few
        # tests come out, but at most 3 times.
        while attempts < 3 and not success:
            attempts += 1
            try:
                # Stage 1: build random/adversarial generator functions and run
                # them (in parallel) to obtain candidate inputs.
                generators = _request_generators(responser, item, original_response, log_file)
                try:
                    inputs_list = run_func_code_parallel(
                        generators, funcname="construct_inputs", time_limit=100
                    )
                except Exception as e:
                    write_log(f"{tcb_id} fail to generate input {e}", log_file)
                    continue
                # The reference method keeps at most 100 inputs.
                test_inputs = _collect_test_inputs(inputs_list, limit=100)

                # Stage 2: derive expected outputs from the golden solutions.
                fail_count = 0
                for test in test_inputs:
                    try:
                        test_output = _majority_output(item["solutions"], test, tcb_id, log_file)
                    except Exception as e:
                        write_log(f"{tcb_id} fail to generate {e}", log_file)
                        fail_count += 1
                        # Abandon this attempt after more than 10 consecutive failures.
                        if fail_count > 10:
                            break
                        continue
                    fail_count = 0
                    append_dict_to_jsonl(tests_file, {"input": test, "output": test_output})
                    write_log(f"{tcb_id} + 1", log_file)
                    problem_test_count += 1

                # Too few tests so far (counted across attempts): try again.
                if problem_test_count < 10:
                    continue
                success = True
                append_dict_to_jsonl(original_response_file, original_response)
            except Exception as e:
                write_log(f"{tcb_id} error!: {e}", log_file)
import argparse
if __name__ == "__main__":
    # Usage (e.g. from data/TestGen):
    #   LCB_API_KEYS=key0,key1,... python /home/i-luoxianzhen/data/TestCase-Gen/methods/lcb/pipeline.py 0
    # API keys come from the comma-separated LCB_API_KEYS environment variable
    # instead of being hard-coded, so credentials never land in version control.
    # Each key serves 3 consecutive batches (batch // 3 selects the key).
    parser = argparse.ArgumentParser(description="接收1个命令行参数")
    parser.add_argument("param1", type=int, help="第一个参数")
    args = parser.parse_args()
    batch = args.param1
    keys = [k.strip() for k in os.environ.get("LCB_API_KEYS", "").split(",") if k.strip()]
    if not keys:
        parser.error("set LCB_API_KEYS to a comma-separated list of API keys")
    key_index = batch // 3
    # Reject out-of-range indices explicitly (negative batches would otherwise
    # silently wrap around to the last key).
    if not 0 <= key_index < len(keys):
        parser.error(
            f"batch {batch} needs API key #{key_index}, "
            f"but only {len(keys)} key(s) are configured"
        )
    lcb_pipeline(batch, keys[key_index])

Xet Storage Details

Size:
8.78 kB
·
Xet hash:
6fae970a11dc26cefbe208d6bcae9caf674905f1f073622fdeb3d380a6c779fd

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.