import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.nn.functional as F
import torch.distributed as dist

from model import load_encoder_components, ProteinMoleculeDualEncoder
from dataset import ProteinMoleculeDataset, DualTowerCollator


def print_model_stats(model):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Format parameter counts, e.g. as millions (M)
    def format_params(num):
        if num >= 1e6:
            return f"{num / 1e6:.2f}M"
        elif num >= 1e3:
            return f"{num / 1e3:.2f}K"
        else:
            return str(num)

    print("|" + "-" * 50 + "|")
    print(f"| {'Model Statistics':^48} |")
    print("|" + "-" * 50 + "|")
    print(f"| {'Total Parameters':<21}: {format_params(total_params):>25} |")
    print(f"| {'Trainable Parameters':<21}: {format_params(trainable_params):>25} |")
    print(f"| {'Frozen Parameters':<21}: {format_params(total_params - trainable_params):>25} |")
    print("|" + "-" * 50 + "|")

    # Report each tower separately (optional; confirms the backbone sizes)
    if hasattr(model, 'protein_encoder'):
        p_params = sum(p.numel() for p in model.protein_encoder.parameters())
        print(f"| {'Protein Tower':<21}: {format_params(p_params):>25} |")
    if hasattr(model, 'molecule_encoder'):
        m_params = sum(p.numel() for p in model.molecule_encoder.parameters())
        print(f"| {'Molecule Tower':<21}: {format_params(m_params):>25} |")
    print("|" + "-" * 50 + "|")


class DualTowerTrainer:
    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        learning_rate: float = 1e-4,
        temperature: float = 0.07,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        save_dir: str = "./checkpoints",
    ):
        print_model_stats(model)
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.temperature = temperature
        self.save_dir = save_dir

        self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=10, eta_min=1e-6
        )

        os.makedirs(save_dir, exist_ok=True)

    def compute_loss(self, p_vec, m_vec, labels):
        """
        Masked InfoNCE loss.

        Only pairs with label == 1 (active) contribute a positive term.
        Every sample in the batch (including label == 0) still appears as an
        in-batch negative in the denominator.
        """
        # 1. Similarity matrix of shape (batch_size, batch_size):
        #    logits[i][j] is the similarity of protein i and molecule j.
        logits = torch.matmul(p_vec, m_vec.T) / self.temperature

        # 2. Targets: the diagonal entries are the positives.
        batch_size = p_vec.size(0)
        targets = torch.arange(batch_size, device=self.device)

        # 3. Per-row cross entropy (F.cross_entropy applies softmax internally);
        #    each row asks one protein to pick out its own molecule. We keep the
        #    per-sample losses because only rows with label == 1 should count:
        #    a label-0 pair does not bind, so maximizing its diagonal
        #    similarity would be wrong.
        raw_loss = F.cross_entropy(logits, targets, reduction='none')

        # 4. Masking: average only over the active (label == 1) rows.
        active_mask = (labels == 1).float()
        num_actives = active_mask.sum()

        if num_actives > 0:
            final_loss = (raw_loss * active_mask).sum() / num_actives
        else:
            # All-negative batch: the loss is 0, but multiplying the embeddings
            # by 0 keeps the graph connected so backward() still runs.
            final_loss = 0.0 * (p_vec.sum() + m_vec.sum())

        return final_loss
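    # Worked illustration of the masking (comments only, not executed): with a
    # batch of 3 pairs and labels = [1, 0, 1], logits is a 3x3 matrix and
    # targets = [0, 1, 2]. Rows 0 and 2 each contribute a cross-entropy term
    # that pushes their diagonal similarity above the rest of their row; row 1
    # is masked out of the numerator, yet molecule 1 still acts as an in-batch
    # negative (a competing column) for rows 0 and 2.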
    def train_epoch(self, epoch_idx):
        self.model.train()
        total_loss = 0.0
        active_batches = 0  # batches that actually contributed a loss
        loop = tqdm(self.train_loader, desc=f"Train Epoch {epoch_idx}")

        for batch in loop:
            # 1. Move the batch to the device
            prot_inputs = {k: v.to(self.device) for k, v in batch['protein_inputs'].items()}
            mol_inputs = {k: v.to(self.device) for k, v in batch['molecule_inputs'].items()}
            labels = batch['labels'].to(self.device)  # 0 or 1

            if (labels == 1).sum() == 0:
                print("\nSkipping batch with no active samples")
                continue

            # 2. Forward pass
            self.optimizer.zero_grad()
            p_vec, m_vec = self.model(prot_inputs, mol_inputs)

            # 3. Loss
            loss = self.compute_loss(p_vec, m_vec, labels)

            # 4. Backward pass, with gradient clipping to prevent explosions
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            total_loss += loss.item()
            active_batches += 1  # every batch reaching this point has actives
            loop.set_postfix(loss=loss.item())

        # Average over the batches actually trained on, not len(self.train_loader),
        # since all-negative batches are skipped above.
        avg_loss = total_loss / max(active_batches, 1)
        self.scheduler.step()  # update the learning rate
        return avg_loss

    @torch.no_grad()
    def evaluate(self):
        self.model.eval()
        total_loss = 0.0
        # Also track retrieval accuracy: for an active pair, is the top-1
        # prediction the pair itself?
        correct_retrieval = 0
        total_actives = 0

        for batch in tqdm(self.val_loader, desc="Evaluating"):
            prot_inputs = {k: v.to(self.device) for k, v in batch['protein_inputs'].items()}
            mol_inputs = {k: v.to(self.device) for k, v in batch['molecule_inputs'].items()}
            labels = batch['labels'].to(self.device)

            p_vec, m_vec = self.model(prot_inputs, mol_inputs)
            loss = self.compute_loss(p_vec, m_vec, labels)
            total_loss += loss.item()

            # Simple top-1 accuracy over active samples only
            logits = torch.matmul(p_vec, m_vec.T)
            preds = torch.argmax(logits, dim=1)  # per row, the best column index
            targets = torch.arange(p_vec.size(0), device=self.device)

            active_mask = (labels == 1)
            if active_mask.sum() > 0:
                matches = (preds[active_mask] == targets[active_mask])
                correct_retrieval += matches.sum().item()
                total_actives += active_mask.sum().item()

        avg_loss = total_loss / len(self.val_loader)
        acc = correct_retrieval / total_actives if total_actives > 0 else 0.0
        return avg_loss, acc

    def save_checkpoint(self, epoch, metric):
        # Under DDP only rank 0 should write the file; in single-process
        # training the process group is never initialized, so save directly.
        # (Calling dist.get_rank() without this guard raises a RuntimeError
        # outside a distributed run.)
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        path = os.path.join(self.save_dir, f"model_epoch_{epoch}_acc_{metric:.4f}.pt")

        # Unwrap the model if it is wrapped by DDP or DataParallel so the
        # saved state_dict carries no 'module.' prefix.
        if isinstance(self.model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()

        torch.save(state_dict, path)
        print(f"Model saved to {path} (clean state_dict without 'module.' prefix)")
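
# A minimal sketch (hypothetical helper, not called anywhere in this script) of
# restoring a checkpoint written by DualTowerTrainer.save_checkpoint. It reuses
# the load_encoder_components / ProteinMoleculeDualEncoder API imported above;
# the function name and the ckpt_path argument are illustrative.
def load_checkpoint_for_inference(protein_model_path, molecule_model_path, ckpt_path, device="cpu"):
    p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
        protein_model_path, molecule_model_path
    )
    model = ProteinMoleculeDualEncoder(p_encoder, m_encoder, projection_dim=256)
    # save_checkpoint stores a plain state_dict without the 'module.' prefix,
    # so it loads directly into the unwrapped model.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.to(device).eval()
    return model, p_tokenizer, m_tokenizer
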
def train_model(
    dataset_path: str,
    model_and_tokenizers=None,
    protein_model_path: str = None,
    molecule_model_path: str = None,
    model_save_dir: str = "./output_checkpoints",
    epochs: int = 10,
    batch_size: int = 32,
    lr: float = 1e-4
):
    # 1. Load (or build) the model and both tokenizers
    print("Initialize components...")
    if model_and_tokenizers is not None:
        model, p_tokenizer, m_tokenizer = model_and_tokenizers
    else:
        p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
            protein_model_path, molecule_model_path
        )
        model = ProteinMoleculeDualEncoder(p_encoder, m_encoder, projection_dim=256)

    # 2. Load the dataset and split it into train/validation (80/20)
    full_dataset = ProteinMoleculeDataset(dataset_path)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, val_size])

    # 3. Build the collator and the DataLoaders
    collator = DualTowerCollator(p_tokenizer, m_tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=4,  # tune to the number of CPU cores
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=4
    )

    # 4. Initialize the trainer
    trainer = DualTowerTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        save_dir=model_save_dir,
    )

    # 5. Training loop
    print("Start Training...")
    best_acc = 0.0
    for epoch in range(epochs):
        train_loss = trainer.train_epoch(epoch)
        val_loss, val_acc = trainer.evaluate()
        print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Retrieval Acc: {val_acc:.4f}")

        # Keep the best model by validation retrieval accuracy
        if val_acc > best_acc:
            best_acc = val_acc
            trainer.save_checkpoint(epoch, val_acc)


if __name__ == "__main__":
    protein_model_path = "./SaProt_650M_AF2"
    molecule_model_path = "./ChemBERTa-zinc-base-v1"
    dataset_path = 'drug_target_activity/processed_train.parquet'
    model_save_dir = './Dual_Tower_Model/output_checkpoints'

    train_model(
        protein_model_path=protein_model_path,
        molecule_model_path=molecule_model_path,
        dataset_path=dataset_path,
        model_save_dir=model_save_dir
    )
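# Usage note: running this file directly (the __main__ block above) assumes the
# local SaProt / ChemBERTa checkpoints and the parquet dataset exist at the
# paths shown. Each epoch prints train loss, validation loss, and the in-batch
# top-1 retrieval accuracy over active pairs; whenever that accuracy improves,
# a checkpoint is written to model_save_dir.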