# Recommend_system / server_infer.py
# tong
# clear code
# 2180e31
import torch
import json
import os
import torch.nn.functional as F
# Default locations of the retrieval artifacts (metadata JSON + vector .pt files).
BASE_DIR = 'drug_target_activity/candidates'

# Per query side: its metadata JSON, its precomputed vectors, and which side is
# searched against it (a protein query ranks molecules and vice versa).
PATH_CONFIG = {
    side: {
        "json": os.path.join(BASE_DIR, json_name),
        "pt": os.path.join(BASE_DIR, pt_name),
        "target_type": opposite_side,
    }
    for side, json_name, pt_name, opposite_side in (
        ("protein", "unique_targets.json", "protein_candidates.pt", "molecule"),
        ("molecule", "unique_compounds.json", "molecule_candidates.pt", "protein"),
    )
}
def _find_id_by_query(metadata, query):
"""
辅助函数:在元数据字典中查找 query 对应的 Unique ID (Key)。
搜索优先级:Key (ID) -> Values (Sequence/Name/Gene...)
"""
# 1. 尝试直接匹配 Key (UniProtID 或 SMILES)
if query in metadata:
return query, "Direct ID Match"
# 2. 遍历 Values 进行匹配
# 注意:这在大数据量下可能稍慢,但在 inference 场景可接受
for unique_id, info in metadata.items():
# 检查 info 字典里的所有值 (target__gene, compound__name, sequence 等)
# 将所有值转为字符串进行比对
for key, val in info.items():
if val and str(val).strip() == str(query).strip():
return unique_id, f"Matched field: {key}"
return None, None
def retrieve_topk(input_type, input_query, topk=5, device='cpu'):
"""
双塔检索主函数。
Args:
input_type (str): 'protein' 或 'molecule'
input_query (str): 查询内容 (ID, Name, Gene, Sequence, SMILES...)
topk (int): 返回前 K 个结果
device (str): 计算设备
Returns:
list: 包含 TopK 结果的字典列表
"""
input_type = input_type.lower()
if input_type not in PATH_CONFIG:
raise ValueError("input_type must be 'protein' or 'molecule'")
# 获取当前侧(Query Side)和对侧(Target Side)的配置
q_config = PATH_CONFIG[input_type]
t_type = q_config['target_type']
t_config = PATH_CONFIG[t_type]
print(f"[*] Searching for {input_type}: '{input_query}'...")
# ---------------------------------------------------------
# 1. 解析 Query,找到唯一的 ID
# ---------------------------------------------------------
if not os.path.exists(q_config['json']):
raise FileNotFoundError(f"Metadata file not found: {q_config['json']}")
with open(q_config['json'], 'r', encoding='utf-8') as f:
q_meta = json.load(f)
query_id, match_reason = _find_id_by_query(q_meta, input_query)
if not query_id:
print(f"[!] Query not found in {input_type} metadata.")
return []
print(f" -> Resolved to ID: {query_id}... ({match_reason})") #query_id[:20]
print(f" -> Info: {q_meta[query_id]}")
# ---------------------------------------------------------
# 2. 获取 Query Vector
# ---------------------------------------------------------
if not os.path.exists(q_config['pt']):
raise FileNotFoundError(f"Vector file not found: {q_config['pt']}")
q_pt_data = torch.load(q_config['pt'], map_location=device)
try:
# 在 .pt 文件的 ids 列表中找到 query_id 的索引
q_index = q_pt_data['ids'].index(query_id)
q_vector = q_pt_data['vectors'][q_index] # [Dim]
except ValueError:
print(f"[!] ID {query_id} found in JSON but NOT in .pt vector file. (Did you update vectors?)")
return []
# ---------------------------------------------------------
# 3. 加载 Database Vectors (对侧)
# ---------------------------------------------------------
if not os.path.exists(t_config['pt']):
raise FileNotFoundError(f"Target vector file not found: {t_config['pt']}")
print(f"[*] Loading {t_type} database...")
t_pt_data = torch.load(t_config['pt'], map_location=device)
t_vectors = t_pt_data['vectors'] # [N, Dim]
t_ids = t_pt_data['ids'] # List [N]
# ---------------------------------------------------------
# 4. 计算相似度 & TopK
# ---------------------------------------------------------
# 确保 Query 向量维度匹配 [1, Dim]
q_vector = q_vector.unsqueeze(0)
# 计算余弦相似度 (假设存储的向量已经是 Normalized 的,直接点积即可)
# score shape: [1, N]
scores = torch.matmul(q_vector, t_vectors.T).squeeze(0)
# 获取 TopK
# k 不能超过数据库总量
actual_k = min(topk, len(t_ids))
top_scores, top_indices = torch.topk(scores, k=actual_k)
# ---------------------------------------------------------
# 5. 组装结果
# ---------------------------------------------------------
# 加载对侧的元数据用于展示详细信息
with open(t_config['json'], 'r', encoding='utf-8') as f:
t_meta = json.load(f)
results = []
print(f"\n{' Top ' + str(actual_k) + ' Matches ':=^50}")
for rank, (score, idx) in enumerate(zip(top_scores, top_indices)):
idx = idx.item()
score = score.item()
target_id = t_ids[idx]
target_info = t_meta.get(target_id, {"error": "Metadata missing"})
# 结果对象
res_item = {
"rank": rank + 1,
"score": round(score, 4),
"id": target_id,
"info": target_info
}
results.append(res_item)
# 打印简略信息
# 如果是分子,打印名字;如果是蛋白,打印基因名
display_name = target_info.get('compound__name') or target_info.get('target__gene') or "Unknown"
print(f"Rank {rank+1} | Score: {score:.4f} | ID: {target_id[:15]}... | {display_name}")
return results
# ==========================================
# Usage examples
# ==========================================
if __name__ == "__main__":
    # Example 1: protein query -> ranked molecules.
    # input_query may be a gene name ("ROS1"), a UniProt ID ("P08922"), the full
    # Foldseek sequence, or any other field value such as a class ("Kinase");
    # a non-unique value resolves to the first metadata entry that has it.
    # Sample metadata entry:
    # "P08922": {
    #     "target__foldseek_seq": "<long Foldseek string; see ros1_foldseek_seq below>",
    #     "target__class": "Kinase",
    #     "target__gene": "ROS1"
    # },
    retrieve_topk(
        input_type="protein",
        input_query="ROS1",
        topk=5
    )

    # Foldseek sequence of the sample entry above. Written as two adjacent
    # literals (concatenated by Python) so no single source line is absurdly long.
    ros1_foldseek_seq = (
        "#####################################################################################################WAFWDFPDAAQFKTKTFTDDI####KKKWKWKDFP######DIDDIDPDRIDMDFTDHGQTWMKMKMWID######HIHDIDPIDHHHADDAQLFFWAW##WDA###FKIKTFTHDGPRGSADWPWKKKWKDFPVDIDTDIGR###DMGG#HHGQTKIKMKMWIADLQGIFFIDIDMDTY#########KWKWWWFWLKIWIDTPVC##DATQIA####SPTGWQEWEDDPVQQWIWIDGAQWIWIDRNPHNDDCPRIDTQDGDPFGWREWEAQPLQQKIWTATPQWIWMDHNVDPPDIDTQRAPPGTRFRYWYDPQAQQWIWTQDLQGIWIFGD#########IADQDGDNLFQEWDDQN##QWIWTQPLFQQFIKIAHV######TD#####D#AWQYWEAD##WIWTDNQQFIWTD####TDTRFHHNDVV#####GIHDIYMD###########EW##WDWQAAFFKIKIFTHG############################WWKKKWKWF#####DIDIDTRHRDRMDMGGGDHGQTKIWIWMWTDD###I#DIH#TDIDHHAHD#PDFKWKWFAFQQGIWIDTS######HGDDRPH#AWLDWDDDDQKIWTAGPFFWIWIWGR########HTD#VTTQFNEWEFQPQQQKIWGFG##QWIKIAHPVPRDIDTQDGH#AGWQYKYDDNQLQKIWTDHQFWIWIDGLNNPPTDTQGGA######GWAEKEDDPPVQKIWTWID####IWIKMWGHP########IDT###########HAWYDHRQKIWGQH##QFIWMDGR######G#####HTNYMDMDDDVVNDAH##DPDNFQQEFADFDPVQWDWDD###WTKIFGDGTD#RRDFQKWKFKWKD############HGIDIDG#######IWIWIFIHGSNDIHPIPIDRD#HH#DAFEKWFDKDKAKD###########KMKIFTGDTP###AD#FWKKKWKDK#############DIDIDG##DGMDIDG#####IKMKMWMATAHPHTGYDTDDMDMDGH#####FFWKFFFDFQWGFTARQVVRDGPDIQGHPGGWQEKDDDPVQCWIWTAGSQWIWIANNPVRDIDTLDGH##RPQQNYKEDQPAQQKIKTKH######IWMWIWFSFDP#TDTDTL######WYWQDKYADFQQCKIWTWTQDP#AIFIWIARPVVSDIDGQA##############DDPDD##QFQGEWDWAQLDVVFIWIWGDGFQFIKIDGPNNRDIDTQAGH########WRYWEDDDFKIWTWID####IKIWMATRVH#DT##IR#####RYMYGDDPSRDDDD#L#QAWW######KAWDDAAQFKTKIFGGDT########HDYYQKKKKKWKFWD##########IDIDIDRDRMDMDGGHHFQTKMKMWMTMDGDSD#########DIDMDGHHADAFDEFDDKDWDAQAQWKIKIFGDD#DGGSHDLQQKFKWKDKP#DDIPP#DTHDCCPDDP#GRMDMGGRHHGQDKMKMWMWIAG####RIDIYDIDMDHGHHAFAWWGSFDAEQFKTKTWGADG#####DDKFKWKFW####DTDTADWPFDD##TTTTMGGGHHAQTKMWMWMKTATPSGDIHIYDIDMDTHHFDWWAAKAQWDQD####FKTKIFDTDRR#FDWFWKWKWKFA############IDTDDIGD####DMDG#####HWMKMKMWIGGPRGIHDIYDIDHGGG######################################################################################PFAAADVVQKAFDAFP######TKTWIFGQQLVHDPPGRDIWIKGWGD####"
        "###VR#VVSQSSVQVPDDDQQAWHFSGWHPPDPRTITITHDQPQAFQLVNLQCQQP#####GLADPLNLLQQLLQVLVVLVRCVVQQKDQQQQARNQKGWNDNDSPDDIGIHRHGR####################PLLLFALCCLPGVIDHLLRVLSSSLVRSVCSLQSNDDRPPVADSVRSNVCLNVPHGGDDRDPDFFLSVVLSVLSSDNDSVSRDGSVVSSVSSVVVSVVVVVV#V########################################################################################################################"
    )
    retrieve_topk(
        input_type="protein",
        input_query=ros1_foldseek_seq,
        topk=5
    )

    print("\n" + "-"*50 + "\n")

    # Example 2: molecule query -> ranked proteins.
    # input_query may be a name ("Lorlatinib"), a DrugBank ID ("DB12130"), or
    # the SMILES string (the metadata key). Sample metadata entry:
    # "C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2": {
    #     "compound__name": "Lorlatinib",
    #     "compound__drugbank_id": "DB12130",
    #     "compound__cas": "1454846-35-5",
    #     "compound__unii": "OSP71S83EU",
    #     "compound__inchikey": "IIXWYSCJSQVBQM-LLVKDONJSA-N"
    # },
    retrieve_topk(
        input_type="molecule",
        input_query="Lorlatinib",
        topk=3
    )
    retrieve_topk(
        input_type="molecule",
        input_query="C[C@H]1OC2=C(N)N=CC(=C2)C2=C(C#N)N(C)N=C2CN(C)C(=O)C2=C1C=C(F)C=C2",
        topk=3
    )