import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_exon_classifier import Evo2ExonConfig


class Evo2ExonModel(PreTrainedModel):
    """Per-position exon classifier head over precomputed Evo2 embeddings."""

    config_class = Evo2ExonConfig
    base_model_prefix = "evo2_exon_classifier"

    def __init__(self, config: Evo2ExonConfig):
        super().__init__(config)

        # MLP head: embedding_dim -> hidden_dim with ReLU, followed by
        # (num_hidden_layers - 1) hidden blocks, then a single output logit.
        layers = [nn.Linear(config.embedding_dim, config.hidden_dim), nn.ReLU()]
        for _ in range(config.num_hidden_layers - 1):
            layers += [nn.Linear(config.hidden_dim, config.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(config.hidden_dim, 1)]

        self.fc_layers = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs_embeds, labels=None, **kwargs):
        """
        inputs_embeds : (batch, seq_len, embedding_dim) precomputed Evo2 embeddings
        labels        : (batch, seq_len), optional, 0/1 floats or ints

        Returns {"logits": probs} with per-position exon probabilities in
        [0, 1], plus {"loss": ...} when labels are provided.
        """
        bsz, seq_len, _ = inputs_embeds.shape

        # Score each position independently: flatten the batch and sequence
        # axes, run the MLP over the embeddings, then restore (batch, seq_len).
        logits = self.fc_layers(inputs_embeds.view(-1, inputs_embeds.size(-1)))
        logits = logits.view(bsz, seq_len)
        probs = self.sigmoid(logits)

        if labels is not None:
            # BCEWithLogitsLoss applies the sigmoid inside the loss, which is
            # numerically more stable than BCELoss on the probabilities.
            loss = nn.BCEWithLogitsLoss()(logits, labels.float())
            return {"loss": loss, "logits": probs}

        return {"logits": probs}