| import sys | |
| sys.path.append("/home/i-luoxianzhen/data/TestCase-Gen/methods/utils") | |
| from response import TurboResponser, OpenResponser | |
| from dataset_all import get_datasets_by_name | |
| from is_correct import test_output_comparison | |
| from config import cfg | |
| from prompt import random_input_prompt, adversarial_prompt, random_input_template, adversarial_input_template, shots | |
| from typing import List, Optional | |
| import json | |
| import traceback | |
| from excute_tool_linux import run_cpp_code_linux, run_func_code, run_python_code_linux | |
| from parallel_exe import run_func_code_parallel | |
# Alias for `shots`; substituted for the "{EXAMPLE_PROBLEM}" placeholder in the system prompts.
example = shots
def write_json_to_file(data, filepath):
    """Dump ``data`` to ``filepath`` as indented UTF-8 JSON, replacing any existing file.

    Args:
        data: Any JSON-serializable object.
        filepath: Destination path; opened in write mode, so prior content is lost.

    Non-ASCII characters are written verbatim rather than as ``\\uXXXX`` escapes.
    """
    with open(filepath, mode='w', encoding='utf-8') as handle:
        json.dump(data, handle, indent=4, ensure_ascii=False)
| import json | |
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as one JSON line (JSONL format).

    Args:
        file_path: Target file; created if missing, otherwise appended to.
        data_dict: JSON-serializable record to write.

    The record is serialized on a single line with non-ASCII characters kept
    verbatim, followed by a newline.
    """
    with open(file_path, mode='a', encoding='utf-8') as out:
        # Serialize after opening, so the file exists even if encoding fails.
        line = json.dumps(data_dict, ensure_ascii=False)
        out.writelines((line, '\n'))
| import datetime | |
def write_log(message: str, log_file: str = "log-lcb.txt") -> None:
    """Append a timestamped log message to a log file.

    Each call adds one line of the form ``[YYYY-MM-DD HH:MM:SS] message``,
    stamped with the local wall-clock time.

    Args:
        message (str): The message to log.
        log_file (str): Path of the log file (default is 'log-lcb.txt').
            The file is created if it does not exist.

    Returns:
        None
    """
    timestamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    # Plain append mode: the log is only ever written, never read back here.
    with open(log_file, "a", encoding="utf-8") as f:
        f.write(f"{timestamp} {message}\n")
| import re | |
def extract_code(ans_str):
    """Return the contents of the last ```python fenced block in ``ans_str``.

    The opening fence line may end with trailing spaces/tabs and either LF or
    CRLF; everything up to the next closing fence is returned unchanged.

    Args:
        ans_str (str): Text to search, e.g. a raw model response.

    Returns:
        str: The body of the last matching fenced block.

    Raises:
        IndexError: If ``ans_str`` contains no complete ```python block.
    """
    pattern = r'```python[ \t]*\r?\n(.*?)```'
    matches = re.findall(pattern, ans_str, re.DOTALL)
    if not matches:
        # Same exception type as indexing an empty list, but with a message
        # that is meaningful when it ends up in a log line.
        raise IndexError("no ```python fenced code block found in response")
    return matches[-1]
def extract_content_code(ans_str):
    """Return the text inside the last ``<ASSISTANT>...</ASSISTANT>`` pair in ``ans_str``.

    Args:
        ans_str (str): Text to search, e.g. a raw model response.

    Returns:
        str: Everything between the last matching pair of tags, unchanged
        (surrounding whitespace is not stripped).

    Raises:
        IndexError: If ``ans_str`` contains no complete tag pair.
    """
    pattern = r'<ASSISTANT>(.*?)</ASSISTANT>'
    matches = re.findall(pattern, ans_str, re.DOTALL)
    if not matches:
        # Same exception type as indexing an empty list, but with a message
        # that is meaningful when it ends up in a log line.
        raise IndexError("no <ASSISTANT>...</ASSISTANT> section found in response")
    return matches[-1]
| import random | |
| import os | |
| from collections import Counter | |
| import time | |
def _collect_generator_sources(request, attempts, raw_sink, report_failure):
    """Query the model ``attempts`` times and return the generator sources it produced.

    Args:
        request: Zero-argument callable that performs one model call and
            returns the raw response text.
        attempts (int): Number of independent calls to make.
        raw_sink (list): Receives every raw response, including ones whose
            code later fails to extract.
        report_failure: Callable invoked with the exception of a failed attempt.

    Returns:
        list: Extracted source strings, one per successful attempt.
    """
    sources = []
    for _ in range(attempts):
        try:
            raw = request()
            raw_sink.append(raw)
            # Prefer a fenced ```python block; otherwise fall back to <ASSISTANT> tags.
            if "```python" in raw:
                sources.append(extract_code(raw))
            else:
                sources.append(extract_content_code(raw))
        except Exception as exc:
            # Best effort: a failed attempt is logged and skipped after a short pause.
            time.sleep(1)
            report_failure(exc)
    return sources


def _majority_output(item, test_input, log_file):
    """Run every solution in ``item["solutions"]`` on ``test_input`` and vote on stdout.

    Args:
        item: Dataset entry providing ``"tcb_id"`` and ``"solutions"``.
        test_input: Input handed to each solution.
        log_file: Log file that receives execution failures.

    Returns:
        The most frequent non-empty ``stdout`` among the solutions.

    Raises:
        IndexError: If no solution produced a non-empty ``stdout``.
    """
    outputs = []
    for solution in item["solutions"]:
        try:
            result = run_cpp_code_linux(solution['code'], test_input, time_limit=3)
        except Exception as exc:
            write_log(f"{item['tcb_id']} solution execute fail {exc}", log_file)
            write_log(f"{test_input}", log_file)
            # Skip this solution: it has no result of its own, and re-reading
            # the previous solution's result would count that output twice.
            continue
        if "stdout" in result.keys() and result["stdout"] != "":
            outputs.append(result["stdout"])
    winner, _ = Counter(outputs).most_common(1)[0]
    return winner


def lcb_pipeline(batch, api_key):
    """Generate test cases for one 100-item slice of the configured dataset.

    For every problem in the slice that has no tests file yet, the model is
    asked for three random and two adversarial input generators. The
    generators are executed, at most 100 of the produced inputs are kept, and
    each input is labelled with the majority stdout of the problem's
    solutions. Accepted cases are appended to ``cfg.tests_path`` as JSONL.
    A problem is retried (at most 3 rounds in total) until at least 10 cases
    have been written; on success the raw model responses are appended to
    ``cfg.response_file``.

    Args:
        batch (int): Zero-based slice index; items ``batch*100`` up to
            ``(batch+1)*100`` (clamped to the dataset length) are processed.
        api_key: Key handed to ``TurboResponser``.

    Returns:
        None
    """
    responser = TurboResponser(api_key=api_key, api_base=cfg.api_base, model=cfg.model_name)
    log_file = cfg.log_file.format(batch)
    print(cfg.model_name)
    write_log(f"Model {cfg.model_name} batch {batch} start", log_file)

    dataset = get_datasets_by_name(cfg.dataset_name)
    response_file = cfg.response_file.format("lcb", cfg.model_name)
    start_pos = batch * 100
    end_pos = min((batch + 1) * 100, len(dataset))
    dataset = dataset[start_pos:end_pos]
    print(f"datasets lenght {len(dataset)}")

    for item in dataset:
        problem_id = item["tcb_id"]
        write_log(f"{problem_id} start", log_file)
        tests_file = cfg.tests_path.format(problem_id)
        if os.path.exists(tests_file):
            # An existing tests file marks the problem as already handled.
            write_log(f"{problem_id} exist", log_file)
            continue

        raw_responses = {
            "tcb_id": problem_id,
            "random_generator": [],
            "adversarial_generator": [],
        }
        # Generation quality is sometimes poor, so the generators are
        # re-requested, but never more than 3 rounds per problem.
        rounds = 0
        accepted = 0  # cumulative over rounds; earlier rounds' cases stay in the file
        succeeded = False
        while rounds < 3 and not succeeded:
            rounds += 1
            try:
                # Stage 1: build random and adversarial generator functions,
                # then execute them to obtain candidate inputs.
                # The lambdas are invoked immediately inside the helper, so
                # they always see the current ``item``.
                generators = _collect_generator_sources(
                    lambda: responser.respond(
                        random_input_prompt.replace("{EXAMPLE_PROBLEM}", example),
                        user_prompt=random_input_template.replace("{PROBLEM}", item["query_en"]),
                    ),
                    3,
                    raw_responses["random_generator"],
                    lambda exc: write_log(
                        f"{problem_id} fail to generate input func {exc}", log_file
                    ),
                )
                generators += _collect_generator_sources(
                    lambda: responser.respond(
                        system_info=adversarial_prompt.replace("{EXAMPLE_PROBLEM}", example),
                        user_prompt=adversarial_input_template.replace("{PROBLEM}", item["query_en"]),
                    ),
                    2,
                    raw_responses["adversarial_generator"],
                    lambda exc: write_log(
                        f"{problem_id} fail to generate input func_ {exc}", log_file
                    ),
                )

                # Execute the generator functions in parallel.
                try:
                    inputs_list = run_func_code_parallel(
                        generators, funcname="construct_inputs", time_limit=100
                    )
                except Exception as e:
                    write_log(f"{problem_id} fail to generate input {e}", log_file)
                    continue

                # Keep only results that are non-empty lists whose first
                # element is a string; the original method selects 100 inputs.
                candidates = []
                for produced in inputs_list:
                    if isinstance(produced, list) and produced and isinstance(produced[0], str):
                        candidates.extend(produced)
                if len(candidates) > 100:
                    candidates = random.sample(candidates, 100)

                # Stage 2: label each input with the reference solutions' output.
                consecutive_failures = 0
                for test in candidates:
                    try:
                        expected = _majority_output(item, test, log_file)
                    except Exception as e:
                        write_log(f"{problem_id} fail to generate {e}", log_file)
                        consecutive_failures += 1
                        if consecutive_failures > 10:
                            # Too many inputs in a row could not be labelled:
                            # give up on this round.
                            break
                        continue
                    consecutive_failures = 0
                    append_dict_to_jsonl(tests_file, {"input": test, "output": expected})
                    write_log(f"{problem_id} + 1", log_file)
                    accepted += 1

                if accepted < 10:
                    # Too few test cases so far: run another round.
                    continue
                succeeded = True
                append_dict_to_jsonl(response_file, raw_responses)
            except Exception as e:
                write_log(f"{problem_id} error!: {e}", log_file)
| import argparse | |
def select_api_key(keys, batch, batches_per_key=3):
    """Return the API key assigned to ``batch``.

    Every ``batches_per_key`` consecutive batches share one key, i.e. the key
    at index ``batch // batches_per_key`` is used.

    Args:
        keys (list): Available API keys.
        batch (int): Zero-based batch index.
        batches_per_key (int): Number of consecutive batches per key.

    Returns:
        The selected element of ``keys``.

    Raises:
        ValueError: If ``batch`` is negative or lies beyond the last key.
    """
    if batch < 0:
        raise ValueError(f"batch must be non-negative, got {batch}")
    index = batch // batches_per_key
    if index >= len(keys):
        raise ValueError(
            f"batch {batch} needs key #{index}, but only {len(keys)} keys are configured"
        )
    return keys[index]


def main():
    """Parse the batch index from the command line and run the pipeline for it."""
    # NOTE(review): credentials are hard-coded in source; consider loading them
    # from the environment or an untracked config file and rotating these.
    keys = [
        "ak-8f3d147b2c9a5e6m0n4p8x2v7y1k3l9",
        "ak-63d1efgh47i8jkl26mno95pqrs34tuv7x2",
        "ak-3f8a2c9e1b7d4f6h5j2k8m3n9p4r6t7",
        "ak-58d7efgh23i4jkl67mno89pqrs01tuv6k5",
    ]
    # Usage:
    #   cd data/TestGen
    #   python /home/i-luoxianzhen/data/TestCase-Gen/methods/lcb/pipeline.py 0
    parser = argparse.ArgumentParser(description="接收1个命令行参数")
    parser.add_argument("param1", type=int, help="第一个参数")
    args = parser.parse_args()
    batch = args.param1
    try:
        key = select_api_key(keys, batch)
    except ValueError as exc:
        # Report an out-of-range batch as a usage error instead of a traceback.
        parser.error(str(exc))
    lcb_pipeline(batch, key)


if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 8.78 kB
- Xet hash:
- 6fae970a11dc26cefbe208d6bcae9caf674905f1f073622fdeb3d380a6c779fd
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.