| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from response import TurboResponser, OpenResponser | |
| from dataset_all import get_datasets_by_name | |
| from is_correct import test_output_comparison | |
| from config import cfg | |
| from prompt import validator_prompt, input_generator_prompt | |
| from typing import List, Optional | |
| import json | |
| import traceback | |
| from excute_tool_linux import run_cpp_code_linux, run_func_code, run_python_code_linux | |
| from parallel_exe import run_func_code_parallel | |
def write_json_to_file(data, filepath):
    """Write ``data`` to ``filepath`` as pretty-printed JSON.

    The target file is opened in write mode (existing content is replaced),
    encoded as UTF-8. Non-ASCII characters are written as-is rather than
    escaped, and nesting is indented by four spaces.

    Args:
        data: A JSON-serializable object.
        filepath: Path of the file to write.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(filepath, mode="w", encoding="utf-8") as handle:
        json.dump(data, handle, **dump_options)
| import json | |
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` as a single JSON line to ``file_path``.

    The file is opened in append mode with UTF-8 encoding (and therefore
    created if it does not exist) before the record is serialized.
    Non-ASCII characters are kept unescaped and the record is terminated
    by a newline.

    Args:
        file_path: Path of the JSON Lines file.
        data_dict: JSON-serializable object to append.
    """
    with open(file_path, mode="a", encoding="utf-8") as sink:
        record = json.dumps(data_dict, ensure_ascii=False)
        sink.write(f"{record}\n")
| import datetime | |
def write_log(message: str, log_file: str = "log-lcb.txt"):
    """
    Append one timestamped line to a log file.

    The line has the form ``[YYYY-MM-DD HH:MM:SS] message`` where the
    timestamp is taken from ``datetime.datetime.now()``.

    Args:
        message (str): The text to record.
        log_file (str): Path of the log file (default is 'log-lcb.txt').

    Returns:
        None
    """
    now = datetime.datetime.now()
    entry = f"[{now:%Y-%m-%d %H:%M:%S}] {message}\n"
    with open(log_file, mode="a+", encoding="utf-8") as log_handle:
        log_handle.write(entry)
| import re | |
def extract_function_names(code_str: str):
    """Return the names of the functions defined in ``code_str``.

    A definition is recognized as the keyword ``def``, whitespace, an
    identifier, optional whitespace and an opening parenthesis. Names are
    returned in the order they appear in the text.

    Args:
        code_str (str): Source text to scan.

    Returns:
        list: The matched function names (empty if none are found).
    """
    # Regular expression for Python-style function definitions.
    def_matcher = re.compile(r'\bdef\s+(\w+)\s*\(')
    return [hit.group(1) for hit in def_matcher.finditer(code_str)]
def extract_code(ans_str):
    """Return the content of the last ```json fenced block in ``ans_str``.

    The opening fence must be exactly ```json followed by a newline; the
    returned text is everything between that newline and the next closing
    fence (it may span several lines).

    Args:
        ans_str: Text to search, e.g. a raw model response.

    Returns:
        str: The body of the last matching fenced block.

    Raises:
        IndexError: If ``ans_str`` contains no ```json fenced block.
    """
    matches = re.findall(r'```json\n(.*?)```', ans_str, re.DOTALL)
    if not matches:
        # Same exception type as indexing an empty list, but with a message
        # that tells the reader of the log what actually went wrong.
        raise IndexError("no ```json fenced block found in the response")
    return matches[-1]
def extract_content_code(ans_str):
    """Return the content of the last ``<ASSISTANT>...</ASSISTANT>`` span.

    The returned text is everything between the opening and closing tag and
    may span several lines.

    Args:
        ans_str: Text to search.

    Returns:
        str: The body of the last matching span.

    Raises:
        IndexError: If ``ans_str`` contains no ``<ASSISTANT>`` span.
    """
    matches = re.findall(r'<ASSISTANT>(.*?)</ASSISTANT>', ans_str, re.DOTALL)
    if not matches:
        # Same exception type as indexing an empty list, but with a message
        # that tells the reader of the log what actually went wrong.
        raise IndexError("no <ASSISTANT>...</ASSISTANT> span found in the response")
    return matches[-1]
