Spaces:
Sleeping
Sleeping
| import torch | |
| import json | |
| import os | |
| import torch.nn.functional as F | |
# Default on-disk locations of the candidate files.
BASE_DIR = 'drug_target_activity/candidates'

# One entry per query side. Each entry holds the metadata JSON path, the
# vector (.pt) path, and the name of the opposite side that a query of this
# type is searched against (protein -> molecule, molecule -> protein).
PATH_CONFIG = {
    side: {
        "json": os.path.join(BASE_DIR, json_name),
        "pt": os.path.join(BASE_DIR, pt_name),
        "target_type": opposite_side,
    }
    for side, json_name, pt_name, opposite_side in (
        ("protein", "unique_targets.json", "protein_candidates.pt", "molecule"),
        ("molecule", "unique_compounds.json", "molecule_candidates.pt", "protein"),
    )
}
| def _find_id_by_query(metadata, query): | |
| """ | |
| 辅助函数:在元数据字典中查找 query 对应的 Unique ID (Key)。 | |
| 搜索优先级:Key (ID) -> Values (Sequence/Name/Gene...) | |
| """ | |
| # 1. 尝试直接匹配 Key (UniProtID 或 SMILES) | |
| if query in metadata: | |
| return query, "Direct ID Match" | |
| # 2. 遍历 Values 进行匹配 | |
| # 注意:这在大数据量下可能稍慢,但在 inference 场景可接受 | |
| for unique_id, info in metadata.items(): | |
| # 检查 info 字典里的所有值 (target__gene, compound__name, sequence 等) | |
| # 将所有值转为字符串进行比对 | |
| for key, val in info.items(): | |
| if val and str(val).strip() == str(query).strip(): | |
| return unique_id, f"Matched field: {key}" | |
| return None, None | |
def retrieve_topk(input_type, input_query, topk=5, device='cpu'):
    """Two-tower retrieval: rank the opposite side's candidates for a query.

    The query is resolved to a unique ID through the query side's metadata
    JSON, its vector is read from the query side's ``.pt`` file, and every
    vector in the opposite side's ``.pt`` file is scored against it with a
    dot product. Progress and the ranked matches are printed to stdout.

    Args:
        input_type (str): 'protein' or 'molecule' (case-insensitive).
        input_query (str): Query text (ID, name, gene, sequence, SMILES, ...).
        topk (int): Maximum number of results. Values larger than the number
            of candidates are clamped down; values below 0 are treated as 0.
        device (str): Passed to ``torch.load`` as ``map_location``.

    Returns:
        list: One dict per match, best first, with keys ``rank`` (1-based),
        ``score`` (rounded to 4 decimals), ``id`` and ``info``. Empty when
        the query cannot be resolved or has no vector in the ``.pt`` file.

    Raises:
        ValueError: If ``input_type`` is not a key of ``PATH_CONFIG``.
        FileNotFoundError: If a required metadata or vector file is missing.
    """
    input_type = input_type.lower()
    if input_type not in PATH_CONFIG:
        raise ValueError("input_type must be 'protein' or 'molecule'")

    # Configuration of the query side and of the opposite (target) side.
    q_config = PATH_CONFIG[input_type]
    t_type = q_config['target_type']
    t_config = PATH_CONFIG[t_type]
    print(f"[*] Searching for {input_type}: '{input_query}'...")

    # ---------------------------------------------------------
    # 1. Resolve the query to a unique ID
    # ---------------------------------------------------------
    if not os.path.exists(q_config['json']):
        raise FileNotFoundError(f"Metadata file not found: {q_config['json']}")
    with open(q_config['json'], 'r', encoding='utf-8') as f:
        q_meta = json.load(f)

    query_id, match_reason = _find_id_by_query(q_meta, input_query)
    # Compare against None explicitly: "no match" is signalled by None, and a
    # resolved ID must not be discarded just because it happens to be falsy.
    if query_id is None:
        print(f"[!] Query not found in {input_type} metadata.")
        return []
    print(f" -> Resolved to ID: {query_id}... ({match_reason})")
    print(f" -> Info: {q_meta[query_id]}")

    # ---------------------------------------------------------
    # 2. Fetch the query vector
    # ---------------------------------------------------------
    if not os.path.exists(q_config['pt']):
        raise FileNotFoundError(f"Vector file not found: {q_config['pt']}")
    # SECURITY NOTE: torch.load unpickles the file; only load .pt files from
    # a trusted source.
    q_pt_data = torch.load(q_config['pt'], map_location=device)
    try:
        # Position of query_id in the .pt file's 'ids' sequence.
        q_index = q_pt_data['ids'].index(query_id)
    except ValueError:
        print(f"[!] ID {query_id} found in JSON but NOT in .pt vector file. (Did you update vectors?)")
        return []
    q_vector = q_pt_data['vectors'][q_index]  # presumably shape [Dim]

    # ---------------------------------------------------------
    # 3. Load the target-side database
    # ---------------------------------------------------------
    if not os.path.exists(t_config['pt']):
        raise FileNotFoundError(f"Target vector file not found: {t_config['pt']}")
    # Checked here, like the other three input files, so a missing metadata
    # file fails before the tensor work instead of after it.
    if not os.path.exists(t_config['json']):
        raise FileNotFoundError(f"Target metadata file not found: {t_config['json']}")
    print(f"[*] Loading {t_type} database...")
    t_pt_data = torch.load(t_config['pt'], map_location=device)
    t_vectors = t_pt_data['vectors']  # presumably shape [N, Dim]
    t_ids = t_pt_data['ids']          # presumably N IDs, aligned with t_vectors

    # ---------------------------------------------------------
    # 4. Similarity and top-k
    # ---------------------------------------------------------
    # Plain dot product of the query (as [1, Dim]) with every target vector.
    # NOTE(review): this equals cosine similarity only if the stored vectors
    # are already L2-normalized -- confirm against the export pipeline.
    scores = torch.matmul(q_vector.unsqueeze(0), t_vectors.T).squeeze(0)

    # k must lie in [0, number of candidates]; torch.topk raises for a
    # negative k or one larger than the number of scores.
    actual_k = max(0, min(topk, len(t_ids), scores.numel()))
    top_scores, top_indices = torch.topk(scores, k=actual_k)

    # ---------------------------------------------------------
    # 5. Assemble the results
    # ---------------------------------------------------------
    # Target-side metadata is only used to enrich the output.
    with open(t_config['json'], 'r', encoding='utf-8') as f:
        t_meta = json.load(f)

    results = []
    print(f"\n{' Top ' + str(actual_k) + ' Matches ':=^50}")
    ranked = zip(top_scores.tolist(), top_indices.tolist())
    for rank, (score, idx) in enumerate(ranked, start=1):
        target_id = t_ids[idx]
        target_info = t_meta.get(target_id, {"error": "Metadata missing"})
        results.append({
            "rank": rank,
            "score": round(score, 4),
            "id": target_id,
            "info": target_info
        })
        # Short console line: compound name for molecules, gene name for
        # proteins, "Unknown" when neither field is present.
        display_name = target_info.get('compound__name') or target_info.get('target__gene') or "Unknown"
        print(f"Rank {rank} | Score: {score:.4f} | ID: {target_id[:15]}... | {display_name}")
    return results
