# Recommend_system/train.py
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torch.distributed as dist
from model import load_encoder_components, ProteinMoleculeDualEncoder
from dataset import ProteinMoleculeDataset, DualTowerCollator
def print_model_stats(model):
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # Format parameter counts, e.g. in millions (M) or thousands (K)
def format_params(num):
if num >= 1e6:
return f"{num / 1e6:.2f}M"
elif num >= 1e3:
return f"{num / 1e3:.2f}K"
else:
return str(num)
print(f"|" + "-"*50 + "|")
print(f"| {'Model Statistics':^48} |")
print(f"|" + "-"*50 + "|")
print(f"| Total Parameters : {format_params(total_params):>15} |")
print(f"| Trainable Parameters : {format_params(trainable_params):>15} |")
print(f"| Frozen Parameters : {format_params(total_params - trainable_params):>15} |")
print(f"|" + "-"*50 + "|")
    # Report each tower's size separately (optional; helps confirm the backbone sizes)
    if hasattr(model, 'protein_encoder'):
        p_params = sum(p.numel() for p in model.protein_encoder.parameters())
        print(f"| {'Protein Tower':<21} : {format_params(p_params):>24} |")
    if hasattr(model, 'molecule_encoder'):
        m_params = sum(p.numel() for p in model.molecule_encoder.parameters())
        print(f"| {'Molecule Tower':<21} : {format_params(m_params):>24} |")
    print("|" + "-" * 50 + "|")
class DualTowerTrainer:
def __init__(
self,
model: nn.Module,
train_loader: DataLoader,
val_loader: DataLoader,
learning_rate: float = 1e-4,
temperature: float = 0.07,
device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
save_dir: str = "./checkpoints",
):
print_model_stats(model)
self.model = model.to(device)
self.train_loader = train_loader
self.val_loader = val_loader
self.device = device
self.temperature = temperature
self.save_dir = save_dir
self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate)
self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=10, eta_min=1e-6
)
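        # NOTE: T_max is hard-coded to 10 and scheduler.step() runs once per
        # epoch (see train_epoch), so this schedule assumes roughly 10 epochs,
        # the default in train_model below; pass the real epoch count through
        # if training longer.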
        os.makedirs(save_dir, exist_ok=True)
    def compute_loss(self, p_vec, m_vec, labels):
        """
        Compute a masked InfoNCE loss.

        Only pairs with label=1 (active) contribute a positive term; every
        sample in the batch (including label=0) still acts as an in-batch
        negative in the denominator.
        """
        # 1. Similarity matrix (batch_size, batch_size):
        #    logits[i][j] is the similarity between protein i and molecule j
        logits = torch.matmul(p_vec, m_vec.T) / self.temperature

        # 2. Targets: the diagonal entries are the positive pairs
        batch_size = p_vec.size(0)
        targets = torch.arange(batch_size, device=self.device)

        # 3. Cross entropy (F.cross_entropy applies softmax over each row):
        #    each protein should retrieve its own molecule. We only keep rows
        #    with label=1; forcing a label=0 diagonal to be the maximum would
        #    be wrong, since those pairs do not actually bind.
        raw_loss = F.cross_entropy(logits, targets, reduction='none')

        # 4. Masking: average the loss over active (label=1) rows only
        active_mask = (labels == 1).float()
        num_actives = active_mask.sum()  # guard against division by zero
        if num_actives > 0:
            final_loss = (raw_loss * active_mask).sum() / num_actives
        else:
            # All-negative batch: return a zero that is still connected to the
            # graph, so loss.backward() does not fail on a tensor without grad_fn
            final_loss = 0.0 * (p_vec.sum() + m_vec.sum())
        return final_loss
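    # A minimal sketch of the masking behavior (hypothetical example; assumes a
    # trainer built with device='cpu' and a model whose forward returns
    # L2-normalized embeddings with projection_dim=256):
    #   p = F.normalize(torch.randn(4, 256), dim=-1)
    #   m = F.normalize(torch.randn(4, 256), dim=-1)
    #   labels = torch.tensor([1, 0, 1, 1])
    #   loss = trainer.compute_loss(p, m, labels)
    #   # -> cross entropy averaged over rows 0, 2 and 3 only; sample 1 still
    #   #    appears as an in-batch negative column for the other rows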
    def train_epoch(self, epoch_idx):
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        loop = tqdm(self.train_loader, desc=f"Train Epoch {epoch_idx}")
        for batch in loop:
            # 1. Move the batch to the target device
            prot_inputs = {k: v.to(self.device) for k, v in batch['protein_inputs'].items()}
            mol_inputs = {k: v.to(self.device) for k, v in batch['molecule_inputs'].items()}
            labels = batch['labels'].to(self.device)  # 0 or 1

            # A batch with no actives contributes no positive pairs, so skip it
            if (labels == 1).sum() == 0:
                print("\nSkipping batch with no active samples")
                continue

            # 2. Forward pass
            self.optimizer.zero_grad()
            p_vec, m_vec = self.model(prot_inputs, mol_inputs)

            # 3. Compute the loss
            loss = self.compute_loss(p_vec, m_vec, labels)

            # 4. Backward pass; clip gradients to prevent explosion
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            num_batches += 1
            loop.set_postfix(loss=loss.item())

        # Average over the batches actually trained on, not the skipped ones
        avg_loss = total_loss / max(num_batches, 1)
        self.scheduler.step()  # step the LR schedule once per epoch
        return avg_loss
    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        total_loss = 0.0
        # Also track retrieval accuracy: for each active pair, is its own
        # partner the Top-1 prediction?
        correct_retrieval = 0
        total_actives = 0
        for batch in tqdm(self.val_loader, desc="Evaluating"):
            prot_inputs = {k: v.to(self.device) for k, v in batch['protein_inputs'].items()}
            mol_inputs = {k: v.to(self.device) for k, v in batch['molecule_inputs'].items()}
            labels = batch['labels'].to(self.device)

            p_vec, m_vec = self.model(prot_inputs, mol_inputs)
            loss = self.compute_loss(p_vec, m_vec, labels)
            total_loss += loss.item()

            # Simple in-batch Top-1 accuracy (active samples only)
            logits = torch.matmul(p_vec, m_vec.T)
            preds = torch.argmax(logits, dim=1)  # per row, the column index with the highest similarity
            targets = torch.arange(p_vec.size(0), device=self.device)

            active_mask = (labels == 1)
            if active_mask.sum() > 0:
                matches = (preds[active_mask] == targets[active_mask])
                correct_retrieval += matches.sum().item()
                total_actives += active_mask.sum().item()

        avg_loss = total_loss / len(self.val_loader)
        acc = correct_retrieval / total_actives if total_actives > 0 else 0.0
        return avg_loss, acc
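    # NOTE: this is in-batch retrieval, so the reported accuracy depends on the
    # validation batch size (each protein only competes against batch_size - 1
    # candidate molecules); compare runs at a fixed batch size.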
    def save_checkpoint(self, epoch, metric):
        path = os.path.join(self.save_dir, f"model_epoch_{epoch}_acc_{metric:.4f}.pt")
        # Unwrap the model if it is wrapped in DDP or DataParallel, so the
        # saved state_dict carries no 'module.' prefix
        if isinstance(self.model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        # dist.get_rank() raises unless a process group has been initialized,
        # so only consult it when torch.distributed is actually in use; under
        # DDP, let rank 0 do the writing to avoid concurrent writes to one file
        if not dist.is_initialized() or dist.get_rank() == 0:
            torch.save(state_dict, path)
            print(f"Model saved to {path} (clean state_dict without 'module.' prefix)")
def train_model(
dataset_path: str,
model_and_tokenizers = None,
protein_model_path: str = None,
molecule_model_path: str = None,
model_save_dir: str = "./output_checkpoints",
epochs: int = 10,
batch_size: int = 32,
lr: float = 1e-4
):
    # 1. Load the model and tokenizers (either passed in directly, or built
    #    from the pretrained encoder checkpoints)
    print("Initializing components...")
if model_and_tokenizers is not None:
model, p_tokenizer, m_tokenizer = model_and_tokenizers
else:
p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
protein_model_path, molecule_model_path
)
model = ProteinMoleculeDualEncoder(p_encoder, m_encoder, projection_dim=256)
    # 2. Load the dataset and split it into train/validation sets (80/20)
    full_dataset = ProteinMoleculeDataset(dataset_path)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])
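    # For a reproducible split you could pass a seeded generator (optional sketch):
    #   split_gen = torch.Generator().manual_seed(42)
    #   train_dataset, val_dataset = torch.utils.data.random_split(
    #       full_dataset, [train_size, val_size], generator=split_gen
    #   )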
    # 3. Build the collator and DataLoaders
collator = DualTowerCollator(p_tokenizer, m_tokenizer)
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
collate_fn=collator,
        num_workers=4,  # tune to the number of available CPU cores
pin_memory=True
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
collate_fn=collator,
num_workers=4
)
    # 4. Initialize the trainer
trainer = DualTowerTrainer(
model=model,
train_loader=train_loader,
val_loader=val_loader,
learning_rate=lr,
save_dir=model_save_dir,
)
    # 5. Run the training loop
    print("Starting training...")
    best_acc = 0.0
for epoch in range(epochs):
train_loss = trainer.train_epoch(epoch)
val_loss, val_acc = trainer.evaluate()
print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Retrieval Acc: {val_acc:.4f}")
        # Save whenever the validation retrieval accuracy improves
if val_acc > best_acc:
best_acc = val_acc
trainer.save_checkpoint(epoch, val_acc)
if __name__ == "__main__":
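    # Local paths to the pretrained backbones (SaProt for proteins, ChemBERTa
    # for molecules); adjust them to wherever the checkpoints live locally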
protein_model_path = "./SaProt_650M_AF2"
molecule_model_path = "./ChemBERTa-zinc-base-v1"
dataset_path = 'drug_target_activity/processed_train.parquet'
model_save_dir = './Dual_Tower_Model/output_checkpoints'
train_model(
protein_model_path=protein_model_path,
molecule_model_path=molecule_model_path,
dataset_path=dataset_path,
model_save_dir=model_save_dir
)