Spaces:
Runtime error
Runtime error
| import os | |
| import json | |
| import sys | |
| import tempfile | |
| import gradio as gr | |
| from typing import Dict, List, Optional | |
| from pathlib import Path | |
| from Bio import SeqIO | |
| from io import StringIO | |
| # 添加必要的路径 | |
| root_path = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.append(root_path) | |
| sys.path.append(os.path.join(root_path, "Models/ProTrek")) | |
| # 导入所需模块 | |
| from interproscan import InterproScan | |
| from Bio.Blast.Applications import NcbiblastpCommandline | |
| from utils.utils import extract_interproscan_metrics, get_seqnid, extract_blast_metrics, rename_interproscan_keys | |
| from go_integration_pipeline import GOIntegrationPipeline | |
| from utils.openai_access import call_chatgpt | |
| from utils.prompts import FUNCTION_PROMPT | |
def get_prompt_template(selected_info_types=None):
    """
    Return the Jinja2 prompt template used to build the LLM prompt.

    The returned text is ``FUNCTION_PROMPT`` followed by a template that
    emits, in order: Pfam motif descriptions, GO terms with definitions,
    any additional InterPro-derived sections, and the user question.

    Args:
        selected_info_types: Information types to include, e.g.
            ['motif', 'go', 'superfamily', 'panther']. The template text
            itself reads ``selected_info_types`` from the render context,
            so the value passed here does not alter the returned string.

    Returns:
        The template source as a string (not yet rendered).
    """
    if selected_info_types is None:
        # Default to motif and GO information.
        selected_info_types = ['motif', 'go']

    template_body = """
input information:
{%- if 'motif' in selected_info_types and motif_pfam %}
motif:{% for motif_id, motif_info in motif_pfam.items() %}
{{motif_id}}: {{motif_info}}
{% endfor %}
{%- endif %}
{%- if 'go' in selected_info_types and go_data.status == 'success' %}
GO:{% for go_entry in go_data.go_annotations %}
▢ GO term{{loop.index}}: {{go_entry.go_id}}
• definition: {{ go_data.all_related_definitions.get(go_entry.go_id, 'not found definition') }}
{% endfor %}
{%- endif %}
{%- for info_type in selected_info_types %}
{%- if info_type not in ['motif', 'go'] and interpro_descriptions.get(info_type) %}
{{info_type}}:{% for ipr_id, ipr_info in interpro_descriptions[info_type].items() %}
▢ {{ipr_id}}: {{ipr_info.name}}
• description: {{ipr_info.abstract}}
{% endfor %}
{%- endif %}
{%- endfor %}
question: \n {{question}}
"""
    return FUNCTION_PROMPT + '\n' + template_body
