import json
import os
import traceback
from collections import OrderedDict
from typing import Tuple, Optional

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from model import load_encoder_components, ProteinMoleculeDualEncoder
from train import train_model
from train_ddp import train_model_ddp

# Default backbone paths, used when the caller does not specify any.
DEFAULT_PROTEIN_PATH = "./SaProt_650M_AF2"
DEFAULT_MOLECULE_PATH = "./ChemBERTa-zinc-base-v1"


def _strip_ddp_prefix(state_dict):
    """Return a new OrderedDict with any DistributedDataParallel 'module.'
    key prefix removed.

    Checkpoints saved from a DDP-wrapped model prefix every parameter name
    with ``module.``; non-DDP checkpoints pass through unchanged.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # 'module.' is exactly 7 characters long.
        cleaned[key[7:] if key.startswith('module.') else key] = value
    return cleaned


def load_dual_tower_model(
    protein_model_path: Optional[str] = None,
    molecule_model_path: Optional[str] = None,
    pt_path: Optional[str] = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
) -> Tuple[ProteinMoleculeDualEncoder, PreTrainedTokenizer, PreTrainedTokenizer]:
    """Unified loader for the dual-tower model and both tokenizers.

    Flow:
        1. Resolve backbone paths (fall back to defaults when None).
        2. Build the model structure (load_encoder_components +
           ProteinMoleculeDualEncoder).
        3. If ``pt_path`` is given, strip any DDP prefix and load the
           checkpoint over the freshly initialized weights.

    Args:
        protein_model_path: Optional path to the protein backbone.
        molecule_model_path: Optional path to the molecule backbone.
        pt_path: Optional path to a trained ``.pt`` checkpoint.
        device: Target device; defaults to CUDA when available.

    Returns:
        (model, protein_tokenizer, molecule_tokenizer); the model is moved
        to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: If ``pt_path`` is given but does not exist.
    """
    # --- 1. Resolve backbone paths ---
    # Fall back to the defaults when an argument is None; this supports both
    # "only pt_path given -> use default backbones" and "pt_path plus
    # explicitly chosen backbones".
    p_path = protein_model_path if protein_model_path else DEFAULT_PROTEIN_PATH
    m_path = molecule_model_path if molecule_model_path else DEFAULT_MOLECULE_PATH

    print("Step 1: Initializing model structure...")
    print(f"  - Protein Backbone: {p_path}")
    print(f"  - Molecule Backbone: {m_path}")

    # --- 2. Initialize the structure (written once, reused for all cases) ---
    p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
        p_path, m_path
    )
    model = ProteinMoleculeDualEncoder(
        protein_encoder=p_encoder,
        molecule_encoder=m_encoder,
        projection_dim=256  # must match the dimension used at training time
    )

    # --- 3. Load the checkpoint, if one was provided ---
    if pt_path is not None:
        if not os.path.exists(pt_path):
            raise FileNotFoundError(f"Checkpoint not found at: {pt_path}")

        print(f"Step 2: Loading weights from {pt_path} ...")
        # map_location avoids OOM / device-mismatch when loading on CPU.
        # NOTE(review): torch.load is pickle-based — only load checkpoints
        # from trusted sources.
        state_dict = torch.load(pt_path, map_location=device)

        # --- 3.1 Handle the DDP 'module.' prefix ---
        new_state_dict = _strip_ddp_prefix(state_dict)

        # --- 3.2 Overwrite the initial weights ---
        # strict=True guarantees weights and structure match exactly; any
        # mismatch raises, so the warnings below only fire if strict loading
        # is relaxed in the future.
        missing, unexpected = model.load_state_dict(new_state_dict, strict=True)
        print("  - Weights loaded successfully.")
        if missing:
            print(f"  - Warning: Missing keys: {missing}")
        if unexpected:
            print(f"  - Warning: Unexpected keys: {unexpected}")
    else:
        print("Using User-Given Encoders.")

    # Move to the target device and default to eval mode; callers that train
    # should switch back with model.train().
    model.to(device)
    model.eval()
    return model, p_tokenizer, m_tokenizer


def _load_and_extract_data(json_path, extractor_func, desc="Data"):
    """Generic JSON loading helper.

    Args:
        json_path: Path to the JSON file.
        extractor_func: Callable taking (key, value) and returning
            (id, sequence_text); return None for the text to skip an entry.
        desc: Human-readable label for log messages.

    Returns:
        (ids, seqs) parallel lists; both empty when the file is missing.
    """
    print(f"Loading {desc} from {json_path}...")
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} does not exist.")
        return [], []

    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    ids = []
    seqs = []
    # Walk the JSON and let the injected extractor pull out each item.
    for k, v in data.items():
        extracted_id, extracted_seq = extractor_func(k, v)
        # Basic validity check: must be a non-blank string.
        if extracted_seq and isinstance(extracted_seq, str) and len(extracted_seq.strip()) > 0:
            ids.append(extracted_id)
            seqs.append(extracted_seq)

    print(f"Found {len(ids)} valid items for {desc}.")
    return ids, seqs


def _compute_tower_vectors(ids, seqs, tokenizer, encoder, projector, batch_size, device, max_len, desc):
    """Generic single-tower inference helper.

    Pipeline per batch: tokenize -> encoder forward -> CLS pooling ->
    projection -> L2 normalize -> move to CPU.

    Returns:
        (ids, embeddings FloatTensor [N, dim]) or (None, None) when there is
        nothing to encode.
    """
    if not ids:
        return None, None

    embeddings = []
    # Run under no_grad: pure inference, no autograd bookkeeping needed.
    with torch.no_grad():
        for i in tqdm(range(0, len(ids), batch_size), desc=f"Encoding {desc}"):
            batch_seqs = seqs[i : i + batch_size]

            # 1. Tokenize
            inputs = tokenizer(
                batch_seqs,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors='pt'
            ).to(device)

            # 2. Forward chain (one tower of the dual-encoder)
            # Encoder (backbone)
            outputs = encoder(**inputs)
            # Pooling (take the [CLS] token, conventionally index 0)
            vec = outputs.last_hidden_state[:, 0, :]
            # Projection (dimension reduction / mapping layer)
            vec = projector(vec)
            # Normalize (essential: retrieval requires L2-normalized vectors)
            vec = F.normalize(vec, p=2, dim=1)

            # Move back to CPU to avoid accumulating GPU memory.
            embeddings.append(vec.cpu())

    # Concatenate all batch results.
    if not embeddings:
        return None, None
    return ids, torch.cat(embeddings, dim=0)


def update_candidate_vectors(
    model,
    protein_tokenizer,
    molecule_tokenizer,
    protein_json_path,
    molecule_json_path,
    output_dir,
    batch_size=64,
    device='cuda',
    max_prot_len=1024,
    max_mol_len=512
):
    """Main orchestrator: load metadata, encode both towers, save vectors.

    Writes ``protein_candidates.pt`` and ``molecule_candidates.pt`` into
    ``output_dir``, each a dict {"ids": [str], "vectors": FloatTensor [N, D]}.
    """
    # exist_ok avoids the check-then-create race of an os.path.exists guard.
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    model.to(device)

    # --- Core configuration list ---
    # Encodes how the Protein and Molecule sides differ.
    tasks = [
        {
            "desc": "Protein",
            "path": protein_json_path,
            # Protein extraction rule:
            #   ID  = key (UniProt ID)
            #   Seq = value["target__foldseek_seq"]
            "extract_fn": lambda k, v: (k, v.get('target__foldseek_seq')),
            "tokenizer": protein_tokenizer,
            "encoder": model.protein_encoder,
            "projector": model.prot_proj,
            "max_len": max_prot_len,
            "save_name": "protein_candidates.pt"
        },
        {
            "desc": "Molecule",
            "path": molecule_json_path,
            # Molecule extraction rule:
            #   ID  = key (SMILES)
            #   Seq = key (SMILES) — the SMILES string is both ID and content
            "extract_fn": lambda k, v: (k, k),
            "tokenizer": molecule_tokenizer,
            "encoder": model.molecule_encoder,
            "projector": model.mol_proj,
            "max_len": max_mol_len,
            "save_name": "molecule_candidates.pt"
        }
    ]

    # --- Unified execution loop ---
    for task in tasks:
        # 1. Prepare data
        ids, seqs = _load_and_extract_data(
            task['path'],
            task['extract_fn'],
            task['desc']
        )

        # 2. Compute vectors
        valid_ids, vectors = _compute_tower_vectors(
            ids, seqs,
            task['tokenizer'],
            task['encoder'],
            task['projector'],
            batch_size, device, task['max_len'], task['desc']
        )

        # 3. Save results
        if vectors is not None:
            save_path = os.path.join(output_dir, task['save_name'])
            torch.save({
                "ids": valid_ids,   # list of strings (UniProt ID or SMILES)
                "vectors": vectors  # FloatTensor [N, Dim]
            }, save_path)
            print(f"Saved {task['desc']} Vectors: {vectors.shape} to {save_path}\n")
        else:
            print(f"Skipping save for {task['desc']} (No valid data).\n")


def recompute_candidate_vectors(
    protein_json_path: str,
    molecule_json_path: str,
    output_dir: str,
    pt_path: str = None,
    protein_model_path: str = None,
    molecule_model_path: str = None,
    batch_size: int = 64,
    max_prot_len: int = 1024,
    max_mol_len: int = 512,
    device: str = None
):
    """End-to-end pipeline: load the dual-tower model -> read metadata ->
    run inference to build the vector store -> save ``.pt`` files.

    Args:
        protein_json_path: Path to the unique-target JSON.
        molecule_json_path: Path to the unique-compound JSON.
        output_dir: Directory for the output vector files.
        pt_path: (Optional) Trained ``.pt`` checkpoint; when None the
            projection layers keep their random initialization.
        protein_model_path: (Optional) Protein backbone path.
        molecule_model_path: (Optional) Molecule backbone path.
        batch_size: Inference batch size.
        max_prot_len: Maximum protein sequence length.
        max_mol_len: Maximum molecule sequence length.
        device: 'cuda' or 'cpu'; auto-detected when None.
    """
    # 0. Auto-detect the device.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"={' Start Recomputing Candidate Vectors ':=^60}")
    print(f"Device: {device}")

    # 1. Load the model and tokenizers.
    print("\n>>> [Stage 1/2] Loading Model & Tokenizers...")
    try:
        model, p_tokenizer, m_tokenizer = load_dual_tower_model(
            protein_model_path=protein_model_path,
            molecule_model_path=molecule_model_path,
            pt_path=pt_path,
            device=device
        )
    except Exception as e:
        # Best-effort boundary: report, keep the traceback for debugging,
        # and bail out without crashing the caller.
        print(f"Error loading model: {e}")
        traceback.print_exc()
        return

    # 2. Compute and save the vectors.
    print("\n>>> [Stage 2/2] Computing & Saving Vectors...")
    try:
        update_candidate_vectors(
            model=model,
            protein_tokenizer=p_tokenizer,
            molecule_tokenizer=m_tokenizer,
            protein_json_path=protein_json_path,
            molecule_json_path=molecule_json_path,
            output_dir=output_dir,
            batch_size=batch_size,
            device=device,
            max_prot_len=max_prot_len,
            max_mol_len=max_mol_len
        )
    except Exception as e:
        print(f"Error computing vectors: {e}")
        traceback.print_exc()
        return

    print(f"\n={' All Done ':=^60}")
    print(f"Check output files in: {output_dir}")


def continuous_train(
    dataset_path: str,
    model_save_dir: str = 'Dual_Tower_Model/customized_checkpoints',
    protein_model_path: str = None,
    molecule_model_path: str = None,
    best_model_path: str = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    epochs: int = 5,
    lr: float = 1e-4,
    batch_size: int = 16,
    use_ddp: bool = False
):
    """Run the continual-training / fine-tuning flow.

    If ``best_model_path`` is given, training resumes from that checkpoint;
    otherwise it starts from ``protein_model_path`` / ``molecule_model_path``
    (or the default backbones).

    Returns:
        True on completion.
    """
    model, p_tokenizer, m_tokenizer = load_dual_tower_model(
        protein_model_path=protein_model_path,
        molecule_model_path=molecule_model_path,
        pt_path=best_model_path,
        device=device
    )
    # Dispatch to the DDP or single-process trainer with identical arguments.
    trainer = train_model_ddp if use_ddp else train_model
    trainer(
        model_and_tokenizers=[model, p_tokenizer, m_tokenizer],
        dataset_path=dataset_path,
        model_save_dir=model_save_dir,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr
    )
    return True


if __name__ == "__main__":
    # Two possible kinds of user input, giving three scenarios:
    # 1. User provides a model (the two encoders).
    # 2. User provides data (a dataset with the fields 'compound__smiles',
    #    'target__foldseek_seq', 'outcome_potency_pxc50', 'outcome_is_active').
    # (3. Whether to use DDP: train_model_ddp vs train_model.)
    #
    # Inference only:
    # 1. No data, model given
    #    --> recompute_candidate_vectors(protein_model_path, molecule_model_path)
    #
    # Training + inference:
    # 2. Data and model given
    #    --> continuous_train(protein_model_path, molecule_model_path)
    #        --> recompute_candidate_vectors(best_model_path)
    # 3. Data given, no model
    #    --> continuous_train(best_model_path)
    #        --> recompute_candidate_vectors(best_model_path)

    protein_json_path = 'drug_target_activity/candidates/unique_targets.json'
    molecule_json_path = 'drug_target_activity/candidates/unique_compounds.json'
    best_model_path = 'Dual_Tower_Model/output_checkpoints_ddp/model_epoch_7_acc_0.3259.pt'

    recompute_candidate_vectors(
        protein_json_path=protein_json_path,
        molecule_json_path=molecule_json_path,
        output_dir='drug_target_activity/candidates',
        pt_path=best_model_path,  # load the trained weights
        batch_size=32
    )
    # If only data was given (no model), retrain to regenerate the tensors.