# NOTE: removed non-code page artifacts ("Spaces:" / "Sleeping") that were scraped into this file.
| import torch | |
| import torch.nn.functional as F | |
| import json | |
| import os | |
| from collections import OrderedDict | |
| from typing import Tuple, Optional | |
| from tqdm import tqdm | |
| from transformers import PreTrainedTokenizer | |
| from model import load_encoder_components, ProteinMoleculeDualEncoder | |
| from train import train_model | |
| from train_ddp import train_model_ddp | |
# Default backbone paths, used when the caller does not supply explicit model paths.
DEFAULT_PROTEIN_PATH = "./SaProt_650M_AF2"
DEFAULT_MOLECULE_PATH = "./ChemBERTa-zinc-base-v1"
def load_dual_tower_model(
    protein_model_path: Optional[str] = None,
    molecule_model_path: Optional[str] = None,
    pt_path: Optional[str] = None,
    device: Optional[str] = None,
) -> Tuple[ProteinMoleculeDualEncoder, PreTrainedTokenizer, PreTrainedTokenizer]:
    """
    Load the dual-tower model and both tokenizers in one call.

    Flow:
        1. Resolve the backbone paths (fall back to defaults when not given).
        2. Build the model structure (load_encoder_components +
           ProteinMoleculeDualEncoder).
        3. If pt_path is given, strip any DDP ``module.`` prefix and load the
           checkpoint weights over the freshly initialized ones.

    Args:
        protein_model_path: Optional protein backbone path; falls back to
            DEFAULT_PROTEIN_PATH. Supplying it together with pt_path lets the
            caller pair a checkpoint with specific backbones.
        molecule_model_path: Optional molecule backbone path; falls back to
            DEFAULT_MOLECULE_PATH.
        pt_path: Optional trained checkpoint (.pt) to load.
        device: 'cuda' or 'cpu'; auto-detected at call time when None.

    Returns:
        (model, protein_tokenizer, molecule_tokenizer); the model is on
        ``device`` and in eval mode — switch back with model.train() to train.

    Raises:
        FileNotFoundError: if pt_path is given but the file does not exist.
        RuntimeError: from load_state_dict when checkpoint keys mismatch
            (strict=True).
    """
    # Resolve the device at call time, not at import time, so the CUDA check
    # reflects the current runtime state.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # --- 1. Resolve backbone paths ---
    # None falls back to the defaults; this supports both "only pt_path given"
    # and "pt_path plus explicit backbones".
    p_path = protein_model_path if protein_model_path else DEFAULT_PROTEIN_PATH
    m_path = molecule_model_path if molecule_model_path else DEFAULT_MOLECULE_PATH
    print("Step 1: Initializing model structure...")
    print(f" - Protein Backbone: {p_path}")
    print(f" - Molecule Backbone: {m_path}")

    # --- 2. Initialize the structure (written once, reused for all cases) ---
    p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
        p_path, m_path
    )
    model = ProteinMoleculeDualEncoder(
        protein_encoder=p_encoder,
        molecule_encoder=m_encoder,
        projection_dim=256,  # must match the dimension used at training time
    )

    # --- 3. Load checkpoint (if provided) ---
    if pt_path is not None:
        if not os.path.exists(pt_path):
            raise FileNotFoundError(f"Checkpoint not found at: {pt_path}")
        print(f"Step 2: Loading weights from {pt_path} ...")
        # map_location avoids OOM / device-mismatch when the checkpoint was
        # saved on a different device.
        state_dict = torch.load(pt_path, map_location=device)

        # --- 3.1 Strip the DDP prefix ('module.') ---
        ddp_prefix = "module."
        new_state_dict = OrderedDict(
            (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k, v)
            for k, v in state_dict.items()
        )

        # --- 3.2 Overwrite the weights ---
        # strict=True raises on any key mismatch, so execution only continues
        # past this line when the load was complete. (The previous code also
        # printed "missing"/"unexpected" warnings, but under strict=True those
        # lists are always empty — dead code, removed.)
        model.load_state_dict(new_state_dict, strict=True)
        print(" - Weights loaded successfully.")
    else:
        print("Using User-Given Encoders.")

    # Move to the target device; default to eval mode for inference callers.
    model.to(device)
    model.eval()
    return model, p_tokenizer, m_tokenizer
