# ProtEnrich-ESM2-T36 / modeling_protenrich.py
import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import ModelOutput

from .configuration_protenrich import ProtEnrichConfig


@dataclass
class ProtEnrichModelOutput(ModelOutput):
    """Output of ProtEnrichModel: the enriched embedding plus intermediate tensors."""

    h_enrich: torch.FloatTensor = None
    h_anchor: Optional[torch.FloatTensor] = None
    h_algn: Optional[torch.FloatTensor] = None
    struct: Optional[torch.FloatTensor] = None
    dyn: Optional[torch.FloatTensor] = None


class MLPEncoder(nn.Module):
    """MLP with (n_layers - 1) hidden blocks of Linear -> LayerNorm -> GELU -> Dropout,
    followed by a final Linear projection to out_dim."""

    def __init__(self, in_dim, out_dim, hidden_dim=1024, n_layers=2, dropout=0.1):
        super().__init__()
        layers = []
        d = in_dim
        for _ in range(n_layers - 1):
            layers += [
                nn.Linear(d, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            d = hidden_dim
        layers.append(nn.Linear(d, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

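
# A minimal shape check for MLPEncoder (illustrative sketch; the dimensions below are
# placeholders, not values read from ProtEnrichConfig):
#
#     enc = MLPEncoder(in_dim=2560, out_dim=1280)     # 2560 matches ESM2-T36 embeddings
#     x = torch.randn(4, 2560)                        # a batch of 4 pooled embeddings
#     enc(x).shape                                    # -> torch.Size([4, 1280])
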

class ProtEnrichModel(PreTrainedModel):
    config_class = ProtEnrichConfig
    base_model_prefix = "protenrich"

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)
        # Two sequence branches: an anchor encoder and an alignment encoder.
        self.seq_anchor = MLPEncoder(config.seq_dim, config.embed_dim)
        self.seq_algn = MLPEncoder(config.seq_dim, config.embed_dim)
        # The struct and dyn encoders are frozen (their parameters are not trained).
        self.struct_encoder = MLPEncoder(config.struct_dim, config.embed_dim)
        self.dyn_encoder = MLPEncoder(config.dyn_dim, config.embed_dim)
        for p in self.struct_encoder.parameters():
            p.requires_grad = False
        for p in self.dyn_encoder.parameters():
            p.requires_grad = False
        # Projection heads into a shared space and decoders back to each input modality.
        self.seq_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.struct_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.dyn_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.seq_decoder = MLPEncoder(config.embed_dim, config.seq_dim)
        self.struct_decoder = MLPEncoder(config.embed_dim, config.struct_dim)
        self.dyn_decoder = MLPEncoder(config.embed_dim, config.dyn_dim)
        # Learnable gate controlling how much of the alignment embedding is mixed in.
        self.alpha_logit = nn.Parameter(torch.tensor(-2.0))
        self.alpha_max = config.alpha_max
        self.norm_anchor = nn.LayerNorm(config.embed_dim)
        self.norm_algn = nn.LayerNorm(config.embed_dim)
        self.post_init()

    def forward(self, seq: torch.Tensor, return_dict: Optional[bool] = None):
        # Encode the input representation with both sequence branches.
        h_anchor = self.norm_anchor(self.seq_anchor(seq))
        h_algn = self.norm_algn(self.seq_algn(seq))
        # Reconstruct the struct and dyn targets from the alignment branch.
        struct = self.struct_decoder(h_algn)
        dyn = self.dyn_decoder(h_algn)
        # Gated enrichment: alpha lies in (0, alpha_max) and scales the alignment term.
        alpha = torch.sigmoid(self.alpha_logit) * self.alpha_max
        h_enrich = h_anchor + alpha * h_algn
        return ProtEnrichModelOutput(
            h_enrich=h_enrich,
            h_anchor=h_anchor,
            h_algn=h_algn,
            struct=struct,
            dyn=dyn,
        )

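
# With the default initialization alpha_logit = -2.0 and, for example, alpha_max = 1.0
# (the actual cap comes from ProtEnrichConfig), the mixing weight starts at
# sigmoid(-2.0) * 1.0 ≈ 0.12, so h_enrich initially stays close to the anchor embedding
# and training can raise or lower the contribution of the alignment branch.
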

class ProtEnrichForSequenceClassification(PreTrainedModel):
    config_class = ProtEnrichConfig
    base_model_prefix = "protenrich"

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.protenrich = ProtEnrichModel(config)
        # Linear classification head on top of the enriched embedding.
        self.classifier = nn.Linear(config.embed_dim, config.num_labels)
        self.post_init()

    def forward(self, seq: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None):
        outputs = self.protenrich(seq=seq, return_dict=return_dict)
        pooled = outputs.h_enrich
        logits = self.classifier(pooled)
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=(pooled,),  # exposed as a tuple, per SequenceClassifierOutput's typing
        )
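

# Usage sketch (illustrative, not executed as part of this module). It assumes the model
# is hosted as "gabrielbianchin/ProtEnrich-ESM2-T36" with this class registered for
# trust_remote_code loading, and that `seq` is a pre-computed ESM2-T36 protein embedding
# of size config.seq_dim:
#
#     import torch
#     from transformers import AutoModel
#
#     model = AutoModel.from_pretrained("gabrielbianchin/ProtEnrich-ESM2-T36", trust_remote_code=True)
#     seq = torch.randn(1, model.config.seq_dim)   # placeholder for a real ESM2 embedding
#     out = model(seq=seq)
#     print(out.h_enrich.shape)                    # (1, config.embed_dim)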