Spaces:
Runtime error
Runtime error
| import json | |
| import sys | |
| import os | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from jinja2 import Template | |
| try: | |
| from utils.protein_go_analysis import analyze_protein_go | |
| from utils.prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT | |
| from utils.get_motif import get_motif_pfam | |
| except ImportError: | |
| from protein_go_analysis import analyze_protein_go | |
| from prompts import ENZYME_PROMPT, RELATION_SEMANTIC_PROMPT, FUNCTION_PROMPT | |
| from get_motif import get_motif_pfam | |
| from tqdm import tqdm | |
class InterProDescriptionManager:
    """Hold InterPro description data in memory so the JSON files are read only once."""

    def __init__(self, interpro_data_path, interproscan_info_path):
        """
        Load all required data at construction time.

        Args:
            interpro_data_path: path to interpro_data.json, a mapping from IPR
                ID to an entry whose optional 'name' and 'abstract' fields are
                used by get_description.
            interproscan_info_path: path to interproscan_info.json, a mapping
                from protein ID to a record holding an 'interproscan_results'
                dict.
        """
        self.interpro_data_path = interpro_data_path
        self.interproscan_info_path = interproscan_info_path
        self.interpro_data = None
        self.interproscan_info = None
        self._load_data()

    @staticmethod
    def _read_json(path):
        """Parse the JSON file at *path*.

        The encoding is pinned to UTF-8 (the encoding RFC 8259 requires for
        interchanged JSON) rather than open()'s locale-dependent default, so
        non-ASCII text loads identically on every platform.
        """
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _load_data(self):
        """Load both data files once; a falsy or missing path leaves its attribute as is."""
        if self.interpro_data_path and os.path.exists(self.interpro_data_path):
            self.interpro_data = self._read_json(self.interpro_data_path)
        if self.interproscan_info_path and os.path.exists(self.interproscan_info_path):
            self.interproscan_info = self._read_json(self.interproscan_info_path)

    def get_description(self, protein_id, selected_types=None):
        """
        Return InterPro names/abstracts for *protein_id*, grouped by info type.

        Args:
            protein_id: protein ID to look up in interproscan_info.
            selected_types: info types to include, e.g. ['superfamily',
                'panther', 'gene3d']; None means no types.

        Returns:
            dict: maps each selected type that yielded at least one known IPR
            ID to {ipr_id: {'name': ..., 'abstract': ...}}. Empty when either
            data file was not loaded or the protein is unknown.
        """
        if selected_types is None:
            selected_types = []
        if not self.interpro_data or not self.interproscan_info:
            return {}
        if protein_id not in self.interproscan_info:
            return {}

        interproscan_results = self.interproscan_info[protein_id].get('interproscan_results', {})
        result = {}
        for info_type in selected_types:
            if info_type not in interproscan_results:
                continue
            type_descriptions = {}
            # Each entry is a dict; every non-empty value that is a known IPR
            # ID contributes a description (the dict keys are not needed).
            for entry in interproscan_results[info_type]:
                for ipr_id in entry.values():
                    if ipr_id and ipr_id in self.interpro_data:
                        ipr_entry = self.interpro_data[ipr_id]
                        type_descriptions[ipr_id] = {
                            'name': ipr_entry.get('name', ''),
                            'abstract': ipr_entry.get('abstract', ''),
                        }
            if type_descriptions:
                result[info_type] = type_descriptions
        return result
# Module-level caches for the InterProDescriptionManager instance and the lmdb
# connection, reused by repeated calls within one process.
_interpro_manager = None
# (interpro_data_path, interproscan_info_path) the cached manager was built from.
_interpro_manager_paths = None
_lmdb_db = None
_lmdb_path = None


def get_interpro_manager(interpro_data_path, interproscan_info_path):
    """
    Return the cached InterProDescriptionManager, creating it if needed.

    The cache is keyed on both paths: requesting different files rebuilds the
    manager instead of silently returning one loaded from other files.
    """
    global _interpro_manager, _interpro_manager_paths
    paths = (interpro_data_path, interproscan_info_path)
    if _interpro_manager is None or _interpro_manager_paths != paths:
        _interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)
        _interpro_manager_paths = paths
    return _interpro_manager


def get_lmdb_connection(lmdb_path):
    """
    Return a cached read-only lmdb environment for *lmdb_path*.

    A handle for a different path is closed first. Returns None (and clears
    the cache) when *lmdb_path* is falsy or does not exist.
    """
    global _lmdb_db, _lmdb_path
    if _lmdb_db is not None and _lmdb_path == lmdb_path:
        return _lmdb_db
    if _lmdb_db is not None:
        _lmdb_db.close()
    if lmdb_path and os.path.exists(lmdb_path):
        import lmdb  # imported lazily: only needed for QA (lmdb) runs
        _lmdb_db = lmdb.open(lmdb_path, readonly=True)
        _lmdb_path = lmdb_path
    else:
        _lmdb_db = None
        _lmdb_path = None
    return _lmdb_db