# ==========================================
# Usage examples
# ==========================================
if __name__ == "__main__":
    # Example 1: protein query -> retrieve molecules.
    # input_query may be "ROS1", "Kinase", "P08922", or the full Foldseek
    # sequence, given a metadata record such as:
    #   "P08922": {
    #       "target__foldseek_seq": "<the sequence held in ros1_foldseek_seq below>",
    #       "target__class": "Kinase",
    #       "target__gene": "ROS1"
    #   },
    # The sequence is one string; it is written as two adjacent literals
    # only to keep the pieces manageable.
    ros1_foldseek_seq = (
        "#####################################################################################################WAFWDFPDAAQFKTKTFTDDI####KKKWKWKDFP######DIDDIDPDRIDMDFTDHGQTWMKMKMWID######HIHDIDPIDHHHADDAQLFFWAW##WDA###FKIKTFTHDGPRGSADWPWKKKWKDFPVDIDTDIGR###DMGG#HHGQTKIKMKMWIADLQGIFFIDIDMDTY#########KWKWWWFWLKIWIDTPVC##DATQIA####SPTGWQEWEDDPVQQWIWIDGAQWIWIDRNPHNDDCPRIDTQDGDPFGWREWEAQPLQQKIWTATPQWIWMDHNVDPPDIDTQRAPPGTRFRYWYDPQAQQWIWTQDLQGIWIFGD#########IADQDGDNLFQEWDDQN##QWIWTQPLFQQFIKIAHV######TD#####D#AWQYWEAD##WIWTDNQQFIWTD####TDTRFHHNDVV#####GIHDIYMD###########EW##WDWQAAFFKIKIFTHG############################WWKKKWKWF#####DIDIDTRHRDRMDMGGGDHGQTKIWIWMWTDD###I#DIH#TDIDHHAHD#PDFKWKWFAFQQGIWIDTS######HGDDRPH#AWLDWDDDDQKIWTAGPFFWIWIWGR########HTD#VTTQFNEWEFQPQQQKIWGFG##QWIKIAHPVPRDIDTQDGH#AGWQYKYDDNQLQKIWTDHQFWIWIDGLNNPPTDTQGGA######GWAEKEDDPPVQKIWTWID####IWIKMWGHP########IDT###########HAWYDHRQKIWGQH##QFIWMDGR######G#####HTNYMDMDDDVVNDAH##DPDNFQQEFADFDPVQWDWDD###WTKIFGDGTD#RRDFQKWKFKWKD############HGIDIDG#######IWIWIFIHGSNDIHPIPIDRD#HH#DAFEKWFDKDKAKD###########KMKIFTGDTP###AD#FWKKKWKDK#############DIDIDG##DGMDIDG#####IKMKMWMATAHPHTGYDTDDMDMDGH#####FFWKFFFDFQWGFTARQVVRDGPDIQGHPGGWQEKDDDPVQCWIWTAGSQWIWIANNPVRDIDTLDGH##RPQQNYKEDQPAQQKIKTKH######IWMWIWFSFDP#TDTDTL######WYWQDKYADFQQCKIWTWTQDP#AIFIWIARPVVSDIDGQA##############DDPDD##QFQGEWDWAQLDVVFIWIWGDGFQFIKIDGPNNRDIDTQAGH########WRYWEDDDFKIWTWID####IKIWMATRVH#DT##IR#####RYMYGDDPSRDDDD#L#QAWW######KAWDDAAQFKTKIFGGDT########HDYYQKKKKKWKFWD##########IDIDIDRDRMDMDGGHHFQTKMKMWMTMDGDSD#########DIDMDGHHADAFDEFDDKDWDAQAQWKIKIFGDD#DGGSHDLQQKFKWKDKP#DDIPP#DTHDCCPDDP#GRMDMGGRHHGQDKMKMWMWIAG####RIDIYDIDMDHGHHAFAWWGSFDAEQFKTKTWGADG#####DDKFKWKFW####DTDTADWPFDD##TTTTMGGGHHAQTKMWMWMKTATPSGDIHIYDIDMDTHHFDWWAAKAQWDQD####FKTKIFDTDRR#FDWFWKWKWKFA############IDTDDIGD####DMDG#####HWMKMKMWIGGPRGIHDIYDIDHGGG######################################################################################PFAAADVVQKAFDAFP######TKTWIFGQQLVHDPPGRDIWIKGWGD##"
        "#####VR#VVSQSSVQVPDDDQQAWHFSGWHPPDPRTITITHDQPQAFQLVNLQCQQP#####GLADPLNLLQQLLQVLVVLVRCVVQQKDQQQQARNQKGWNDNDSPDDIGIHRHGR####################PLLLFALCCLPGVIDHLLRVLSSSLVRSVCSLQSNDDRPPVADSVRSNVCLNVPHGGDDRDPDFFLSVVLSVLSSDNDSVSRDGSVVSSVSSVVVSVVVVVV#V########################################################################################################################"
    )
    for protein_query in ("ROS1", ros1_foldseek_seq):
        retrieve_topk(input_type="protein", input_query=protein_query, topk=5)

    print("\n" + "-"*50 + "\n")

    # Example 2: molecule query -> retrieve proteins.
    # input_query may be "Lorlatinib", "DB12130", or the SMILES string,
    # given a metadata record such as:
    #   "C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2": {
    #       "compound__name": "Lorlatinib",
    #       "compound__drugbank_id": "DB12130",
    #       "compound__cas": "1454846-35-5",
    #       "compound__unii": "OSP71S83EU",
    #       "compound__inchikey": "IIXWYSCJSQVBQM-LLVKDONJSA-N"
    #   },
    lorlatinib_smiles = "C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2"
    for molecule_query in ("Lorlatinib", lorlatinib_smiles):
        retrieve_topk(input_type="molecule", input_query=molecule_query, topk=3)