import torch
import torch.nn as nn
from dataclasses import dataclass
from typing import Optional

from transformers import PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.utils import ModelOutput

from .configuration_protenrich import ProtEnrichConfig


@dataclass
class ProtEnrichModelOutput(ModelOutput):
    """Output of ProtEnrichModel.

    h_enrich: anchor embedding enriched with the gated alignment embedding.
    h_anchor: sequence embedding from the anchor encoder.
    h_algn:   sequence embedding from the alignment encoder.
    struct:   structure features decoded from h_algn.
    dyn:      dynamics features decoded from h_algn.
    """

    h_enrich: torch.FloatTensor = None
    h_anchor: Optional[torch.FloatTensor] = None
    h_algn: Optional[torch.FloatTensor] = None
    struct: Optional[torch.FloatTensor] = None
    dyn: Optional[torch.FloatTensor] = None


class MLPEncoder(nn.Module):
    """Stack of (Linear -> LayerNorm -> GELU -> Dropout) blocks followed by a
    final Linear projection. With n_layers=1 this reduces to a single Linear."""

    def __init__(self, in_dim, out_dim, hidden_dim=1024, n_layers=2, dropout=0.1):
        super().__init__()
        layers = []
        d = in_dim
        for _ in range(n_layers - 1):
            layers += [
                nn.Linear(d, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ]
            d = hidden_dim
        layers.append(nn.Linear(d, out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


class ProtEnrichModel(PreTrainedModel):
    config_class = ProtEnrichConfig
    base_model_prefix = "protenrich"

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)
        # Two sequence encoders: the anchor branch preserves the input
        # representation, the alignment branch carries the cross-modal signal.
        self.seq_anchor = MLPEncoder(config.seq_dim, config.embed_dim)
        self.seq_algn = MLPEncoder(config.seq_dim, config.embed_dim)

        # Frozen encoders for the structure and dynamics modalities. They are
        # not called in forward(); they provide fixed targets at training time.
        self.struct_encoder = MLPEncoder(config.struct_dim, config.embed_dim)
        self.dyn_encoder = MLPEncoder(config.dyn_dim, config.embed_dim)
        for p in self.struct_encoder.parameters():
            p.requires_grad = False
        for p in self.dyn_encoder.parameters():
            p.requires_grad = False

        # Projection heads into a shared space (used during training, not in
        # forward()).
        self.seq_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.struct_projector = nn.Linear(config.embed_dim, config.project_dim)
        self.dyn_projector = nn.Linear(config.embed_dim, config.project_dim)

        # Decoders that reconstruct each modality from the embedding space.
        self.seq_decoder = MLPEncoder(config.embed_dim, config.seq_dim)
        self.struct_decoder = MLPEncoder(config.embed_dim, config.struct_dim)
        self.dyn_decoder = MLPEncoder(config.embed_dim, config.dyn_dim)

        # Learnable gate for mixing the alignment embedding into the anchor
        # embedding; sigmoid(-2.0) ~ 0.12 at init, capped at alpha_max.
        self.alpha_logit = nn.Parameter(torch.tensor(-2.0))
        self.alpha_max = config.alpha_max

        self.norm_anchor = nn.LayerNorm(config.embed_dim)
        self.norm_algn = nn.LayerNorm(config.embed_dim)

        self.post_init()

    def forward(self, seq: torch.Tensor, return_dict: Optional[bool] = None):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        h_anchor = self.norm_anchor(self.seq_anchor(seq))
        h_algn = self.norm_algn(self.seq_algn(seq))

        # Reconstruct the structure and dynamics modalities from the
        # alignment embedding.
        struct = self.struct_decoder(h_algn)
        dyn = self.dyn_decoder(h_algn)

        # Gated residual enrichment: alpha lies in (0, alpha_max).
        alpha = torch.sigmoid(self.alpha_logit) * self.alpha_max
        h_enrich = h_anchor + alpha * h_algn

        if not return_dict:
            return (h_enrich, h_anchor, h_algn, struct, dyn)

        return ProtEnrichModelOutput(
            h_enrich=h_enrich,
            h_anchor=h_anchor,
            h_algn=h_algn,
            struct=struct,
            dyn=dyn,
        )


class ProtEnrichForSequenceClassification(PreTrainedModel):
    config_class = ProtEnrichConfig

    def __init__(self, config: ProtEnrichConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.protenrich = ProtEnrichModel(config)
        self.classifier = nn.Linear(config.embed_dim, config.num_labels)
        self.post_init()

    def forward(
        self,
        seq: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Always request the dict output internally so h_enrich is accessible
        # by name regardless of the caller's return_dict setting.
        outputs = self.protenrich(seq=seq, return_dict=True)
        pooled = outputs.h_enrich
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, pooled)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=(pooled,),  # tuple, per the SequenceClassifierOutput contract
        )
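
if __name__ == "__main__":
    # Minimal smoke-test sketch, illustrative only. It assumes
    # ProtEnrichConfig follows the usual custom-config pattern and accepts
    # these fields as keyword arguments; the constructor signature and the
    # example dimensions are assumptions, not part of this module.
    config = ProtEnrichConfig(
        seq_dim=1280,
        struct_dim=512,
        dyn_dim=256,
        embed_dim=768,
        project_dim=256,
        alpha_max=1.0,
        num_labels=2,
    )
    model = ProtEnrichForSequenceClassification(config)
    model.eval()

    seq = torch.randn(4, config.seq_dim)  # batch of 4 pooled sequence features
    labels = torch.tensor([0, 1, 1, 0])

    with torch.no_grad():
        out = model(seq=seq, labels=labels)
    print(out.logits.shape)  # torch.Size([4, 2])
    print(out.loss)          # scalar cross-entropy loss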