File size: 2,932 Bytes
2180e31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer
from typing import Tuple, Dict

def load_encoder_components(
    protein_model_path: str, 
    molecule_model_path: str
) -> Tuple[PreTrainedModel, PreTrainedTokenizer, PreTrainedModel, PreTrainedTokenizer]:
    """Load the protein and molecule encoders together with their tokenizers.

    :param protein_model_path: HuggingFace hub id or local path of the protein model
    :param molecule_model_path: HuggingFace hub id or local path of the molecule model
    :return: (protein_encoder, protein_tokenizer, molecule_encoder, molecule_tokenizer)
    """
    print(f"Loading Protein Encoder from: {protein_model_path}")
    prot_tok = AutoTokenizer.from_pretrained(protein_model_path)
    prot_enc = AutoModel.from_pretrained(protein_model_path)

    print(f"Loading Molecule Encoder from: {molecule_model_path}")
    mol_tok = AutoTokenizer.from_pretrained(molecule_model_path)
    mol_enc = AutoModel.from_pretrained(molecule_model_path)

    return prot_enc, prot_tok, mol_enc, mol_tok

class ProteinMoleculeDualEncoder(nn.Module):
    """Two-tower (dual) encoder mapping proteins and molecules into a shared,
    L2-normalized embedding space of size ``projection_dim``.

    Each tower is a pre-trained transformer backbone followed by a linear
    projection head; the towers are independent except for sharing the
    output dimension.
    """

    def __init__(
        self, 
        protein_encoder: PreTrainedModel, 
        molecule_encoder: PreTrainedModel, 
        projection_dim: int = 256,
        freeze_backbone: bool = False
    ):
        """
        Initialize the dual-tower model.

        :param protein_encoder: an already-instantiated HuggingFace model object
        :param molecule_encoder: an already-instantiated HuggingFace model object
        :param projection_dim: dimensionality of the shared embedding space
        :param freeze_backbone: if True, disable gradients on both backbones
            so only the projection heads are trained
        """
        super().__init__()
        
        self.protein_encoder = protein_encoder
        self.molecule_encoder = molecule_encoder

        # The two backbones may have different hidden widths, so each tower
        # gets its own projection head into the shared space.
        self.prot_hidden_size: int = protein_encoder.config.hidden_size
        self.mol_hidden_size: int = molecule_encoder.config.hidden_size

        self.prot_proj = nn.Linear(self.prot_hidden_size, projection_dim)
        self.mol_proj = nn.Linear(self.mol_hidden_size, projection_dim)

        if freeze_backbone:
            self._freeze_parameters(self.protein_encoder)
            self._freeze_parameters(self.molecule_encoder)

    def _freeze_parameters(self, model: nn.Module) -> None:
        """Helper: disable gradient updates for all parameters of ``model``."""
        for param in model.parameters():
            param.requires_grad = False
        # Fixed grammar of the log message ("Freezed" -> "Froze").
        print(f"Froze parameters for {model.__class__.__name__}")

    @staticmethod
    def _encode_tower(
        encoder: nn.Module, proj: nn.Linear, inputs: Dict
    ) -> torch.Tensor:
        """Run one tower: backbone -> CLS-token embedding -> projection.

        NOTE(review): assumes the backbone returns ``last_hidden_state`` and
        that index 0 of the sequence axis is a CLS-style summary token —
        confirm for the specific checkpoints used.
        """
        out = encoder(**inputs)
        vec = out.last_hidden_state[:, 0, :]  # CLS token
        return proj(vec)

    def forward(
        self, protein_inputs: Dict, molecule_inputs: Dict
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode a batch of proteins and molecules.

        :param protein_inputs: tokenizer output dict (input_ids, attention_mask)
        :param molecule_inputs: tokenizer output dict
        :return: (protein_embeddings, molecule_embeddings), each of shape
            (batch, projection_dim) and L2-normalized along dim 1.
        """
        p_vec = self._encode_tower(self.protein_encoder, self.prot_proj, protein_inputs)
        m_vec = self._encode_tower(self.molecule_encoder, self.mol_proj, molecule_inputs)

        # Unit-normalize so dot products equal cosine similarity
        # (standard for contrastive / retrieval objectives).
        p_vec = F.normalize(p_vec, p=2, dim=1)
        m_vec = F.normalize(m_vec, p=2, dim=1)

        return p_vec, m_vec