""" PyTorch implementation of GeneMamba model for Hugging Face Transformers. Includes backbone model and task-specific heads for various downstream tasks. """ import math import logging from typing import Optional, Tuple, Union import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.init import normal_, constant_ from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput try: from transformers.models.auto import register_model_for_auto_class except ImportError: def register_model_for_auto_class(auto_class): def wrapper(cls): return cls return wrapper try: from mamba_ssm import Mamba2 as MambaBlock except ImportError: from mamba_ssm import Mamba as MambaBlock from mamba_ssm.ops.triton.layer_norm import RMSNorm from .configuration_genemamba import GeneMambaConfig from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput logger = logging.getLogger(__name__) # =========================== # Core Architecture Components # =========================== class EncoderLayer(nn.Module): """ Single Mamba encoder layer with residual connection. Applies a Mamba2 or Mamba layer followed by addition with input. Args: hidden_size (int): Dimension of hidden states. """ def __init__(self, hidden_size: int): super(EncoderLayer, self).__init__() self.mamba = MambaBlock(d_model=hidden_size, d_state=64, d_conv=4, expand=2) def forward(self, X: torch.Tensor) -> torch.Tensor: """ Args: X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size). Returns: torch.Tensor: Output after Mamba layer and residual connection. """ output = self.mamba(X) + X return output class MambaMixer(nn.Module): """ Stack of Mamba encoder layers with bidirectional processing and aggregation. Processes sequences in both forward and reverse directions, then aggregates. Args: mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate". hidden_size (int): Dimension of hidden states. num_hidden_layers (int): Number of Mamba layers. """ def __init__( self, mode: str = "gate", hidden_size: int = 512, num_hidden_layers: int = 24 ): super(MambaMixer, self).__init__() self.mode = mode self.hidden_size = hidden_size # Create Mamba layers self.layers = nn.ModuleList( [EncoderLayer(hidden_size) for _ in range(num_hidden_layers)] ) # Aggregation modules for certain modes if mode in ["concat", "gate"]: self.aggr = nn.Linear(hidden_size * 2, hidden_size) def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor: """ Reverse a sequence based on actual length (ignoring padding). Args: X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size). mask (torch.Tensor, optional): Padding mask of shape (batch_size, seq_len). Returns: torch.Tensor: Reversed tensor. """ batch_size, seq_length, embedding_dim = X.size() if mask is None: # Simple flip return X.flip([1]) # Flip based on actual sequence length (marked by mask) lengths = (~mask).sum(dim=1) pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1) flip_mask = pos_tensor < lengths.unsqueeze(1) reversed_positions = torch.where( flip_mask, lengths.unsqueeze(1) - 1 - pos_tensor, pos_tensor ) X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim)) return X_reverse def forward( self, X: torch.Tensor, padding_mask: Optional[torch.Tensor] = None ) -> torch.Tensor: """ Process sequence through bidirectional Mamba layers. Args: X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size). padding_mask (torch.Tensor, optional): Padding mask. Returns: torch.Tensor: Output after processing all layers and aggregation. """ for layer in self.layers: # Flip sequence for reverse processing X_flip = self.flip_sequence(X, padding_mask) # Forward and reverse passes X_f = layer(X) X_b = layer(X_flip) # Flip back the reverse output X_b = self.flip_sequence(X_b, padding_mask) # Aggregate forward and reverse if self.mode == "mean": X = (X_f + X_b) / 2 elif self.mode == "sum": X = X_f + X_b elif self.mode == "concat": X = torch.cat([X_f, X_b], dim=-1) X = self.aggr(X) elif self.mode == "gate": z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1))) X = z * X_f + (1 - z) * X_b else: raise ValueError(f"Invalid aggregation mode: {self.mode}") return X # =========================== # Base Model Classes # =========================== class GeneMambaPreTrainedModel(PreTrainedModel): """ Base class for all GeneMamba models. Handles weight initialization and provides standard model interfaces. """ config_class = GeneMambaConfig base_model_prefix = "genemamba" supports_gradient_checkpointing = True def _init_weights(self, module): """Initialize module weights.""" if isinstance(module, nn.Linear): normal_(module.weight, std=self.config.initializer_range) if module.bias is not None: constant_(module.bias, 0.0) elif isinstance(module, nn.Embedding): normal_(module.weight, std=self.config.initializer_range) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() elif isinstance(module, nn.LayerNorm): constant_(module.bias, 0.0) constant_(module.weight, 1.0) class GeneMambaModel(GeneMambaPreTrainedModel): """ GeneMamba backbone model - outputs cell embeddings and hidden states. This is the core model used by task-specific heads. Args: config (GeneMambaConfig): Model configuration class. """ def __init__(self, config: GeneMambaConfig): super().__init__(config) self.config = config # Embedding layer self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id) # Mamba layers with bidirectional aggregation self.mamba_mixer = MambaMixer( mode=config.mamba_mode, hidden_size=config.hidden_size, num_hidden_layers=config.num_hidden_layers ) # Final layer normalization (kept as norm_f to match checkpoint key names) self.norm_f = RMSNorm(config.hidden_size) self.apply(self._init_weights) def get_input_embeddings(self) -> nn.Embedding: """Return embedding layer.""" return self.embeddings def set_input_embeddings(self, value: nn.Embedding): """Set embedding layer.""" self.embeddings = value def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, output_hidden_states: bool = False, ) -> GeneMambaModelOutput: """ Args: input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len). attention_mask (torch.Tensor, optional): Attention mask of shape (batch_size, seq_len). output_hidden_states (bool): Whether to output hidden states from all layers. Returns: GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc. """ # Get embeddings hidden_states = self.embeddings(input_ids) # Pass through Mamba layers hidden_states = self.mamba_mixer(hidden_states, attention_mask) # Apply final normalization hidden_states = self.norm_f(hidden_states) # Compute pooled embedding (cell representation) if self.config.embedding_pooling == "CLS": # Use first token (CLS) pooled_embedding = hidden_states[:, 0, :] elif self.config.embedding_pooling == "mean": # Mean pooling over sequence if attention_mask is not None: mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float() pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1) else: pooled_embedding = hidden_states.mean(dim=1) else: raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}") return GeneMambaModelOutput( last_hidden_state=hidden_states, pooled_embedding=pooled_embedding, hidden_states=hidden_states if output_hidden_states else None, embedding_pooling=self.config.embedding_pooling, ) # =========================== # Task-Specific Models # =========================== @register_model_for_auto_class("AutoModel") class GeneMambaForMaskedLM(GeneMambaPreTrainedModel): """ GeneMamba model for masked language modeling (MLM). Suitable for pretraining and domain adaptation. Args: config (GeneMambaConfig): Model configuration class. """ def __init__(self, config: GeneMambaConfig): super().__init__(config) self.genemamba = GeneMambaModel(config) # Language modeling head self.lm_head = nn.Linear(config.hidden_size, config.vocab_size) self.apply(self._init_weights) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_hidden_states: bool = False, ) -> GeneMambaMaskedLMOutput: """ Args: input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len). attention_mask (torch.Tensor, optional): Attention mask. labels (torch.Tensor, optional): Target token ids for MLM loss. output_hidden_states (bool): Whether to output hidden states. Returns: GeneMambaMaskedLMOutput: Contains logits and optional loss. """ outputs = self.genemamba( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states, ) logits = self.lm_head(outputs.last_hidden_state) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1)) return GeneMambaMaskedLMOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states if output_hidden_states else None, ) @register_model_for_auto_class("AutoModelForSequenceClassification") class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel): """ GeneMamba model for sequence classification tasks. Ideal for cell type annotation, tissue classification, etc. Args: config (GeneMambaConfig): Model configuration class. """ def __init__(self, config: GeneMambaConfig): super().__init__(config) self.num_labels = config.num_labels self.config = config self.genemamba = GeneMambaModel(config) # Classification head self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, config.num_labels) self.apply(self._init_weights) def forward( self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, output_hidden_states: bool = False, ) -> GeneMambaSequenceClassifierOutput: """ Args: input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len). attention_mask (torch.Tensor, optional): Attention mask. labels (torch.Tensor, optional): Class labels for classification loss. output_hidden_states (bool): Whether to output hidden states. Returns: GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding. """ outputs = self.genemamba( input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=output_hidden_states, ) pooled_embedding = outputs.pooled_embedding logits = self.classifier(self.dropout(pooled_embedding)) loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits, labels) return GeneMambaSequenceClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states if output_hidden_states else None, pooled_embedding=pooled_embedding, ) # Register tokenizer class register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)