---
library_name: transformers
pipeline_tag: feature-extraction
tags:
- genemamba
- mamba
- genomics
- single-cell
- custom_code
---

# GeneMamba: Foundation Model for Single-Cell Analysis

A Hugging Face compatible implementation of GeneMamba, a foundational state-space model (Mamba) designed for advanced single-cell RNA-seq analysis.

## 📋 Table of Contents

- [Overview](#overview)
- [Datasets](#datasets)
- [Installation](#installation)
- [Quick Start](#quick-start)
  - [Phase 1: Extract Cell Embeddings](#phase-1-extract-cell-embeddings)
  - [Phase 2: Downstream Tasks](#phase-2-downstream-tasks)
  - [Phase 3: Train from Scratch](#phase-3-train-from-scratch)
- [Model Variants](#model-variants)
- [Architecture](#architecture)
- [Important Notes](#important-notes)
- [Examples](#examples)
- [Citation](#citation)
- [Troubleshooting](#troubleshooting)

---

## Overview

GeneMamba is a **state-space model (SSM)** based on the **Mamba architecture**, optimized for single-cell gene expression analysis. The model:

- **Takes ranked gene sequences** as input (genes sorted by expression level)
- **Outputs cell embeddings** suitable for clustering, classification, and batch integration
- **Supports multiple downstream tasks**, including cell type annotation and next-token pretraining
- **Is compatible with Hugging Face Transformers** for easy integration into existing pipelines

### Key Features

✅ **Efficient Sequence Processing**: SSM-based architecture with linear complexity
✅ **Cell Representation Learning**: Direct cell embedding without intermediate steps
✅ **Multi-task Support**: Classification, next-token pretraining, and embeddings in one model
✅ **Hugging Face Integration**: Standard `from_pretrained()` and `save_pretrained()` interface
✅ **Production Ready**: Pretrained checkpoints available on the Hugging Face Hub

---

## Datasets

The pretraining dataset and downstream datasets can be found in the official GeneMamba GitHub repository: https://github.com/MineSelf2016/GeneMamba

---

## Installation

### Option 1: Install from Source

```bash
cd GeneMamba_HuggingFace
pip install -e .
```

### Option 2: Install from PyPI (coming soon)

```bash
pip install genemamba-hf
```

### Dependencies

- Python >= 3.9
- PyTorch >= 2.0
- Transformers >= 4.40.0
- mamba-ssm >= 2.2.0

Install all dependencies:

```bash
pip install -r requirements.txt
```

---

## Quick Start

### Phase 1: Extract Cell Embeddings

This is the **most common use case**. Extract single-cell embeddings for downstream analysis:

```python
import torch
from transformers import AutoTokenizer, AutoModel

# Load pretrained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)
model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True
)

# Prepare input: ranked gene sequences
# Shape: (batch_size, seq_len) with gene token IDs
batch_size, seq_len = 8, 2048
input_ids = torch.randint(2, 25426, (batch_size, seq_len))

# Extract cell embeddings
outputs = model(input_ids)
cell_embeddings = outputs.pooled_embedding  # shape: (8, 512)

print(f"Cell embeddings shape: {cell_embeddings.shape}")
# Output: Cell embeddings shape: torch.Size([8, 512])
```

#### Key Points

- **Input format**: Ranked sequences of gene token IDs (genes sorted by expression, descending)
- **Recommended embedding**: Always use `outputs.pooled_embedding` for downstream tasks
- **Pooling method**: Default is mean pooling over the sequence (see `config.embedding_pooling`)
- **Sequence length**: Maximum 2048; shorter sequences are auto-padded
- **Token vocabulary**: Based on Ensembl Gene IDs (e.g., `ENSG00000000003`)

#### Use Cases for Cell Embeddings

- **Clustering**: KMeans, Leiden, etc.
- **Visualization**: UMAP, t-SNE
- **Classification**: Logistic regression with frozen embeddings
- **Batch integration**: Evaluate with batch correction metrics
- **Retrieval**: Find similar cells or genes

---

### Phase 2: Downstream Tasks

Use GeneMamba for **cell type annotation** and other sequence classification tasks:

```python
import torch
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import Dataset

# Load model with classification head
model = AutoModelForSequenceClassification.from_pretrained(
    "mineself2016/GeneMamba",
    num_labels=10,  # number of cell types
    trust_remote_code=True
)

# Prepare dataset
class GeneExpressionDataset(Dataset):
    def __init__(self, input_ids, labels):
        self.input_ids = input_ids
        self.labels = labels

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx]
        }

# Example data
X_train = torch.randint(2, 25426, (1000, 2048))
y_train = torch.randint(0, 10, (1000,))
train_dataset = GeneExpressionDataset(X_train, y_train)

# Fine-tune with Trainer
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir="./results",
        num_train_epochs=5,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
        save_strategy="epoch",
    ),
    train_dataset=train_dataset,
)
trainer.train()
```

#### Classification Variants

The model also supports:

- **Binary classification**: `num_labels=2`
- **Multi-class**: `num_labels=N`
- **Multi-label**: Use `BCEWithLogitsLoss` in a custom training loop
- **Regression**: Modify the head (custom implementation needed)

---

### Phase 3: Train from Scratch

Train a new GeneMamba model with **next-token prediction**. If a checkpoint exists, training resumes automatically; otherwise it starts from scratch.
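The training script in this phase references a `train_dataset` without defining one. As a minimal stand-in (assuming the same 25,426-token vocabulary and 2048-token sequence length used elsewhere in this README; the class name `RankedGeneDataset` is illustrative), a synthetic dataset of ranked token sequences can be sketched as:

```python
import torch
from torch.utils.data import Dataset

class RankedGeneDataset(Dataset):
    """Minimal stand-in: each item is a dict with a ranked sequence of gene token IDs."""
    def __init__(self, input_ids):
        self.input_ids = input_ids

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return {"input_ids": self.input_ids[idx]}

# Synthetic data for demonstration only; replace with real ranked gene sequences
train_dataset = RankedGeneDataset(torch.randint(2, 25426, (1000, 2048)))
```

For real pretraining, build `input_ids` from ranked gene sequences as described in the Input Format section.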
```python
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM, Trainer, TrainingArguments
from transformers.trainer_utils import get_last_checkpoint

tokenizer = AutoTokenizer.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
)
print("vocab_size:", tokenizer.vocab_size)  # 25426
print("unk/pad:", tokenizer.unk_token_id, tokenizer.pad_token_id)  # 0, 1
print("cls/mask:", tokenizer.cls_token_id, tokenizer.mask_token_id)  # None, None

# Build model config (no local modeling file import required)
config = AutoConfig.from_pretrained("mineself2016/GeneMamba", trust_remote_code=True)
config.vocab_size = 25426
config.hidden_size = 512
config.num_hidden_layers = 24
config.max_position_embeddings = 2048
config.mamba_mode = "mean"

# Resume if a checkpoint exists
output_dir = "./from_scratch_pretrain"
checkpoint_dir = Path(output_dir) / "checkpoint-last"
if checkpoint_dir.exists():
    resume_from_checkpoint = str(checkpoint_dir)
else:
    resume_from_checkpoint = get_last_checkpoint(output_dir)

if resume_from_checkpoint is not None:
    model = AutoModelForMaskedLM.from_pretrained(
        resume_from_checkpoint,
        trust_remote_code=True,
        local_files_only=True,
    )
else:
    model = AutoModelForMaskedLM.from_config(config, trust_remote_code=True)

class NextTokenTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g., num_items_in_batch) passed by
    # newer versions of transformers
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        input_ids = inputs["input_ids"]
        logits = model(input_ids=input_ids).logits
        # Shift so that each position predicts the next token
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = input_ids[:, 1:].contiguous().to(shift_logits.device)
        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
        return loss

# `train_dataset` should yield dicts with an "input_ids" tensor of ranked gene
# token IDs (see the Dataset class in Phase 2)
trainer = NextTokenTrainer(
    model=model,
    args=TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,
        per_device_train_batch_size=32,
        learning_rate=2e-5,
    ),
    train_dataset=train_dataset,
)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
```

---

## Model Variants

We provide
several pretrained checkpoint sizes:

| Model Name | Layers | Hidden Size | Parameters | Download |
|-----------|--------|------------|-----------|----------|
| `GeneMamba-24l-512d` | 24 | 512 | ~170M | 🤗 Hub |
| `GeneMamba-24l-768d` | 24 | 768 | ~380M | 🤗 Hub |
| `GeneMamba-48l-512d` | 48 | 512 | ~340M | 🤗 Hub |
| `GeneMamba-48l-768d` | 48 | 768 | ~750M | 🤗 Hub |

All models share the same tokenizer (25,426 Ensembl Gene IDs + special tokens).

---

## Architecture

### Model Components

```
GeneMambaModel (Backbone)
├── Embedding Layer (vocab_size × hidden_size)
├── MambaMixer (Bidirectional SSM processing)
│   ├── EncoderLayer 0
│   ├── EncoderLayer 1
│   ├── ...
│   └── EncoderLayer N-1
├── RMSNorm (Layer Normalization)
└── Output: Pooled Embedding (batch_size × hidden_size)

Task-Specific Heads:
├── GeneMambaForSequenceClassification
│   └── Linear(hidden_size → num_labels)
├── GeneMambaForMaskedLM
│   └── Linear(hidden_size → vocab_size)
```

### Key Design Choices

- **Bidirectional Mamba Block**: Bidirectional processing yields a significant improvement on the gene rank reconstruction task
- **Pooling Strategy**: Bidirectional Mamba with multiple aggregation modes (mean/sum/concat/gate)
- **Regularization**: Dropout on the classification head
- **Activation**: No explicit activation (Mamba uses internal gating)

---

## Important Notes ⚠️

### Input Format

**GeneMamba expects a very specific input format:**

1. Each cell is represented as a **ranked sequence** of genes
2. Genes should be **sorted by expression value in descending order**
3. Use **Ensembl Gene IDs** as tokens (e.g., `ENSG00000000003`)
4. Sequences are **padded/truncated to max_position_embeddings** (default 2048)

**Example preparation:**

```python
import numpy as np
import scanpy as sc

# Load scRNA-seq data
adata = sc.read_h5ad("data.h5ad")

# For each cell, rank genes by expression (assumes a sparse expression matrix
# and a `gene_id_mapping` from column indices to Ensembl Gene IDs)
gene_ids = []
for cell_idx in range(adata.n_obs):
    expression = adata.X[cell_idx].toarray().flatten()
    ranked_indices = np.argsort(-expression)  # Descending order
    ranked_gene_ids = [gene_id_mapping[idx] for idx in ranked_indices[:2048]]
    gene_ids.append(ranked_gene_ids)

# Convert to token IDs
input_ids = tokenizer(gene_ids, return_tensors="pt", padding=True)["input_ids"]
```

### Limitations

- **Gene vocabulary**: Only genes in the Ensembl vocabulary (25,426 total) can be directly tokenized
- **Sequence order**: Expects ranked order; random order will degrade performance
- **Batch size**: Larger batches (32-64) are recommended for better convergence
- **GPU memory**: The base model needs ~10GB for batch_size=32; larger variants need more

---

## Examples

See the `examples/` directory for complete scripts:

- `1_extract_embeddings.py` - Extract cell embeddings
- `2_finetune_classification.py` - Cell type annotation
- `3_pretrain_from_scratch.py` - Train from scratch (next-token + optional resume)

---

## Citation

If you find GeneMamba useful in your research, please cite:

```bibtex
@article{qi2025genemamba,
  title={GeneMamba: An Efficient and Effective Foundation Model on Single Cell Data},
  author={Qi, Cong and Fang, Hanzhang and Jiang, Siqi and Song, Xun and Hu, Tianxing and Zhi, Wei},
  journal={arXiv preprint arXiv:2504.16956},
  year={2025}
}
```

---

## Troubleshooting

### `trust_remote_code=True` Error

This is expected for custom models. Either:

1. Set `trust_remote_code=True` (safe if loading from the official repo)
2. Or use `sys.path.insert(0, '.')` if loading local code

### Old Cached Code / Shape Mismatch

If you still see old loading errors after an update, force a refresh of the files from the Hub:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "mineself2016/GeneMamba",
    trust_remote_code=True,
    force_download=True,
)
```

You can also clear the local cache if needed:

```bash
rm -rf ~/.cache/huggingface/hub/models--mineself2016--GeneMamba
```

### Out of Memory (OOM)

Reduce the batch size:

```python
args = TrainingArguments(
    per_device_train_batch_size=8,  # Reduced from 32
    ...
)
```

### Tokenizer Not Found

Make sure the tokenizer files are in the same directory:

```
GeneMamba_repo/
├── config.json
├── model.safetensors
├── tokenizer.json          ← Required
├── tokenizer_config.json   ← Required
└── ...
```

---

**Last Updated**: March 2026