| import sys | |
| sys.path.append("/home/luoxianzhen/yang/methods/utils") | |
| from dataset_all import get_datasets_by_name | |
| from config import cfg | |
| from typing import List, Optional | |
| import json | |
| from collections import Counter | |
| from excute_tool_linux import run_cpp_code_linux | |
| import os | |
| import re | |
| import json | |
def append_dict_to_jsonl(file_path, data_dict):
    """Append ``data_dict`` to ``file_path`` as a single JSON line.

    The file is opened in append mode with UTF-8 encoding, so existing
    content is preserved.  Non-ASCII characters are written verbatim
    (``ensure_ascii=False``) and the record is terminated by a newline.

    Args:
        file_path: Path of the JSONL file to append to.
        data_dict: JSON-serialisable object to write.
    """
    with open(file_path, mode='a', encoding='utf-8') as sink:
        serialized = json.dumps(data_dict, ensure_ascii=False)
        sink.write(serialized + '\n')
def extract_jsonl_from_markdown(markdown_text):
    """Return the first parsable JSON value found in a ```json fenced block.

    Every ```json ... ``` fence in ``markdown_text`` is tried in order of
    appearance; fences whose content is not valid JSON are skipped.

    Args:
        markdown_text (str): Markdown text that may contain ```json fences.

    Returns:
        The decoded JSON value of the first fence that parses.

    Raises:
        IndexError: If no ```json fence with valid JSON is present.  The
            exception type is kept as IndexError so existing callers that
            catch it keep working; only the message is more descriptive.
    """
    # DOTALL lets the captured fence body span multiple lines.
    json_pattern = re.compile(r'```json(.*?)```', re.DOTALL)
    for fence in json_pattern.finditer(markdown_text):
        try:
            # Only the first valid block is needed, so stop as soon as one parses.
            return json.loads(fence.group(1).strip())
        except json.JSONDecodeError:
            # Malformed block: ignore it and try the next fence.
            continue
    raise IndexError("no parsable ```json block found in markdown text")
| import datetime | |
def write_log(message: str, log_file: str = "log-predo.txt"):
    """
    Write one timestamped line to the end of a log file.

    Each line has the form ``[YYYY-MM-DD HH:MM:SS] <message>`` using the
    current local time, and the file is opened in append mode (UTF-8).

    Args:
        message (str): The text to record.
        log_file (str): Path of the log file to append to
            (default is 'log-predo.txt').

    Returns:
        None
    """
    stamp = datetime.datetime.now().strftime("[%Y-%m-%d %H:%M:%S]")
    with open(log_file, mode="a", encoding="utf-8") as handle:
        handle.write(f"{stamp} {message}\n")
def extract_code_between_tags(text, start_tag="<BEGIN>", end_tag="<END>"):
    """
    Return the text enclosed by the first start/end tag pair.

    Args:
        text (str): The full text to search.
        start_tag (str): Literal opening marker (default is "<BEGIN>").
        end_tag (str): Literal closing marker (default is "<END>").

    Returns:
        str: The content between the first ``start_tag`` and the nearest
        following ``end_tag`` with surrounding whitespace removed, or an
        empty string when no such pair exists.
    """
    # Tags are escaped so they are matched literally, not as regex syntax.
    opening = re.escape(start_tag)
    closing = re.escape(end_tag)
    found = re.search(f"{opening}\\s*(.*?)\\s*{closing}", text, re.DOTALL)
    if found is None:
        return ""
    return found.group(1).strip()
