Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| import torch.nn.functional as F | |
| import torch.distributed as dist | |
| from model import load_encoder_components, ProteinMoleculeDualEncoder | |
| from dataset import ProteinMoleculeDataset, DualTowerCollator | |
def print_model_stats(model):
    """Print a boxed summary of the parameter counts of ``model``.

    The summary lists total, trainable (``requires_grad`` is true) and frozen
    parameter counts. When the model exposes ``protein_encoder`` and/or
    ``molecule_encoder`` attributes, the size of each tower is appended.

    Args:
        model: Any object whose ``parameters()`` yields items providing
            ``numel()`` and ``requires_grad``.
    """

    def _humanize(count):
        # Abbreviate large counts: millions as "M", thousands as "K".
        if count >= 1e6:
            return f"{count / 1e6:.2f}M"
        if count >= 1e3:
            return f"{count / 1e3:.2f}K"
        return str(count)

    def _size_of(module):
        # Number of scalar parameters held by a sub-module.
        return sum(tensor.numel() for tensor in module.parameters())

    def _row(label, count):
        print(f"| {label} : {_humanize(count):>15} |")

    # Tally total and trainable parameters in a single pass.
    n_total = 0
    n_trainable = 0
    for tensor in model.parameters():
        size = tensor.numel()
        n_total += size
        if tensor.requires_grad:
            n_trainable += size

    rule = "|" + "-" * 50 + "|"

    print(rule)
    print(f"| {'Model Statistics':^48} |")
    print(rule)
    for label, count in (
        ("Total Parameters", n_total),
        ("Trainable Parameters", n_trainable),
        ("Frozen Parameters", n_total - n_trainable),
    ):
        _row(label, count)
    print(rule)

    # Optional per-tower breakdown, useful for checking backbone sizes.
    for attr_name, label in (
        ("protein_encoder", "Protein Tower"),
        ("molecule_encoder", "Molecule Tower"),
    ):
        if hasattr(model, attr_name):
            _row(label, _size_of(getattr(model, attr_name)))
    print(rule)