def get_prompt_template(selected_info_types=None,lmdb_path=None):
    """
    Assemble the Jinja2 source text used to render a protein prompt.

    The header is ENZYME_PROMPT when lmdb_path is None and FUNCTION_PROMPT
    otherwise; in the latter (QA) case a trailing ``{{question}}`` slot is
    appended. Which info sections appear is decided at render time from the
    ``selected_info_types`` template variable; the argument of the same name
    is not embedded in the returned text.

    Args:
        selected_info_types: info types to include, e.g. ['motif', 'go',
            'superfamily', 'panther']; defaults to ['motif', 'go'].
        lmdb_path: QA database path; here it only selects the prompt variant.

    Returns:
        str: the (uncompiled) template source.
    """
    if selected_info_types is None:
        selected_info_types = ['motif', 'go']  # default: motif and GO info
    qa_mode = lmdb_path is not None
    header = FUNCTION_PROMPT if qa_mode else ENZYME_PROMPT
    info_section = """
input information:
{%- if 'motif' in selected_info_types and motif_pfam %}
motif:{% for motif_id, motif_info in motif_pfam.items() %}
{{motif_id}}: {{motif_info}}
{% endfor %}
{%- endif %}
{%- if 'go' in selected_info_types and go_data.status == 'success' %}
GO:{% for go_entry in go_data.go_annotations %}
▢ GO term{{loop.index}}: {{go_entry.go_id}}
• definition: {{ go_data.all_related_definitions.get(go_entry.go_id, 'not found definition') }}
{% endfor %}
{%- endif %}
{%- for info_type in selected_info_types %}
{%- if info_type not in ['motif', 'go'] and interpro_descriptions.get(info_type) %}
{{info_type}}:{% for ipr_id, ipr_info in interpro_descriptions[info_type].items() %}
▢ {{ipr_id}}: {{ipr_info.name}}
• description: {{ipr_info.abstract}}
{% endfor %}
{%- endif %}
{%- endfor %}
"""
    pieces = [header, "\n", info_section]
    if qa_mode:
        pieces.append("\n" + "question: \n {{question}}")
    return "".join(pieces)
def get_qa_data(protein_id, lmdb_path):
    """
    Collect every QA pair stored for *protein_id* in the lmdb database.

    Only records whose key decodes to a purely decimal string are considered.
    Their value is expected to be UTF-8 JSON of the form
    ``[protein_id, question, ground_truth, ...]``. Records that fail to decode,
    or do not have that shape, are skipped.

    NOTE(review): this walks the whole database on every call, so generating
    prompts for P proteins costs P full scans. If that becomes a bottleneck,
    build a protein_id -> keys index once per process.

    Args:
        protein_id: protein ID compared against the first element of each record.
        lmdb_path: lmdb database path; a falsy or non-existent path yields [].

    Returns:
        list: {'question': ..., 'ground_truth': ...} dicts in cursor order.
        On an unexpected error the message is printed and the pairs collected
        so far are returned.
    """
    if not lmdb_path or not os.path.exists(lmdb_path):
        return []
    qa_pairs = []
    try:
        db = get_lmdb_connection(lmdb_path)
        if db is None:
            return []
        with db.begin() as txn:
            for key, value in txn.cursor():
                try:
                    if not key.decode('utf-8').isdigit():
                        continue  # only numerically keyed records hold QA data
                    data = json.loads(value.decode('utf-8'))
                except ValueError:
                    # UnicodeDecodeError and JSONDecodeError are both ValueErrors:
                    # an undecodable record is skipped, not fatal.
                    continue
                if isinstance(data, list) and len(data) >= 3 and data[0] == protein_id:
                    qa_pairs.append({
                        'question': data[1],
                        'ground_truth': data[2]
                    })
    except Exception as e:
        print(f"Error reading lmdb for protein {protein_id}: {e}")
    return qa_pairs
# Compiled Jinja2 templates keyed by their source text. get_prompt_template()
# yields one source per prompt variant, so compiling once per source (instead
# of once per protein/question) avoids re-parsing the same template repeatedly.
_COMPILED_TEMPLATES = {}


def _compile_template(source):
    """Return a compiled jinja2 Template for *source*, compiling it at most once."""
    template = _COMPILED_TEMPLATES.get(source)
    if template is None:
        template = Template(source)
        _COMPILED_TEMPLATES[source] = template
    return template


def generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                    interpro_data_path=None, interproscan_info_path=None, selected_info_types=None, lmdb_path=None, interpro_manager=None, question=None):
    """
    Render the prompt text for one protein (and optionally one question).

    Args:
        protein_id: protein ID to describe.
        protein2gopath, go_info_path: passed to analyze_protein_go.
        protein2pfam_path, pfam_descriptions_path: passed to get_motif_pfam.
        interpro_data_path: path to interpro_data.json.
        interproscan_info_path: path to interproscan_info.json.
        selected_info_types: info types to include, e.g. ['motif', 'go',
            'superfamily', 'panther']; defaults to ['motif', 'go'].
        lmdb_path: when not None, the QA prompt variant (with a question) is used.
        interpro_manager: InterProDescriptionManager to use; takes precedence
            over the two InterPro paths.
        question: question text for QA prompts.

    Returns:
        str: the rendered prompt.
    """
    if selected_info_types is None:
        selected_info_types = ['motif', 'go']

    analysis = analyze_protein_go(protein_id, protein2gopath, go_info_path)
    motif_pfam = get_motif_pfam(protein_id, protein2pfam_path, pfam_descriptions_path)

    # InterPro descriptions are only needed for types beyond motif/GO.
    interpro_descriptions = {}
    other_types = [t for t in selected_info_types if t not in ('motif', 'go')]
    if other_types:
        if interpro_manager is not None:
            interpro_descriptions = interpro_manager.get_description(protein_id, other_types)
        elif interpro_data_path and interproscan_info_path:
            manager = get_interpro_manager(interpro_data_path, interproscan_info_path)
            interpro_descriptions = manager.get_description(protein_id, other_types)

    go_ok = analysis["status"] == "success"
    template_data = {
        "protein_id": protein_id,
        "selected_info_types": selected_info_types,
        "go_data": {
            "status": analysis["status"],
            "go_annotations": analysis["go_annotations"] if go_ok else [],
            "all_related_definitions": analysis["all_related_definitions"] if go_ok else {},
        },
        "motif_pfam": motif_pfam,
        "interpro_descriptions": interpro_descriptions,
        "question": question,
    }
    template = _compile_template(get_prompt_template(selected_info_types, lmdb_path))
    return template.render(**template_data)
def save_prompts_parallel(protein_ids, output_path, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                          interpro_data_path=None, interproscan_info_path=None, selected_info_types=None, lmdb_path=None, n_process=8):
    """
    Generate prompts for many proteins in parallel and save them to *output_path*.

    Workers append one JSON object per line to ``output_path + '.tmp'``:
    - without lmdb_path, each line is {protein_id: prompt}; the lines are then
      merged into a single JSON dict written to *output_path*;
    - with lmdb_path, each QA pair of a protein becomes a line carrying a
      shared, process-wide "index", and the temp file is moved to
      *output_path* as JSONL.

    Args:
        protein_ids: protein ID strings; surrounding whitespace is stripped and
            blank entries are skipped.
        output_path: final output file path.
        protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
        interpro_data_path, interproscan_info_path, selected_info_types,
        lmdb_path: forwarded to generate_prompt (lmdb_path also selects QA mode).
        n_process: number of worker processes.
    """
    global _lmdb_db, _lmdb_path
    import multiprocessing
    import shutil
    try:
        from utils.mpr import MultipleProcessRunnerSimplifier
    except ImportError:
        from mpr import MultipleProcessRunnerSimplifier

    if selected_info_types is None:
        selected_info_types = ['motif', 'go']
    tmp_path = output_path + '.tmp'

    # Build the InterPro manager once, before the workers start, and hand the
    # same instance to every generate_prompt call.
    interpro_manager = None
    other_types = [t for t in selected_info_types if t not in ('motif', 'go')]
    if other_types and interpro_data_path and interproscan_info_path:
        interpro_manager = InterProDescriptionManager(interpro_data_path, interproscan_info_path)

    # Shared counter (guarded by a lock) giving each QA record a unique index
    # across worker processes.
    if lmdb_path:
        global_index = multiprocessing.Value('i', 0)
        index_lock = multiprocessing.Lock()
    else:
        global_index = None
        index_lock = None

    def build_prompt(protein_id, question=None):
        """Render one prompt with this run's fixed configuration."""
        return generate_prompt(protein_id, protein2gopath, protein2pfam_path, pfam_descriptions_path, go_info_path,
                               interpro_data_path, interproscan_info_path, selected_info_types, lmdb_path,
                               interpro_manager, question)

    def process_protein(process_id, idx, protein_id, writer):
        protein_id = protein_id.strip()
        if not protein_id:
            # Blank lines in the ID list would otherwise yield a bogus "" entry.
            return
        if not lmdb_path:
            prompt = build_prompt(protein_id)
            if prompt and writer:
                writer.write(json.dumps({protein_id: prompt}) + '\n')
            return
        # Open (or reuse) this process's own lmdb handle before reading QA pairs.
        get_lmdb_connection(lmdb_path)
        for qa_pair in get_qa_data(protein_id, lmdb_path):
            question = qa_pair['question']
            prompt = build_prompt(protein_id, question)
            if not prompt or not writer:
                continue
            with index_lock:
                current_index = global_index.value
                global_index.value += 1
            result = {
                "index": current_index,
                "protein_id": protein_id,
                "prompt": prompt,
                "question": question,
                "ground_truth": qa_pair['ground_truth'],
            }
            writer.write(json.dumps(result) + '\n')

    runner = MultipleProcessRunnerSimplifier(
        data=protein_ids,
        do=process_protein,
        save_path=tmp_path,
        n_process=n_process,
        split_strategy="static"
    )
    runner.run()

    # Release the module-level lmdb handle, if this process holds one.
    if _lmdb_db is not None:
        _lmdb_db.close()
    _lmdb_db = None
    _lmdb_path = None

    if lmdb_path:
        # QA mode: the temp file already is the final JSONL output.
        shutil.move(tmp_path, output_path)
    else:
        # Merge the per-protein lines into one dict (legacy JSON format).
        final_results = {}
        with open(tmp_path, 'r') as f:
            for line in f:
                if line.strip():  # ignore blank lines
                    final_results.update(json.loads(line))
        with open(output_path, 'w') as f:
            json.dump(final_results, f, indent=2)

    if os.path.exists(tmp_path):
        os.remove(tmp_path)
