# GeneMamba / modeling_outputs.py
# (Hugging Face repo-page residue: uploaded by mineself2016,
#  commit 0231a38 "Fix ModelOutput dataclass defaults")
"""
Custom ModelOutput classes for GeneMamba.
Defines the output structure for different GeneMamba tasks.
"""
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers.utils import ModelOutput
@dataclass
class GeneMambaModelOutput(ModelOutput):
    """
    Base output class for GeneMamba models.

    Attributes:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        pooled_embedding (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
            Cell/sequence-level embedding (pooled representation) used for downstream tasks.
            This is the recommended embedding to use for classification, clustering, etc.
        embedding_pooling (`str`, *optional*):
            The pooling method used to generate `pooled_embedding`.
    """

    # All fields default to None so ModelOutput can omit absent entries;
    # annotations are Optional to agree with those defaults, and the
    # per-layer tuple is variadic (one tensor per layer).
    last_hidden_state: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
    embedding_pooling: Optional[str] = None
@dataclass
class GeneMambaSequenceClassifierOutput(ModelOutput):
    """
    Output class for GeneMamba sequence classification models.

    Attributes:
        loss (`torch.FloatTensor` of shape `()`, *optional*):
            Classification loss (if labels were provided).
        logits (`torch.FloatTensor` of shape `(batch_size, num_labels)`, *optional*):
            Classification scores (before softmax).
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
        pooled_embedding (`torch.FloatTensor` of shape `(batch_size, hidden_size)`, *optional*):
            Cell embedding before classification head.
    """

    # Optional annotations match the None defaults required by ModelOutput;
    # the hidden-states tuple is variadic (one tensor per layer).
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    pooled_embedding: Optional[torch.FloatTensor] = None
@dataclass
class GeneMambaMaskedLMOutput(ModelOutput):
    """
    Output class for GeneMamba masked language modeling.

    Attributes:
        loss (`torch.FloatTensor` of shape `()`, *optional*):
            MLM loss (if labels were provided).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, vocab_size)`, *optional*):
            Prediction scores of the language modeling head.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*):
            Hidden-states of the model at the output of each layer.
    """

    # Optional annotations match the None defaults required by ModelOutput;
    # the hidden-states tuple is variadic (one tensor per layer).
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None