| from prompt import CF_PROMPT | |
| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from response import TurboResponser | |
| from dataset_all import get_datasets_by_name | |
| from is_correct import test_output_comparison | |
| from config import cfg | |
| from execute_tool_linux import run_func_code, run_cpp_code_linux, run_python_code_linux | |
| import random | |
| import json | |
def write_json_to_file(data, filepath):
    """Serialize *data* to *filepath* as pretty-printed JSON.

    The file is overwritten, encoded as UTF-8 and indented by four spaces;
    non-ASCII characters are written literally instead of being escaped.
    """
    with open(filepath, mode='w', encoding='utf-8') as out_file:
        json.dump(
            data,
            out_file,
            indent=4,
            ensure_ascii=False,
        )
def append_dict_to_jsonl(file_path, data_dict):
    """Append *data_dict* to *file_path* as one JSON Lines record.

    The record is written compactly on a single line, UTF-8 encoded with
    non-ASCII characters kept literal, and terminated by a newline. The file
    is created if it does not exist.
    """
    with open(file_path, mode='a', encoding='utf-8') as sink:
        record = json.dumps(data_dict, ensure_ascii=False)
        sink.write(f"{record}\n")
class InputGenerator:
    """Send prompts to an LLM responder under a fixed software-tester persona."""

    def __init__(self, responser):
        # `responser` is called as responser.respond(system_prompt, user_prompt).
        self.responser = responser
        self.SYSTEM_PROMPT = (
            "You are an expert software tester tasked with thoroughly testing "
            "a given piece of code. "
        )

    def respond(self, prompt: str) -> str:
        """Forward *prompt* together with SYSTEM_PROMPT and return the reply."""
        send = self.responser.respond
        system_prompt = self.SYSTEM_PROMPT
        return send(system_prompt, prompt)
class Verifier:
    """Send prompts to an LLM responder under a generic programming-assistant persona."""

    def __init__(self, responser):
        # `responser` is called as responser.respond(system_prompt, user_prompt).
        self.responser = responser
        self.SYSTEM_PROMPT = "You are an AI programming assistant."

    def respond(self, prompt: str) -> str:
        """Forward *prompt* together with SYSTEM_PROMPT and return the reply."""
        send = self.responser.respond
        system_prompt = self.SYSTEM_PROMPT
        return send(system_prompt, prompt)
| import re | |
def extract_python_code_from_markdown(markdown_text):
    """Return the body of the first ```python fenced code block in *markdown_text*.

    Args:
        markdown_text (str): Text (e.g. an LLM response) that should contain
            at least one fenced block opened with ```python.

    Returns:
        str: The code between the opening fence line and the closing fence,
        without the newline that precedes the closing fence.

    Raises:
        IndexError: If no ```python block is present. IndexError is kept
            (rather than ValueError) because earlier versions raised it from
            ``blocks[0]``; it now carries a descriptive message.
    """
    # Tolerate trailing spaces/tabs after the language tag and CRLF line
    # endings; the original pattern required "```python\n" exactly.
    code_block_pattern = re.compile(r'```python[ \t]*\r?\n(.*?)\r?\n```', re.DOTALL)
    # search() returns the same first block findall()[0] did, without
    # collecting every later block.
    match = code_block_pattern.search(markdown_text)
    if match is None:
        raise IndexError("no ```python fenced code block found in the text")
    return match.group(1)
| import datetime | |
def write_log(message: str, log_file: str = "log-algo.txt") -> None:
    """Append a timestamped log message to a log file.

    Each call appends one line of the form ``[YYYY-mm-dd HH:MM:SS] message``,
    stamped with the local wall-clock time from ``datetime.now()``.

    Args:
        message (str): The message to log.
        log_file (str): The path to the log file (default is 'log-algo.txt').
            The file is created if it does not exist.

    Returns:
        None
    """
    timestamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"{timestamp} {message}\n")
