| """ |
| PyTorch implementation of GeneMamba model for Hugging Face Transformers. |
| Includes backbone model and task-specific heads for various downstream tasks. |
| """ |
|
|
| import math |
| import logging |
| from typing import Optional, Tuple, Union |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.init import normal_, constant_ |
|
|
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput |
try:
    from transformers.models.auto import register_model_for_auto_class
except ImportError:
    # Older transformers releases do not ship this helper; substitute a
    # no-op decorator factory so the class registrations below still evaluate.
    def register_model_for_auto_class(auto_class):
        return lambda cls: cls
|
|
| try: |
| from mamba_ssm import Mamba2 as MambaBlock |
| except ImportError: |
| from mamba_ssm import Mamba as MambaBlock |
| from mamba_ssm.ops.triton.layer_norm import RMSNorm |
|
|
| from .configuration_genemamba import GeneMambaConfig |
| from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput |
|
|
# Module-level logger, namespaced to this file's import path.
logger = logging.getLogger(__name__)
|
|
|
|
| |
| |
| |
|
|
class EncoderLayer(nn.Module):
    """One residual Mamba block.

    Wraps a single Mamba/Mamba2 layer and adds a skip connection around it.

    Args:
        hidden_size (int): Dimension of the hidden states.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        # Fixed SSM hyper-parameters (state size, conv width, expansion).
        self.mamba = MambaBlock(d_model=hidden_size, d_state=64, d_conv=4, expand=2)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the Mamba layer followed by a residual connection.

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).

        Returns:
            torch.Tensor: Same shape as the input.
        """
        residual = X
        return residual + self.mamba(X)
