import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import ModelOutput

from .configuration_protenrich import ProtEnrichConfig

@dataclass
class ProtEnrichModelOutput(ModelOutput):
  """Output of ProtEnrichModel: the enriched embedding plus its intermediate parts."""

  h_enrich: torch.FloatTensor = None
  h_anchor: Optional[torch.FloatTensor] = None
  h_algn: Optional[torch.FloatTensor] = None
  struct: Optional[torch.FloatTensor] = None
  dyn: Optional[torch.FloatTensor] = None

class MLPEncoder(nn.Module):
  """MLP with (n_layers - 1) blocks of Linear -> LayerNorm -> GELU -> Dropout,
  followed by a final Linear with no activation."""

  def __init__(self, in_dim, out_dim, hidden_dim=1024, n_layers=2, dropout=0.1):
    super().__init__()
    layers = []
    d = in_dim
    for _ in range(n_layers - 1):
      layers += [
        nn.Linear(d, hidden_dim),
        nn.LayerNorm(hidden_dim),
        nn.GELU(),
        nn.Dropout(dropout),
      ]
      d = hidden_dim
    layers.append(nn.Linear(d, out_dim))
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    return self.net(x)

class ProtEnrichModel(PreTrainedModel):
  config_class = ProtEnrichConfig
  base_model_prefix = "protenrich"

  def __init__(self, config: ProtEnrichConfig):
    super().__init__(config)

    # Two sequence encoders: an anchor branch and an alignment branch.
    self.seq_anchor = MLPEncoder(config.seq_dim, config.embed_dim)
    self.seq_algn = MLPEncoder(config.seq_dim, config.embed_dim)
    self.struct_encoder = MLPEncoder(config.struct_dim, config.embed_dim)
    self.dyn_encoder = MLPEncoder(config.dyn_dim, config.embed_dim)

    # The structure and dynamics encoders are frozen; only the sequence
    # branches and the heads below receive gradient updates.
    for p in self.struct_encoder.parameters():
      p.requires_grad = False
    for p in self.dyn_encoder.parameters():
      p.requires_grad = False

    # Per-modality projection heads (not used in this forward pass).
    self.seq_projector = nn.Linear(config.embed_dim, config.project_dim)
    self.struct_projector = nn.Linear(config.embed_dim, config.project_dim)
    self.dyn_projector = nn.Linear(config.embed_dim, config.project_dim)

    # Decoders that reconstruct each modality's features from an embedding.
    self.seq_decoder = MLPEncoder(config.embed_dim, config.seq_dim)
    self.struct_decoder = MLPEncoder(config.embed_dim, config.struct_dim)
    self.dyn_decoder = MLPEncoder(config.embed_dim, config.dyn_dim)

    # Learnable fusion gate: alpha = sigmoid(alpha_logit) * alpha_max,
    # initialized small so the anchor branch dominates early on.
    self.alpha_logit = nn.Parameter(torch.tensor(-2.0))
    self.alpha_max = config.alpha_max

    self.norm_anchor = nn.LayerNorm(config.embed_dim)
    self.norm_algn = nn.LayerNorm(config.embed_dim)

    self.post_init()

  def forward(self, seq: torch.Tensor, return_dict: Optional[bool] = None):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Anchor and alignment embeddings of the input sequence representation.
    h_anchor = self.norm_anchor(self.seq_anchor(seq))
    h_algn = self.norm_algn(self.seq_algn(seq))

    # Reconstruct structure and dynamics features from the alignment embedding.
    struct = self.struct_decoder(h_algn)
    dyn = self.dyn_decoder(h_algn)

    # Gated fusion: the sigmoid bounds alpha to (0, alpha_max).
    alpha = torch.sigmoid(self.alpha_logit) * self.alpha_max
    h_enrich = h_anchor + alpha * h_algn

    if not return_dict:
      return (h_enrich, h_anchor, h_algn, struct, dyn)

    return ProtEnrichModelOutput(
      h_enrich=h_enrich,
      h_anchor=h_anchor,
      h_algn=h_algn,
      struct=struct,
      dyn=dyn,
    )

class ProtEnrichForSequenceClassification(PreTrainedModel):
  config_class = ProtEnrichConfig

  def __init__(self, config: ProtEnrichConfig):
    super().__init__(config)

    self.num_labels = config.num_labels

    self.protenrich = ProtEnrichModel(config)
    # Classification head on top of the enriched embedding h_enrich.
    self.classifier = nn.Linear(config.embed_dim, config.num_labels)

    self.post_init()

  def forward(self, seq: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None):
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Always request a ModelOutput from the backbone so h_enrich is accessible by name.
    outputs = self.protenrich(seq=seq, return_dict=True)
    pooled = outputs.h_enrich

    logits = self.classifier(pooled)

    loss = None
    if labels is not None:
      loss_fct = nn.CrossEntropyLoss()
      loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

    if not return_dict:
      output = (logits, pooled)
      return ((loss,) + output) if loss is not None else output

    return SequenceClassifierOutput(
      loss=loss,
      logits=logits,
      hidden_states=(pooled,),  # SequenceClassifierOutput expects a tuple of tensors
    )
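
# Minimal usage sketch (assumptions: ProtEnrichConfig accepts the fields read
# above -- seq_dim, struct_dim, dyn_dim, embed_dim, project_dim, alpha_max,
# num_labels -- as constructor kwargs; the dimensions below are illustrative,
# not those of any released checkpoint):
#
#   config = ProtEnrichConfig(seq_dim=1280, struct_dim=512, dyn_dim=512,
#                             embed_dim=1024, project_dim=256,
#                             alpha_max=1.0, num_labels=2)
#   model = ProtEnrichForSequenceClassification(config)
#   seq = torch.randn(8, config.seq_dim)  # batch of pooled sequence embeddings
#   out = model(seq=seq, labels=torch.randint(0, config.num_labels, (8,)))
#   out.loss.backward()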