| import json | |
| import os | |
def extract_ids_from_jsonl(file_pattern, num_files=8):
    """Collect every ``tcb_id`` recorded in a set of sharded JSONL files.

    Args:
        file_pattern (str): A ``str.format`` template with one positional
            placeholder for the shard index, e.g. ``"pass_rate_{}.jsonl"``.
        num_files (int): Number of shards to probe, indices ``0..num_files-1``
            (default 8, the value previously hard-coded).

    Returns:
        list: ``tcb_id`` values in shard order, then line order. Records
        without a ``tcb_id`` key are ignored; duplicates are kept.

    Missing shard files are skipped. Blank lines are ignored instead of
    making ``json.loads`` raise.
    """
    ids = []
    for shard in range(num_files):
        filepath = file_pattern.format(shard)
        if not os.path.exists(filepath):
            continue
        with open(filepath, 'r', encoding='utf-8') as file:
            # Iterate lazily rather than readlines() to avoid loading the whole file.
            for line in file:
                if not line.strip():
                    continue
                data = json.loads(line)
                if 'tcb_id' in data:
                    ids.append(data['tcb_id'])
    return ids
| import argparse | |
# --- Run configuration -------------------------------------------------------
# Usage: <script> <batch>. `batch` selects a BATCH_SIZE-item slice of the
# dataset and which API key is used.

BATCH_SIZE = 100  # dataset items processed per batch
BATCHES_PER_KEY = 3  # consecutive batches that share one API key

# SECURITY(review): API credentials are committed to source. Prefer supplying
# them via the TESTCASE_GEN_API_KEYS environment variable (comma-separated);
# the literals below remain only as a backward-compatible fallback and should
# be rotated and removed.
keys = [
    k.strip() for k in os.environ.get("TESTCASE_GEN_API_KEYS", "").split(",") if k.strip()
] or [
    "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
    "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
    "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
    "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
]

parser = argparse.ArgumentParser(description="接收1个命令行参数")
parser.add_argument("param1", type=str, help="第一个参数")
args = parser.parse_args()
batch = int(args.param1)
if batch < 0:
    # Negative slice bounds would index from the end of the dataset and
    # silently process the wrong items.
    parser.error(f"batch must be a non-negative integer, got {batch}")

