import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from typing import Tuple, Dict


def load_encoder_components(
    protein_model_path: str,
    molecule_model_path: str
) -> Tuple[PreTrainedModel, PreTrainedTokenizer, PreTrainedModel, PreTrainedTokenizer]:
    """Load the protein and molecule encoders together with their tokenizers.

    :param protein_model_path: HuggingFace model id or local path for the protein encoder.
    :param molecule_model_path: HuggingFace model id or local path for the molecule encoder.
    :return: (protein_encoder, protein_tokenizer, molecule_encoder, molecule_tokenizer)
    """
    print(f"Loading Protein Encoder from: {protein_model_path}")
    protein_tokenizer = AutoTokenizer.from_pretrained(protein_model_path)
    protein_encoder = AutoModel.from_pretrained(protein_model_path)

    print(f"Loading Molecule Encoder from: {molecule_model_path}")
    molecule_tokenizer = AutoTokenizer.from_pretrained(molecule_model_path)
    molecule_encoder = AutoModel.from_pretrained(molecule_model_path)

    return protein_encoder, protein_tokenizer, molecule_encoder, molecule_tokenizer


class ProteinMoleculeDualEncoder(nn.Module):
    """Two-tower (dual) encoder projecting proteins and molecules into a shared space.

    Each tower is a pretrained HuggingFace encoder followed by a linear
    projection to a common ``projection_dim``; outputs are L2-normalized so
    that dot products between towers are cosine similarities.
    """

    def __init__(
        self,
        protein_encoder: PreTrainedModel,
        molecule_encoder: PreTrainedModel,
        projection_dim: int = 256,
        freeze_backbone: bool = False
    ):
        """Initialize the dual-tower model.

        :param protein_encoder: an already-instantiated HuggingFace model object
        :param molecule_encoder: an already-instantiated HuggingFace model object
        :param projection_dim: dimensionality of the shared embedding space
        :param freeze_backbone: if True, freeze both backbone encoders so only
            the projection heads are trained
        """
        super().__init__()
        self.protein_encoder = protein_encoder
        self.molecule_encoder = molecule_encoder

        # Hidden sizes come from each backbone's own config, so the two
        # towers may use encoders of different widths.
        self.prot_hidden_size = protein_encoder.config.hidden_size
        self.mol_hidden_size = molecule_encoder.config.hidden_size

        self.prot_proj = nn.Linear(self.prot_hidden_size, projection_dim)
        self.mol_proj = nn.Linear(self.mol_hidden_size, projection_dim)

        if freeze_backbone:
            self._freeze_parameters(self.protein_encoder)
            self._freeze_parameters(self.molecule_encoder)

    def _freeze_parameters(self, model: nn.Module):
        """Helper: disable gradients for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False
        print(f"Froze parameters for {model.__class__.__name__}")

    def forward(
        self, protein_inputs: Dict, molecule_inputs: Dict
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode a batch of proteins and molecules.

        :param protein_inputs: tokenizer output dict (input_ids, attention_mask)
        :param molecule_inputs: tokenizer output dict
        :return: (p_vec, m_vec) — L2-normalized embeddings of shape
            (batch, projection_dim) for each tower
        """
        # --- Tower A: protein ---
        p_out = self.protein_encoder(**protein_inputs)
        p_vec = p_out.last_hidden_state[:, 0, :]  # CLS token
        p_vec = self.prot_proj(p_vec)

        # --- Tower B: molecule ---
        m_out = self.molecule_encoder(**molecule_inputs)
        m_vec = m_out.last_hidden_state[:, 0, :]  # CLS token
        m_vec = self.mol_proj(m_vec)

        # --- Normalize so that p_vec @ m_vec.T is a cosine-similarity matrix ---
        p_vec = F.normalize(p_vec, p=2, dim=1)
        m_vec = F.normalize(m_vec, p=2, dim=1)

        return p_vec, m_vec