import json
import os

import torch
import torch.nn.functional as F

# Default locations of the candidate metadata (JSON) and embedding (.pt) files.
BASE_DIR = 'drug_target_activity/candidates'
PATH_CONFIG = {
    "protein": {
        "json": os.path.join(BASE_DIR, "unique_targets.json"),
        "pt": os.path.join(BASE_DIR, "protein_candidates.pt"),
        "target_type": "molecule",  # a protein query searches the molecule side
    },
    "molecule": {
        "json": os.path.join(BASE_DIR, "unique_compounds.json"),
        "pt": os.path.join(BASE_DIR, "molecule_candidates.pt"),
        "target_type": "protein",  # a molecule query searches the protein side
    },
}


def _find_id_by_query(metadata, query):
    """Resolve ``query`` to a unique ID (a top-level key of ``metadata``).

    Search order:
      1. ``query`` itself, then its whitespace-stripped form, as a direct key.
      2. Every value of every entry, compared as stripped strings.

    Non-unique values (e.g. a target class shared by many entries) resolve to
    the first matching entry in the metadata's iteration order.

    Args:
        metadata (dict): Mapping of unique ID -> dict of fields.
        query: Value to look up; compared via ``str(query).strip()``.

    Returns:
        tuple: ``(unique_id, reason)`` on success, ``(None, None)`` otherwise.
    """
    # 1. Direct key match (UniProt ID or SMILES).
    if query in metadata:
        return query, "Direct ID Match"
    needle = str(query).strip()
    if needle in metadata:
        return needle, "Direct ID Match"
    if not needle:
        # An empty query would otherwise only match whitespace-only values.
        return None, None

    # 2. Linear scan over every field value. O(entries * fields), which is
    # acceptable for interactive inference use.
    for unique_id, info in metadata.items():
        if not isinstance(info, dict):
            continue
        for key, val in info.items():
            if val and str(val).strip() == needle:
                return unique_id, f"Matched field: {key}"

    return None, None


def _load_json(path, missing_msg):
    """Load a UTF-8 JSON file; raise FileNotFoundError('<missing_msg>: <path>') if absent."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"{missing_msg}: {path}")
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _load_vectors(path, missing_msg, device):
    """Load a saved vector bundle (expected keys: 'ids', 'vectors') onto ``device``."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"{missing_msg}: {path}")
    return torch.load(path, map_location=device)