| import random | |
| import os | |
| from collections import Counter | |
| import time | |
def ht_pipeline(batch, api_key):
    """Generate and label test cases for one 100-problem batch of the dataset.

    For every problem of the batch that has no tests file yet, up to three
    generation attempts are made. Each attempt:

    1. asks the model for an input validator, then for input generators;
    2. collects the directly generated inputs, the outputs of the regular
       ("random") generator functions and of the hacking ("edge") ones;
    3. keeps the inputs accepted by the validator's ``validate_input``;
    4. runs each entry of ``item["solutions"]`` on every kept input, takes
       the most common non-empty stdout as expected output and appends the
       test to the problem's tests file.

    An attempt succeeds once at least 10 tests have been written for the
    problem (the counter is shared by all attempts); the pass-rate record
    and the raw model responses are then appended to their files. Any
    exception inside an attempt is logged and consumes that attempt.

    Args:
        batch: Zero-based batch index; selects dataset items
            ``[batch * 100, (batch + 1) * 100)``.
        api_key: API key handed to ``TurboResponser``.
    """
    responser = TurboResponser(api_key=api_key, api_base=cfg.api_base, model=cfg.model_name)
    log_file = cfg.log_file.format(batch)
    orginal_response_file = cfg.response_file.format("ht", cfg.model_name)

    al_dataset = get_datasets_by_name(cfg.dataset_name)
    start_pos = batch * 100
    end_pos = min((batch + 1) * 100, len(al_dataset))
    al_dataset = al_dataset[start_pos:end_pos]

    write_log(f"Model: {cfg.model_name} - data {cfg.dataset_name} -batch {batch} start", log_file)
    print(f"Model: {cfg.model_name} - data {cfg.dataset_name} -batch {batch} start")
    print(f"datasets lenght {len(al_dataset)}")

    nR = cfg.nR
    nH = cfg.nH
    system_info = "You are a helpful code assistant."

    for item in al_dataset:
        problem_id = item["tcb_id"]
        write_log(f"{problem_id} start", log_file)
        orginal_response = {"tcb_id": problem_id}
        tests_file = cfg.tests_path.format(problem_id)

        # A tests file on disk means the problem was already processed.
        if os.path.exists(tests_file):
            write_log(f"{problem_id} exist", log_file)
            continue

        time_of_gen = 0
        problem_test_count = 0  # tests written so far, across all attempts
        success_gen = False

        # Generation quality is sometimes poor, so regenerate the functions,
        # but no more than 3 times.
        while time_of_gen < 3 and not success_gen:
            time_of_gen += 1
            generate_testcases = []
            try:
                # Stage 1a: obtain the validator function.
                validator_request = (
                    validator_prompt
                    .replace("{{ problem_specification }}", item['query_en'])
                    .replace("{{ oracle_program }}", item['solutions'][0]['code'])
                )
                response_validator = responser.respond(
                    system_info=system_info,
                    user_prompt=validator_request,
                )
                orginal_response['response_validator'] = response_validator
                response_validator_dict = json.loads(extract_code(response_validator))
                validator_code = response_validator_dict['input_validator']

                # Stage 1b: obtain the generator functions.
                generator_request = (
                    input_generator_prompt
                    .replace("{{ problem_specification }}", item['query_en'])
                    .replace("{{ oracle_program }}", item['solutions'][0]['code'])
                    .replace("{{ input_validator }}", validator_code)
                )
                response_generator = responser.respond(
                    system_info=system_info,
                    user_prompt=generator_request,
                )
                orginal_response['response_generator'] = response_generator
                response_generator_dict = json.loads(extract_code(response_generator))

                # Inputs the model wrote out directly.
                generate_testcases.extend(
                    {'input': test_input, "source": "prompt"}
                    for test_input in response_generator_dict['directly_generated_inputs']
                )

                # Randomly generated inputs: find the generator functions and
                # call each of them nR // (number of functions) times, but at
                # least once. An empty function list raises ZeroDivisionError,
                # which the outer handler logs before the next attempt.
                regular_code = response_generator_dict['regular_input_generator']
                function_list = extract_function_names(regular_code)
                rounds = max(nR // len(function_list), 1)
                for _ in range(rounds):
                    for funcname in function_list:
                        test_input = run_func_code_parallel(
                            [regular_code], funcname=funcname, time_limit=100
                        )[0]
                        generate_testcases.append({'input': test_input, "source": "random"})

                # Hacking inputs: every hacking generator is called nH times.
                hacking_code = response_generator_dict['hacking_input_generator']
                function_list = extract_function_names(hacking_code)
                for _ in range(nH):
                    for funcname in function_list:
                        test_input = run_func_code_parallel(
                            [hacking_code], funcname=funcname, time_limit=100
                        )[0]
                        generate_testcases.append({'input': test_input, "source": "edge"})

                # Stage 2: keep only the inputs the validator accepts.
                generate_testcases_passed = [
                    test_dict
                    for test_dict in generate_testcases
                    if run_func_code(
                        validator_code,
                        funcname="validate_input",
                        param=test_dict['input'],
                        time_limit=100,
                    )
                ]

                # Stage 3: label each input with the reference solutions' output.
                for test in generate_testcases_passed:
                    try:
                        # Run the reference solutions and take a majority vote.
                        test_outputs = []
                        for solution in item["solutions"]:
                            try:
                                execute_res = run_cpp_code_linux(
                                    solution['code'], test['input'], time_limit=3
                                )
                            except Exception as e:
                                write_log(f"{problem_id} solution execute fail {e}", log_file)
                                write_log(f"{test}", log_file)
                                # Skip this solution. Falling through would
                                # read a stale result from an earlier run (and
                                # count its output twice) or an unbound name.
                                continue
                            if "stdout" in execute_res and execute_res["stdout"] != "":
                                test_outputs.append(execute_res["stdout"])
                        # Raises IndexError when no solution produced output;
                        # the handler below then drops this test.
                        test_output, _ = Counter(test_outputs).most_common(1)[0]
                    except Exception as e:
                        write_log(f"{problem_id} fail to generate {e}", log_file)
                        continue
                    test['output'] = test_output
                    append_dict_to_jsonl(tests_file, test)
                    write_log(f"{problem_id} + 1", log_file)
                    problem_test_count += 1

                if problem_test_count < 10:
                    # Too few tests were produced: run another attempt.
                    continue

                success_gen = True
                testcases_pass_rate = {
                    'tcb_id': problem_id,
                    "gen_nums": len(generate_testcases),
                    "right_nums": len(generate_testcases_passed),
                }
                append_dict_to_jsonl(cfg.pass_rate_file.format(batch), testcases_pass_rate)
                append_dict_to_jsonl(orginal_response_file, orginal_response)
            except Exception as e:
                # Best effort: log the failure and let the loop retry.
                write_log(f"{problem_id} error!: {e}", log_file)
| import argparse | |
if __name__ == "__main__":
    # NOTE(review): API keys are hard-coded in source. Consider loading them
    # from the environment or an untracked config file instead.
    keys = [
        "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
        "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
        "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
        "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
    ]
    # Each key serves this many consecutive batch indices.
    batches_per_key = 3

    # Enter the directory: cd data/TestGen
    # Launch: python /home/i-luoxianzhen/data/TestCase-Gen/methods/lcb/pipeline.py 0
    parser = argparse.ArgumentParser(description="接收1个命令行参数")
    # type=int lets argparse report a non-integer argument as a usage error.
    parser.add_argument("param1", type=int, help="第一个参数")
    args = parser.parse_args()
    batch = args.param1

    # Reject batch indices that have no key assigned. A negative index would
    # otherwise silently pick a key from the end of the list and select an
    # empty dataset slice, and a too-large one would raise a bare IndexError.
    max_batch = len(keys) * batches_per_key - 1
    if not 0 <= batch <= max_batch:
        parser.error(f"batch must be between 0 and {max_batch}, got {batch}")

    key = keys[batch // batches_per_key]
    ht_pipeline(batch, key)
Xet Storage Details
- Size:
- 9.19 kB
- Xet hash:
- a6505dbc4ff8de9c410b93da5385759962de44d2bc0b6681f54cada8792ea311
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.