class ProteinAnalysisDemo:
    """
    Single-sequence protein analysis pipeline behind the Gradio demo.

    For one protein sequence the pipeline runs BLAST and InterProScan,
    merges GO terms through ``GOIntegrationPipeline``, renders a prompt
    from the collected evidence and sends it to the LLM.
    """

    # One-letter codes accepted from the free-text sequence box.
    _VALID_AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')

    def __init__(self):
        """
        Set the analysis configuration and create the helper objects.
        """
        self.blast_database = "uniprot_swissprot"
        self.expect_value = 0.01
        self.interproscan_path = "interproscan/interproscan-5.75-106.0/interproscan.sh"
        self.interproscan_libraries = [
            "PFAM", "PIRSR", "PROSITE_PROFILES", "SUPERFAMILY", "PRINTS",
            "PANTHER", "CDD", "GENE3D", "NCBIFAM", "SFLM", "MOBIDB_LITE",
            "COILS", "PROSITE_PATTERNS", "FUNFAM", "SMART"
        ]
        self.go_topk = 2
        self.selected_info_types = ['motif', 'go']
        # Data file locations.
        self.pfam_descriptions_path = 'data/raw_data/all_pfam_descriptions.json'
        self.go_info_path = 'data/raw_data/go.json'
        self.interpro_data_path = 'data/raw_data/interpro_data.json'
        # GO integration pipeline.
        self.go_pipeline = GOIntegrationPipeline(topk=self.go_topk)
        # The InterPro manager is only needed when info types other than
        # 'motif' and 'go' are selected.
        self.interpro_manager = None
        other_types = [t for t in self.selected_info_types if t not in ['motif', 'go']]
        if other_types and os.path.exists(self.interpro_data_path):
            try:
                from utils.generate_protein_prompt import get_interpro_manager
                self.interpro_manager = get_interpro_manager(self.interpro_data_path, None)
            except Exception as e:
                # Best effort: the demo still works without the manager.
                print(f"初始化InterPro管理器失败: {str(e)}")

    @staticmethod
    def _clean_sequence(sequence: str) -> str:
        """Drop every whitespace character (spaces, tabs, CR/LF) and upper-case."""
        return "".join(sequence.split()).upper()

    @staticmethod
    def _read_uploaded_text(uploaded) -> str:
        """
        Return the UTF-8 text of an uploaded file.

        Accepts a filesystem path (str / PathLike), raw bytes, or a
        file-like object whose ``read()`` returns bytes or str.
        """
        if isinstance(uploaded, (str, os.PathLike)):
            with open(uploaded, 'r', encoding='utf-8') as handle:
                return handle.read()
        if isinstance(uploaded, bytes):
            return uploaded.decode('utf-8')
        data = uploaded.read()
        return data.decode('utf-8') if isinstance(data, bytes) else data

    def validate_protein_sequence(self, sequence: str) -> bool:
        """
        Check that a sequence consists only of the 20 standard amino acids.

        Whitespace of any kind is ignored and the check is case-insensitive.

        Returns:
            True for a non-empty sequence made of valid residues, else False.
        """
        if not sequence:
            return False
        cleaned = self._clean_sequence(sequence)
        return bool(cleaned) and set(cleaned) <= self._VALID_AMINO_ACIDS

    def parse_fasta_content(self, fasta_content: str) -> tuple:
        """
        Parse FASTA text that must contain exactly one record.

        Returns:
            ``(sequence, message)`` on success, ``(None, error_message)``
            when there is no usable record, more than one record, or the
            parser raises.
        """
        try:
            records = list(SeqIO.parse(StringIO(fasta_content), "fasta"))
            if not records:
                return None, "FASTA文件中没有找到有效的序列"
            if len(records) > 1:
                return None, "演示版本只支持单一序列,检测到多个序列"
            record = records[0]
            sequence = str(record.seq)
            if not sequence:
                # A header without residues is not a usable sequence.
                return None, "FASTA文件中没有找到有效的序列"
            return sequence, f"成功解析序列 ID: {record.id}"
        except Exception as e:
            return None, f"解析FASTA文件出错: {str(e)}"

    def create_temp_fasta(self, sequence: str, seq_id: str = "demo_protein",
                          directory: Optional[str] = None) -> str:
        """
        Write a one-record FASTA file and return its path.

        Args:
            sequence: Residues to write.
            seq_id: Record identifier placed after ``>``.
            directory: Directory to create the file in; the system default
                temporary directory is used when omitted.

        Returns:
            Path of the created file. The caller owns the file and is
            responsible for removing it.
        """
        # delete=False keeps the file after the handle is closed.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.fasta',
                                         delete=False, dir=directory) as temp_file:
            temp_file.write(f">{seq_id}\n{sequence}\n")
        return temp_file.name

    def run_blast_analysis(self, fasta_file: str, temp_dir: str) -> Dict:
        """
        Run blastp for the FASTA file and collect the hit metrics.

        Returns:
            ``{sequence_id: {"sequence": ..., "blast_results": ...}}``, or
            an empty dict when any step fails.
        """
        blast_xml = os.path.join(temp_dir, "blast_results.xml")
        try:
            blast_cmd = NcbiblastpCommandline(
                query=fasta_file,
                db=self.blast_database,
                out=blast_xml,
                outfmt=5,  # XML output
                evalue=self.expect_value
            )
            blast_cmd()
            blast_results = extract_blast_metrics(blast_xml)
            seq_dict = get_seqnid(fasta_file)
            return {
                uid: {"sequence": seq_dict[uid], "blast_results": info}
                for uid, info in blast_results.items()
            }
        except Exception as e:
            print(f"BLAST分析出错: {str(e)}")
            return {}
        finally:
            # The XML report is only an intermediate artifact.
            if os.path.exists(blast_xml):
                os.remove(blast_xml)

    def run_interproscan_analysis(self, fasta_file: str, temp_dir: str) -> Dict:
        """
        Run InterProScan for the FASTA file and collect the metrics.

        Returns:
            ``{sequence_id: {"sequence": ..., "interproscan_results": ...}}``,
            or an empty dict when any step fails.
        """
        interproscan_json = os.path.join(temp_dir, "interproscan_results.json")
        try:
            interproscan = InterproScan(self.interproscan_path)
            interproscan.run(
                fasta_file=fasta_file,
                goterms=True,
                pathways=True,
                save_dir=interproscan_json,
            )
            interproscan_results = extract_interproscan_metrics(
                interproscan_json,
                librarys=self.interproscan_libraries
            )
            seq_dict = get_seqnid(fasta_file)
            interproscan_info = {}
            for seq_key, seq in seq_dict.items():
                # The extracted metrics are looked up by the sequence string.
                info = rename_interproscan_keys(interproscan_results[seq])
                interproscan_info[seq_key] = {"sequence": seq, "interproscan_results": info}
            return interproscan_info
        except Exception as e:
            print(f"InterProScan分析出错: {str(e)}")
            return {}
        finally:
            if os.path.exists(interproscan_json):
                os.remove(interproscan_json)

    def generate_prompt(self, protein_id: str, interproscan_info: Dict,
                        protein_go_dict: Dict, question: str) -> str:
        """
        Render the LLM prompt from in-memory results.

        Collects GO definitions and Pfam motif descriptions for
        ``protein_id`` and renders the Jinja2 template. If anything in
        that process raises, the simplified fallback prompt is returned.
        """
        try:
            from utils.protein_go_analysis import get_go_definition
            from jinja2 import Template

            # GO annotations and their definitions.
            go_ids = protein_go_dict.get(protein_id, [])
            go_annotations = []
            all_related_definitions = {}
            go_info_available = os.path.exists(self.go_info_path)
            for go_id in go_ids:
                # Keep only the part after the last ':' (e.g. "GO:0008150" -> "0008150").
                clean_go_id = go_id.split(":")[-1] if ":" in go_id else go_id
                go_annotations.append({"go_id": clean_go_id})
                if go_info_available:
                    definition = get_go_definition(clean_go_id, self.go_info_path)
                    if definition:
                        all_related_definitions[clean_go_id] = definition

            # Pfam motif descriptions.
            motif_pfam = {}
            if os.path.exists(self.pfam_descriptions_path):
                try:
                    interproscan_results = interproscan_info[protein_id].get('interproscan_results', {})
                    pfam_entries = interproscan_results.get('pfam_id', [])
                    with open(self.pfam_descriptions_path, 'r') as f:
                        pfam_descriptions = json.load(f)
                    # Each entry is a mapping whose keys are Pfam IDs.
                    for entry in pfam_entries:
                        for pfam_id in entry:
                            if pfam_id and pfam_id in pfam_descriptions:
                                motif_pfam[pfam_id] = pfam_descriptions[pfam_id]['description']
                except Exception as e:
                    # Motif information is optional; continue without it.
                    print(f"获取motif信息时出错: {str(e)}")

            # Descriptions for any additional InterPro info types.
            interpro_descriptions = {}
            other_types = [t for t in self.selected_info_types if t not in ['motif', 'go']]
            if other_types and self.interpro_manager:
                interpro_descriptions = self.interpro_manager.get_description(protein_id, other_types)

            template_data = {
                "protein_id": protein_id,
                "selected_info_types": self.selected_info_types,
                "go_data": {
                    "status": "success" if go_annotations else "no_data",
                    "go_annotations": go_annotations,
                    "all_related_definitions": all_related_definitions
                },
                "motif_pfam": motif_pfam,
                "interpro_descriptions": interpro_descriptions,
                "question": question
            }
            # The demo builds the prompt in memory (no LMDB).
            template = Template(get_prompt_template(self.selected_info_types))
            return template.render(**template_data)
        except Exception as e:
            print(f"生成prompt时出错 (protein_id: {protein_id}): {str(e)}")
            return self._generate_fallback_prompt(protein_id, interproscan_info, protein_go_dict, question)

    def _generate_fallback_prompt(self, protein_id: str, interproscan_info: Dict,
                                  protein_go_dict: Dict, question: str) -> str:
        """
        Build a simplified prompt when the template-based path fails.

        Only identifiers are listed (with placeholder descriptions); no
        data files are read.
        """
        prompt_parts = [FUNCTION_PROMPT, "\ninput information:"]

        if 'motif' in self.selected_info_types:
            # .get() so a missing protein cannot raise inside the recovery path.
            interproscan_results = interproscan_info.get(protein_id, {}).get('interproscan_results', {})
            pfam_entries = interproscan_results.get('pfam_id', [])
            if pfam_entries:
                prompt_parts.append("\nmotif:")
                for entry in pfam_entries:
                    for value in entry.values():
                        if value:
                            prompt_parts.append(f"{value}: motif information")

        if 'go' in self.selected_info_types:
            go_ids = protein_go_dict.get(protein_id, [])
            if go_ids:
                prompt_parts.append("\nGO:")
                for i, go_id in enumerate(go_ids[:10], 1):
                    prompt_parts.append(f"▢ GO term{i}: {go_id}")
                    prompt_parts.append("• definition: GO term definition")

        prompt_parts.append(f"\nquestion: \n{question}")
        return "\n".join(prompt_parts)

    def analyze_protein(self, sequence_input: str, fasta_file, question: str) -> str:
        """
        Analyze one protein sequence and answer the user's question.

        Args:
            sequence_input: Raw sequence typed into the text box.
            fasta_file: Uploaded FASTA file (path, bytes or file-like
                object), or None. Takes precedence over ``sequence_input``.
            question: Question to answer about the protein.

        Returns:
            A report string, or a user-facing error message. This method
            does not raise; failures are reported in the returned text.
        """
        if not (question or "").strip():
            return "请输入您的问题"

        # Decide which input supplies the sequence.
        if fasta_file is not None:
            # An uploaded file wins over the text box.
            try:
                fasta_content = self._read_uploaded_text(fasta_file)
                final_sequence, parse_msg = self.parse_fasta_content(fasta_content)
                if final_sequence is None:
                    return f"文件解析错误: {parse_msg}"
                sequence_source = f"来自上传文件: {parse_msg}"
            except Exception as e:
                return f"读取上传文件出错: {str(e)}"
        elif (sequence_input or "").strip():
            if not self.validate_protein_sequence(sequence_input):
                return "输入的序列格式不正确,请输入有效的蛋白质序列"
            final_sequence = self._clean_sequence(sequence_input)
            sequence_source = "来自文本框输入"
        else:
            return "请输入蛋白质序列或上传FASTA文件"

        # Every intermediate file lives in temp_dir and is removed with it.
        with tempfile.TemporaryDirectory() as temp_dir:
            try:
                temp_fasta = self.create_temp_fasta(final_sequence, "demo_protein", directory=temp_dir)
                status_msg = f"序列来源: {sequence_source}\n序列长度: {len(final_sequence)} 氨基酸\n\n正在进行分析...\n"

                status_msg += "步骤1: 运行BLAST分析...\n"
                blast_info = self.run_blast_analysis(temp_fasta, temp_dir)
                status_msg += "步骤2: 运行InterProScan分析...\n"
                interproscan_info = self.run_interproscan_analysis(temp_fasta, temp_dir)
                if not blast_info or not interproscan_info:
                    return status_msg + "分析失败: 无法获取BLAST或InterProScan结果"

                status_msg += "步骤3: 整合GO信息...\n"
                protein_go_dict = self.go_pipeline.first_level_filtering(interproscan_info, blast_info)

                status_msg += "步骤4: 生成分析prompt...\n"
                protein_id = "demo_protein"
                prompt = self.generate_prompt(protein_id, interproscan_info, protein_go_dict, question)

                status_msg += "步骤5: 生成答案...\n"
                llm_response = call_chatgpt(prompt)

                blast_hits = blast_info.get(protein_id, {}).get('blast_results', [])
                pfam_hits = interproscan_info.get(protein_id, {}).get('interproscan_results', {}).get('pfam_id', [])
                go_terms = protein_go_dict.get(protein_id, [])
                report_lines = [
                    "",
                    status_msg,
                    "=== 分析完成 ===",
                    f"问题: {question}",
                    f"答案: {llm_response}",
                    "=== 分析详情 ===",
                    f"- BLAST匹配数: {len(blast_hits)}",
                    f"- InterProScan域数: {len(pfam_hits)}",
                    f"- GO术语数: {len(go_terms)}",
                    "",
                ]
                return "\n".join(report_lines)
            except Exception as e:
                return f"分析过程中出错: {str(e)}"