def retrieve_topk(input_type, input_query, topk=5, device='cpu', normalize=True):
    """Dual-tower retrieval: find the top-K candidates on the opposite side.

    Args:
        input_type (str): 'protein' or 'molecule' (case-insensitive).
        input_query (str): ID, name, gene, sequence, SMILES, ... of the query.
        topk (int): Maximum number of results; values <= 0 yield ``[]``.
        device (str): Device passed as ``map_location`` when loading vectors.
        normalize (bool): L2-normalize query and database vectors so scores are
            cosine similarities. Set False to use a raw dot product.

    Returns:
        list[dict]: One dict per hit with keys ``rank``, ``score``, ``id`` and
        ``info``. Empty if the query cannot be resolved or has no vector.

    Raises:
        ValueError: If ``input_type`` is not 'protein' or 'molecule'.
        FileNotFoundError: If a required metadata or vector file is missing.
    """
    input_type = input_type.lower()
    if input_type not in PATH_CONFIG:
        raise ValueError("input_type must be 'protein' or 'molecule'")

    # Query-side and target-side configuration.
    q_config = PATH_CONFIG[input_type]
    t_type = q_config['target_type']
    t_config = PATH_CONFIG[t_type]

    print(f"[*] Searching for {input_type}: '{input_query}'...")

    # 1. Resolve the query to its unique ID.
    q_meta = _load_json(q_config['json'], "Metadata file not found")
    query_id, match_reason = _find_id_by_query(q_meta, input_query)
    if query_id is None:
        print(f"[!] Query not found in {input_type} metadata.")
        return []
    print(f" -> Resolved to ID: {query_id}... ({match_reason})")
    print(f" -> Info: {q_meta[query_id]}")

    # 2. Fetch the query vector.
    q_pt_data = _load_vectors(q_config['pt'], "Vector file not found", device)
    try:
        q_index = q_pt_data['ids'].index(query_id)
    except ValueError:
        print(f"[!] ID {query_id} found in JSON but NOT in .pt vector file. (Did you update vectors?)")
        return []
    q_vector = q_pt_data['vectors'][q_index]

    # 3. Load the database vectors of the opposite side.
    print(f"[*] Loading {t_type} database...")
    t_pt_data = _load_vectors(t_config['pt'], "Target vector file not found", device)
    t_vectors = t_pt_data['vectors']  # assumed shape [N, Dim]
    t_ids = t_pt_data['ids']

    # k cannot exceed the database size; non-positive k means "no results".
    actual_k = min(topk, len(t_ids))
    if actual_k <= 0:
        return []

    # 4. Similarity & top-K. Shape the query as [1, Dim] and match the dtype.
    q_vector = q_vector.reshape(1, -1).to(dtype=t_vectors.dtype)
    if normalize:
        # Explicit normalization makes the score a true cosine similarity even
        # if the stored vectors were not unit-length.
        q_vector = F.normalize(q_vector, dim=-1)
        t_vectors = F.normalize(t_vectors, dim=-1)
    scores = torch.matmul(q_vector, t_vectors.T).squeeze(0)  # [N]
    top_scores, top_indices = torch.topk(scores, k=actual_k)

    # 5. Assemble results using the target side's metadata.
    t_meta = _load_json(t_config['json'], "Target metadata file not found")

    results = []
    print(f"\n{' Top ' + str(actual_k) + ' Matches ':=^50}")
    for rank, (score, idx) in enumerate(zip(top_scores.tolist(), top_indices.tolist()), start=1):
        target_id = t_ids[idx]
        target_info = t_meta.get(target_id, {"error": "Metadata missing"})

        results.append({
            "rank": rank,
            "score": round(score, 4),
            "id": target_id,
            "info": target_info,
        })

        # Short display label: compound name for molecules, gene for proteins.
        display_name = target_info.get('compound__name') or target_info.get('target__gene') or "Unknown"
        print(f"Rank {rank} | Score: {score:.4f} | ID: {target_id[:15]}... | {display_name}")

    return results


