---
library_name: transformers
pipeline_tag: feature-extraction
tags:
- genemamba
- mamba
- genomics
- single-cell
- custom_code
---

# GeneMamba: Foundation Model for Single-Cell Analysis

A Hugging Face-compatible implementation of GeneMamba, a foundation model built on the Mamba state-space architecture for single-cell RNA-seq analysis.

## Table of Contents

- [Overview](#overview)
- [Datasets](#datasets)
- [Installation](#installation)
- [Quick Start](#quick-start)
  - [Phase 1: Extract Cell Embeddings](#phase-1-extract-cell-embeddings)
  - [Phase 2: Downstream Tasks](#phase-2-downstream-tasks)
  - [Phase 3: Train from Scratch](#phase-3-train-from-scratch)
- [Model Variants](#model-variants)
- [Architecture](#architecture)
- [Examples](#examples)
- [Citation](#citation)
- [Troubleshooting](#troubleshooting)

---

## Overview

GeneMamba is a **state-space model (SSM)** based on the **Mamba architecture** and optimized for single-cell gene expression analysis. The model:

- **Takes ranked gene sequences** as input (genes sorted by expression level)
- **Outputs cell embeddings** suitable for clustering, classification, and batch integration
- **Supports multiple tasks**, including cell type annotation (via fine-tuning) and next-token pretraining
- **Is compatible with Hugging Face Transformers** for easy integration into existing pipelines

### Key Features

- ✅ **Efficient Sequence Processing**: SSM-based architecture whose cost grows linearly with sequence length
- ✅ **Cell Representation Learning**: Produces cell embeddings directly, without intermediate steps
- ✅ **Multi-task Support**: Classification, next-token pretraining, and embeddings in one model
- ✅ **Hugging Face Integration**: Standard `from_pretrained()` and `save_pretrained()` interface
- ✅ **Production Ready**: Pretrained checkpoints available on the Hugging Face Hub

---

## Datasets

The pretraining and downstream datasets are available in the official GeneMamba GitHub repository:

https://github.com/MineSelf2016/GeneMamba

---

## Installation

### Option 1: Install from Source

Install the package in editable mode from a local copy of the `GeneMamba_HuggingFace` source directory:

```bash
cd GeneMamba_HuggingFace
pip install -e .
```

### Option 2: Install from PyPI (coming soon)

```bash
pip install genemamba-hf
```

### Dependencies

- Python >= 3.9
- PyTorch >= 2.0
- Transformers >= 4.40.0
- mamba-ssm >= 2.2.0

Install all dependencies:

```bash
pip install -r requirements.txt
```
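
To check an existing environment against these minimums, a small standard-library sketch can compare installed versions. The `MINIMUMS` table simply restates the list above; `parse_version` and `meets_minimum` are illustrative helpers (not part of GeneMamba) that compare only the numeric release part of a version, so `packaging.version.Version` is the more thorough choice where it is available.

```python
import re
import sys
from importlib.metadata import PackageNotFoundError, version

# Minimum versions from the dependency list above (installed distribution names)
MINIMUMS = {"torch": "2.0", "transformers": "4.40.0", "mamba_ssm": "2.2.0"}


def parse_version(text):
    """Return the numeric release part of a version, e.g. '2.3.1+cu121' -> (2, 3, 1)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:  # stop at tags such as 'dev0' or 'post1'
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """True if `installed` >= `minimum`, padding the shorter version with zeros."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


if __name__ == "__main__":
    py_ok = sys.version_info >= (3, 9)
    print("python", sys.version.split()[0], "ok" if py_ok else "too old (need >= 3.9)")
    for name, minimum in MINIMUMS.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            print(f"{name}: not installed (need >= {minimum})")
            continue
        status = "ok" if meets_minimum(installed, minimum) else f"too old (need >= {minimum})"
        print(f"{name} {installed}: {status}")
```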

---

## Quick Start

### Phase 1: Extract Cell Embeddings

This is the **most common use case**: extracting single-cell embeddings for downstream analysis.

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)

# mamba-ssm kernels generally expect a CUDA GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

# Prepare input: ranked gene sequences
# Shape: (batch_size, seq_len); each entry is the token ID of an Ensembl gene ID.
# Random IDs are used here for illustration only (IDs 0 and 1 are reserved for unk/pad).
batch_size, seq_len = 8, 2048
input_ids = torch.randint(2, 25426, (batch_size, seq_len), device=device)

# Extract cell embeddings without tracking gradients
with torch.no_grad():
    outputs = model(input_ids)
cell_embeddings = outputs.pooled_embedding  # shape: (8, 512)

print(f"Cell embeddings shape: {cell_embeddings.shape}")
# Output: Cell embeddings shape: torch.Size([8, 512])
```

#### Key Points

- **Input format**: Ranked sequences of gene token IDs (genes sorted by expression, descending)
- **Recommended embedding**: Use `outputs.pooled_embedding` for downstream tasks
- **Pooling method**: Mean pooling over the sequence by default (see `config.embedding_pooling`)
- **Sequence length**: Maximum 2048 tokens; shorter sequences must be padded with the pad token (the tokenizer does this when called with `padding=True`)
- **Token vocabulary**: Based on Ensembl Gene IDs (e.g., `ENSG00000000003`)
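
To make the pooling step concrete, here is a minimal sketch of mean pooling that averages per-token hidden states while skipping padding positions. It assumes per-token states of shape `(batch, seq_len, hidden)` and pad ID 1 (the value the tokenizer reports in Phase 3); `masked_mean_pool` is a hypothetical helper, and this README does not state whether GeneMamba's built-in pooling excludes padding in the same way.

```python
import torch


def masked_mean_pool(hidden_states, input_ids, pad_token_id=1):
    """Average token states over non-padding positions.

    hidden_states: (batch, seq_len, hidden); input_ids: (batch, seq_len).
    Returns a (batch, hidden) tensor.
    """
    # 1.0 where the token is real, 0.0 where it is padding
    mask = (input_ids != pad_token_id).unsqueeze(-1).to(hidden_states.dtype)
    summed = (hidden_states * mask).sum(dim=1)
    # Guard against division by zero for an all-padding row
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


# Toy example: the third position is padding and is ignored
states = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
ids = torch.tensor([[5, 6, 1]])
print(masked_mean_pool(states, ids))  # tensor([[2., 3.]])
```

Without the mask, the padding position would pull the average toward whatever the model outputs at pad tokens, so cells with many padded positions would drift together in embedding space.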

#### Use Cases for Cell Embeddings

- **Clustering**: KMeans, Leiden, etc.
- **Visualization**: UMAP, t-SNE
- **Classification**: Logistic regression with frozen embeddings
- **Batch integration**: Evaluate with batch correction metrics
- **Retrieval**: Find similar cells or genes
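
For the retrieval use case, cosine similarity between pooled embeddings is a common choice. The sketch below is a plain NumPy illustration under that assumption (it is not a GeneMamba API); `top_k_similar` is a hypothetical helper that takes a `(n_cells, hidden)` array such as `cell_embeddings.cpu().numpy()`.

```python
import numpy as np


def top_k_similar(embeddings, query_idx, k=5):
    """Return (indices, similarities) of the k cells most cosine-similar to cell `query_idx`."""
    emb = np.asarray(embeddings, dtype=np.float64)
    # L2-normalize rows so that dot products equal cosine similarities
    unit = emb / np.clip(np.linalg.norm(emb, axis=1, keepdims=True), 1e-12, None)
    sims = unit @ unit[query_idx]
    sims[query_idx] = -np.inf  # never return the query cell itself
    order = np.argsort(-sims, kind="stable")[:k]
    return order, sims[order]


# Toy embeddings: cell 1 points almost the same way as cell 0, cell 3 the opposite way
emb = np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [-1.0, 0.0]])
idx, sim = top_k_similar(emb, query_idx=0, k=2)
print(idx)  # [1 2]
```

This brute-force version scores every cell for each query; for large atlases an approximate nearest-neighbor index is usually the better fit.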

---

### Phase 2: Downstream Tasks

Use GeneMamba for **cell type annotation** and other sequence classification tasks:

```python
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset

# Load model with classification head
model = AutoModelForSequenceClassification.from_pretrained(
    "mineself2016/GeneMamba",
    num_labels=10,  # number of cell types
    trust_remote_code=True
)

# Prepare dataset
class GeneExpressionDataset(Dataset):
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx]
        }

# Random example data; replace with tokenized cells and integer cell-type labels
X_train = torch.randint(2, 25426, (1000, 2048))
y_train = torch.randint(0, 10, (1000,))

train_dataset = GeneExpressionDataset(X_train, y_train)

# Fine-tune with Trainer
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        save_strategy="epoch",
    ),
    train_dataset=train_dataset,
)

trainer.train()
```
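
To monitor annotation quality, `Trainer` accepts a `compute_metrics` callable that receives the logits and labels of an evaluation run. A minimal NumPy sketch computing accuracy and macro-F1, assuming logits of shape `(n_cells, num_labels)`; macro-F1 here averages over the classes that appear in the labels.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy and macro-F1 from the (logits, labels) pair that Trainer passes in."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    accuracy = float(np.mean(preds == labels))
    f1_scores = []
    for c in np.unique(labels):
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return {"accuracy": accuracy, "macro_f1": float(np.mean(f1_scores))}
```

Pass it as `Trainer(..., eval_dataset=eval_dataset, compute_metrics=compute_metrics)` and call `trainer.evaluate()`; `eval_dataset` stands for a hypothetical held-out `GeneExpressionDataset`.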

#### Classification Variants

The model also supports:

- **Binary classification**: `num_labels=2`
- **Multi-class**: `num_labels=N`
- **Multi-label**: Compute `BCEWithLogitsLoss` on the logits in a custom training loop
- **Regression**: Requires a modified head (custom implementation needed)
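
For the multi-label case, the loss treats each label as an independent binary decision. A minimal PyTorch sketch, assuming `logits` of shape `(batch, num_labels)` from the classification head and multi-hot `targets` of the same shape; `multilabel_loss`, `multilabel_predict`, and the `multi_hot_labels` batch key are hypothetical, and the 0.5 threshold is an arbitrary default.

```python
import torch

loss_fn = torch.nn.BCEWithLogitsLoss()


def multilabel_loss(logits, targets):
    """Binary cross-entropy over independent labels; targets are multi-hot (0/1)."""
    return loss_fn(logits, targets.float())


def multilabel_predict(logits, threshold=0.5):
    """Turn logits into 0/1 label indicators via a sigmoid threshold."""
    return (torch.sigmoid(logits) > threshold).long()


# Inside a custom training loop (model, batch, and optimizer are assumed to exist):
# logits = model(input_ids=batch["input_ids"]).logits
# loss = multilabel_loss(logits, batch["multi_hot_labels"])
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```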

---

### Phase 3: Train from Scratch

Train a new GeneMamba model with **next-token prediction**.
If a checkpoint exists in the output directory, training resumes from it automatically; otherwise it starts from scratch.

```python
import torch
from pathlib import Path
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint

tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
)

print("vocab_size:", tokenizer.vocab_size)  # 25426
print("unk/pad:", tokenizer.unk_token_id, tokenizer.pad_token_id)  # 0, 1
print("cls/mask:", tokenizer.cls_token_id, tokenizer.mask_token_id)  # None, None

# Build model config (no local modeling file import required)
config = AutoConfig.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
config.vocab_size = 25426
config.hidden_size = 512
config.num_hidden_layers = 24
config.max_position_embeddings = 2048
config.mamba_mode = "mean"

# Training data: replace the random IDs with real tokenized, ranked cells
class RankedGeneDataset(Dataset):
    def __init__(self, input_ids):
        self.input_ids = input_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {"input_ids": self.input_ids[idx]}

train_dataset = RankedGeneDataset(torch.randint(2, 25426, (1000, 2048)))

# Resume if a checkpoint exists
output_dir = "./from_scratch_pretrain"
checkpoint_dir = Path(output_dir) / "checkpoint-last"

if checkpoint_dir.exists():
    resume_from_checkpoint = str(checkpoint_dir)
elif Path(output_dir).is_dir():
    # Newest checkpoint-<step> folder, or None if there is none
    resume_from_checkpoint = get_last_checkpoint(output_dir)
else:
    # Fresh run: get_last_checkpoint would fail on a missing directory
    resume_from_checkpoint = None

# The LM head is registered under AutoModelForMaskedLM; here it is trained with a next-token objective
if resume_from_checkpoint is not None:
    model = AutoModelForMaskedLM.from_pretrained(
        resume_from_checkpoint,
        trust_remote_code=True,
        local_files_only=True,
    )
else:
    model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

class NextTokenTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer Transformers versions
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs["input_ids"]
        outputs = model(input_ids=input_ids)
        # Logits at position t predict the token at position t+1
        shift_logits = outputs.logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)
        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=tokenizer.pad_token_id,  # padding targets do not contribute
        )
        return (loss, outputs) if return_outputs else loss

trainer = NextTokenTrainer(
    model=model,
    args=TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
    ),
    train_dataset=train_dataset,
)

trainer.train(resume_from_checkpoint=resume_from_checkpoint)
```
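
The slicing in `compute_loss` aligns the logits at position t with the token at position t+1. The pure-Python sketch below shows that alignment on a short sequence and also drops pairs whose target is padding, assuming pad ID 1 as printed by the tokenizer above; `next_token_pairs` is a hypothetical helper for illustration only.

```python
def next_token_pairs(input_ids, pad_token_id=1):
    """Pair each token with the token that follows it, skipping padding targets.

    Mirrors logits[:, :-1] (inputs) versus input_ids[:, 1:] (labels).
    """
    inputs = input_ids[:-1]
    targets = input_ids[1:]
    return [(x, y) for x, y in zip(inputs, targets) if y != pad_token_id]


# Three real gene tokens followed by two padding tokens
print(next_token_pairs([5, 9, 7, 1, 1]))  # [(5, 9), (9, 7)]
```

A sequence of length L therefore contributes at most L-1 prediction targets.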

---

## Model Variants

We provide several pretrained checkpoint sizes:

| Model Name | Layers | Hidden Size | Parameters | Download |
|-----------|--------|------------|-----------|----------|
| `GeneMamba-24l-512d` | 24 | 512 | ~170M | 🤗 Hub |
| `GeneMamba-24l-768d` | 24 | 768 | ~380M | 🤗 Hub |
| `GeneMamba-48l-512d` | 48 | 512 | ~340M | 🤗 Hub |
| `GeneMamba-48l-768d` | 48 | 768 | ~750M | 🤗 Hub |

All models share the same tokenizer, whose 25,426-token vocabulary consists of Ensembl gene IDs plus special tokens (`unk` = 0, `pad` = 1).

---

## Architecture

### Model Components

```
GeneMambaModel (Backbone)
├── Embedding Layer (vocab_size × hidden_size)
├── MambaMixer (Bidirectional SSM processing)
│   ├── EncoderLayer 0
│   ├── EncoderLayer 1
│   ├── ...
│   └── EncoderLayer N-1
├── RMSNorm (Layer Normalization)
└── Output: Pooled Embedding (batch_size × hidden_size)

Task-Specific Heads:
├── GeneMambaForSequenceClassification
│   └── Linear(hidden_size → num_labels)
└── GeneMambaForMaskedLM
    └── Linear(hidden_size → vocab_size)
```

### Key Design Choices

- **Bidirectional Mamba Block**: Processing the ranked gene sequence in both directions significantly improves performance on the gene rank reconstruction task
- **Pooling Strategy**: The outputs of the two Mamba directions are combined with one of several aggregation modes (mean/sum/concat/gate)
- **Regularization**: Dropout on the classification head
- **Activation**: No explicit activation function (Mamba uses internal gating)
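
The bidirectional design and its aggregation modes can be illustrated with a NumPy sketch. It assumes the common pattern of running a causal scan over the sequence and over its reverse, flipping the backward output back into place, and merging the two; a cumulative sum stands in for the real SSM, and the gate parameterization (`gate_weight`) is hypothetical. None of these names come from the GeneMamba code.

```python
import numpy as np


def causal_scan(x):
    """Stand-in for a unidirectional SSM: output t depends only on inputs 0..t."""
    return np.cumsum(x, axis=0)


def bidirectional(x, mode="mean", gate_weight=None):
    """Combine forward and backward scans over x of shape (seq_len, hidden)."""
    h_fwd = causal_scan(x)
    h_bwd = causal_scan(x[::-1])[::-1]  # scan the reversed sequence, then restore order
    if mode == "mean":
        return (h_fwd + h_bwd) / 2
    if mode == "sum":
        return h_fwd + h_bwd
    if mode == "concat":
        return np.concatenate([h_fwd, h_bwd], axis=-1)  # hidden size doubles
    if mode == "gate":
        # Hypothetical learned gate: g = sigmoid([h_fwd; h_bwd] @ W), W of shape (2*hidden, hidden)
        g = 1.0 / (1.0 + np.exp(-np.concatenate([h_fwd, h_bwd], axis=-1) @ gate_weight))
        return g * h_fwd + (1.0 - g) * h_bwd
    raise ValueError(f"unknown mode: {mode}")


x = np.arange(4, dtype=float).reshape(4, 1)  # four tokens, hidden size 1
print(bidirectional(x, "mean").ravel().tolist())  # [3.0, 3.5, 4.0, 4.5]
```

Every position in the merged output depends on both earlier and later tokens (position 0 already reflects the whole sequence), which a single causal scan cannot provide.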

---

<!-- ## Usage Guide

### Loading Models

```python
# Standard loading (backbone only)
from transformers import AutoModel
model = AutoModel.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)

# Classification
from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained(
    "mineself2016/GeneMamba", num_labels=10, trust_remote_code=True
)

# Language modeling head (used with next-token objective)
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
```

Load other model sizes from subfolders:

```python
model_24l_768d = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    subfolder="24l-768d",
    trust_remote_code=True,
)
```

### Saving Models

```python
# Save locally
model.save_pretrained("./my_model")
tokenizer.save_pretrained("./my_model")

# Push to Hugging Face Hub
model.push_to_hub("username/my-genemamba")
tokenizer.push_to_hub("username/my-genemamba")
```

### Configuration

All hyperparameters are stored in `config.json`:

```json
{
  "model_type": "genemamba",
  "hidden_size": 512,
  "num_hidden_layers": 24,
  "vocab_size": 25426,
  "mamba_mode": "mean",
  "embedding_pooling": "mean"
}
```

Modify at runtime:

```python
config = model.config
config.hidden_dropout_prob = 0.2
```

--- -->

## Important Notes ⚠️

### Input Format

**GeneMamba expects a very specific input format:**

1. Each cell is represented as a **ranked sequence** of genes
2. Genes should be **sorted by expression value in descending order**
3. Use **Ensembl Gene IDs** as tokens (e.g., `ENSG00000000003`)
4. Sequences are **padded/truncated to `max_position_embeddings`** (default 2048)
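
**Example preparation:**

```python
import numpy as np
import scanpy as sc
import scipy.sparse as sp
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)

# Load scRNA-seq data (assumes adata.var_names holds Ensembl gene IDs)
adata = sc.read_h5ad("data.h5ad")
ensembl_ids = np.asarray(adata.var_names)
max_len = 2048

# For each cell, rank genes by expression (descending) and keep the top max_len
gene_ids = []
for cell_idx in range(adata.n_obs):
    row = adata.X[cell_idx]
    expression = row.toarray().ravel() if sp.issparse(row) else np.asarray(row).ravel()
    ranked_indices = np.argsort(-expression, kind="stable")
    gene_ids.append(ensembl_ids[ranked_indices[:max_len]].tolist())

# Convert gene IDs to token IDs (genes outside the vocabulary typically map to the unk token)
input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(ids) for ids in gene_ids])
```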
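
If sequences end up with different lengths (for example, when genes with zero counts are dropped before ranking, or when a dataset has fewer than 2048 genes), they must be brought to a common length before batching. A minimal sketch, assuming right-padding with pad ID 1 (the tokenizer's reported pad ID); `pad_or_truncate` is a hypothetical helper. Because the sequences are ranked, truncation keeps the most highly expressed genes.

```python
def pad_or_truncate(token_ids, max_len=2048, pad_token_id=1):
    """Keep the first max_len ranked tokens, then right-pad with pad_token_id."""
    ids = list(token_ids[:max_len])
    return ids + [pad_token_id] * (max_len - len(ids))


batch = [[5, 6, 7], list(range(2, 12))]
print([pad_or_truncate(ids, max_len=5) for ids in batch])
# [[5, 6, 7, 1, 1], [2, 3, 4, 5, 6]]
```

Calling `torch.tensor(...)` on the padded lists then yields an `(n_cells, max_len)` tensor.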

### Limitations

- **Gene vocabulary**: Only Ensembl gene IDs in the tokenizer vocabulary (25,426 tokens, including special tokens) can be tokenized directly; other genes fall back to the unknown token
- **Sequence order**: The model expects ranked order; random order degrades performance
- **Batch size**: Larger batches (32-64) are recommended for better convergence
- **GPU memory**: The base model needs ~10GB for batch_size=32; larger variants need more

---

## Examples

See the `examples/` directory for complete scripts:

- `1_extract_embeddings.py` - Extract cell embeddings
- `2_finetune_classification.py` - Cell type annotation
- `3_pretrain_from_scratch.py` - Train from scratch (next-token + optional resume)

---

## Citation

If you find GeneMamba useful in your research, please cite:

```bibtex
@article{qi2025genemamba,
  title={GeneMamba: An Efficient and Effective Foundation Model on Single Cell Data},
  author={Qi, Cong and Fang, Hanzhang and Jiang, Siqi and Song, Xun and Hu, Tianxing and Zhi, Wei},
  journal={arXiv preprint arXiv:2504.16956},
  year={2025}
}
```

---

## Troubleshooting

### `trust_remote_code=True` Error

GeneMamba ships custom model code, so loading it without `trust_remote_code=True` raises an error. Either:

1. Set `trust_remote_code=True` (safe when loading from the official repository), or
2. If you are loading local model code, add its directory to the Python path (e.g., `sys.path.insert(0, '.')`)

### Old Cached Code / Shape Mismatch

If you still see loading errors after an update, force a fresh download of the files from the Hub:

```python
from transformers import AutoModel
model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
    force_download=True,
)
```

You can also clear the local cache if needed:

```bash
rm -rf ~/.cache/huggingface/hub/models--mineself2016--GeneMamba
```
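
### Out of Memory (OOM)

Reduce the per-device batch size; gradient accumulation keeps the effective batch size unchanged:

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,   # Reduced from 32
    gradient_accumulation_steps=4,   # 8 x 4 = 32 samples per optimizer step (per device)
)
```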

### Tokenizer Not Found

Make sure the tokenizer files are in the same directory as the model files:

```
GeneMamba_repo/
├── config.json
├── model.safetensors
├── tokenizer.json          ← Required
├── tokenizer_config.json   ← Required
└── ...
```
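
A quick way to diagnose this is to list which of the required files are missing. A small `pathlib` sketch; `missing_tokenizer_files` is a hypothetical helper, and its file list mirrors the two files marked as required above.

```python
from pathlib import Path

REQUIRED_TOKENIZER_FILES = ("tokenizer.json", "tokenizer_config.json")


def missing_tokenizer_files(model_dir):
    """Return the required tokenizer files that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in REQUIRED_TOKENIZER_FILES if not (root / name).is_file()]


missing = missing_tokenizer_files("./GeneMamba_repo")
if missing:
    print("Missing tokenizer files:", ", ".join(missing))
```

If files are missing, saving the tokenizer you already loaded with `tokenizer.save_pretrained("./GeneMamba_repo")` writes its files into that directory.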

---

**Last Updated**: March 2026