| import json | |
def extract_ids_from_jsonl(filepath):
    """Collect the ``tcb_id`` value of every record in a JSONL file.

    Args:
        filepath: Path to a file with one JSON value per line.

    Returns:
        list: The ``tcb_id`` values in file order.  Records without a
        ``tcb_id`` key are ignored.  An empty list is returned when the
        file does not exist.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if not os.path.exists(filepath):
        return []
    ids = []
    with open(filepath, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                # Blank lines carry no record; json.loads would raise on them.
                continue
            data = json.loads(line)
            if 'tcb_id' in data:
                ids.append(data['tcb_id'])
    return ids
| import json | |
def read_jsonl_to_array(file_path):
    """Read a JSONL file into a list, one decoded value per non-blank line.

    Reading is best-effort: any failure (missing file, invalid JSON, ...)
    is printed and an empty list is returned instead of raising.

    Args:
        file_path: Path to a UTF-8 file with one JSON value per line.

    Returns:
        list: The decoded values in file order, or ``[]`` on any error.
    """
    result = []
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                stripped = line.strip()
                if not stripped:
                    # A blank line would otherwise raise in json.loads and
                    # make the handler below discard the whole file.
                    continue
                result.append(json.loads(stripped))  # parse each line and collect it
        return result
    except Exception as e:
        # Deliberately broad: callers receive [] rather than an exception.
        print(f"Error reading the file: {e}")
        return []
def write_json_to_file(data, filepath):
    """Serialise ``data`` as pretty-printed JSON into ``filepath``.

    The target file is overwritten.  Output is UTF-8, indented by four
    spaces, with non-ASCII characters written verbatim.

    Args:
        data: JSON-serialisable object to store.
        filepath: Destination path.
    """
    with open(filepath, mode='w', encoding='utf-8') as target:
        json.dump(obj=data, fp=target, ensure_ascii=False, indent=4)
def process_code(code_string, input_string, std, time_limit=2, memory_limit=256):
    """Run code through ``run_cpp_code_linux`` without ever raising.

    Args:
        code_string: Source code handed to the runner.
        input_string: Input handed to the runner.
        std: Language-standard string handed to the runner (e.g. "c++11").
        time_limit: Passed through to the runner unchanged (default 2).
        memory_limit: Passed through to the runner unchanged (default 256).

    Returns:
        The runner's return value on success.  If anything raises, the dict
        that was handed to the runner, with ``"status"`` set to ``"EXE"``
        and ``"details"`` set to the exception text.
    """
    outcome = {}
    try:
        report = run_cpp_code_linux(
            outcome, code_string, input_string, std, time_limit, memory_limit
        )
        # The looked-up value is not used, but the call is kept on purpose:
        # if the runner returned something without ``.get`` (e.g. None), it
        # raises here and the caller receives the "EXE" fallback below.
        report.get("error", "Unknown")
    except Exception as exc:
        outcome["status"] = "EXE"
        outcome["details"] = str(exc)
        return outcome
    return report
| import os | |
def main(batch):
    """Filter the generated test cases for one 100-item batch of the dataset.

    For every dataset item in the batch, each generated test input is run
    through the item's reference solutions.  A test is written to the
    filtered test file (with the agreed stdout stored under ``"output"``)
    only when no executed solution returned a status other than ``"AC"``,
    all collected stdouts are identical, and that stdout is non-empty.
    A per-item summary is appended to the batch's pass-rate file.

    A solution that reports ``"CE"`` is logged once and skipped for the
    remaining tests of that item.  It is tracked by index rather than
    removed from ``item["solutions"]``: removing while enumerating skipped
    the following solution and shifted indices away from
    ``code_test_pass_count``.

    Args:
        batch (int): Zero-based batch number; the batch covers dataset
            positions ``batch * 100`` up to ``(batch + 1) * 100``.
    """
    log_file = cfg.log_file.format(batch)
    al_dataset = get_datasets_by_name(cfg.dataset_name)
    start_pos = batch * 100
    # Clamp the end so the last (possibly partial) batch stays in range.
    end_pos = min((batch + 1) * 100, len(al_dataset))
    al_dataset = al_dataset[start_pos:end_pos]
    print(f"datasets len {len(al_dataset)}")
    for item in al_dataset:
        # Bound before the try so the error handler below can always log it,
        # even when reading item['tcb_id'] is what failed.
        tcb_id = None
        try:
            tcb_id = item['tcb_id']
            gen_test_count = 0
            write_log(f"{tcb_id} Start", log_file)
            # Load the generated test cases.
            test_file_path = cfg.tests_path.format(tcb_id)
            testcase_list = read_jsonl_to_array(test_file_path)
            if not testcase_list:
                continue
            # Run the test inputs through the reference (correct) solutions.
            # All solutions must agree and must stay within the limits.
            solutions = item["solutions"]
            code_test_pass_count = [0] * len(solutions)
            # Indices of solutions that failed to compile; skipped from then on.
            ce_indices = set()
            test_save_path = cfg.tests_path_flited.format(tcb_id)
            # Start from a clean file if a filtered test file already exists.
            if os.path.exists(test_save_path):
                os.remove(test_save_path)
            for test in testcase_list:
                save_flag = True
                output_lists = []
                for idx, solution in enumerate(solutions):
                    if idx in ce_indices:
                        continue
                    try:
                        # TODO: the reference code carries no compiler/standard
                        # information, so c++11 is used for now.
                        execute_res = process_code(
                            solution,
                            test["input"],
                            "c++11",
                            time_limit=item["runtime_limit"],
                            memory_limit=item['memory_limit'],
                        )
                        status = execute_res["status"]
                        if status == "CE":
                            # A CE would need the code itself fixed; ignore that
                            # for now, log it and rely on the other solutions.
                            write_log(f"{tcb_id} Correct Code {idx} CE", log_file)
                            ce_indices.add(idx)
                            continue
                        if status != "AC":
                            save_flag = False
                            write_log(f"{tcb_id} - 1", log_file)
                        else:
                            output_lists.append(execute_res['stdout'])
                            code_test_pass_count[idx] += 1
                    except Exception:
                        write_log(f"{tcb_id} solution execute fail", log_file)
                # Keep the test only when every collected output is the same
                # non-empty string.
                if save_flag and len(set(output_lists)) == 1 and output_lists[0] != "":
                    test['output'] = output_lists[0]
                    gen_test_count += 1
                    append_dict_to_jsonl(test_save_path, test)
                    write_log(f"{tcb_id} + 1", log_file)
            for idx, passed in enumerate(code_test_pass_count):
                # CE'd solutions were already reported above.
                if idx not in ce_indices and passed == 0:
                    write_log(f"{tcb_id} {idx} psssed no test", log_file)
            testcases_pass_rate = {
                "tcb_id": tcb_id,
                "gen_nums": len(testcase_list),
                "right_nums": gen_test_count
            }
            append_dict_to_jsonl(cfg.pass_rate_file.format(batch), testcases_pass_rate)
        except Exception as e:
            # One broken item must not abort the rest of the batch.
            write_log(f"{tcb_id} error {e}", log_file)
            continue
if __name__ == "__main__":
    import argparse

    # NOTE(review): a hard-coded list of API keys and a `keys[batch // 3]`
    # lookup were removed from this entry point.  The selected key was never
    # used, and the lookup raised IndexError for batch >= 12 before main()
    # could run.  Credentials belong in the environment or a config file,
    # not in source.
    parser = argparse.ArgumentParser(description="接收1个命令行参数")
    parser.add_argument("param1", type=str, help="第一个参数")
    args = parser.parse_args()
    # The single positional argument is the zero-based batch number.
    batch = int(args.param1)
    main(batch)
Xet Storage Details
- Size:
- 7.24 kB
- Xet hash:
- 3e4e8de820cc4792214a6ea1b1718134f3acd77c202fd8557db59ddc0331b9e1
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.