| def _load_and_extract_data(json_path, extractor_func, desc="Data"): | |
| """ | |
| 通用数据加载辅助函数 | |
| Args: | |
| json_path: JSON文件路径 | |
| extractor_func: 一个函数,接收 (key, value),返回 (id, sequence_text) | |
| 如果 sequence_text 无效,应该返回 None | |
| """ | |
| print(f"Loading {desc} from {json_path}...") | |
| if not os.path.exists(json_path): | |
| print(f"Warning: {json_path} does not exist.") | |
| return [], [] | |
| with open(json_path, 'r', encoding='utf-8') as f: | |
| data = json.load(f) | |
| ids = [] | |
| seqs = [] | |
| # 遍历 JSON,利用传入的 extractor_func 提取数据 | |
| for k, v in data.items(): | |
| extracted_id, extracted_seq = extractor_func(k, v) | |
| # 简单的有效性检查 | |
| if extracted_seq and isinstance(extracted_seq, str) and len(extracted_seq.strip()) > 0: | |
| ids.append(extracted_id) | |
| seqs.append(extracted_seq) | |
| print(f"Found {len(ids)} valid items for {desc}.") | |
| return ids, seqs | |
| def _compute_tower_vectors(ids, seqs, tokenizer, encoder, projector, batch_size, device, max_len, desc): | |
| """ | |
| 通用推理辅助函数:负责 Batch处理 -> Tokenize -> Model Forward -> Normalize | |
| """ | |
| if not ids: | |
| return None, None | |
| embeddings = [] | |
| # 使用 no_grad 并在推理模式下运行 | |
| with torch.no_grad(): | |
| for i in tqdm(range(0, len(ids), batch_size), desc=f"Encoding {desc}"): | |
| batch_seqs = seqs[i : i + batch_size] | |
| # 1. Tokenize | |
| inputs = tokenizer( | |
| batch_seqs, | |
| padding=True, | |
| truncation=True, | |
| max_length=max_len, | |
| return_tensors='pt' | |
| ).to(device) | |
| # 2. Forward Chain (拆解双塔模型的单边逻辑) | |
| # Encoder (Backbone) | |
| outputs = encoder(**inputs) | |
| # Pooling (取 [CLS] token, 通常是 idx 0) | |
| vec = outputs.last_hidden_state[:, 0, :] | |
| # Projection (降维/映射层) | |
| vec = projector(vec) | |
| # Normalize (关键!检索任务必须做 L2 Normalize) | |
| vec = F.normalize(vec, p=2, dim=1) | |
| # 移回 CPU 防止显存溢出 | |
| embeddings.append(vec.cpu()) | |
| # 拼接所有 batch 的结果 | |
| if not embeddings: | |
| return None, None | |
| return ids, torch.cat(embeddings, dim=0) | |
def update_candidate_vectors(
    model,
    protein_tokenizer,
    molecule_tokenizer,
    protein_json_path,
    molecule_json_path,
    output_dir,
    batch_size=64,
    device='cuda',
    max_prot_len=1024,
    max_mol_len=512
):
    """
    Orchestrate the full candidate-vector refresh:
    load metadata -> encode each tower -> save one .pt file per tower.
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    model.to(device)

    # Per-tower configuration: everything that differs between the protein
    # side and the molecule side lives in these tuples.
    # (desc, json_path, extract_fn, tokenizer, encoder, projector, max_len, save_name)
    tower_configs = (
        (
            "Protein",
            protein_json_path,
            # Protein rule: ID = key (UniProt ID),
            # seq = value['target__foldseek_seq'].
            lambda k, v: (k, v.get('target__foldseek_seq')),
            protein_tokenizer,
            model.protein_encoder,
            model.prot_proj,
            max_prot_len,
            "protein_candidates.pt",
        ),
        (
            "Molecule",
            molecule_json_path,
            # Molecule rule: the SMILES key doubles as both ID and sequence.
            lambda k, v: (k, k),
            molecule_tokenizer,
            model.molecule_encoder,
            model.mol_proj,
            max_mol_len,
            "molecule_candidates.pt",
        ),
    )

    # Unified pipeline over both towers.
    for desc, path, extract_fn, tokenizer, encoder, projector, max_len, save_name in tower_configs:
        # 1. Load and filter the raw entries.
        ids, seqs = _load_and_extract_data(path, extract_fn, desc)
        # 2. Encode into normalized vectors.
        valid_ids, vectors = _compute_tower_vectors(
            ids, seqs, tokenizer, encoder, projector,
            batch_size, device, max_len, desc
        )
        # 3. Persist (skip when there was no valid data).
        if vectors is not None:
            save_path = os.path.join(output_dir, save_name)
            torch.save({
                "ids": valid_ids,   # list of strings (UniProt IDs or SMILES)
                "vectors": vectors  # FloatTensor [N, Dim]
            }, save_path)
            print(f"Saved {desc} Vectors: {vectors.shape} to {save_path}\n")
        else:
            print(f"Skipping save for {desc} (No valid data).\n")
def recompute_candidate_vectors(
    protein_json_path: str,
    molecule_json_path: str,
    output_dir: str,
    pt_path: str = None,
    protein_model_path: str = None,
    molecule_model_path: str = None,
    batch_size: int = 64,
    max_prot_len: int = 1024,
    max_mol_len: int = 512,
    device: str = None
):
    """
    End-to-end pipeline: load the dual-tower model, read the candidate
    metadata, run inference to build the vector stores, and save .pt files.

    Args:
        protein_json_path: path to the unique-target JSON.
        molecule_json_path: path to the unique-compound JSON.
        output_dir: directory for the output vector files.
        pt_path: optional trained .pt checkpoint; when None the projection
            layers keep their random initialization.
        protein_model_path: optional protein backbone path.
        molecule_model_path: optional molecule backbone path.
        batch_size: inference batch size.
        max_prot_len: max protein sequence length.
        max_mol_len: max molecule sequence length.
        device: 'cuda' or 'cpu'; auto-detected when None.
    """
    # 0. Device auto-detection.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"={' Start Recomputing Candidate Vectors ':=^60}")
    print(f"Device: {device}")

    # 1. Model + tokenizers (see load_dual_tower_model).
    print(f"\n>>> [Stage 1/2] Loading Model & Tokenizers...")
    try:
        model, p_tokenizer, m_tokenizer = load_dual_tower_model(
            protein_model_path=protein_model_path,
            molecule_model_path=molecule_model_path,
            pt_path=pt_path,
            device=device
        )
    except Exception as e:
        # Best-effort pipeline: report and stop rather than propagate.
        print(f"Error loading model: {e}")
        return

    # 2. Compute and persist the vectors (see update_candidate_vectors).
    print(f"\n>>> [Stage 2/2] Computing & Saving Vectors...")
    try:
        update_candidate_vectors(
            model=model,
            protein_tokenizer=p_tokenizer,
            molecule_tokenizer=m_tokenizer,
            protein_json_path=protein_json_path,
            molecule_json_path=molecule_json_path,
            output_dir=output_dir,
            batch_size=batch_size,
            device=device,
            max_prot_len=max_prot_len,
            max_mol_len=max_mol_len
        )
    except Exception as e:
        print(f"Error computing vectors: {e}")
        return

    print(f"\n={' All Done ':=^60}")
    print(f"Check output files in: {output_dir}")
def continuous_train(
    dataset_path: str,
    model_save_dir: str = 'Dual_Tower_Model/customized_checkpoints',
    protein_model_path: Optional[str] = None,
    molecule_model_path: Optional[str] = None,
    best_model_path: Optional[str] = None,
    device: Optional[str] = None,
    epochs: int = 5,
    lr: float = 1e-4,
    batch_size: int = 16,
    use_ddp: bool = False
):
    """
    Run continuous training / fine-tuning.

    If best_model_path is given, training resumes from that checkpoint;
    otherwise it starts from protein_model_path / molecule_model_path
    (or the default backbones).

    Args:
        dataset_path: path to the training dataset.
        model_save_dir: directory where checkpoints are written.
        protein_model_path: optional protein backbone path.
        molecule_model_path: optional molecule backbone path.
        best_model_path: optional checkpoint (.pt) to resume from.
        device: 'cuda' or 'cpu'; auto-detected at call time when None.
        epochs: number of training epochs.
        lr: learning rate.
        batch_size: training batch size.
        use_ddp: use the DDP trainer (train_model_ddp) instead of train_model.

    Returns:
        True on completion.
    """
    # Resolve the device at call time, matching recompute_candidate_vectors.
    # (The previous def-time default froze the CUDA check at module import.)
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    model, p_tokenizer, m_tokenizer = load_dual_tower_model(
        protein_model_path=protein_model_path,
        molecule_model_path=molecule_model_path,
        pt_path=best_model_path,
        device=device
    )

    # Both trainers share the same keyword signature; select once instead of
    # duplicating the call in each branch.
    trainer = train_model_ddp if use_ddp else train_model
    trainer(
        model_and_tokenizers=[model, p_tokenizer, m_tokenizer],
        dataset_path=dataset_path,
        model_save_dir=model_save_dir,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr
    )
    return True
| if __name__ == "__main__": | |
| # 假设你已经定义好了 load_dual_tower_model 和 update_candidate_vectors | |
| # 这里有两个可能的输入以及对应的三种情况 | |
| # 1. 用户给定模型:(给定两个encoder) | |
| # 2. 用户给定数据:(给定一个有'compound__smiles', 'target__foldseek_seq', 'outcome_potency_pxc50', 'outcome_is_active'字段的数据集) | |
| # (3. 是否要使用ddp: train_model_ddp或者train_model) | |
| # **需要推理** | |
| # 1. 用户没输入数据, 输入模型 --> recompute_candidate_vectors(protein_model_path, molecule_model_path) | |
| # **需要训练 + 推理** | |
| # 2. 用户输入数据, 输入模型 | |
| # --> continuous_train(protein_model_path, molecule_model_path) --> recompute_candidate_vectors(best_model_path) | |
| # 3. 用户输入数据, 没输入模型 | |
| # --> continuous_train(best_model_path) --> recompute_candidate_vectors(best_model_path) | |
| protein_json_path = 'drug_target_activity/candidates/unique_targets.json' | |
| molecule_json_path = 'drug_target_activity/candidates/unique_compounds.json' | |
| best_model_path = 'Dual_Tower_Model/output_checkpoints_ddp/model_epoch_7_acc_0.3259.pt' | |
| recompute_candidate_vectors( | |
| protein_json_path=protein_json_path, | |
| molecule_json_path=molecule_json_path, | |
| output_dir='drug_target_activity/candidates', | |
| pt_path=best_model_path, # 加载训练好的权重 | |
| batch_size=32 | |
| ) | |
| #只给了数据没给模型,就重新训tensor |