def build_output_path(output_path, selected_info_types):
    """
    Return *output_path* with the selected info types appended to the file stem.

    Example: ('prompts.json', ['motif', 'go']) -> 'prompts_motif_go.json'.
    Only the final extension is split off, so a '.json' appearing elsewhere in
    the path (e.g. in a directory name) is left untouched.
    """
    root, ext = os.path.splitext(output_path)
    return root + '_' + '_'.join(selected_info_types) + ext


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Generate protein prompt')
    parser.add_argument('--protein_path', type=str, default='data/raw_data/protein_ids_clean.txt')
    parser.add_argument('--protein2pfam_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--pfam_descriptions_path', type=str, default='data/raw_data/all_pfam_descriptions.json')
    parser.add_argument('--protein2gopath', type=str, default='data/processed_data/go_integration_final_topk2.json')
    parser.add_argument('--go_info_path', type=str, default='data/raw_data/go.json')
    parser.add_argument('--interpro_data_path', type=str, default='data/raw_data/interpro_data.json')
    parser.add_argument('--interproscan_info_path', type=str, default='data/processed_data/interproscan_info.json')
    parser.add_argument('--lmdb_path', type=str, default=None)
    parser.add_argument('--output_path', type=str, default='data/processed_data/prompts@clean_test.json')
    parser.add_argument('--selected_info_types', type=str, nargs='+', default=['motif', 'go'],
                        help='选择要包含的信息类型,如: motif go superfamily panther gene3d')
    parser.add_argument('--n_process', type=int, default=32)
    args = parser.parse_args()
    # The output file name must include the selected info types.
    args.output_path = build_output_path(args.output_path, args.selected_info_types)
    print(args)
    with open(args.protein_path, 'r') as file:
        protein_ids = file.readlines()
    save_prompts_parallel(
        protein_ids=protein_ids,
        output_path=args.output_path,
        n_process=args.n_process,
        protein2gopath=args.protein2gopath,
        protein2pfam_path=args.protein2pfam_path,
        pfam_descriptions_path=args.pfam_descriptions_path,
        go_info_path=args.go_info_path,
        interpro_data_path=args.interpro_data_path,
        interproscan_info_path=args.interproscan_info_path,
        selected_info_types=args.selected_info_types,
        lmdb_path=args.lmdb_path
    )
    # Example: render a single prompt
    # protein_id = 'A8CF74'
    # prompt = generate_prompt(protein_id, 'data/processed_data/go_integration_final_topk2.json',
    #                          'data/processed_data/interproscan_info.json', 'data/raw_data/all_pfam_descriptions.json',
    #                          'data/raw_data/go.json', 'data/raw_data/interpro_data.json',
    #                          'data/processed_data/interproscan_info.json',
    #                          ['motif', 'go', 'superfamily', 'panther'])
    # print(prompt)