# Wrap around the key list instead of raising IndexError once
# batch // BATCHES_PER_KEY runs past the last key.
key = keys[(batch // BATCHES_PER_KEY) % len(keys)]
log_file = cfg.log_file.format(batch)
orginal_response_file = cfg.response_file.format("algo", cfg.model_name)
write_log(f"model {cfg.model_name}", log_file)
print(f"model {cfg.model_name}")

al_dataset = get_datasets_by_name(cfg.dataset_name)
start_pos = batch * BATCH_SIZE
end_pos = min((batch + 1) * BATCH_SIZE, len(al_dataset))
al_dataset = al_dataset[start_pos:end_pos]
print(f"start {start_pos} end {end_pos}")
print(f"datasets lenght {len(al_dataset)}")

input_gen = InputGenerator(TurboResponser(api_key=key, api_base=cfg.api_base, model=cfg.model_name))
oracle_gen = Verifier(TurboResponser(api_key=key, api_base=cfg.api_base, model=cfg.model_name))

# Resume support (currently disabled): collect ids already recorded in the
# pass-rate shards so they can be skipped.
# finished_list = extract_ids_from_jsonl(cfg.pass_rate_file)
# --- Main loop ---------------------------------------------------------------
# For each problem:
#   1. ask the LLM for a random-input generator and check that it runs;
#   2. ask the LLM for a brute-force oracle;
#   3. generate cfg.gen_test_nums inputs and label them with the oracle;
#   4. keep a test case only if one of the first three reference solutions
#      produces a matching output.
# Per-problem pass counts are appended to cfg.pass_rate_file.

DATA_GEN_MAX_TRIES = 5  # LLM attempts to obtain a generator that runs cleanly
ORACLE_MAX_TRIES = 5  # LLM attempts to obtain an extractable oracle


def _is_usable_generated_input(candidate):
    """Return True if a generator run produced a str without an error marker.

    Mirrors the original heuristic (reject text mentioning 'Error'/'error'),
    but checks the type first so the substring tests never see a non-str.
    NOTE(review): run_func_code's exact failure contract is not visible here.
    """
    return isinstance(candidate, str) and 'Error' not in candidate and 'error' not in candidate


for idx, item in enumerate(al_dataset):
    tcb_id = item['tcb_id']
    problem_test_count = 0
    write_log(f"{item['tcb_id']} - {idx} - start", log_file)
    # Raw LLM responses for this problem, appended to orginal_response_file.
    orginal_response = {
        "tcb_id": item["tcb_id"]
    }
    # if os.path.exists(cfg.tests_path.format(tcb_id)):
    #     write_log(f"{item['tcb_id']} exist", log_file)
    #     continue
    try:
        # Step 1: generate the input-generator function and run it once;
        # regenerate if it does not run successfully. Reset per problem so a
        # generator left over from the previous problem is never reused.
        test_case_gen_code = None
        for _ in range(DATA_GEN_MAX_TRIES):
            try:
                gen_prompt = CF_PROMPT["data_generator_original"]
                gen_prompt = gen_prompt.format(problem=item["query_en"])
                gen_response = input_gen.respond(gen_prompt)
                orginal_response['test_case_gen_code'] = gen_response
                candidate_code = extract_python_code_from_markdown(gen_response)
                input_str = run_func_code(candidate_code, funcname="gen_input", time_limit=10)
                if _is_usable_generated_input(input_str):
                    test_case_gen_code = candidate_code
                    break
            except Exception as e:
                write_log(f"{item['tcb_id']} fail to construct data_generator: {e}", log_file)
        if test_case_gen_code is None:
            # No working generator: every test would fail, so skip the oracle
            # request. The finally clause still records 0 passes.
            write_log(f"{item['tcb_id']} fail to construct data_generator: retries exhausted", log_file)
            continue

        # Step 2: generate the brute-force (oracle) solution.
        oracle_prompt = CF_PROMPT["oracle_generator"]
        oracle_prompt = oracle_prompt.format(problem=item["query_en"])
        oracle = None
        for _ in range(ORACLE_MAX_TRIES):
            try:
                oracle_response = oracle_gen.respond(oracle_prompt)
                orginal_response['oracle'] = oracle_response
                oracle = extract_python_code_from_markdown(oracle_response)
                break
            except Exception as e:
                write_log(f"{item['tcb_id']} fail to oracle: {e}", log_file)
        if oracle is None:
            continue
        append_dict_to_jsonl(orginal_response_file, orginal_response)

        # Step 3: generate the requested number of test cases.
        for _ in range(cfg.gen_test_nums):
            try:
                input_str = run_func_code(test_case_gen_code, funcname="gen_input", time_limit=10)
                output_str = run_python_code_linux(oracle, input_str, time_limit=10)
            except Exception as e:
                write_log(f"{item['tcb_id']} fail to generate testcase {e}", log_file)
                continue
            # A str result is treated as a failed oracle run (presumably an
            # error message from run_python_code_linux -- verify).
            if isinstance(output_str, str):
                continue
            if 'error' in output_str.stdout or 'Error' in output_str.stdout or output_str.returncode != 0:
                continue
            test_case = {
                'input': input_str,
                'output': output_str.stdout
            }
            # Cross-check the oracle against the first three reference
            # solutions; keep the test case if any one of them matches.
            try:
                for solution in item["solutions"][0:3]:
                    try:
                        execute_res = run_cpp_code_linux(solution['code'], input_str, time_limit=3)
                    except Exception as e:
                        write_log(f"{item['tcb_id']} solution execute fail: {e}", log_file)
                        # Skip this solution. Falling through would compare a
                        # stale execute_res from an earlier solution/test (or
                        # hit an unbound name) and could accept a wrong case.
                        continue
                    if "stdout" in execute_res.keys() and execute_res['returncode'] == 0 and test_output_comparison(execute_res["stdout"], test_case['output']):
                        append_dict_to_jsonl(cfg.tests_path.format(tcb_id), test_case)
                        problem_test_count += 1
                        write_log(f"{item['tcb_id']} +1", log_file)
                        break
            except Exception as e:
                write_log(f"{item['tcb_id']} fail to generate gloden-code output {e}", log_file)
    except Exception as e:
        write_log(f"{item['tcb_id']} fail: {e}", log_file)
    finally:
        # Always record how many test cases were accepted for this problem,
        # including when it was skipped or failed part-way.
        testcases_pass_rate = {
            "tcb_id": tcb_id,
            "right_nums": problem_test_count,
            "gen_nums": cfg.gen_test_nums
        }
        append_dict_to_jsonl(cfg.pass_rate_file.format(batch), testcases_pass_rate)
Xet Storage Details
- Size: 7.98 kB
- Xet hash: 01d06621d80c5ff181298a2e0ea37ddbb7f6ecbb9db634b573fb764473aa3651
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.