import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import json
from pathlib import Path
from tqdm import tqdm
from utils.openai_access import call_chatgpt
from utils.mpr import MultipleProcessRunnerSimplifier
from utils.generate_protein_prompt import generate_prompt

qa_data = None


def _load_qa_data(prompt_path):
    """Load QA records from a JSONL file and cache them, keyed by their 'index' field."""
    global qa_data
    if qa_data is None:
        qa_data = {}
        with open(prompt_path, 'r') as f:
            for line in f:
                if line.strip():
                    item = json.loads(line.strip())
                    qa_data[item['index']] = item
    return qa_data
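
# Illustrative shape of one line in the prompt JSONL (a sketch inferred from the fields
# accessed in process_single_qa; values are placeholders and real records may carry extra keys):
#   {"index": 0, "protein_id": "P12345", "prompt": "...", "question": "...", "ground_truth": "..."}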


def process_single_qa(process_id, idx, qa_index, writer, save_dir):
    """Process a single QA pair and generate an answer."""
    try:
        qa_item = qa_data[qa_index]
        protein_id = qa_item['protein_id']
        prompt = qa_item['prompt']
        question = qa_item['question']
        ground_truth = qa_item['ground_truth']

        # Call the LLM to generate an answer
        llm_response = call_chatgpt(prompt)

        # Build the result record
        result = {
            'protein_id': protein_id,
            'index': qa_index,
            'question': question,
            'ground_truth': ground_truth,
            'llm_answer': llm_response
        }

        # Save the result; the file name combines protein_id and index
        save_path = os.path.join(save_dir, f"{protein_id}_{qa_index}.json")
        with open(save_path, 'w') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)
    except Exception as e:
        print(f"Error processing QA index {qa_index}: {str(e)}")


def get_missing_qa_indices(save_dir):
    """Check which QA indices have not yet produced a valid result file."""
    # All QA indices that should have been generated
    all_qa_indices = list(qa_data.keys())

    # QA indices with problems (missing, empty, or incomplete files)
    problem_qa_indices = set()

    # Check every QA index that should exist
    for qa_index in tqdm(all_qa_indices, desc="Checking QA data files"):
        protein_id = qa_data[qa_index]['protein_id']
        json_file = Path(save_dir) / f"{protein_id}_{qa_index}.json"

        # The file does not exist: mark it as a problem
        if not json_file.exists():
            problem_qa_indices.add(qa_index)
            continue

        # Check the file contents
        try:
            with open(json_file, 'r') as f:
                data = json.load(f)
            # Empty content or a missing/blank answer field also counts as a problem
            if (data is None or len(data) == 0 or
                    'llm_answer' not in data or
                    data.get('llm_answer') is None or
                    data.get('llm_answer') == ''):
                problem_qa_indices.add(qa_index)
                json_file.unlink()  # delete the empty or incomplete file
        except Exception:
            # Unreadable or invalid JSON also counts as a problem
            problem_qa_indices.add(qa_index)
            try:
                json_file.unlink()  # delete the corrupted file
            except OSError:
                pass

    return problem_qa_indices


def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--prompt_path", type=str,
                        default="data/processed_data/prompts@clean_test.jsonl",
                        help="Path to the JSONL file containing QA prompts")
    parser.add_argument("--n_process", type=int, default=64,
                        help="Number of parallel processes")
    parser.add_argument("--save_dir", type=str,
                        default="data/clean_test_results_top2",
                        help="Directory to save results")
    parser.add_argument("--max_iterations", type=int, default=3,
                        help="Maximum number of iterations to try generating all QA pairs")
    args = parser.parse_args()

    # Create the output directory
    os.makedirs(args.save_dir, exist_ok=True)

    # Load the QA data
    _load_qa_data(args.prompt_path)
    print(f"Loaded {len(qa_data)} QA pairs")

    # Check and generate in rounds until every QA pair has a valid result file
    # or the maximum number of iterations is reached
    iteration = 0
    while iteration < args.max_iterations:
        iteration += 1
        print(f"\nStarting check-and-generate round {iteration}")

        # Collect the QA indices that are still missing
        missing_qa_indices = get_missing_qa_indices(args.save_dir)

        # Nothing missing: we are done
        if not missing_qa_indices:
            print("All QA data has been generated successfully!")
            break

        print(f"Found {len(missing_qa_indices)} missing QA records; generating them now")

        # Sort the missing indices for a stable processing order
        missing_qa_indices_list = sorted(missing_qa_indices)

        # Record the missing indices for this round
        missing_ids_file = Path(args.save_dir) / f"missing_qa_indices_iteration_{iteration}.txt"
        with open(missing_ids_file, 'w') as f:
            for qa_index in missing_qa_indices_list:
                protein_id = qa_data[qa_index]['protein_id']
                f.write(f"{protein_id}_{qa_index}\n")

        # Generate the missing QA data with multiple processes
        mprs = MultipleProcessRunnerSimplifier(
            data=missing_qa_indices_list,
            do=lambda process_id, idx, qa_index, writer: process_single_qa(
                process_id, idx, qa_index, writer, args.save_dir),
            n_process=args.n_process,
            split_strategy="static"
        )
        mprs.run()

        print(f"Round {iteration} finished")

    # Final check
    final_missing_indices = get_missing_qa_indices(args.save_dir)
    if final_missing_indices:
        print(f"After {iteration} rounds, {len(final_missing_indices)} QA records were still not generated")
        # Save the indices that are still missing
        final_missing_ids_file = Path(args.save_dir) / "final_missing_qa_indices.txt"
        with open(final_missing_ids_file, 'w') as f:
            for qa_index in sorted(final_missing_indices):
                protein_id = qa_data[qa_index]['protein_id']
                f.write(f"{protein_id}_{qa_index}\n")
        print(f"Indices of the remaining missing QA records were saved to: {final_missing_ids_file}")
    else:
        print(f"All QA data was generated successfully after {iteration} round(s)!")


if __name__ == "__main__":
    main()
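
# Example invocation (a sketch: the script filename below is assumed, and the defaults
# defined in main() apply whenever a flag is omitted):
#   python run_qa_generation.py \
#       --prompt_path data/processed_data/prompts@clean_test.jsonl \
#       --n_process 64 \
#       --save_dir data/clean_test_results_top2 \
#       --max_iterations 3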