"""
PyTorch implementation of GeneMamba model for Hugging Face Transformers.
Includes backbone model and task-specific heads for various downstream tasks.
"""
import math
import logging
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import normal_, constant_
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput, ModelOutput
try:
from transformers.models.auto import register_model_for_auto_class
except ImportError:
def register_model_for_auto_class(auto_class):
def wrapper(cls):
return cls
return wrapper
try:
from mamba_ssm import Mamba2 as MambaBlock
except ImportError:
from mamba_ssm import Mamba as MambaBlock
try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm
except ImportError:  # older mamba_ssm releases expose this module as `layernorm`
    from mamba_ssm.ops.triton.layernorm import RMSNorm
from .configuration_genemamba import GeneMambaConfig
from .modeling_outputs import GeneMambaModelOutput, GeneMambaSequenceClassifierOutput, GeneMambaMaskedLMOutput
logger = logging.getLogger(__name__)
# ===========================
# Core Architecture Components
# ===========================
class EncoderLayer(nn.Module):
"""
Single Mamba encoder layer with residual connection.
Applies a Mamba2 or Mamba layer followed by addition with input.
Args:
hidden_size (int): Dimension of hidden states.
"""
def __init__(self, hidden_size: int):
super(EncoderLayer, self).__init__()
self.mamba = MambaBlock(d_model=hidden_size, d_state=64, d_conv=4, expand=2)
def forward(self, X: torch.Tensor) -> torch.Tensor:
"""
Args:
X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
Returns:
torch.Tensor: Output after Mamba layer and residual connection.
"""
output = self.mamba(X) + X
return output
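# Illustrative usage sketch (not part of the model definition): a single EncoderLayer
# maps (batch_size, seq_len, hidden_size) to the same shape, because the Mamba output
# is added back onto its input. Assumes mamba_ssm is installed and a CUDA device is
# available; the sizes below are arbitrary placeholders.
def _example_encoder_layer(hidden_size: int = 512):
    layer = EncoderLayer(hidden_size).to("cuda")
    x = torch.randn(2, 128, hidden_size, device="cuda")
    y = layer(x)
    assert y.shape == x.shape  # residual connection preserves the shape
    return y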
class MambaMixer(nn.Module):
"""
Stack of Mamba encoder layers with bidirectional processing and aggregation.
Processes sequences in both forward and reverse directions, then aggregates.
Args:
mode (str): Aggregation mode. Options: "mean", "sum", "concat", "gate".
hidden_size (int): Dimension of hidden states.
num_hidden_layers (int): Number of Mamba layers.
"""
def __init__(
self,
mode: str = "gate",
hidden_size: int = 512,
num_hidden_layers: int = 24
):
super(MambaMixer, self).__init__()
self.mode = mode
self.hidden_size = hidden_size
# Create Mamba layers
self.layers = nn.ModuleList(
[EncoderLayer(hidden_size) for _ in range(num_hidden_layers)]
)
# Aggregation modules for certain modes
if mode in ["concat", "gate"]:
self.aggr = nn.Linear(hidden_size * 2, hidden_size)
def flip_sequence(self, X: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Reverse a sequence based on actual length (ignoring padding).
Args:
X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            mask (torch.Tensor, optional): Boolean padding mask of shape
                (batch_size, seq_len); True marks padding positions.
Returns:
torch.Tensor: Reversed tensor.
"""
batch_size, seq_length, embedding_dim = X.size()
if mask is None:
# Simple flip
return X.flip([1])
# Flip based on actual sequence length (marked by mask)
        lengths = (~mask.bool()).sum(dim=1)  # number of non-padding tokens per sequence
pos_tensor = torch.arange(seq_length, device=X.device).unsqueeze(0).expand(batch_size, -1)
flip_mask = pos_tensor < lengths.unsqueeze(1)
reversed_positions = torch.where(
flip_mask,
lengths.unsqueeze(1) - 1 - pos_tensor,
pos_tensor
)
X_reverse = torch.gather(X, 1, reversed_positions.unsqueeze(-1).expand(-1, -1, embedding_dim))
return X_reverse
def forward(
self,
X: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Process sequence through bidirectional Mamba layers.
Args:
X (torch.Tensor): Input tensor of shape (batch_size, seq_len, hidden_size).
            padding_mask (torch.Tensor, optional): Boolean padding mask of shape
                (batch_size, seq_len); True marks padding positions.
Returns:
torch.Tensor: Output after processing all layers and aggregation.
"""
for layer in self.layers:
# Flip sequence for reverse processing
X_flip = self.flip_sequence(X, padding_mask)
# Forward and reverse passes
X_f = layer(X)
X_b = layer(X_flip)
# Flip back the reverse output
X_b = self.flip_sequence(X_b, padding_mask)
# Aggregate forward and reverse
if self.mode == "mean":
X = (X_f + X_b) / 2
elif self.mode == "sum":
X = X_f + X_b
elif self.mode == "concat":
X = torch.cat([X_f, X_b], dim=-1)
X = self.aggr(X)
elif self.mode == "gate":
z = torch.sigmoid(self.aggr(torch.cat([X_f, X_b], dim=-1)))
X = z * X_f + (1 - z) * X_b
else:
raise ValueError(f"Invalid aggregation mode: {self.mode}")
return X
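# Illustrative sketch of the length-aware flip used by MambaMixer.flip_sequence,
# re-derived with plain torch so it runs on CPU (not part of the model). Tokens
# within each sequence's true length are reversed while padded positions stay in
# place; the boolean mask follows the convention used inside MambaMixer (True = padding).
def _example_flip_sequence():
    X = torch.arange(1, 9, dtype=torch.float32).view(2, 4, 1)  # two toy sequences of length 4
    mask = torch.tensor([[False, False, False, True],          # sequence 1: 3 real tokens, 1 pad
                         [False, False, True, True]])          # sequence 2: 2 real tokens, 2 pads
    lengths = (~mask).sum(dim=1)                                # true lengths: [3, 2]
    pos = torch.arange(X.size(1)).unsqueeze(0).expand(X.size(0), -1)
    rev = torch.where(pos < lengths.unsqueeze(1), lengths.unsqueeze(1) - 1 - pos, pos)
    flipped = torch.gather(X, 1, rev.unsqueeze(-1).expand(-1, -1, X.size(2)))
    # sequence 1: [1, 2, 3, pad] -> [3, 2, 1, pad]; sequence 2: [5, 6, pad, pad] -> [6, 5, pad, pad]
    return flipped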
# ===========================
# Base Model Classes
# ===========================
class GeneMambaPreTrainedModel(PreTrainedModel):
"""
Base class for all GeneMamba models.
Handles weight initialization and provides standard model interfaces.
"""
config_class = GeneMambaConfig
base_model_prefix = "genemamba"
supports_gradient_checkpointing = True
def _init_weights(self, module):
"""Initialize module weights."""
if isinstance(module, nn.Linear):
normal_(module.weight, std=self.config.initializer_range)
if module.bias is not None:
constant_(module.bias, 0.0)
elif isinstance(module, nn.Embedding):
normal_(module.weight, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
constant_(module.bias, 0.0)
constant_(module.weight, 1.0)
class GeneMambaModel(GeneMambaPreTrainedModel):
"""
GeneMamba backbone model - outputs cell embeddings and hidden states.
This is the core model used by task-specific heads.
Args:
config (GeneMambaConfig): Model configuration class.
"""
def __init__(self, config: GeneMambaConfig):
super().__init__(config)
self.config = config
# Embedding layer
self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
# Mamba layers with bidirectional aggregation
self.mamba_mixer = MambaMixer(
mode=config.mamba_mode,
hidden_size=config.hidden_size,
num_hidden_layers=config.num_hidden_layers
)
# Final layer normalization (kept as norm_f to match checkpoint key names)
self.norm_f = RMSNorm(config.hidden_size)
self.apply(self._init_weights)
def get_input_embeddings(self) -> nn.Embedding:
"""Return embedding layer."""
return self.embeddings
def set_input_embeddings(self, value: nn.Embedding):
"""Set embedding layer."""
self.embeddings = value
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_hidden_states: bool = False,
) -> GeneMambaModelOutput:
"""
Args:
input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
attention_mask (torch.Tensor, optional): Attention mask of shape (batch_size, seq_len).
            output_hidden_states (bool): Whether to also return the final hidden
                states in the `hidden_states` field of the output.
Returns:
GeneMambaModelOutput: Contains last_hidden_state, pooled_embedding, etc.
"""
# Get embeddings
hidden_states = self.embeddings(input_ids)
        # Pass through Mamba layers. MambaMixer expects a boolean padding mask
        # (True = padding), so convert the HF-style attention_mask (1 = token, 0 = padding).
        padding_mask = attention_mask.eq(0) if attention_mask is not None else None
        hidden_states = self.mamba_mixer(hidden_states, padding_mask)
# Apply final normalization
hidden_states = self.norm_f(hidden_states)
# Compute pooled embedding (cell representation)
if self.config.embedding_pooling == "CLS":
# Use first token (CLS)
pooled_embedding = hidden_states[:, 0, :]
elif self.config.embedding_pooling == "mean":
# Mean pooling over sequence
if attention_mask is not None:
mask = attention_mask.unsqueeze(-1).expand(hidden_states.shape).float()
pooled_embedding = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1)
else:
pooled_embedding = hidden_states.mean(dim=1)
else:
raise ValueError(f"Unsupported embedding_pooling: {self.config.embedding_pooling}")
return GeneMambaModelOutput(
last_hidden_state=hidden_states,
pooled_embedding=pooled_embedding,
hidden_states=hidden_states if output_hidden_states else None,
embedding_pooling=self.config.embedding_pooling,
)
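# Illustrative end-to-end sketch for the backbone (not part of the model definition).
# Assumes mamba_ssm and a CUDA device; the config values are small placeholders, it is
# assumed that GeneMambaConfig accepts them as keyword arguments, and remaining fields
# (e.g. initializer_range) are assumed to have defaults.
def _example_backbone_forward():
    config = GeneMambaConfig(
        vocab_size=1000,
        hidden_size=256,
        num_hidden_layers=2,
        pad_token_id=0,
        mamba_mode="gate",
        embedding_pooling="mean",
    )
    model = GeneMambaModel(config).to("cuda").eval()
    input_ids = torch.randint(1, config.vocab_size, (2, 64), device="cuda")
    attention_mask = torch.ones_like(input_ids)
    attention_mask[:, 48:] = 0  # mark the tail of each sequence as padding
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    # per-token states of shape (batch, seq_len, hidden) and a (batch, hidden) cell embedding
    return outputs.last_hidden_state.shape, outputs.pooled_embedding.shape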
# ===========================
# Task-Specific Models
# ===========================
@register_model_for_auto_class("AutoModel")
class GeneMambaForMaskedLM(GeneMambaPreTrainedModel):
"""
GeneMamba model for masked language modeling (MLM).
Suitable for pretraining and domain adaptation.
Args:
config (GeneMambaConfig): Model configuration class.
"""
def __init__(self, config: GeneMambaConfig):
super().__init__(config)
self.genemamba = GeneMambaModel(config)
# Language modeling head
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.apply(self._init_weights)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: bool = False,
) -> GeneMambaMaskedLMOutput:
"""
Args:
input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
attention_mask (torch.Tensor, optional): Attention mask.
labels (torch.Tensor, optional): Target token ids for MLM loss.
output_hidden_states (bool): Whether to output hidden states.
Returns:
GeneMambaMaskedLMOutput: Contains logits and optional loss.
"""
outputs = self.genemamba(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=output_hidden_states,
)
logits = self.lm_head(outputs.last_hidden_state)
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
return GeneMambaMaskedLMOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
)
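# Illustrative masked-LM training step (not part of the model definition). The model,
# mask_token_id, and masking rate are placeholders; positions that should not contribute
# to the loss keep the label -100, which nn.CrossEntropyLoss ignores by default.
def _example_mlm_step(model: "GeneMambaForMaskedLM", mask_token_id: int):
    device = next(model.parameters()).device
    input_ids = torch.randint(1, model.config.vocab_size, (2, 64), device=device)
    labels = torch.full_like(input_ids, -100)                  # ignore every position by default
    masked = torch.rand(input_ids.shape, device=device) < 0.15
    labels[masked] = input_ids[masked]                         # supervise only the masked positions
    input_ids[masked] = mask_token_id                          # corrupt the inputs at those positions
    outputs = model(input_ids=input_ids, labels=labels)
    outputs.loss.backward()
    return outputs.loss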
@register_model_for_auto_class("AutoModelForSequenceClassification")
class GeneMambaForSequenceClassification(GeneMambaPreTrainedModel):
"""
GeneMamba model for sequence classification tasks.
Ideal for cell type annotation, tissue classification, etc.
Args:
config (GeneMambaConfig): Model configuration class.
"""
def __init__(self, config: GeneMambaConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.config = config
self.genemamba = GeneMambaModel(config)
# Classification head
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
self.apply(self._init_weights)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: bool = False,
) -> GeneMambaSequenceClassifierOutput:
"""
Args:
input_ids (torch.Tensor): Token indices of shape (batch_size, seq_len).
attention_mask (torch.Tensor, optional): Attention mask.
labels (torch.Tensor, optional): Class labels for classification loss.
output_hidden_states (bool): Whether to output hidden states.
Returns:
GeneMambaSequenceClassifierOutput: Contains logits, optional loss, and embedding.
"""
outputs = self.genemamba(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=output_hidden_states,
)
pooled_embedding = outputs.pooled_embedding
logits = self.classifier(self.dropout(pooled_embedding))
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(logits, labels)
return GeneMambaSequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states if output_hidden_states else None,
pooled_embedding=pooled_embedding,
)
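# Illustrative fine-tuning setup for cell type annotation (not part of the model
# definition). The checkpoint path is a placeholder, not a published model id, and
# num_labels is task-specific; mamba_ssm and a CUDA device are assumed.
def _example_classifier(num_labels: int = 10):
    model = GeneMambaForSequenceClassification.from_pretrained(
        "path/to/genemamba-checkpoint",  # placeholder path
        num_labels=num_labels,
    ).to("cuda")
    input_ids = torch.randint(1, model.config.vocab_size, (4, 64), device="cuda")
    labels = torch.randint(0, num_labels, (4,), device="cuda")
    outputs = model(input_ids=input_ids, labels=labels)
    # cross-entropy loss over the pooled cell embedding plus per-class logits
    return outputs.loss, outputs.logits.shape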
# Also register the MLM head model for AutoModelForMaskedLM
register_model_for_auto_class("AutoModelForMaskedLM")(GeneMambaForMaskedLM)