"""Recommend_system / server_train.py

Author: tong — cleaned-up code (commit 2180e31).
"""
import torch
import torch.nn.functional as F
import json
import os
from collections import OrderedDict
from typing import Tuple, Optional
from tqdm import tqdm
from transformers import PreTrainedTokenizer
from model import load_encoder_components, ProteinMoleculeDualEncoder
from train import train_model
from train_ddp import train_model_ddp
# Default backbone checkpoint paths, used when the caller does not supply one.
DEFAULT_PROTEIN_PATH = "./SaProt_650M_AF2"
DEFAULT_MOLECULE_PATH = "./ChemBERTa-zinc-base-v1"
def load_dual_tower_model(
    protein_model_path: Optional[str] = None,
    molecule_model_path: Optional[str] = None,
    pt_path: Optional[str] = None,
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
) -> Tuple[ProteinMoleculeDualEncoder, PreTrainedTokenizer, PreTrainedTokenizer]:
    """
    Load the dual-tower model together with both tokenizers.

    Flow:
        1. Resolve backbone paths (fall back to the module defaults when not given).
        2. Build the model structure (load_encoder_components + ProteinMoleculeDualEncoder).
        3. If pt_path is given, strip any DDP 'module.' prefix and load the weights.

    Args:
        protein_model_path: Optional protein backbone path; defaults to DEFAULT_PROTEIN_PATH.
        molecule_model_path: Optional molecule backbone path; defaults to DEFAULT_MOLECULE_PATH.
        pt_path: Optional trained checkpoint (.pt). When None the projection
            layers keep their fresh initialization.
        device: Target device string; auto-detected by default.

    Returns:
        (model, protein_tokenizer, molecule_tokenizer) with the model moved to
        `device` and set to eval() mode.

    Raises:
        FileNotFoundError: If pt_path is given but the file does not exist.
        RuntimeError: If the checkpoint keys do not match the model (strict load).
    """
    # --- 1. Resolve backbone paths ---
    # Falling back to defaults supports "only pass pt_path", while still
    # allowing specific backbones to be combined with a checkpoint.
    p_path = protein_model_path if protein_model_path else DEFAULT_PROTEIN_PATH
    m_path = molecule_model_path if molecule_model_path else DEFAULT_MOLECULE_PATH
    print(f"Step 1: Initializing model structure...")
    print(f" - Protein Backbone: {p_path}")
    print(f" - Molecule Backbone: {m_path}")
    # --- 2. Build the structure once (shared by fresh and checkpoint paths) ---
    p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
        p_path, m_path
    )
    model = ProteinMoleculeDualEncoder(
        protein_encoder=p_encoder,
        molecule_encoder=m_encoder,
        projection_dim=256  # must match the dimension used during training
    )
    # --- 3. Optionally restore trained weights ---
    if pt_path is not None:
        if not os.path.exists(pt_path):
            raise FileNotFoundError(f"Checkpoint not found at: {pt_path}")
        print(f"Step 2: Loading weights from {pt_path} ...")
        # map_location avoids OOM / device mismatches when the checkpoint was
        # saved on a different device.
        state_dict = torch.load(pt_path, map_location=device)
        # --- 3.1 Strip the 'module.' prefix that DDP prepends to every key ---
        new_state_dict = OrderedDict(
            (k[len('module.'):] if k.startswith('module.') else k, v)
            for k, v in state_dict.items()
        )
        # --- 3.2 Load the weights ---
        # strict=True raises RuntimeError on any key mismatch, so a successful
        # call implies empty missing/unexpected lists. (The previous warning
        # prints for missing/unexpected keys were unreachable dead code.)
        model.load_state_dict(new_state_dict, strict=True)
        print(" - Weights loaded successfully.")
    else:
        print("Using User-Given Encoders.")
    # Move to the requested device.
    model.to(device)
    model.eval()  # default to eval mode; callers that fine-tune switch back to train()
    return model, p_tokenizer, m_tokenizer
def _load_and_extract_data(json_path, extractor_func, desc="Data"):
    """
    Shared data-loading helper.

    Args:
        json_path: Path to the JSON file to read.
        extractor_func: Callable taking (key, value) and returning
            (id, sequence_text). When sequence_text is invalid (None,
            non-string, or blank) the entry is skipped.
        desc: Human-readable label used in log messages.

    Returns:
        Two parallel lists (ids, sequences); both empty when the file
        does not exist.
    """
    print(f"Loading {desc} from {json_path}...")
    if not os.path.exists(json_path):
        print(f"Warning: {json_path} does not exist.")
        return [], []
    with open(json_path, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)
    kept_ids = []
    kept_seqs = []
    # Walk the JSON mapping and delegate field extraction to extractor_func.
    for key, value in payload.items():
        item_id, item_seq = extractor_func(key, value)
        # Keep only non-empty string sequences.
        if isinstance(item_seq, str) and item_seq.strip():
            kept_ids.append(item_id)
            kept_seqs.append(item_seq)
    print(f"Found {len(kept_ids)} valid items for {desc}.")
    return kept_ids, kept_seqs
