# Recommend_system / model.py
# Author: tong
# clear code — commit 2180e31
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from typing import Tuple, Dict
def load_encoder_components(
    protein_model_path: str,
    molecule_model_path: str
) -> Tuple[PreTrainedModel, PreTrainedTokenizer, PreTrainedModel, PreTrainedTokenizer]:
    """Load the protein and molecule encoders together with their tokenizers.

    :param protein_model_path: HuggingFace hub id or local path of the protein checkpoint.
    :param molecule_model_path: HuggingFace hub id or local path of the molecule checkpoint.
    :return: tuple of (protein_encoder, protein_tokenizer,
        molecule_encoder, molecule_tokenizer).
    """
    loaded = []
    # Same loading recipe for both towers; only the checkpoint path differs.
    for tower_name, model_path in (
        ("Protein", protein_model_path),
        ("Molecule", molecule_model_path),
    ):
        print(f"Loading {tower_name} Encoder from: {model_path}")
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        encoder = AutoModel.from_pretrained(model_path)
        loaded.append((encoder, tokenizer))
    (prot_encoder, prot_tokenizer), (mol_encoder, mol_tokenizer) = loaded
    return prot_encoder, prot_tokenizer, mol_encoder, mol_tokenizer
class ProteinMoleculeDualEncoder(nn.Module):
    """Two-tower (dual-encoder) model that maps proteins and molecules into a
    shared L2-normalized embedding space of size ``projection_dim``.

    Each tower pools its backbone's CLS token ([:, 0, :]) and passes it through
    a linear projection head; both outputs are unit-length vectors suitable for
    cosine-similarity / contrastive training.
    """

    def __init__(
        self,
        protein_encoder: PreTrainedModel,
        molecule_encoder: PreTrainedModel,
        projection_dim: int = 256,
        freeze_backbone: bool = False
    ):
        """
        Build the dual-tower model.

        :param protein_encoder: already-instantiated HuggingFace model object.
        :param molecule_encoder: already-instantiated HuggingFace model object.
        :param projection_dim: dimensionality of the shared embedding space.
        :param freeze_backbone: when True, both backbones are frozen so that
            only the two projection heads remain trainable.
        """
        super().__init__()
        self.protein_encoder = protein_encoder
        self.molecule_encoder = molecule_encoder
        # Hidden sizes may differ between towers; each gets its own head.
        self.prot_hidden_size = protein_encoder.config.hidden_size
        self.mol_hidden_size = molecule_encoder.config.hidden_size
        self.prot_proj = nn.Linear(self.prot_hidden_size, projection_dim)
        self.mol_proj = nn.Linear(self.mol_hidden_size, projection_dim)
        if freeze_backbone:
            for backbone in (self.protein_encoder, self.molecule_encoder):
                self._freeze_parameters(backbone)

    def _freeze_parameters(self, model: nn.Module):
        """Disable gradient updates for every parameter of *model*."""
        for parameter in model.parameters():
            parameter.requires_grad = False
        print(f"Freezed parameters for {model.__class__.__name__}")

    def forward(self, protein_inputs: Dict, molecule_inputs: Dict):
        """
        Encode one batch per tower and return unit-length embeddings.

        :param protein_inputs: tokenizer output dict (input_ids, attention_mask).
        :param molecule_inputs: tokenizer output dict.
        :return: (protein_embeddings, molecule_embeddings), each of shape
            (batch, projection_dim), L2-normalized along dim 1.
        """
        # CLS-token pooling for both towers, then project and normalize.
        protein_repr = self.protein_encoder(**protein_inputs).last_hidden_state[:, 0, :]
        molecule_repr = self.molecule_encoder(**molecule_inputs).last_hidden_state[:, 0, :]
        protein_emb = F.normalize(self.prot_proj(protein_repr), p=2, dim=1)
        molecule_emb = F.normalize(self.mol_proj(molecule_repr), p=2, dim=1)
        return protein_emb, molecule_emb