import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_exon_classifier import Evo2ExonConfig
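
# NOTE (assumption): configuration_exon_classifier.py is not shown here. Based
# on the attributes read in __init__ below, it is expected to look roughly
# like the sketch in this comment; only embedding_dim, hidden_dim and
# num_hidden_layers are actually required, and the default values are
# placeholders, not the model's real dimensions:
#
#     from transformers import PretrainedConfig
#
#     class Evo2ExonConfig(PretrainedConfig):
#         model_type = "evo2_exon_classifier"
#
#         def __init__(self, embedding_dim=4096, hidden_dim=512,
#                      num_hidden_layers=2, **kwargs):
#             self.embedding_dim = embedding_dim
#             self.hidden_dim = hidden_dim
#             self.num_hidden_layers = num_hidden_layers
#             super().__init__(**kwargs)
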
class Evo2ExonModel(PreTrainedModel):
    config_class      = Evo2ExonConfig
    base_model_prefix = "evo2_exon_classifier"

    def __init__(self, config: Evo2ExonConfig):
        super().__init__(config)

        # ▸ build (Linear + ReLU) * n  + final Linear(…, 1)
        layers = [nn.Linear(config.embedding_dim, config.hidden_dim), nn.ReLU()]
        for _ in range(config.num_hidden_layers - 1):
            layers += [nn.Linear(config.hidden_dim, config.hidden_dim), nn.ReLU()]
        layers += [nn.Linear(config.hidden_dim, 1)]

        self.fc_layers = nn.Sequential(*layers)
        self.sigmoid    = nn.Sigmoid()      # convert logits → probability

    def forward(self, inputs_embeds, labels=None, **kwargs):
        """
        inputs_embeds : (batch, seq_len, embedding_dim)
        labels        : (batch, seq_len) optional, 0/1 floats or ints
        """
        # nn.Linear operates on the last dimension, so the MLP can be applied
        # to the (batch, seq_len, embedding_dim) tensor directly; squeezing the
        # trailing singleton yields per-token logits of shape (batch, seq_len)
        logits = self.fc_layers(inputs_embeds).squeeze(-1)
        probs  = self.sigmoid(logits)

        if labels is not None:
            # BCEWithLogitsLoss on the raw logits is numerically more stable
            # than applying BCELoss to the sigmoid output
            loss = nn.BCEWithLogitsLoss()(logits, labels.float())
            return {"loss": loss, "logits": probs}  # "logits" holds probabilities

        return {"logits": probs}
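

# --- Usage sketch (assumption: added for illustration, not part of the
# original file). Runs a random forward pass to show the expected shapes.
# It assumes Evo2ExonConfig accepts embedding_dim / hidden_dim /
# num_hidden_layers as keyword arguments; the sizes are placeholders.
if __name__ == "__main__":
    import torch

    config = Evo2ExonConfig(embedding_dim=4096, hidden_dim=512, num_hidden_layers=2)
    model = Evo2ExonModel(config)

    embeds = torch.randn(2, 8, config.embedding_dim)   # (batch, seq_len, embedding_dim)
    labels = torch.randint(0, 2, (2, 8))               # per-token 0/1 exon labels

    out = model(inputs_embeds=embeds, labels=labels)
    print(out["loss"].item())     # scalar BCE loss
    print(out["logits"].shape)    # torch.Size([2, 8]), per-token probabilities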