def _compute_tower_vectors(ids, seqs, tokenizer, encoder, projector, batch_size, device, max_len, desc):
    """
    Shared inference helper for a single tower:
    batch -> tokenize -> encoder forward -> CLS pooling -> projection -> L2 normalize.

    Returns:
        (ids, FloatTensor[N, dim]) on success, or (None, None) when there is
        nothing to encode.
    """
    if not ids:
        return None, None
    chunks = []
    # Run everything under no_grad: pure inference, no autograd state needed.
    with torch.no_grad():
        for start in tqdm(range(0, len(ids), batch_size), desc=f"Encoding {desc}"):
            batch = seqs[start:start + batch_size]
            # 1. Tokenize the batch and move the tensors to the target device.
            encoded = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_len,
                return_tensors='pt'
            ).to(device)
            # 2. Forward chain (one side of the dual-tower model):
            #    backbone forward, then pool the [CLS] token (index 0).
            pooled = encoder(**encoded).last_hidden_state[:, 0, :]
            # Projection head maps into the shared retrieval space.
            projected = projector(pooled)
            # L2 normalization — required for similarity-based retrieval.
            normalized = F.normalize(projected, p=2, dim=1)
            # Move results back to CPU so GPU memory does not accumulate.
            chunks.append(normalized.cpu())
    # Concatenate the per-batch results.
    if not chunks:
        return None, None
    return ids, torch.cat(chunks, dim=0)
def update_candidate_vectors(
    model,
    protein_tokenizer,
    molecule_tokenizer,
    protein_json_path,
    molecule_json_path,
    output_dir,
    batch_size=64,
    device='cuda',
    max_prot_len=1024,
    max_mol_len=512
):
    """
    Orchestrate candidate-vector generation for both towers.

    For each tower (protein, molecule): load IDs/sequences from its JSON file,
    encode them with the matching encoder + projection head, and save the
    normalized vectors to <output_dir>/<save_name> as
    {"ids": list[str], "vectors": FloatTensor[N, dim]}.

    Args:
        model: Dual-tower model exposing protein_encoder / molecule_encoder
            and prot_proj / mol_proj.
        protein_tokenizer: Tokenizer for the protein tower.
        molecule_tokenizer: Tokenizer for the molecule tower.
        protein_json_path: Unique-target metadata JSON path.
        molecule_json_path: Unique-compound metadata JSON path.
        output_dir: Directory for the output .pt files (created if missing).
        batch_size: Inference batch size.
        device: Device to run inference on.
        max_prot_len: Truncation length for protein sequences.
        max_mol_len: Truncation length for molecule sequences.
    """
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() guard.
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    model.to(device)
    # --- Per-tower configuration ---
    # Captures how the protein and molecule pipelines differ.
    tasks = [
        {
            "desc": "Protein",
            "path": protein_json_path,
            # Protein extraction rule:
            #   ID  = key (UniProt ID)
            #   Seq = value["target__foldseek_seq"]
            "extract_fn": lambda k, v: (k, v.get('target__foldseek_seq')),
            "tokenizer": protein_tokenizer,
            "encoder": model.protein_encoder,
            "projector": model.prot_proj,
            "max_len": max_prot_len,
            "save_name": "protein_candidates.pt"
        },
        {
            "desc": "Molecule",
            "path": molecule_json_path,
            # Molecule extraction rule:
            #   ID  = key (SMILES)
            #   Seq = key (SMILES) — the SMILES string is both ID and content
            "extract_fn": lambda k, v: (k, k),
            "tokenizer": molecule_tokenizer,
            "encoder": model.molecule_encoder,
            "projector": model.mol_proj,
            "max_len": max_mol_len,
            "save_name": "molecule_candidates.pt"
        }
    ]
    # --- Unified execution over both towers ---
    for task in tasks:
        # 1. Load and filter the candidate data.
        ids, seqs = _load_and_extract_data(
            task['path'],
            task['extract_fn'],
            task['desc']
        )
        # 2. Encode into L2-normalized vectors.
        valid_ids, vectors = _compute_tower_vectors(
            ids, seqs,
            task['tokenizer'],
            task['encoder'],
            task['projector'],
            batch_size,
            device,
            task['max_len'],
            task['desc']
        )
        # 3. Persist the results.
        if vectors is not None:
            save_path = os.path.join(output_dir, task['save_name'])
            torch.save({
                "ids": valid_ids,   # list[str] (UniProt IDs or SMILES)
                "vectors": vectors  # FloatTensor [N, Dim]
            }, save_path)
            print(f"Saved {task['desc']} Vectors: {vectors.shape} to {save_path}\n")
        else:
            print(f"Skipping save for {task['desc']} (No valid data).\n")
