File size: 2,919 Bytes
c174f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0231a38
c174f3b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""
Custom ModelOutput classes for GeneMamba.
Defines the output structure for different GeneMamba tasks.
"""

from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers.utils import ModelOutput


@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str, optional):
            The pooling method used to generate pooled_embedding.
    """

    # All fields default to None so the model can populate only what a given
    # forward pass produces; annotations are Optional to match (the originals
    # declared non-Optional types with None defaults, which type-checkers flag).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Tuple[..., ...] (variable length): one tensor per layer, not a 1-tuple.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    embedding_pooling: Optional[str] = None


@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels), optional):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional annotation added: the original declared torch.FloatTensor with a
    # None default, which is inconsistent with the type.
    logits: Optional[torch.FloatTensor] = None
    # Variable-length tuple (one tensor per layer), hence the trailing ellipsis.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None


@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size), optional):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    loss: Optional[torch.FloatTensor] = None
    # Optional annotation added: the original declared torch.FloatTensor with a
    # None default, which is inconsistent with the type.
    logits: Optional[torch.FloatTensor] = None
    # Variable-length tuple (one tensor per layer), hence the trailing ellipsis.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None