|
|
|
|
class MambaMixer(nn.Module):
    """
    Stack of Mamba encoder layers with bidirectional processing and aggregation.

    Every layer runs the sequence left-to-right and (length-aware)
    right-to-left, then merges the two direction outputs according to ``mode``.

    Args:
        mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
        hidden_size (int): Dimension of hidden states.
        num_hidden_layers (int): Number of stacked Mamba layers.
    """

    def __init__(
        self,
        mode: str = "gate",
        hidden_size: int = 512,
        num_hidden_layers: int = 24
    ):
        super(MambaMixer, self).__init__()
        self.mode = mode
        self.hidden_size = hidden_size

        self.layers = nn.ModuleList(
            [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
        )

        # "concat" and "gate" need a projection from 2*hidden back to hidden.
        if mode in ["concat", "gate"]:
            self.aggr = nn.Linear(hidden_size * 2, hidden_size)

    def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Reverse each sequence over its real (non-padded) prefix.

        Assumes right padding and the Hugging Face attention-mask convention:
        nonzero / True marks a real token, zero / False marks padding. Padded
        positions stay in place at the tail of the sequence.

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Attention mask of shape
                (batch_size, seq_len); nonzero = real token.

        Returns:
            torch.Tensor: Tensor with each valid prefix reversed.
        """
        if mask is None:
            # No padding information: reverse the full time axis.
            return X.flip([1])

        batch_size, seq_length, embedding_dim = X.size()

        # BUGFIX: the previous code computed lengths as (~mask).sum(dim=1),
        # which is only correct for a boolean mask where True marks *padding*.
        # The in-file caller (GeneMambaModel) passes a standard HF
        # attention_mask (1 = real token), for which bitwise-not on a 0/1
        # integer tensor yields -1/-2 and garbage lengths. Count the real
        # tokens directly instead.
        lengths = mask.to(torch.bool).sum(dim=1)

        positions = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
        within_prefix = positions < lengths.unsqueeze(1)
        reversed_positions = torch.where(
            within_prefix,
            lengths.unsqueeze(1) - 1 - positions,  # mirror inside the valid prefix
            positions,                             # keep padding where it is
        )

        return torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))

    def forward(
        self,
        X: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Process the sequence through all bidirectional Mamba layers.

        Args:
            X (torch.Tensor): Input of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Attention mask
                (nonzero = real token).

        Returns:
            torch.Tensor: Aggregated output, (batch_size, seq_len, hidden_size).
        """
        for layer in self.layers:
            # NOTE(review): the same layer (shared weights) processes both
            # directions -- presumably intentional weight tying; confirm.
            X_f = layer(X)
            X_b = layer(self.flip_sequence(X, padding_mask))

            # Undo the reversal so both direction outputs are position-aligned.
            X_b = self.flip_sequence(X_b, padding_mask)

            if self.mode == "mean":
                X = (X_f + X_b) / 2
            elif self.mode == "sum":
                X = X_f + X_b
            elif self.mode == "concat":
                X = self.aggr(torch.cat([X_f, X_b], dim=-1))
            elif self.mode == "gate":
                # Learned convex combination of the two directions.
                z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
                X = z * X_f + (1 - z) * X_b
            else:
                raise ValueError(f"Invalid aggregation mode: {self.mode}")

        return X
|
|
|
|
| |
| |
| |
|
|
class GeneMambaPreTrainedModel(PreTrainedModel):
    """
    Base class for all GeneMamba models.
    Handles weight initialization and provides standard model interfaces.
    """

    config_class = GeneMambaConfig
    base_model_prefix = "genemamba"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize module weights (BERT-style truncated-normal scheme).

        Linear/Embedding weights are drawn from N(0, initializer_range);
        biases are zeroed; embedding padding rows are zeroed; LayerNorm is
        reset to the identity transform.
        """
        if isinstance(module, nn.Linear):
            normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                constant_(module.bias, 0.0)
        elif isinstance(module, nn.Embedding):
            normal_(module.weight, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # Keep the padding token's embedding fixed at zero.
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            # BUGFIX: LayerNorm created with elementwise_affine=False (or
            # bias=False) has weight/bias of None; the old code crashed on
            # constant_(None, ...). Guard both parameters.
            if module.weight is not None:
                constant_(module.weight, 1.0)
            if module.bias is not None:
                constant_(module.bias, 0.0)
|
|
|
|
class GeneMambaModel(GeneMambaPreTrainedModel):
    """
    GeneMamba backbone model - outputs cell embeddings and hidden states.
    This is the core model used by task-specific heads.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.config = config

        # Gene-token embedding table; the padding row is zeroed by init.
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

        # Bidirectional Mamba encoder stack.
        self.mamba_mixer = MambaMixer(
            mode=config.mamba_mode,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers
        )

        # Final pre-output normalization.
        self.norm_f = RMSNorm(config.hidden_size)

        self.apply(self._init_weights)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding layer."""
        return self.embeddings

    def set_input_embeddings(self, value: nn.Embedding):
        """Replace the token embedding layer."""
        self.embeddings = value

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaModelOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Mask of shape
                (batch_size, seq_len); 1 = real token, 0 = padding.
            output_hidden_states (bool): Whether to also expose the final
                hidden states via the ``hidden_states`` field.

        Returns:
            GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
        """
        hidden_states = self.embeddings(input_ids)

        # NOTE(review): MambaMixer.flip_sequence interprets this mask when
        # reversing sequences -- verify the mask convention matches the HF
        # attention_mask (1 = real token) used by the pooling below.
        hidden_states = self.mamba_mixer(hidden_states, attention_mask)

        hidden_states = self.norm_f(hidden_states)

        if self.config.embedding_pooling == "CLS":
            # First-token representation serves as the cell embedding.
            pooled_embedding = hidden_states[:, 0, :]
        elif self.config.embedding_pooling == "mean":
            if attention_mask is not None:
                # Masked mean over real tokens only.
                mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
                # BUGFIX: clamp the divisor so an all-padding row yields zeros
                # instead of NaNs from a 0/0 division.
                denom = mask.sum(dim=1).clamp(min=1e-9)
                pooled_embedding = (hidden_states * mask).sum(dim=1) / denom
            else:
                pooled_embedding = hidden_states.mean(dim=1)
        else:
            raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")

        return GeneMambaModelOutput(
            last_hidden_state=hidden_states,
            pooled_embedding=pooled_embedding,
            # Only the final layer's states are tracked; per-layer states are not kept.
            hidden_states=hidden_states if output_hidden_states else None,
            embedding_pooling=self.config.embedding_pooling,
        )
|
|
|
|
| |
| |
| |
|
|
@register_model_for_auto_class("AutoModel")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
    """
    GeneMamba with a masked-language-modeling head on top.

    Suitable for pretraining and domain adaptation: predicts vocabulary
    logits at every position from the backbone's hidden states.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.genemamba = GeneMambaModel(config)

        # Per-token projection from hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaMaskedLMOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Target token ids for the MLM loss.
            output_hidden_states (bool): Whether to also return hidden states.

        Returns:
            GeneMambaMaskedLMOutput: Contains logits and optional loss.
        """
        backbone_out = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        prediction_scores = self.lm_head(backbone_out.last_hidden_state)

        mlm_loss = None
        if labels is not None:
            # Flatten (batch, seq) into one axis for token-level cross entropy.
            criterion = nn.CrossEntropyLoss()
            mlm_loss = criterion(
                prediction_scores.view(-1, self.config.vocab_size),
                labels.view(-1),
            )

        return GeneMambaMaskedLMOutput(
            loss=mlm_loss,
            logits=prediction_scores,
            hidden_states=backbone_out.hidden_states if output_hidden_states else None,
        )
|
|
|
|
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
    """
    GeneMamba model for sequence classification tasks.
    Ideal for cell type annotation, tissue classification, etc.

    Args:
        config (GeneMambaConfig): Model configuration class.
    """

    def __init__(self, config: GeneMambaConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.genemamba = GeneMambaModel(config)

        # Classification head applied to the pooled cell embedding.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.apply(self._init_weights)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: bool = False,
    ) -> GeneMambaSequenceClassifierOutput:
        """
        Args:
            input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
            attention_mask (torch.Tensor, optional): Attention mask.
            labels (torch.Tensor, optional): Class ids for classification, or
                float targets when num_labels == 1 (regression).
            output_hidden_states (bool): Whether to output hidden states.

        Returns:
            GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
        """
        outputs = self.genemamba(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states,
        )

        pooled_embedding = outputs.pooled_embedding
        logits = self.classifier(self.dropout(pooled_embedding))

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # BUGFIX: with a single logit, CrossEntropyLoss is identically
                # zero (softmax over one class), so regression configs never
                # trained. Follow the standard HF convention and use MSE when
                # num_labels == 1.
                loss = nn.MSELoss()(logits.view(-1), labels.view(-1).float())
            else:
                loss = nn.CrossEntropyLoss()(logits, labels)

        return GeneMambaSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            pooled_embedding=pooled_embedding,
        )
|
|
|
|
| |
# Additionally expose the MLM model under AutoModelForMaskedLM; the class is
# already registered for AutoModel by its decorator.
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)
|
|