import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from typing import Tuple, Dict


def load_encoder_components(
    protein_model_path: str,
    molecule_model_path: str
) -> Tuple[PreTrainedModel, PreTrainedTokenizer, PreTrainedModel, PreTrainedTokenizer]:
    """
    Load the protein and molecule encoders together with their tokenizers.
    """
    print(f"Loading Protein Encoder from: {protein_model_path}")
    protein_tokenizer = AutoTokenizer.from_pretrained(protein_model_path)
    protein_encoder = AutoModel.from_pretrained(protein_model_path)

    print(f"Loading Molecule Encoder from: {molecule_model_path}")
    molecule_tokenizer = AutoTokenizer.from_pretrained(molecule_model_path)
    molecule_encoder = AutoModel.from_pretrained(molecule_model_path)

    return protein_encoder, protein_tokenizer, molecule_encoder, molecule_tokenizer
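
# Usage sketch (illustrative, not part of the original module): the checkpoint names
# below are assumptions; substitute whichever protein / molecule encoders you use.
#
# prot_enc, prot_tok, mol_enc, mol_tok = load_encoder_components(
#     protein_model_path="facebook/esm2_t6_8M_UR50D",        # example protein LM
#     molecule_model_path="seyonec/ChemBERTa-zinc-base-v1",   # example SMILES LM
# )
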

class ProteinMoleculeDualEncoder(nn.Module):
    def __init__(
        self,
        protein_encoder: PreTrainedModel,
        molecule_encoder: PreTrainedModel,
        projection_dim: int = 256,
        freeze_backbone: bool = False
    ):
        """
        Initialize the dual-tower model.
        :param protein_encoder: an already instantiated HuggingFace model object
        :param molecule_encoder: an already instantiated HuggingFace model object
        """
        super().__init__()
        self.protein_encoder = protein_encoder
        self.molecule_encoder = molecule_encoder
        self.prot_hidden_size = protein_encoder.config.hidden_size
        self.mol_hidden_size = molecule_encoder.config.hidden_size
        # Projection heads map both towers into a shared embedding space.
        self.prot_proj = nn.Linear(self.prot_hidden_size, projection_dim)
        self.mol_proj = nn.Linear(self.mol_hidden_size, projection_dim)
        if freeze_backbone:
            self._freeze_parameters(self.protein_encoder)
            self._freeze_parameters(self.molecule_encoder)

    def _freeze_parameters(self, model: nn.Module):
        """Helper: freeze all parameters of the given model."""
        for param in model.parameters():
            param.requires_grad = False
        print(f"Froze parameters for {model.__class__.__name__}")

    def forward(self, protein_inputs: Dict, molecule_inputs: Dict):
        """
        :param protein_inputs: dict produced by the tokenizer (input_ids, attention_mask)
        :param molecule_inputs: dict produced by the tokenizer
        """
        # --- Tower A: protein encoder ---
        p_out = self.protein_encoder(**protein_inputs)
        p_vec = p_out.last_hidden_state[:, 0, :]  # CLS token
        p_vec = self.prot_proj(p_vec)

        # --- Tower B: molecule encoder ---
        m_out = self.molecule_encoder(**molecule_inputs)
        m_vec = m_out.last_hidden_state[:, 0, :]  # CLS token
        m_vec = self.mol_proj(m_vec)

        # --- L2-normalize so dot products are cosine similarities ---
        p_vec = F.normalize(p_vec, p=2, dim=1)
        m_vec = F.normalize(m_vec, p=2, dim=1)

        return p_vec, m_vec
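

# --- Hedged smoke-test sketch (added for illustration; not part of the original file) ---
# Assumes BERT-style encoders whose last_hidden_state carries a leading CLS-like token,
# and reuses the example checkpoints from the usage sketch above; all names below are
# placeholders rather than the author's actual configuration.
if __name__ == "__main__":
    prot_enc, prot_tok, mol_enc, mol_tok = load_encoder_components(
        protein_model_path="facebook/esm2_t6_8M_UR50D",        # example protein LM
        molecule_model_path="seyonec/ChemBERTa-zinc-base-v1",   # example SMILES LM
    )
    model = ProteinMoleculeDualEncoder(
        prot_enc, mol_enc, projection_dim=256, freeze_backbone=True
    )
    model.eval()

    # Toy inputs: one protein sequence and one SMILES string (aspirin).
    protein_inputs = prot_tok(
        ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ"], return_tensors="pt", padding=True
    )
    molecule_inputs = mol_tok(
        ["CC(=O)Oc1ccccc1C(=O)O"], return_tensors="pt", padding=True
    )

    with torch.no_grad():
        p_vec, m_vec = model(protein_inputs, molecule_inputs)

    # Both vectors are L2-normalized, so their dot product is a cosine similarity.
    similarity = (p_vec * m_vec).sum(dim=-1)
    print("Embedding shapes:", tuple(p_vec.shape), tuple(m_vec.shape))
    print("Cosine similarity:", similarity.item())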