| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from dataclasses import dataclass |
| | from typing import Optional, Tuple |
| | import torch |
| | from transformers.utils import ModelOutput |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import BaseModelOutputWithPooling, SequenceClassifierOutput |
| | from transformers import AutoModel |
| |
|
| | from .configuration_protenrich import ProtEnrichConfig |
| |
|
@dataclass
class ProtEnrichModelOutput(ModelOutput):
    """Output container for :class:`ProtEnrichModel`.

    All fields default to ``None`` per the HuggingFace ``ModelOutput``
    convention, so the object can be built with any subset of results.
    """

    # Enriched embedding: h_anchor + alpha * h_algn (see ProtEnrichModel.forward).
    h_enrich: Optional[torch.FloatTensor] = None
    # Anchor-branch sequence embedding (LayerNorm'd output of the anchor MLP).
    h_anchor: Optional[torch.FloatTensor] = None
    # Alignment-branch sequence embedding (LayerNorm'd output of the alignment MLP).
    h_algn: Optional[torch.FloatTensor] = None
    # Structure features reconstructed from h_algn by the structure decoder.
    struct: Optional[torch.FloatTensor] = None
    # Dynamics features reconstructed from h_algn by the dynamics decoder.
    dyn: Optional[torch.FloatTensor] = None
| |
|
class MLPEncoder(nn.Module):
    """Feed-forward encoder.

    Stacks ``n_layers - 1`` hidden blocks of
    ``Linear -> LayerNorm -> GELU -> Dropout`` and finishes with a bare
    ``Linear`` projection to ``out_dim``. With the default
    ``n_layers=2`` this is one hidden block plus the output head; with
    ``n_layers=1`` it degenerates to a single linear layer.
    """

    def __init__(self, in_dim, out_dim, hidden_dim=1024, n_layers=2, dropout=0.1):
        super().__init__()
        blocks = []
        width = in_dim
        # Every hidden block widens (or keeps) the representation at hidden_dim.
        for _ in range(n_layers - 1):
            blocks.extend(
                (
                    nn.Linear(width, hidden_dim),
                    nn.LayerNorm(hidden_dim),
                    nn.GELU(),
                    nn.Dropout(dropout),
                )
            )
            width = hidden_dim
        # Output head: no norm/activation/dropout on the final projection.
        blocks.append(nn.Linear(width, out_dim))
        self.net = nn.Sequential(*blocks)

    def forward(self, x):
        """Apply the MLP; the last dimension of ``x`` must equal ``in_dim``."""
        return self.net(x)
| |
|
class ProtEnrichModel(PreTrainedModel):
    """Sequence encoder that enriches an anchor embedding with a gated
    alignment embedding.

    ``forward`` consumes only sequence features. The structure/dynamics
    encoders and the three projection heads are constructed here but are
    not called in ``forward`` — presumably they are used by training
    code elsewhere in the project (TODO confirm).
    """

    config_class = ProtEnrichConfig
    base_model_prefix = "protenrich"

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)

        # Two parallel sequence branches plus one encoder per auxiliary modality.
        self.seq_anchor = MLPEncoder(config.seq_dim, config.embed_dim)
        self.seq_algn = MLPEncoder(config.seq_dim, config.embed_dim)
        self.struct_encoder = MLPEncoder(config.struct_dim, config.embed_dim)
        self.dyn_encoder = MLPEncoder(config.dyn_dim, config.embed_dim)

        # The auxiliary-modality encoders are frozen (no gradient updates).
        self.struct_encoder.requires_grad_(False)
        self.dyn_encoder.requires_grad_(False)

        # Projection heads, one per modality, into a shared space.
        self.seq_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.struct_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.dyn_projector = nn.Linear(config.embed_dim, config.project_dim)

        # Decoders map the embedding space back to each modality's feature space.
        self.seq_decoder = MLPEncoder(config.embed_dim, config.seq_dim)
        self.struct_decoder = MLPEncoder(config.embed_dim, config.struct_dim)
        self.dyn_decoder = MLPEncoder(config.embed_dim, config.dyn_dim)

        # Learnable mixing gate; sigmoid(-2.0) * alpha_max starts the
        # alignment contribution near zero.
        self.alpha_logit = nn.Parameter(torch.tensor(-2.0))
        self.alpha_max = config.alpha_max

        self.norm_anchor = nn.LayerNorm(config.embed_dim)
        self.norm_algn = nn.LayerNorm(config.embed_dim)

        self.post_init()

    def forward(self, seq: torch.Tensor, return_dict: Optional[bool] = None):
        """Encode ``seq`` and return a :class:`ProtEnrichModelOutput`.

        NOTE(review): ``return_dict`` is accepted but ignored — a
        ``ProtEnrichModelOutput`` is always returned.
        """
        anchor = self.norm_anchor(self.seq_anchor(seq))
        algn = self.norm_algn(self.seq_algn(seq))

        # Reconstruct the auxiliary modalities from the alignment branch only.
        struct_rec = self.struct_decoder(algn)
        dyn_rec = self.dyn_decoder(algn)

        # Bounded mixing coefficient in (0, alpha_max).
        alpha = self.alpha_max * torch.sigmoid(self.alpha_logit)

        return ProtEnrichModelOutput(
            h_enrich=anchor + alpha * algn,
            h_anchor=anchor,
            h_algn=algn,
            struct=struct_rec,
            dyn=dyn_rec,
        )
| |
|
class ProtEnrichForSequenceClassification(PreTrainedModel):
    """ProtEnrich backbone with a linear classification head on ``h_enrich``."""

    config_class = ProtEnrichConfig

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)

        self.num_labels = config.num_labels

        # Backbone produces the enriched embedding; a single linear layer
        # maps it to class logits.
        self.protenrich = ProtEnrichModel(config)
        self.classifier = nn.Linear(config.embed_dim, config.num_labels)

        self.post_init()

    def forward(self, seq: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None):
        """Classify ``seq``; returns a ``SequenceClassifierOutput``.

        When ``labels`` is provided, a cross-entropy loss over the
        flattened logits/labels is included in the output.
        """
        backbone_out = self.protenrich(seq=seq, return_dict=return_dict)
        features = backbone_out.h_enrich

        logits = self.classifier(features)

        loss = None
        if labels is not None:
            loss = F.cross_entropy(logits.view(-1, self.num_labels), labels.view(-1))

        # NOTE(review): `hidden_states` conventionally carries a tuple of
        # per-layer activations; here the pooled feature tensor is passed
        # directly — confirm downstream consumers expect a bare tensor.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=features,
        )