| """ |
| Custom ModelOutput classes for GeneMamba. |
| Defines the output structure for different GeneMamba tasks. |
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
| import torch |
| from transformers.utils import ModelOutput |
|
|
|
|
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Every field defaults to ``None``. Following the ``transformers.ModelOutput``
    contract, fields left as ``None`` are skipped by ``to_tuple()`` and by
    integer/slice indexing, so positional access only sees populated fields.

    Attributes:
        last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)):
            Sequence of hidden-states at the output of the last layer of the model.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.

        embedding_pooling (str, optional):
            The pooling method used to generate pooled_embedding.
    """

    # Annotated Optional because the dataclass default is None (no implicit Optional).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # One tensor per layer (plus embeddings), so the tuple length is variable.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    # NOTE: a plain string, not a tensor; when set it is still included in
    # to_tuple() / positional indexing alongside the tensor fields.
    embedding_pooling: Optional[str] = None
|
|
|
|
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Every field defaults to ``None``. Following the ``transformers.ModelOutput``
    contract, fields left as ``None`` (e.g. ``loss`` when no labels were given)
    are skipped by ``to_tuple()`` and by integer/slice indexing.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            Classification loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, num_labels)):
            Classification scores (before softmax).

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.

        pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional):
            Cell embedding before classification head.
    """

    loss: Optional[torch.FloatTensor] = None
    # Annotated Optional because the dataclass default is None (no implicit Optional).
    logits: Optional[torch.FloatTensor] = None
    # One tensor per layer, so the tuple length is variable.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
|
|
|
|
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Every field defaults to ``None``. Following the ``transformers.ModelOutput``
    contract, fields left as ``None`` (e.g. ``loss`` when no labels were given)
    are skipped by ``to_tuple()`` and by integer/slice indexing.

    Attributes:
        loss (torch.FloatTensor of shape (), optional):
            MLM loss (if labels were provided).

        logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)):
            Prediction scores of the language modeling head.

        hidden_states (tuple(torch.FloatTensor), optional):
            Hidden-states of the model at the output of each layer.
    """

    loss: Optional[torch.FloatTensor] = None
    # Annotated Optional because the dataclass default is None (no implicit Optional).
    logits: Optional[torch.FloatTensor] = None
    # One tensor per layer, so the tuple length is variable.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|