evo2-exon-classifier / wrapper_exon_classifier.py
Jonathan Schmok
Resolving import
d602c10
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_exon_classifier import Evo2ExonConfig
class Evo2ExonModel(PreTrainedModel):
    """Per-position binary classifier applied to input embedding vectors.

    A stack of fully connected layers, ``(Linear + ReLU) * n`` followed by a
    final ``Linear(hidden_dim, 1)``, is applied independently to the embedding
    vector at every position, and the resulting score is passed through a
    sigmoid.
    """

    config_class = Evo2ExonConfig
    base_model_prefix = "evo2_exon_classifier"

    def __init__(self, config: Evo2ExonConfig):
        super().__init__(config)
        # Build (Linear + ReLU) * n + final Linear(hidden_dim, 1).
        # NOTE: the first hidden layer is always created, so a
        # num_hidden_layers value below 1 still yields one hidden layer.
        layers = [nn.Linear(config.embedding_dim, config.hidden_dim), nn.ReLU()]
        for _ in range(config.num_hidden_layers - 1):
            layers += [nn.Linear(config.hidden_dim, config.hidden_dim), nn.ReLU()]
        layers.append(nn.Linear(config.hidden_dim, 1))
        self.fc_layers = nn.Sequential(*layers)
        self.sigmoid = nn.Sigmoid()  # convert logits -> probability

    def forward(self, inputs_embeds, labels=None, **kwargs):
        """Score every position and optionally compute the binary cross-entropy loss.

        Args:
            inputs_embeds: tensor of shape ``(batch, seq_len, embedding_dim)``.
                Any number of leading dimensions is accepted, because
                ``nn.Linear`` only operates on the last axis.
            labels: optional tensor shaped like ``inputs_embeds`` without its
                last axis (e.g. ``(batch, seq_len)``), holding 0/1 targets as
                floats or ints.
            **kwargs: accepted and ignored.

        Returns:
            A dict with key ``"logits"`` and, when ``labels`` is given, also
            ``"loss"``. NOTE: despite its name, ``"logits"`` holds sigmoid
            *probabilities*; the key is kept as-is for backward compatibility.
        """
        # nn.Linear maps (*, embedding_dim) -> (*, 1), so no flatten/reshape
        # is needed. This also avoids .view(), which fails on non-contiguous
        # inputs. squeeze(-1) drops only the trailing size-1 output axis.
        raw_logits = self.fc_layers(inputs_embeds).squeeze(-1)
        probs = self.sigmoid(raw_logits)
        if labels is not None:
            # BCEWithLogitsLoss on raw logits is the numerically stable
            # equivalent of BCELoss(sigmoid(logits)). Labels are cast to the
            # logits' dtype (rather than always float32) so fp16/bf16 models
            # do not hit a dtype mismatch.
            loss = nn.BCEWithLogitsLoss()(raw_logits, labels.to(raw_logits.dtype))
            return {"loss": loss, "logits": probs}
        return {"logits": probs}