def recompute_candidate_vectors(
    protein_json_path: str,
    molecule_json_path: str,
    output_dir: str,
    pt_path: Optional[str] = None,
    protein_model_path: Optional[str] = None,
    molecule_model_path: Optional[str] = None,
    batch_size: int = 64,
    max_prot_len: int = 1024,
    max_mol_len: int = 512,
    device: Optional[str] = None
):
    """
    End-to-end pipeline: load the dual-tower model -> read candidate metadata
    -> run inference to build the vector store -> save .pt files.

    Args:
        protein_json_path: Unique-target JSON path.
        molecule_json_path: Unique-compound JSON path.
        output_dir: Output directory for the vector files.
        pt_path: (Optional) Trained .pt checkpoint path. When None the
            projection layers keep their random initialization.
        protein_model_path: (Optional) Protein backbone path override.
        molecule_model_path: (Optional) Molecule backbone path override.
        batch_size: Inference batch size.
        max_prot_len: Max protein sequence length.
        max_mol_len: Max molecule sequence length.
        device: 'cuda' or 'cpu'; auto-detected when None.
    """
    # 0. Auto-detect the device when not specified.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"={' Start Recomputing Candidate Vectors ':=^60}")
    print(f"Device: {device}")
    # 1. Load the model and both tokenizers (load_dual_tower_model above).
    print(f"\n>>> [Stage 1/2] Loading Model & Tokenizers...")
    try:
        model, p_tokenizer, m_tokenizer = load_dual_tower_model(
            protein_model_path=protein_model_path,
            molecule_model_path=molecule_model_path,
            pt_path=pt_path,
            device=device
        )
    except Exception as e:
        # Best-effort behavior: report the failure and return instead of raising.
        print(f"Error loading model: {e}")
        return
    # 2. Compute and save the candidate vectors (update_candidate_vectors above).
    print(f"\n>>> [Stage 2/2] Computing & Saving Vectors...")
    try:
        update_candidate_vectors(
            model=model,
            protein_tokenizer=p_tokenizer,
            molecule_tokenizer=m_tokenizer,
            protein_json_path=protein_json_path,
            molecule_json_path=molecule_json_path,
            output_dir=output_dir,
            batch_size=batch_size,
            device=device,
            max_prot_len=max_prot_len,
            max_mol_len=max_mol_len
        )
    except Exception as e:
        # Same best-effort policy as the loading stage.
        print(f"Error computing vectors: {e}")
        return
    print(f"\n={' All Done ':=^60}")
    print(f"Check output files in: {output_dir}")
def continuous_train(
    dataset_path: str,
    model_save_dir: str = 'Dual_Tower_Model/customized_checkpoints',
    protein_model_path:str = None,
    molecule_model_path:str = None,
    best_model_path:str = None,
    device:str = "cuda" if torch.cuda.is_available() else "cpu",
    epochs: int = 5,
    lr: float = 1e-4,
    batch_size: int = 16,
    use_ddp: bool = False
):
    """
    Run a continuous-training / fine-tuning pass.

    When best_model_path is given, training resumes from that checkpoint;
    otherwise it starts from protein_model_path / molecule_model_path
    (or the default backbones). Dispatches to the DDP trainer when
    use_ddp is True.
    """
    model, prot_tok, mol_tok = load_dual_tower_model(
        protein_model_path=protein_model_path,
        molecule_model_path=molecule_model_path,
        pt_path=best_model_path,
        device=device
    )
    # Both trainers share the same keyword interface; pick one and call it.
    trainer = train_model_ddp if use_ddp else train_model
    trainer(
        model_and_tokenizers=[model, prot_tok, mol_tok],
        dataset_path=dataset_path,
        model_save_dir=model_save_dir,
        epochs=epochs,
        batch_size=batch_size,
        lr=lr
    )
    return True
if __name__ == "__main__":
    # Relies on load_dual_tower_model and update_candidate_vectors defined above.
    # Two possible user inputs, giving three scenarios:
    #   1. User supplies a model (the two encoders).
    #   2. User supplies data (a dataset with 'compound__smiles',
    #      'target__foldseek_seq', 'outcome_potency_pxc50',
    #      'outcome_is_active' fields).
    #   (3. Whether to use DDP: train_model_ddp vs train_model.)
    # **Inference only**
    #   1. No data, model given
    #      -> recompute_candidate_vectors(protein_model_path, molecule_model_path)
    # **Training + inference**
    #   2. Data given, model given
    #      -> continuous_train(protein_model_path, molecule_model_path) -> recompute_candidate_vectors(best_model_path)
    #   3. Data given, no model
    #      -> continuous_train(best_model_path) -> recompute_candidate_vectors(best_model_path)
    protein_json_path = 'drug_target_activity/candidates/unique_targets.json'
    molecule_json_path = 'drug_target_activity/candidates/unique_compounds.json'
    best_model_path = 'Dual_Tower_Model/output_checkpoints_ddp/model_epoch_7_acc_0.3259.pt'
    recompute_candidate_vectors(
        protein_json_path=protein_json_path,
        molecule_json_path=molecule_json_path,
        output_dir='drug_target_activity/candidates',
        pt_path=best_model_path,  # load the trained weights
        batch_size=32
    )
    # If only data (and no model) is supplied, retrain first and then rebuild the vectors.