# ==========================================
# Usage examples
# ==========================================
if __name__ == "__main__":
    # Example 1: protein query -> molecule candidates.
    # input_query may be a UniProt ID ("P08922"), a gene name ("ROS1"), a class
    # ("Kinase"), or the full Foldseek sequence. A non-unique value such as
    # "Kinase" resolves to the first matching entry in the metadata.
    # Example metadata entry:
    # "P08922": {
    #     "target__foldseek_seq": <EXAMPLE_FOLDSEEK_SEQ below>,
    #     "target__class": "Kinase",
    #     "target__gene": "ROS1"
    # },
    EXAMPLE_FOLDSEEK_SEQ = (
        "#####################################################################################################WAFWDFPDAAQFKTKTFTDDI####KKKWKWKDFP######DIDDIDPDRIDMDFTDHGQTWMKMKMWID######HIHDIDPIDHHHADDAQLFFWAW##WDA###FKIKTFTHDGPRGSADWPWKKKWKDFPVDIDTDIGR###DMGG#HHGQTKIKMKMWIADLQGIFFIDIDMDTY#########KWKWWWFWLKIWIDTPVC##DATQIA####SPTGWQEWEDDPVQQWIWIDGAQWIWIDRNPHNDDCPRIDTQDGDPFGWREWEAQPLQQKIWTATPQWIWMDHNVDPPDIDTQRAPPGTRFRYWYDPQAQQWIWTQDLQGIWIFGD#########IADQDGDNLFQEWDDQN##QWIWTQPLFQQFIKIAHV######TD#####D#AWQYWEAD##WIWTDNQQFIWTD####TDTRFHHNDVV#####GIHDIYMD###########EW##WDWQAAFFKIKIFTHG############################WWKKKWKWF#####DIDIDTRHRDRMDMGGGDHGQTKIWIWMWTDD###I#DIH#TDIDHHAHD#PDFKWKWFAFQQGIWIDTS######HGDDRPH#AWLDWDDDDQKIWTAGPFFWIWIWGR########HTD#VTTQFNEWEFQPQQQKIWGFG##QWIKIAHPVPRDIDTQDGH#AGWQYKYDDNQLQKIWTDHQFWIWIDGLNNPPTDTQGGA######GWAEKEDDPPVQKIWTWID####IWIKMWGHP########IDT###########HAWYDHRQKIWGQH##QFIWMDGR######G#####HTNYMDMDDDVVNDAH##DPDNFQQEFADFDPVQWDWDD###WTKIFGDGTD#RRDFQKWKFKWKD############HGIDIDG#######IWIWIFIHGSNDIHPIPIDRD#HH#DAFEKWFDKDKAKD###########KMKIFTGDTP###AD#FWKKKWKDK#############DIDIDG##DGMDIDG#####IKMKMWMATAHPHTGYDTDDMDMDGH#####FFWKFFFDFQWGFTAR"
        "QVVRDGPDIQGHPGGWQEKDDDPVQCWIWTAGSQWIWIANNPVRDIDTLDGH##RPQQNYKEDQPAQQKIKTKH######IWMWIWFSFDP#TDTDTL######WYWQDKYADFQQCKIWTWTQDP#AIFIWIARPVVSDIDGQA##############DDPDD##QFQGEWDWAQLDVVFIWIWGDGFQFIKIDGPNNRDIDTQAGH########WRYWEDDDFKIWTWID####IKIWMATRVH#DT##IR#####RYMYGDDPSRDDDD#L#QAWW######KAWDDAAQFKTKIFGGDT########HDYYQKKKKKWKFWD##########IDIDIDRDRMDMDGGHHFQTKMKMWMTMDGDSD#########DIDMDGHHADAFDEFDDKDWDAQAQWKIKIFGDD#DGGSHDLQQKFKWKDKP#DDIPP#DTHDCCPDDP#GRMDMGGRHHGQDKMKMWMWIAG####RIDIYDIDMDHGHHAFAWWGSFDAEQFKTKTWGADG#####DDKFKWKFW####DTDTADWPFDD##TTTTMGGGHHAQTKMWMWMKTATPSGDIHIYDIDMDTHHFDWWAAKAQWDQD####FKTKIFDTDRR#FDWFWKWKWKFA############IDTDDIGD####DMDG#####HWMKMKMWIGGPRGIHDIYDIDHGGG######################################################################################PFAAADVVQKAFDAFP######TKTWIFGQQLVHDPPGRDIWIKGWGD#######VR#VVSQSSVQVPDDDQQAWHFSGWHPPDPRTITITHDQPQAFQLVNLQCQQP#####GLADPLNLLQQLLQVLVVLVRCVVQQKDQQQQARNQKGWNDNDSPDDIGIHRHGR####################PLLLFALCCLPGVIDHLLRVLSSSLVRSVCSLQSNDDRPPVADSVRSNVCLNVPHGGDDRDPDFFLSVVLSVLSSDNDSVSRDGSVVSSVSSVVVSVVVVVV#V########################################################################################################################"
    )

    retrieve_topk(input_type="protein", input_query="ROS1", topk=5)
    retrieve_topk(input_type="protein", input_query=EXAMPLE_FOLDSEEK_SEQ, topk=5)

    print("\n" + "-" * 50 + "\n")

    # Example 2: molecule query -> protein candidates.
    # input_query may be a name ("Lorlatinib"), a DrugBank ID ("DB12130"), or
    # the SMILES string itself. Example metadata entry:
    # "C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2": {
    #     "compound__name": "Lorlatinib",
    #     "compound__drugbank_id": "DB12130",
    #     "compound__cas": "1454846-35-5",
    #     "compound__unii": "OSP71S83EU",
    #     "compound__inchikey": "IIXWYSCJSQVBQM-LLVKDONJSA-N"
    # },
    retrieve_topk(input_type="molecule", input_query="Lorlatinib", topk=3)
    retrieve_topk(
        input_type="molecule",
        input_query="C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2",
        topk=3,
    )