def create_demo():
    """
    Build the Gradio interface for the protein analysis demo.

    Creates one ``ProteinAnalysisDemo`` and a ``gr.Blocks`` layout with a
    sequence text box, a FASTA upload, a question box, an analyze button
    and a result box, and binds the button to ``analyze_protein``.

    Returns:
        The configured ``gr.Blocks`` object (not yet launched).
    """
    analyzer = ProteinAnalysisDemo()

    # Each row is [sequence, question]; it fills the two text boxes.
    example_rows = [
        ["MKALIVLGLVLLSVTVQGKVFERCELARTLKRLGMDGYRGISLANWMCLAKWESGYNTRATNYNAGDRSTDYGIFQINSRYWCNDGKTPGAVNACHLSCSALLQDNIADAVACAKRVVRDPQGIRAWVAWRNRCQNRDVRQYVQGCGV", "这个蛋白质的主要功能是什么?"],
        ["MGSSHHHHHHSSGLVPRGSHMRGPNPTAASLEASAGPFTVRSFTVSRPSGYGAGTVYYPTNAGGTVGAIAIVPGYTARQSSIKWWGPRLASHGFVVITIDTNSTLDQPSSRSSQQMAALRQVASLNGTSSSPIYGKVDTARMGVMGWSMGGGGSLISAANNPSLKAAAPQAPWDSSTNFSSVTVPTLIFACENDSIAPVNSSALPIYDSMSRNAKQFLEINGGSHSCANSGNSNQALIGKKGVAWMKRFPTSREJ", "这个蛋白质可能参与哪些生物学过程?"],
        ["ATGAGTGAACGTCTGAAATCTATCATCACCGTCGACGACGAGAACGTCAAGCTGATCGACAAGATCCTGGCCTCCATCAAGGACCTGAACGAGCTGGTGGACATGATCGACGAGATCAAGAACGTCGACGACGAGCTGATCGACAAGATCCTGGCC", "这个序列编码的蛋白质具有什么结构特征?"]
    ]

    with gr.Blocks(title="蛋白质功能分析演示") as demo:
        gr.Markdown("# 🧬 蛋白质功能分析演示")
        gr.Markdown("输入蛋白质序列和问题,AI将基于BLAST、InterProScan和GO信息为您提供专业分析")

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                gr.Markdown("### 📝 序列输入")
                sequence_input = gr.Textbox(
                    label="蛋白质序列",
                    placeholder="请输入蛋白质序列(单字母氨基酸代码)...",
                    lines=5,
                    max_lines=10
                )
                gr.Markdown("**或者**")
                fasta_file = gr.File(
                    label="上传FASTA文件",
                    file_types=[".fasta", ".fa", ".fas"],
                    file_count="single"
                )
                gr.Markdown("### ❓ 您的问题")
                question_input = gr.Textbox(
                    label="问题",
                    placeholder="请输入关于该蛋白质的问题,例如:这个蛋白质的主要功能是什么?",
                    lines=3
                )
                analyze_btn = gr.Button("🔍 开始分析", variant="primary", size="lg")

            # Right column: output.
            with gr.Column(scale=2):
                gr.Markdown("### 📊 分析结果")
                output = gr.Textbox(
                    label="分析结果",
                    lines=20,
                    max_lines=30,
                    show_copy_button=True
                )

        # Examples.
        gr.Markdown("### 💡 示例")
        gr.Examples(
            examples=example_rows,
            inputs=[sequence_input, question_input]
        )

        # Argument order matches analyze_protein(sequence_input, fasta_file, question).
        click_inputs = [sequence_input, fasta_file, question_input]
        analyze_btn.click(
            fn=analyzer.analyze_protein,
            inputs=click_inputs,
            outputs=[output]
        )

    return demo
if __name__ == "__main__":
    # Build the UI and serve it when the file is run as a script.
    demo = create_demo()
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 30002,
        "share": True,
        "debug": False,
    }
    demo.launch(**launch_options)