""" Custom ModelOutput classes for GeneMamba. Defines the output structure for different GeneMamba tasks. """ from dataclasses import dataclass from typing import Optional, Tuple import torch from transformers.utils import ModelOutput @dataclass class GeneMambaModelOutput(ModelOutput): """ Base output class for GeneMamba models. Attributes: last_hidden_state (torch.FloatTensor of shape (batch_size, sequence_length, hidden_size)): Sequence of hidden-states at the output of the last layer of the model. hidden_states (tuple(torch.FloatTensor), optional): Hidden-states of the model at the output of each layer plus the initial embedding outputs. pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size)): Cell/sequence-level embedding (pooled representation) used for downstream tasks. This is the recommended embedding to use for classification, clustering, etc. embedding_pooling (str): The pooling method used to generate pooled_embedding. """ last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None pooled_embedding: torch.FloatTensor = None embedding_pooling: Optional[str] = None @dataclass class GeneMambaSequenceClassifierOutput(ModelOutput): """ Output class for GeneMamba sequence classification models. Attributes: loss (torch.FloatTensor of shape (), optional): Classification loss (if labels were provided). logits (torch.FloatTensor of shape (batch_size, num_labels)): Classification scores (before softmax). hidden_states (tuple(torch.FloatTensor), optional): Hidden-states of the model at the output of each layer. pooled_embedding (torch.FloatTensor of shape (batch_size, hidden_size), optional): Cell embedding before classification head. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None pooled_embedding: Optional[torch.FloatTensor] = None @dataclass class GeneMambaMaskedLMOutput(ModelOutput): """ Output class for GeneMamba masked language modeling. Attributes: loss (torch.FloatTensor of shape (), optional): MLM loss (if labels were provided). logits (torch.FloatTensor of shape (batch_size, sequence_length, vocab_size)): Prediction scores of the language modeling head. hidden_states (tuple(torch.FloatTensor), optional): Hidden-states of the model at the output of each layer. """ loss: Optional[torch.FloatTensor] = None logits: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None