"""
Custom ModelOutput classes for GeneMamba.
Defines the output structure for different GeneMamba tasks.
"""
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers.utils import ModelOutput
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size), optional):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.
        embedding_pooling (str, optional):
            The pooling method used to generate pooled_embedding.
    """

    # All fields default to None so ModelOutput can drop unset entries when
    # converting to tuple/dict; annotations are Optional to match the defaults,
    # and the hidden_states tuple is variadic (one tensor per layer).
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    embedding_pooling: Optional[str] = None
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).
        logits (torch.FloatTensor of shape (batch_size, num_labels), optional):
            Classification scores (before softmax).
        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    # Optional annotations match the None defaults required by ModelOutput's
    # tuple/dict conversion; hidden_states is a variable-length tuple.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).
        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size), optional):
            Prediction scores of the language modeling head.
        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    # Optional annotations match the None defaults required by ModelOutput's
    # tuple/dict conversion; hidden_states is a variable-length tuple.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None