class DualTowerTrainer:
    """Training/evaluation driver for a protein-molecule dual-tower encoder.

    The model is optimised with AdamW under a cosine-annealed learning rate,
    using a masked InfoNCE loss in which only label == 1 (active) pairs act as
    positives while every molecule in the batch serves as a negative.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        learning_rate: float = 1e-4,
        temperature: float = 0.07,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        save_dir: str = "./checkpoints",
        scheduler_t_max: int = 10,
    ):
        """Set up the model, optimiser, scheduler and checkpoint directory.

        Args:
            model: Module called as ``model(protein_inputs, molecule_inputs)``
                and returning a ``(protein_vectors, molecule_vectors)`` pair.
            train_loader: Loader yielding dict batches with the keys
                ``protein_inputs``, ``molecule_inputs`` and ``labels``.
            val_loader: Loader with the same batch layout, used by ``evaluate``.
            learning_rate: Initial AdamW learning rate.
            temperature: Divisor applied to the similarity logits.
            device: Device the model and batches are moved to.
            save_dir: Directory that receives checkpoints (created if missing).
            scheduler_t_max: ``T_max`` of the cosine schedule. The scheduler is
                stepped once per ``train_epoch`` call, so this should normally
                equal the number of epochs. Defaults to the previous fixed
                value of 10.
        """
        print_model_stats(model)
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.temperature = temperature
        self.save_dir = save_dir
        self.optimizer = optim.AdamW(self.model.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=scheduler_t_max, eta_min=1e-6
        )
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(save_dir, exist_ok=True)

    def compute_loss(self, p_vec, m_vec, labels):
        """Compute the masked InfoNCE loss for one batch.

        ``logits[i][j]`` is the similarity between protein ``i`` and molecule
        ``j``; the diagonal holds the paired samples. Cross entropy is taken
        per row, and only rows whose label equals 1 contribute to the mean:
        forcing the diagonal of an inactive pair to be the maximum would be
        wrong, yet inactive molecules still appear as negatives in every row's
        softmax denominator.

        Args:
            p_vec: Protein embeddings, one row per sample.
            m_vec: Molecule embeddings, one row per sample (same row count).
            labels: Per-sample labels; the value 1 marks an active pair.

        Returns:
            A scalar tensor. When the batch holds no active pair the result is
            a zero that is still connected to ``p_vec`` and ``m_vec`` so that a
            backward pass remains valid.
        """
        logits = torch.matmul(p_vec, m_vec.T) / self.temperature
        # Build the targets directly on the logits' device instead of
        # creating them on CPU and copying afterwards.
        targets = torch.arange(p_vec.size(0), device=logits.device)
        per_row_loss = F.cross_entropy(logits, targets, reduction='none')

        active_mask = (labels == 1).float()
        num_actives = active_mask.sum()
        if num_actives > 0:
            return (per_row_loss * active_mask).sum() / num_actives
        # No positives: return a graph-connected zero rather than a constant.
        return 0.0 * (p_vec.sum() + m_vec.sum())

    def train_epoch(self, epoch_idx):
        """Run one training epoch and step the learning-rate scheduler.

        Batches without any active (label == 1) sample are skipped because
        they would contribute no positive term to the loss.

        Args:
            epoch_idx: Epoch number, used only for the progress-bar caption.

        Returns:
            The mean loss over the batches that were actually optimised, or
            0.0 when every batch was skipped or the loader was empty.
        """
        self.model.train()
        total_loss = 0.0
        steps = 0
        loop = tqdm(self.train_loader, desc=f"Train Epoch {epoch_idx}")
        for batch in loop:
            # Inspect the labels first so skipped batches are never transferred.
            labels = batch['labels']
            if not (labels == 1).any():
                print("\nSkipping batch with no active samples")
                continue

            labels = labels.to(self.device)
            prot_inputs = {k: v.to(self.device) for k, v in batch['protein_inputs'].items()}
            mol_inputs = {k: v.to(self.device) for k, v in batch['molecule_inputs'].items()}

            self.optimizer.zero_grad()
            p_vec, m_vec = self.model(prot_inputs, mol_inputs)
            loss = self.compute_loss(p_vec, m_vec, labels)
            loss.backward()
            # Clip gradients to guard against exploding updates.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            steps += 1
            loop.set_postfix(loss=loss_value)

        self.scheduler.step()
        # Average over optimised batches only; dividing by len(loader) would
        # dilute the figure with skipped batches and fail on an empty loader.
        return total_loss / steps if steps else 0.0

    def evaluate(self):
        """Evaluate on the validation loader without tracking gradients.

        Besides the loss, a top-1 retrieval accuracy is computed over active
        pairs: a hit means the most similar molecule in the batch for a
        protein is its own paired molecule.

        Returns:
            ``(avg_loss, acc)`` where ``avg_loss`` is the mean loss over the
            batches containing at least one active pair, and ``acc`` is the
            fraction of active pairs retrieved at rank 1. Both are 0.0 when no
            active pair was seen.
        """
        self.model.eval()
        total_loss = 0.0
        scored_batches = 0
        correct_retrieval = 0
        total_actives = 0
        # no_grad: evaluation needs no autograd graph, which saves memory.
        with torch.no_grad():
            for batch in tqdm(self.val_loader, desc="Evaluating"):
                labels = batch['labels']
                active_mask = labels == 1
                if not active_mask.any():
                    # Such a batch has neither a loss term nor an accuracy term.
                    continue

                labels = labels.to(self.device)
                active_mask = active_mask.to(self.device)
                prot_inputs = {k: v.to(self.device) for k, v in batch['protein_inputs'].items()}
                mol_inputs = {k: v.to(self.device) for k, v in batch['molecule_inputs'].items()}

                p_vec, m_vec = self.model(prot_inputs, mol_inputs)
                total_loss += self.compute_loss(p_vec, m_vec, labels).item()
                scored_batches += 1

                # Temperature does not change the argmax, so raw similarities suffice.
                logits = torch.matmul(p_vec, m_vec.T)
                preds = torch.argmax(logits, dim=1)
                targets = torch.arange(p_vec.size(0), device=logits.device)
                matches = preds[active_mask] == targets[active_mask]
                correct_retrieval += matches.sum().item()
                total_actives += active_mask.sum().item()

        avg_loss = total_loss / scored_batches if scored_batches else 0.0
        acc = correct_retrieval / total_actives if total_actives > 0 else 0.0
        return avg_loss, acc

    def save_checkpoint(self, epoch, metric):
        """Save the model weights to ``save_dir``.

        When the model is wrapped in ``DataParallel`` or
        ``DistributedDataParallel`` the inner module's state dict is saved, so
        the keys carry no ``module.`` prefix.

        Args:
            epoch: Epoch index embedded in the file name.
            metric: Accuracy embedded in the file name with four decimals.
        """
        path = os.path.join(self.save_dir, f"model_epoch_{epoch}_acc_{metric:.4f}.pt")
        if isinstance(self.model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)):
            state_dict = self.model.module.state_dict()
        else:
            state_dict = self.model.state_dict()
        torch.save(state_dict, path)

        # dist.get_rank() raises when no process group has been initialised
        # (plain single-process training), so only query it in a real
        # distributed run and otherwise treat this process as the main one.
        distributed = dist.is_available() and dist.is_initialized()
        if not distributed or dist.get_rank() == 0:
            print(f"Model saved to {path} (Clean state_dict without 'module.' prefix)")
def train_model(
    dataset_path: str,
    model_and_tokenizers = None,
    protein_model_path: str = None,
    molecule_model_path: str = None,
    model_save_dir: str = "./output_checkpoints",
    epochs: int = 10,
    batch_size: int = 32,
    lr: float = 1e-4,
    num_workers: int = 4,
    seed: int = None,
):
    """Train a dual-tower model on one dataset with an 80/20 train/val split.

    After every epoch the validation loss and retrieval accuracy are printed,
    and a checkpoint is written whenever the accuracy exceeds the best value
    seen so far (starting from 0.0).

    Args:
        dataset_path: Path handed to ``ProteinMoleculeDataset``.
        model_and_tokenizers: Optional ``(model, protein_tokenizer,
            molecule_tokenizer)`` triple. When given, nothing is loaded from
            ``protein_model_path`` / ``molecule_model_path``.
        protein_model_path: Passed to ``load_encoder_components`` when no
            ready-made triple is supplied.
        molecule_model_path: Passed to ``load_encoder_components`` when no
            ready-made triple is supplied.
        model_save_dir: Directory that receives checkpoints.
        epochs: Number of training epochs.
        batch_size: Batch size of both data loaders.
        lr: Learning rate handed to the trainer.
        num_workers: Worker processes per data loader; tune to the CPU count.
        seed: Optional seed for the train/val split. ``None`` keeps the split
            driven by the global RNG state, as before.

    Raises:
        ValueError: If the dataset is too small to yield a non-empty training
            split and a non-empty validation split.
    """
    print("Initialize components...")
    if model_and_tokenizers is not None:
        model, p_tokenizer, m_tokenizer = model_and_tokenizers
    else:
        p_encoder, p_tokenizer, m_encoder, m_tokenizer = load_encoder_components(
            protein_model_path, molecule_model_path
        )
        model = ProteinMoleculeDualEncoder(p_encoder, m_encoder, projection_dim=256)

    full_dataset = ProteinMoleculeDataset(dataset_path)

    # 80/20 split into training and validation subsets.
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size
    if train_size == 0 or val_size == 0:
        # Fail early with a clear message instead of an obscure sampler or
        # division error later on.
        raise ValueError(
            f"Dataset at {dataset_path!r} has {len(full_dataset)} samples, "
            "which is too few for an 80/20 train/validation split."
        )
    split_kwargs = {}
    if seed is not None:
        # A dedicated generator makes the split reproducible without
        # touching the global RNG state.
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size], **split_kwargs
    )

    collator = DualTowerCollator(p_tokenizer, m_tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=num_workers,
    )

    trainer = DualTowerTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        learning_rate=lr,
        save_dir=model_save_dir,
    )

    print("Start Training...")
    best_acc = 0.0
    for epoch in range(epochs):
        train_loss = trainer.train_epoch(epoch)
        val_loss, val_acc = trainer.evaluate()
        print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Retrieval Acc: {val_acc:.4f}")
        # Keep only checkpoints that improve on the best accuracy so far.
        if val_acc > best_acc:
            best_acc = val_acc
            trainer.save_checkpoint(epoch, val_acc)
if __name__ == "__main__":
    # Script entry point: train with the default local encoder checkpoints,
    # dataset file and output directory.
    run_config = {
        "dataset_path": 'drug_target_activity/processed_train.parquet',
        "protein_model_path": "./SaProt_650M_AF2",
        "molecule_model_path": "./ChemBERTa-zinc-base-v1",
        "model_save_dir": './Dual_Tower_Model/output_checkpoints',
    }
    train_model(**run_config)