---
library_name: transformers
pipeline_tag: feature-extraction
tags:
- genemamba
- mamba
- genomics
- single-cell
- custom_code
---
# GeneMamba: Foundation Model for Single-Cell Analysis
A Hugging Face-compatible implementation of GeneMamba, a foundation model built on the Mamba state-space architecture for single-cell RNA-seq analysis.
## 📋 Table of Contents
- [Overview](#overview)
- [Datasets](#datasets)
- [Installation](#installation)
- [Quick Start](#quick-start)
  - [Phase 1: Extract Cell Embeddings](#phase-1-extract-cell-embeddings)
  - [Phase 2: Downstream Tasks](#phase-2-downstream-tasks)
  - [Phase 3: Train from Scratch](#phase-3-train-from-scratch)
- [Model Variants](#model-variants)
- [Architecture](#architecture)
- [Examples](#examples)
- [Citation](#citation)
- [Troubleshooting](#troubleshooting)
---
## Overview
GeneMamba is a **state-space model (SSM)** based on the **Mamba architecture**, optimized for single-cell gene expression analysis. The model:
- **Takes ranked gene sequences** as input (genes sorted by expression level)
- **Outputs cell embeddings** suitable for clustering, classification, and batch integration
- **Supports multiple downstream tasks** including cell type annotation and next-token pretraining
- **Is compatible with Hugging Face Transformers** for easy integration into existing pipelines
### Key Features
✅ **Efficient Sequence Processing**: SSM-based architecture whose cost grows linearly with sequence length
✅ **Cell Representation Learning**: Produces cell-level embeddings directly from ranked gene sequences
✅ **Multi-task Support**: Classification, next-token pretraining, and embeddings in one model
✅ **Hugging Face Integration**: Standard `from_pretrained()` and `save_pretrained()` interface
✅ **Production Ready**: Pretrained checkpoints available on the Hugging Face Hub
---
## Datasets
The pretraining dataset and downstream datasets can be found in the official GeneMamba GitHub repository:
https://github.com/MineSelf2016/GeneMamba
---
## Installation
### Option 1: Install from Source
```bash
cd GeneMamba_HuggingFace
pip install -e .
```
### Option 2: Install from PyPI (coming soon)
```bash
pip install genemamba-hf
```
### Dependencies
- Python >= 3.9
- PyTorch >= 2.0
- Transformers >= 4.40.0
- mamba-ssm >= 2.2.0
Install all dependencies:
```bash
pip install -r requirements.txt
```
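To confirm that an environment meets the minimum versions listed above, the following standard-library script may help. It is a sketch, not part of the package: the helper names are our own, it compares only the numeric release part of each version string (suffixes such as `+cu118` or `rc1` are ignored), and it assumes `mamba_ssm` is the installed distribution name of the `mamba-ssm` package.

```python
import re
import sys
from importlib import metadata

# Minimum versions from the Dependencies list; "mamba_ssm" is assumed to be the
# distribution name installed by `pip install mamba-ssm`.
MIN_VERSIONS = {"torch": "2.0", "transformers": "4.40.0", "mamba_ssm": "2.2.0"}


def parse_version(text):
    """Leading numeric release components, e.g. "2.0.1+cu118" -> (2, 0, 1)."""
    parts = []
    for piece in text.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() != len(piece):
            break  # stop at suffixes such as "+cu118" or "rc1"
    return tuple(parts)


def at_least(installed, minimum):
    """Compare release tuples after zero-padding them to the same length."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_min_versions(requirements=MIN_VERSIONS):
    """Return human-readable problems; an empty list means all checks passed."""
    problems = []
    if sys.version_info < (3, 9):
        problems.append(f"python: {sys.version.split()[0]} < 3.9")
    for name, minimum in requirements.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            problems.append(f"{name}: not installed")
            continue
        if not at_least(installed, minimum):
            problems.append(f"{name}: {installed} < {minimum}")
    return problems


if __name__ == "__main__":
    print(check_min_versions() or "All dependency versions look OK")
```

Zero-padding before comparing matters because Python orders `(2, 2)` before `(2, 2, 0)`, so `2.2` would otherwise wrongly fail a `>= 2.2.0` check.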
---
## Quick Start
### Phase 1: Extract Cell Embeddings
This is the **most common use case**. Extract single-cell embeddings for downstream analysis:
```python
import torch
from transformers import AutoTokenizer, AutoModel

# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)
model.eval()  # inference mode (disables dropout)

# Prepare input: ranked gene sequences of token IDs, shape (batch_size, seq_len).
# Real token IDs come from the tokenizer's Ensembl Gene ID vocabulary; random IDs
# are used here for illustration (IDs 0 and 1 are reserved for unk/pad).
batch_size, seq_len = 8, 2048
input_ids = torch.randint(2, 25426, (batch_size, seq_len))

# Extract cell embeddings (no gradients needed for inference)
with torch.no_grad():
    outputs = model(input_ids)
cell_embeddings = outputs.pooled_embedding  # shape: (8, 512)
print(f"Cell embeddings shape: {cell_embeddings.shape}")
# Output: Cell embeddings shape: torch.Size([8, 512])
```
#### Key Points
- **Input format**: Ranked sequences of gene token IDs (genes sorted by expression descending)
- **Recommended embedding**: Always use `outputs.pooled_embedding` for downstream tasks
- **Pooling method**: Default is mean pooling over sequence (see `config.embedding_pooling`)
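- **Sequence length**: Maximum 2048 tokens (`max_position_embeddings`); shorter sequences are padded with the pad token (ID 1), e.g. by calling the tokenizer with `padding=True`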
- **Token vocabulary**: Based on Ensembl Gene IDs (e.g., `ENSG00000000003`)
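To make "mean pooling over the sequence" concrete, here is a minimal NumPy sketch that averages per-token vectors into one vector per cell while skipping padding positions. It assumes pad token ID 1 (as reported by `tokenizer.pad_token_id`) and that padding is excluded from the mean; we have not checked this against the remote modeling code, so treat it as an illustration of the idea rather than GeneMamba's exact implementation. The function name is hypothetical.

```python
import numpy as np

PAD_ID = 1  # pad token ID reported by the GeneMamba tokenizer


def masked_mean_pool(hidden_states, input_ids, pad_id=PAD_ID):
    """Average per-token vectors over non-padding positions.

    hidden_states: (batch, seq_len, hidden) array of per-token outputs
    input_ids:     (batch, seq_len) array of token IDs
    returns:       (batch, hidden) array, one embedding per cell
    """
    mask = (input_ids != pad_id).astype(hidden_states.dtype)[..., None]  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)
    counts = np.clip(mask.sum(axis=1), 1.0, None)  # avoid division by zero for all-pad rows
    return summed / counts


# Toy example: 2 cells, 3 positions, hidden size 2; the second cell ends in padding
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                   [[2.0, 0.0], [4.0, 2.0], [9.0, 9.0]]])
ids = np.array([[5, 7, 9],
                [5, 7, PAD_ID]])
print(masked_mean_pool(hidden, ids))
# [[3. 4.]
#  [3. 1.]]
```

The padded position of the second cell (`[9, 9]`) does not affect its embedding; a plain `hidden.mean(axis=1)` would have pulled it toward the padding vector.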
#### Use Cases for Cell Embeddings
- **Clustering**: KMeans, Leiden, etc.
- **Visualization**: UMAP, t-SNE
- **Classification**: Logistic regression with frozen embeddings
- **Batch integration**: Evaluate with batch correction metrics
- **Retrieval**: Find similar cells or genes
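For the retrieval use case, a common approach is cosine similarity between cell embeddings. The sketch below works on a NumPy array, e.g. `outputs.pooled_embedding.cpu().numpy()`; the function name and the choice of cosine similarity are our own assumptions, not something the model prescribes.

```python
import numpy as np


def top_k_similar(embeddings, query_idx, k=5):
    """Return indices and cosine similarities of the k cells most similar to one query cell.

    embeddings: (n_cells, hidden) array of cell embeddings
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    unit = embeddings / np.clip(norms, 1e-12, None)  # L2-normalise each row
    sims = unit @ unit[query_idx]  # cosine similarity of every cell to the query
    sims[query_idx] = -np.inf  # exclude the query cell itself
    order = np.argsort(-sims)[:k]  # highest similarity first
    return order, sims[order]


emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
idx, scores = top_k_similar(emb, query_idx=0, k=2)
print(idx)  # [1 2]
```

For large datasets, an approximate nearest-neighbour index is usually preferable to this brute-force matrix product.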
---
### Phase 2: Downstream Tasks
Use GeneMamba for **cell type annotation** and other sequence classification tasks:
```python
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

# Load model with classification head
model = AutoModelForSequenceClassification.from_pretrained(
    "mineself2016/GeneMamba",
    num_labels=10,  # number of cell types
    trust_remote_code=True
)

# Prepare dataset
class GeneExpressionDataset(Dataset):
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx]
        }

# Example data (random token IDs and labels for illustration)
X_train = torch.randint(2, 25426, (1000, 2048))
y_train = torch.randint(0, 10, (1000,))
train_dataset = GeneExpressionDataset(X_train, y_train)

# Fine-tune with Trainer
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        save_strategy="epoch",
    ),
    train_dataset=train_dataset,
)
trainer.train()
```
#### Classification Variants
The model also supports:
- **Binary classification**: `num_labels=2`
- **Multi-class**: `num_labels=N`
- **Multi-label**: Use `torch.nn.BCEWithLogitsLoss` with multi-hot targets in a custom training loop
- **Regression**: Requires a custom regression head (not included)
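For multi-label annotation, each cell's target is a multi-hot vector with one 0/1 entry per label, and the loss treats every label as an independent binary decision. The NumPy sketch below builds such targets and reproduces, in its numerically stable form, what `torch.nn.BCEWithLogitsLoss()` computes with its default `reduction="mean"`; it is meant for sanity-checking targets, and the helper names are hypothetical. In training you would pass the model's logits and float targets to `BCEWithLogitsLoss` itself.

```python
import numpy as np


def multi_hot(label_lists, num_labels):
    """Turn per-cell lists of label indices into a (n_cells, num_labels) 0/1 matrix."""
    targets = np.zeros((len(label_lists), num_labels), dtype=np.float32)
    for row, labels in enumerate(label_lists):
        targets[row, labels] = 1.0
    return targets


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy on raw logits, using the stable form
    max(x, 0) - x * y + log(1 + exp(-|x|))."""
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    per_element = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return per_element.mean()


# Cell 0 carries labels 0 and 2; cell 1 carries label 1
targets = multi_hot([[0, 2], [1]], num_labels=3)
logits = np.zeros((2, 3))  # an untrained head: every label at probability 0.5
print(round(bce_with_logits(logits, targets), 4))  # 0.6931 (= ln 2)
```

Unlike softmax cross-entropy, this loss does not force the label probabilities of a cell to sum to one, which is what allows several labels per cell.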
---
### Phase 3: Train from Scratch
Train a new GeneMamba model with **next-token prediction**.
If a checkpoint exists, resume automatically; otherwise start from scratch.
```python
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint

tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
)
print("vocab_size:", tokenizer.vocab_size)  # 25426
print("unk/pad:", tokenizer.unk_token_id, tokenizer.pad_token_id)  # 0, 1
print("cls/mask:", tokenizer.cls_token_id, tokenizer.mask_token_id)  # None, None

# Build model config (no local modeling file import required)
config = AutoConfig.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
config.vocab_size = 25426
config.hidden_size = 512
config.num_hidden_layers = 24
config.max_position_embeddings = 2048
config.mamba_mode = "mean"

# Resume if checkpoint exists
output_dir = "./from_scratch_pretrain"
checkpoint_dir = Path(output_dir) / "checkpoint-last"
if checkpoint_dir.exists():
    resume_from_checkpoint = str(checkpoint_dir)
elif Path(output_dir).is_dir():
    resume_from_checkpoint = get_last_checkpoint(output_dir)  # None if no checkpoint-* dirs
else:
    resume_from_checkpoint = None  # first run: output_dir does not exist yet

if resume_from_checkpoint is not None:
    model = AutoModelForMaskedLM.from_pretrained(
        resume_from_checkpoint,
        trust_remote_code=True,
        local_files_only=True,
    )
else:
    model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

class NextTokenTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer Trainer versions
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs["input_ids"]
        outputs = model(input_ids=input_ids)
        logits = outputs.logits
        # Score the logits at position t against the token at position t + 1
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)
        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=tokenizer.pad_token_id,  # do not train on padding targets
        )
        return (loss, outputs) if return_outputs else loss

# train_dataset must yield dicts with an "input_ids" tensor, e.g. the
# GeneExpressionDataset from Phase 2 (its labels are ignored here)
trainer = NextTokenTrainer(
    model=model,
    args=TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
    ),
    train_dataset=train_dataset,
)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
```
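The shifted loss in `NextTokenTrainer` scores the logits at position t against the token at position t + 1, so a sequence of length L yields at most L - 1 training targets. The pure-Python sketch below lists which (position, target) pairs contribute for one padded cell. It assumes padding targets are excluded from the loss (for example via `ignore_index=tokenizer.pad_token_id` in `cross_entropy`); the function name and token IDs are hypothetical.

```python
PAD_ID = 1  # tokenizer.pad_token_id


def next_token_pairs(input_ids, pad_id=PAD_ID):
    """List the (position, target_token) pairs that contribute to the shifted loss.

    The logits at position t are scored against input_ids[t + 1]; targets equal
    to pad_id are skipped, as ignore_index would do in cross_entropy.
    """
    pairs = []
    for t in range(len(input_ids) - 1):  # the last position has no next token
        target = input_ids[t + 1]
        if target != pad_id:
            pairs.append((t, target))
    return pairs


# A ranked cell of 4 gene tokens, padded to length 6
cell = [812, 45, 3907, 120, PAD_ID, PAD_ID]
print(next_token_pairs(cell))  # [(0, 45), (1, 3907), (2, 120)]
```

Without excluding padding, positions 3 and 4 would also be trained to predict the pad token, which rewards the model for learning the padding layout rather than gene ranks.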
---
## Model Variants
We provide several pretrained checkpoint sizes:
| Model Name | Layers | Hidden Size | Parameters | Download |
|-----------|--------|------------|-----------|----------|
| `GeneMamba-24l-512d` | 24 | 512 | ~170M | 🤗 Hub |
| `GeneMamba-24l-768d` | 24 | 768 | ~380M | 🤗 Hub |
| `GeneMamba-48l-512d` | 48 | 512 | ~340M | 🤗 Hub |
| `GeneMamba-48l-768d` | 48 | 768 | ~750M | 🤗 Hub |
All models share the same tokenizer: a vocabulary of 25,426 tokens (Ensembl Gene IDs plus the special tokens `unk` = 0 and `pad` = 1).
---
## Architecture
### Model Components
```
GeneMambaModel (Backbone)
├── Embedding Layer (vocab_size × hidden_size)
├── MambaMixer (Bidirectional SSM processing)
│   ├── EncoderLayer 0
│   ├── EncoderLayer 1
│   ├── ...
│   └── EncoderLayer N-1
├── RMSNorm (Layer Normalization)
└── Output: Pooled Embedding (batch_size × hidden_size)
Task-Specific Heads:
├── GeneMambaForSequenceClassification
│   └── Linear(hidden_size → num_labels)
├── GeneMambaForMaskedLM
│   └── Linear(hidden_size → vocab_size)
```
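RMSNorm differs from standard LayerNorm in that it rescales each vector by its root mean square without subtracting the mean and without a bias term. The NumPy sketch below shows the standard formulation; GeneMamba's epsilon and weight handling may differ, so read it as an illustration of the operation, not the model's code.

```python
import numpy as np


def rms_norm(x, weight, eps=1e-5):
    """RMSNorm over the last axis: x / sqrt(mean(x**2) + eps) * weight.

    Unlike LayerNorm there is no mean subtraction and no bias.
    """
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return x / rms * weight


h = np.array([[3.0, 4.0]])  # one token, hidden size 2
print(rms_norm(h, weight=np.ones(2), eps=0.0))  # divides by sqrt((9 + 16) / 2) ≈ 3.5355
```

With unit weights and `eps=0`, every normalized vector has a root mean square of exactly 1.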
### Key Design Choices
- **Bidirectional Mamba Block**: Processing the gene sequence in both directions significantly improves performance on the gene rank reconstruction task
- **Pooling Strategy**: Outputs of the bidirectional Mamba can be aggregated with multiple modes (mean/sum/concat/gate)
- **Regularization**: Dropout on classification head
- **Activation**: No explicit activation (Mamba uses internal gating)
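The bidirectional idea and the aggregation modes listed above can be sketched with a toy layer. Here `toy_scan` is a simple causal linear recurrence standing in for a Mamba block (it is not an SSM implementation), and only the `mean`, `sum`, and `concat` modes are shown; the `gate` mode is omitted because its exact form is not described here. GeneMamba's real implementation may also share weights or project concatenated outputs back to `hidden_size`.

```python
import numpy as np


def toy_scan(x, decay=0.5):
    """Stand-in for a causal sequence layer: h_t = decay * h_{t-1} + x_t."""
    h = np.zeros_like(x)
    state = np.zeros(x.shape[-1])
    for t in range(x.shape[0]):
        state = decay * state + x[t]
        h[t] = state
    return h


def bidirectional(x, mode="mean", layer=toy_scan):
    """Run `layer` left-to-right and right-to-left, then combine per position."""
    forward = layer(x)
    backward = layer(x[::-1])[::-1]  # scan the reversed sequence, then restore order
    if mode == "mean":
        return (forward + backward) / 2
    if mode == "sum":
        return forward + backward
    if mode == "concat":
        return np.concatenate([forward, backward], axis=-1)  # doubles the feature size
    raise ValueError(f"mode {mode!r} is not covered by this sketch")


x = np.array([[0.0], [1.0], [0.0]])  # 3 positions, 1 feature; the signal is in the middle
print(bidirectional(x, "sum").ravel())
```

In the forward scan alone, position 0 never sees the token at position 1; the backward scan carries that information leftward, which is why a bidirectional model can use the whole ranked gene list at every position.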
---
<!-- ## Usage Guide
### Loading Models
```python
# Standard loading (backbone only)
from transformers import AutoModel
model = AutoModel.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
# Classification
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
"mineself2016/GeneMamba", num_labels=10, trust_remote_code=True
)
# Language modeling head (used with next-token objective)
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
```
Load other model sizes from subfolders:
```python
model_24l_768d = AutoModel.from_pretrained(
"mineself2016/GeneMamba",
subfolder="24l-768d",
trust_remote_code=True,
)
```
### Saving Models
```python
# Save locally
model.save_pretrained("./my_model")
tokenizer.save_pretrained("./my_model")
# Push to Hugging Face Hub
model.push_to_hub("username/my-genemamba")
tokenizer.push_to_hub("username/my-genemamba")
```
### Configuration
All hyperparameters are stored in `config.json`:
```json
{
"model_type": "genemamba",
"hidden_size": 512,
"num_hidden_layers": 24,
"vocab_size": 25426,
"mamba_mode": "mean",
"embedding_pooling": "mean"
}
```
Modify at runtime:
```python
config = model.config
config.hidden_dropout_prob = 0.2
```
--- -->
## Important Notes ⚠️
### Input Format
**GeneMamba expects a very specific input format:**
1. Each cell is represented as a **ranked sequence** of genes
2. Genes should be **sorted by expression value in descending order**
3. Use **Ensembl Gene IDs** as tokens (e.g., `ENSG00000000003`)
4. Sequences are **padded/truncated to max_position_embeddings** (default 2048)
**Example preparation:**
```python
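import numpy as np
import scanpy as sc
import scipy.sparse as sp
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)

# Load scRNA-seq data
adata = sc.read_h5ad("data.h5ad")
# Assumption: adata.var_names holds Ensembl Gene IDs (e.g. "ENSG00000000003");
# map gene symbols to Ensembl IDs first if your data uses symbols.
gene_id_mapping = np.asarray(adata.var_names)

# For each cell, rank genes by expression
gene_ids = []
for cell_idx in range(adata.n_obs):
    row = adata.X[cell_idx]
    expression = row.toarray().ravel() if sp.issparse(row) else np.asarray(row).ravel()
    ranked_indices = np.argsort(-expression)  # Descending order
    ranked_gene_ids = [gene_id_mapping[idx] for idx in ranked_indices[:2048]]
    gene_ids.append(ranked_gene_ids)

# Convert to token IDs (each cell is already a list of gene tokens)
input_ids = tokenizer(
    gene_ids, is_split_into_words=True, return_tensors="pt", padding=True
)["input_ids"]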
```
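The tokenizer call above pads each batch to its longest sequence. If you prefer every cell at exactly `max_position_embeddings` tokens, a minimal pure-Python sketch of truncate-then-pad is shown below. It assumes right padding with pad ID 1; the returned mask is for your own bookkeeping (for example, masked pooling), since whether the model consumes an attention mask is not documented here. The function name is hypothetical.

```python
PAD_ID = 1      # tokenizer.pad_token_id
MAX_LEN = 2048  # config.max_position_embeddings


def to_fixed_length(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    """Truncate a ranked token-ID list to max_len, then right-pad with pad_id.

    Returns (input_ids, mask) as equal-length lists; mask is 1 for real tokens.
    """
    ids = list(token_ids[:max_len])  # keep the highest-ranked genes
    mask = [1] * len(ids)
    padding = max_len - len(ids)
    return ids + [pad_id] * padding, mask + [0] * padding


ids, mask = to_fixed_length([812, 45, 3907], max_len=5)
print(ids, mask)  # [812, 45, 3907, 1, 1] [1, 1, 1, 0, 0]
```

Truncating from the end is what keeps the input consistent with the ranking rule: the genes dropped are always the lowest-expressed ones.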
### Limitations
- **Gene vocabulary**: Only genes present in the tokenizer vocabulary (25,426 tokens, including special tokens) can be tokenized directly
- **Sequence order**: Expects ranked order; random order will degrade performance
- **Batch size**: Larger batches (32-64) recommended for better convergence
- **GPU memory**: Base model needs ~10GB for batch_size=32; larger variants need more
---
## Examples
See the `examples/` directory for complete scripts:
- `1_extract_embeddings.py` - Extract cell embeddings
- `2_finetune_classification.py` - Cell type annotation
- `3_pretrain_from_scratch.py` - Train from scratch (next-token + optional resume)
---
## Citation
If you find GeneMamba useful in your research, please cite:
```bibtex
@article{qi2025genemamba,
title={GeneMamba: An Efficient and Effective Foundation Model on Single Cell Data},
author={Qi, Cong and Fang, Hanzhang and Jiang, Siqi and Song, Xun and Hu, Tianxing and Zhi, Wei},
journal={arXiv preprint arXiv:2504.16956},
year={2025}
}
```
---
## Troubleshooting
### `trust_remote_code=True` Error
GeneMamba uses custom modeling code, so loading it without `trust_remote_code=True` raises an error. Either:
1. Set `trust_remote_code=True` (safe if loading from the official repo)
2. Or, if loading local code, make it importable with `sys.path.insert(0, '.')`
### Old Cached Code / Shape Mismatch
If you still see loading errors from old cached code after an update, force a fresh download from the Hub:
```python
from transformers import AutoModel
model = AutoModel.from_pretrained(
"mineself2016/GeneMamba",
trust_remote_code=True,
force_download=True,
)
```
You can also clear local cache if needed:
```bash
rm -rf ~/.cache/huggingface/hub/models--mineself2016--GeneMamba
```
### Out of Memory (OOM)
Reduce batch size:
```python
args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,  # Reduce from 32
    # ... other arguments unchanged
)
```
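Reducing the per-device batch size also shrinks the effective batch, which the Limitations section recommends keeping around 32-64. `TrainingArguments` accepts `gradient_accumulation_steps` to accumulate gradients over several smaller batches before each optimizer step. The tiny helper below (hypothetical, not part of Transformers) picks a value that restores a target effective batch size.

```python
import math


def accumulation_steps(target_batch, per_device_batch, num_devices=1):
    """Smallest number of accumulation steps such that
    per_device_batch * num_devices * steps >= target_batch."""
    per_step = per_device_batch * num_devices
    return max(1, math.ceil(target_batch / per_step))


steps = accumulation_steps(target_batch=32, per_device_batch=8)
print(steps)  # 4
# Then, for example:
# TrainingArguments(per_device_train_batch_size=8, gradient_accumulation_steps=steps, ...)
```

Accumulation trades speed for memory: each optimizer step now takes several forward/backward passes, but peak activation memory is set by the smaller per-device batch.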
### Tokenizer Not Found
Make sure the tokenizer files are in the same directory as the model files:
```
GeneMamba_repo/
├── config.json
├── model.safetensors
├── tokenizer.json ← Required
├── tokenizer_config.json ← Required
└── ...
```
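A quick way to check a local model directory before calling `from_pretrained()` is to test for the two files marked as required above. This is a standard-library sketch; the function name is hypothetical, and `./GeneMamba_repo` is the example directory from the tree.

```python
from pathlib import Path

# Files marked as required in the directory layout above
REQUIRED_TOKENIZER_FILES = ("tokenizer.json", "tokenizer_config.json")


def missing_tokenizer_files(model_dir):
    """Return the required tokenizer files that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in REQUIRED_TOKENIZER_FILES if not (root / name).is_file()]


missing = missing_tokenizer_files("./GeneMamba_repo")
if missing:
    print("Missing tokenizer files:", ", ".join(missing))
```

If files are missing, copy them from the Hub repository or re-run `tokenizer.save_pretrained()` into the same directory as the model.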
---
**Last Updated**: March 2026