# Spaces: Runtime error
import argparse
import json
import os
import shutil
import sys
from pathlib import Path
from typing import Dict, List, Optional

from tqdm import tqdm

# 添加必要的路径
root_path = os.path.dirname(os.path.abspath(__file__))
print(root_path)
sys.path.append(root_path)
sys.path.append(os.path.join(root_path, "Models/ProTrek"))

# 导入所需模块
from interproscan import InterproScan
from Bio.Blast.Applications import NcbiblastpCommandline
from utils.utils import extract_interproscan_metrics, get_seqnid, extract_blast_metrics, rename_interproscan_keys
from go_integration_pipeline import GOIntegrationPipeline
from utils.generate_protein_prompt import generate_prompt, get_interpro_manager, get_lmdb_connection
from utils.openai_access import call_chatgpt
class IntegratedProteinPipeline:
    """Integrated protein analysis pipeline.

    ``run`` executes four steps:

    1. Run BLAST and InterProScan on the input FASTA, or load precomputed
       step-1 results from the JSON paths carried on ``args``.
    2. Integrate GO terms via ``GOIntegrationPipeline.first_level_filtering``.
    3. Render one prompt per protein, or one per QA pair when ``lmdb_path``
       is set.
    4. Query the LLM via ``call_chatgpt`` and write one JSON file per prompt.
    """

    # Pfam description table, loaded lazily by ``_load_pfam_descriptions``.
    # The class-level ``None`` is only a default: the loaded dict is stored on
    # the instance (rebinding, never mutating), so nothing is shared.
    _pfam_descriptions_cache: Optional[Dict] = None

    def __init__(self,
                 blast_database: str = "uniprot_swissprot",
                 expect_value: float = 0.01,
                 interproscan_path: str = "interproscan/interproscan-5.75-106.0/interproscan.sh",
                 interproscan_libraries: Optional[List[str]] = None,
                 go_topk: int = 2,
                 selected_info_types: Optional[List[str]] = None,
                 pfam_descriptions_path: Optional[str] = None,
                 go_info_path: Optional[str] = None,
                 interpro_data_path: Optional[str] = None,
                 lmdb_path: Optional[str] = None,
                 args: Optional[argparse.Namespace] = None):
        """Configure the pipeline.

        Args:
            blast_database: BLAST database name (passed as ``db`` to blastp).
            expect_value: BLAST E-value threshold.
            interproscan_path: Path to the InterProScan launcher script.
            interproscan_libraries: InterProScan libraries whose hits are
                extracted; ``None`` or an empty list selects the built-in list.
            go_topk: ``topk`` parameter for ``GOIntegrationPipeline``.
            selected_info_types: Information types included in prompts;
                ``None`` or empty means ``['motif', 'go']``.
            pfam_descriptions_path: JSON file mapping Pfam IDs to records with
                a ``'description'`` field.
            go_info_path: GO information file passed to ``get_go_definition``.
            interpro_data_path: InterPro data file; only loaded when
                ``selected_info_types`` contains types other than motif/go.
            lmdb_path: LMDB database with QA pairs; when set, prompts are
                generated per QA pair instead of per protein.
            args: Parsed CLI namespace. Its ``interproscan_info_path`` and
                ``blast_info_path`` attributes (if present) point at
                precomputed step-1 results. May be ``None``.
        """
        self.blast_database = blast_database
        self.expect_value = expect_value
        self.interproscan_path = interproscan_path
        # "SFLD" is the InterProScan member-database name; the previous list
        # had the non-existent "SFLM", so SFLD hits could never be selected.
        self.interproscan_libraries = interproscan_libraries or [
            "PFAM", "PIRSR", "PROSITE_PROFILES", "SUPERFAMILY", "PRINTS",
            "PANTHER", "CDD", "GENE3D", "NCBIFAM", "SFLD", "MOBIDB_LITE",
            "COILS", "PROSITE_PATTERNS", "FUNFAM", "SMART",
        ]
        self.go_topk = go_topk
        self.selected_info_types = selected_info_types or ['motif', 'go']

        # File locations used during prompt generation.
        self.pfam_descriptions_path = pfam_descriptions_path
        self.go_info_path = go_info_path
        self.interpro_data_path = interpro_data_path
        self.lmdb_path = lmdb_path

        # Precomputed step-1 results. ``args`` defaults to None, so read the
        # attributes defensively rather than dereferencing it unconditionally.
        self.interproscan_info_path = getattr(args, "interproscan_info_path", None)
        self.blast_info_path = getattr(args, "blast_info_path", None)

        self.go_pipeline = GOIntegrationPipeline(topk=self.go_topk)

        # The InterPro manager is only needed for info types beyond motif/GO.
        self.interpro_manager = None
        other_types = [t for t in self.selected_info_types if t not in ['motif', 'go']]
        if other_types and self.interpro_data_path:
            self.interpro_manager = get_interpro_manager(self.interpro_data_path, None)

    def step1_run_blast_and_interproscan(self, input_fasta: str, temp_dir: str = "temp") -> tuple:
        """Step 1: run BLAST and InterProScan on ``input_fasta``.

        Intermediate tool output is written under ``temp_dir`` and removed
        once parsed.

        Returns:
            tuple: ``(interproscan_info, blast_info)``. Both are dicts keyed
            by protein ID with a ``"sequence"`` entry plus
            ``"interproscan_results"`` or ``"blast_results"`` respectively.
        """
        print("步骤1: 运行BLAST和InterProScan分析...")
        os.makedirs(temp_dir, exist_ok=True)

        seq_dict = get_seqnid(input_fasta)
        print(f"读取到 {len(seq_dict)} 个序列")

        # --- BLAST ---------------------------------------------------------
        print("运行BLAST分析...")
        blast_xml = os.path.join(temp_dir, "blast_results.xml")
        blast_cmd = NcbiblastpCommandline(
            query=input_fasta,
            db=self.blast_database,
            out=blast_xml,
            outfmt=5,  # XML output, parsed by extract_blast_metrics
            evalue=self.expect_value,
        )
        blast_cmd()
        blast_results = extract_blast_metrics(blast_xml)
        blast_info = {
            uid: {"sequence": seq_dict[uid], "blast_results": info}
            for uid, info in blast_results.items()
        }

        # --- InterProScan --------------------------------------------------
        print("运行InterProScan分析...")
        interproscan_json = os.path.join(temp_dir, "interproscan_results.json")
        interproscan = InterproScan(self.interproscan_path)
        interproscan.run(
            fasta_file=input_fasta,
            goterms=True,
            pathways=True,
            save_dir=interproscan_json,
        )
        interproscan_results = extract_interproscan_metrics(
            interproscan_json,
            librarys=self.interproscan_libraries,
        )
        interproscan_info = {}
        for protein_id, seq in seq_dict.items():
            # InterProScan results are looked up by sequence, not by ID.
            # NOTE(review): raises KeyError if a sequence is missing from the
            # results - confirm extract_interproscan_metrics covers every input.
            info = rename_interproscan_keys(interproscan_results[seq])
            interproscan_info[protein_id] = {"sequence": seq, "interproscan_results": info}

        # Remove the intermediate tool output files.
        for path in (blast_xml, interproscan_json):
            if os.path.exists(path):
                os.remove(path)

        print(f"步骤1完成: 处理了 {len(interproscan_info)} 个蛋白质")
        return interproscan_info, blast_info

    def step2_integrate_go_information(self, interproscan_info: Dict, blast_info: Dict) -> Dict:
        """Step 2: integrate GO terms from the InterProScan and BLAST results.

        Returns:
            Dict keyed by protein ID whose values are the GO IDs selected by
            ``GOIntegrationPipeline.first_level_filtering``.
        """
        print("步骤2: 整合GO信息...")
        protein_go_dict = self.go_pipeline.first_level_filtering(interproscan_info, blast_info)
        print(f"步骤2完成: 为 {len(protein_go_dict)} 个蛋白质整合了GO信息")
        return protein_go_dict

    def step3_generate_prompts(self, interproscan_info: Dict, blast_info: Dict,
                               protein_go_dict: Dict) -> Dict:
        """Step 3: build prompts for every protein in ``interproscan_info``.

        Args:
            interproscan_info: Step-1 InterProScan results keyed by protein ID.
            blast_info: Step-1 BLAST results (unused here; kept for a uniform
                step signature).
            protein_go_dict: Step-2 GO IDs keyed by protein ID.

        Returns:
            Without ``lmdb_path``: protein ID -> prompt string.
            With ``lmdb_path``: running integer index -> dict with ``index``,
            ``protein_id``, ``prompt``, ``question`` and ``ground_truth``.
            Empty prompts are skipped and do not consume an index.
        """
        print("步骤3: 生成蛋白质prompt...")
        # Shallow copy of the GO mapping handed to the prompt renderer.
        go_data = dict(protein_go_dict)

        prompts_data = {}
        if self.lmdb_path:
            from utils.generate_protein_prompt import get_qa_data

            global_index = 0
            for protein_id in tqdm(interproscan_info.keys(), desc="生成prompts"):
                for qa_pair in get_qa_data(protein_id, self.lmdb_path):
                    question = qa_pair['question']
                    ground_truth = qa_pair['ground_truth']
                    prompt = self._generate_prompt_from_memory(
                        protein_id, interproscan_info, go_data, question
                    )
                    if not prompt:
                        continue
                    prompts_data[global_index] = {
                        "index": global_index,
                        "protein_id": protein_id,
                        "prompt": prompt,
                        "question": question,
                        "ground_truth": ground_truth,
                    }
                    global_index += 1
        else:
            for protein_id in tqdm(interproscan_info.keys(), desc="生成prompts"):
                prompt = self._generate_prompt_from_memory(
                    protein_id, interproscan_info, go_data
                )
                if prompt:
                    prompts_data[protein_id] = prompt

        print(f"步骤3完成: 生成了 {len(prompts_data)} 个prompt")
        return prompts_data

    def _load_pfam_descriptions(self) -> Dict:
        """Return the Pfam description table, reading the JSON file at most once per instance."""
        if self._pfam_descriptions_cache is None:
            with open(self.pfam_descriptions_path, 'r', encoding='utf-8') as f:
                self._pfam_descriptions_cache = json.load(f)
        return self._pfam_descriptions_cache

    def _generate_prompt_from_memory(self, protein_id: str, interproscan_info: Dict,
                                     protein_go_dict: Dict, question: Optional[str] = None) -> str:
        """Render the full prompt (motif descriptions + GO definitions) for one protein.

        If anything in this path raises (e.g. a missing optional module or a
        template error), the error is printed and the result of
        ``_generate_simple_prompt`` is returned instead.
        """
        try:
            from utils.protein_go_analysis import get_go_definition
            from jinja2 import Template
            from utils.generate_protein_prompt import get_prompt_template

            # GO annotations and their definitions.
            go_annotations = []
            all_related_definitions = {}
            for go_id in protein_go_dict.get(protein_id, []) or []:
                # Keep only the part after the last ':' ("GO:0008150" -> "0008150").
                clean_go_id = go_id.split(":")[-1]
                go_annotations.append({"go_id": clean_go_id})
                definition = get_go_definition(clean_go_id, self.go_info_path)
                if definition:
                    all_related_definitions[clean_go_id] = definition

            # Motif (Pfam) descriptions; failures here are reported but not fatal.
            motif_pfam = {}
            if self.pfam_descriptions_path:
                try:
                    interproscan_results = interproscan_info[protein_id].get('interproscan_results', {})
                    pfam_entries = interproscan_results.get('pfam_id', [])
                    pfam_descriptions = self._load_pfam_descriptions()
                    for entry in pfam_entries:
                        # Each entry is a dict keyed by Pfam ID; only the key is used.
                        for pfam_id in entry.keys():
                            if pfam_id and pfam_id in pfam_descriptions:
                                motif_pfam[pfam_id] = pfam_descriptions[pfam_id]['description']
                except Exception as e:
                    print(f"获取motif信息时出错: {str(e)}")

            # Extra InterPro descriptions for info types beyond motif/GO.
            interpro_descriptions = {}
            other_types = [t for t in self.selected_info_types if t not in ['motif', 'go']]
            if other_types and self.interpro_manager:
                interpro_descriptions = self.interpro_manager.get_description(protein_id, other_types)

            template_data = {
                "protein_id": protein_id,
                "selected_info_types": self.selected_info_types,
                "go_data": {
                    "status": "success" if go_annotations else "no_data",
                    "go_annotations": go_annotations,
                    "all_related_definitions": all_related_definitions,
                },
                "motif_pfam": motif_pfam,
                "interpro_descriptions": interpro_descriptions,
                "question": question,
            }
            template = Template(get_prompt_template(self.selected_info_types, self.lmdb_path))
            return template.render(**template_data)
        except Exception as e:
            print(f"生成prompt时出错 (protein_id: {protein_id}): {str(e)}")
            return self._generate_simple_prompt(protein_id, interproscan_info, protein_go_dict, question)

    def _generate_simple_prompt(self, protein_id: str, interproscan_info: Dict,
                                protein_go_dict: Dict, question: Optional[str] = None) -> str:
        """Build a fallback prompt without external definitions.

        Starts from ``FUNCTION_PROMPT`` when ``lmdb_path`` is set, otherwise
        from ``ENZYME_PROMPT``. It then lists motif entries and up to 10 GO
        IDs (for the selected info types) and appends the question, if any.
        """
        go_ids = protein_go_dict.get(protein_id, [])
        interproscan_results = interproscan_info[protein_id].get('interproscan_results', {})
        pfam_entries = interproscan_results.get('pfam_id', [])

        prompt_parts = []
        if self.lmdb_path:
            from utils.prompts import FUNCTION_PROMPT
            prompt_parts.append(FUNCTION_PROMPT)
        else:
            from utils.prompts import ENZYME_PROMPT
            prompt_parts.append(ENZYME_PROMPT)
        prompt_parts.append("\ninput information:")

        if 'motif' in self.selected_info_types and pfam_entries:
            prompt_parts.append("\nmotif:")
            for entry in pfam_entries:
                # NOTE(review): this lists the entry *values*, whereas
                # _generate_prompt_from_memory uses the keys (Pfam IDs) -
                # confirm which identifier is intended here.
                for value in entry.values():
                    if value:
                        prompt_parts.append(f"{value}: 无详细描述")

        if 'go' in self.selected_info_types and go_ids:
            prompt_parts.append("\nGO:")
            for i, go_id in enumerate(go_ids[:10], 1):
                prompt_parts.append(f"▢ GO term{i}: {go_id}")
                prompt_parts.append("• definition: 无详细定义")

        if question:
            prompt_parts.append(f"\nquestion: \n{question}")
        return "\n".join(prompt_parts)

    @staticmethod
    def _write_result_json(save_path: str, result: Dict) -> None:
        """Write one result dict to ``save_path`` as indented UTF-8 JSON."""
        # ensure_ascii=False emits raw non-ASCII text, so the file encoding
        # must be pinned instead of depending on the platform locale.
        with open(save_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, indent=2, ensure_ascii=False)

    def step4_generate_llm_answers(self, prompts_data: Dict, save_dir: str) -> None:
        """Step 4: query the LLM for every prompt and save each answer as JSON.

        Files are named ``{protein_id}_{index}.json`` in QA mode (``lmdb_path``
        set) and ``{protein_id}.json`` otherwise. Per-item failures are
        printed and skipped so one bad item does not abort the batch.
        """
        print("步骤4: 生成LLM答案...")
        os.makedirs(save_dir, exist_ok=True)

        if self.lmdb_path:
            for index, qa_item in tqdm(prompts_data.items(), desc="生成LLM答案"):
                try:
                    protein_id = qa_item['protein_id']
                    prompt = qa_item['prompt']
                    question = qa_item['question']
                    ground_truth = qa_item['ground_truth']
                    llm_response = call_chatgpt(prompt)
                    result = {
                        'protein_id': protein_id,
                        'index': index,
                        'question': question,
                        'ground_truth': ground_truth,
                        'llm_answer': llm_response,
                    }
                    self._write_result_json(
                        os.path.join(save_dir, f"{protein_id}_{index}.json"), result
                    )
                except Exception as e:
                    print(f"处理索引 {index} 时出错: {str(e)}")
        else:
            for protein_id, prompt in tqdm(prompts_data.items(), desc="生成LLM答案"):
                try:
                    llm_response = call_chatgpt(prompt)
                    result = {
                        'protein_id': protein_id,
                        'prompt': prompt,
                        'llm_answer': llm_response,
                    }
                    self._write_result_json(
                        os.path.join(save_dir, f"{protein_id}.json"), result
                    )
                except Exception as e:
                    print(f"处理蛋白质 {protein_id} 时出错: {str(e)}")

        print(f"步骤4完成: 结果已保存到 {save_dir}")

    @staticmethod
    def _load_json(path: str):
        """Load and return the JSON document at ``path``, closing the file promptly."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def run(self, input_fasta: str, output_dir: str, temp_dir: str = "temp"):
        """Run the full workflow (steps 1-4).

        Step 1 is skipped when both precomputed result paths are set; the
        JSON files are loaded instead.

        Args:
            input_fasta: Input FASTA file path.
            output_dir: Directory receiving the per-prompt answer files.
            temp_dir: Scratch directory for step-1 tool output. It is removed
                on exit only if it did not exist before this call, so a
                pre-existing (possibly user-owned) directory is never deleted.

        Raises:
            Re-raises any exception from the steps after printing it.
        """
        print(f"开始运行整合蛋白质分析管道...")
        print(f"输入文件: {input_fasta}")
        print(f"输出目录: {output_dir}")
        os.makedirs(output_dir, exist_ok=True)

        temp_dir_preexisted = os.path.exists(temp_dir)
        try:
            if self.interproscan_info_path is None or self.blast_info_path is None:
                interproscan_info, blast_info = self.step1_run_blast_and_interproscan(
                    input_fasta, temp_dir
                )
            else:
                interproscan_info = self._load_json(self.interproscan_info_path)
                blast_info = self._load_json(self.blast_info_path)

            protein_go_dict = self.step2_integrate_go_information(interproscan_info, blast_info)
            prompts_data = self.step3_generate_prompts(
                interproscan_info, blast_info, protein_go_dict
            )
            self.step4_generate_llm_answers(prompts_data, output_dir)
            print("整合管道运行完成!")
        except Exception as e:
            print(f"管道运行出错: {str(e)}")
            raise
        finally:
            if not temp_dir_preexisted and os.path.exists(temp_dir):
                print(f"清理临时目录: {temp_dir}")
                shutil.rmtree(temp_dir)
def main():
    """Parse command-line arguments and run the integrated protein pipeline."""
    parser = argparse.ArgumentParser(description="整合蛋白质分析管道")
    parser.add_argument("--input_fasta", type=str, required=True, help="输入FASTA文件路径")
    parser.add_argument("--output_dir", type=str, required=True, help="输出目录")
    parser.add_argument("--temp_dir", type=str, default="temp", help="临时文件目录")
    # Precomputed step-1 results; when both are given, BLAST/InterProScan are skipped.
    parser.add_argument('--interproscan_info_path', type=str, default=None, help="InterProScan结果文件路径")
    parser.add_argument('--blast_info_path', type=str, default=None, help="BLAST结果文件路径")

    # BLAST options
    parser.add_argument("--blast_database", type=str, default="uniprot_swissprot", help="BLAST数据库")
    parser.add_argument("--expect_value", type=float, default=0.01, help="BLAST E-value阈值")

    # InterProScan options
    parser.add_argument("--interproscan_path", type=str,
                        default="interproscan/interproscan-5.75-106.0/interproscan.sh",
                        help="InterProScan程序路径")
    # Previously not exposed on the CLI; None keeps the pipeline's built-in list.
    parser.add_argument("--interproscan_libraries", type=str, nargs='+', default=None,
                        help="InterProScan库列表(默认使用管道内置列表)")

    # GO integration options
    parser.add_argument("--go_topk", type=int, default=2, help="GO整合topk参数")

    # Prompt generation options
    parser.add_argument("--selected_info_types", type=str, nargs='+',
                        default=['motif', 'go'], help="选择的信息类型")
    parser.add_argument("--pfam_descriptions_path", type=str, default='data/raw_data/all_pfam_descriptions.json', help="Pfam描述文件路径")
    parser.add_argument("--go_info_path", type=str, default='data/raw_data/go.json', help="GO信息文件路径")
    parser.add_argument("--interpro_data_path", type=str, default='data/raw_data/interpro_data.json', help="InterPro数据文件路径")
    parser.add_argument("--lmdb_path", type=str, default=None, help="LMDB数据库路径")

    args = parser.parse_args()

    pipeline = IntegratedProteinPipeline(
        blast_database=args.blast_database,
        expect_value=args.expect_value,
        interproscan_path=args.interproscan_path,
        interproscan_libraries=args.interproscan_libraries,
        go_topk=args.go_topk,
        selected_info_types=args.selected_info_types,
        pfam_descriptions_path=args.pfam_descriptions_path,
        go_info_path=args.go_info_path,
        interpro_data_path=args.interpro_data_path,
        lmdb_path=args.lmdb_path,
        args=args,
    )
    pipeline.run(args.input_fasta, args.output_dir, args.temp_dir)


if __name__ == "__main__":
    main()