# CellFM-800M
## Model Description
CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells, built on a retention-based architecture (MAE Autobin).
- **Model Size**: 800M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 24,072 genes
- **Pre-training Task**: Masked Autoencoding (MAE)
## Model Details
- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See original repository for details
## Architecture Specifications
- **Hidden Dimension**: 1536
- **Number of Layers**: 40
- **Number of Attention Heads**: 48
- **Dropout**: 0.1
- **Max Sequence Length**: 2048 genes
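The same values should appear in the shipped `config.json`. A minimal sketch for double-checking them before loading; the key names below mirror the `CellFMConfig` fields used later in this README and are assumptions about the file's schema:
```python
import json

# Minimal sketch: inspect the shipped configuration. The key names are
# assumptions mirroring CellFMConfig; check config.json for the actual schema.
with open('./weights/cellfm-800m/config.json') as f:
    config = json.load(f)

for key in ('enc_dims', 'enc_nlayers', 'enc_num_heads', 'n_genes'):
    print(key, '=', config.get(key))
```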
## Usage
### Load Model
```python
from perturblab.model.cellfm import CellFMModel
# Load pretrained model (automatically downloads if needed)
model = CellFMModel.from_pretrained('cellfm-800m')
# Or use short name
model = CellFMModel.from_pretrained('800m')
# Or from local path
model = CellFMModel.from_pretrained('./weights/cellfm-800m')
```
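If the returned object is (or wraps) a standard `torch.nn.Module`, which is an assumption here, a quick parameter count is a handy check that the 800M checkpoint was actually loaded:
```python
# Sanity check, assuming the model exposes torch.nn.Module parameters:
# the count should come out near 800M.
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.0f}M parameters")
```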
### Generate Cell Embeddings
```python
import scanpy as sc
from perturblab.model.cellfm import CellFMModel  # for prepare_data

# Load your data
adata = sc.read_h5ad('your_data.h5ad')
# Preprocess into the model's expected input format
adata = CellFMModel.prepare_data(adata)
# Get embeddings with the `model` loaded above
embeddings = model.predict_embeddings(
    adata,
    batch_size=8,  # smaller batch size for the larger 800M model
    return_cls_token=True,
)
# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings'] # Shape: (n_cells, 1536)
```
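A typical next step is to cluster and visualize cells in the embedding space. The sketch below uses only standard scanpy calls; `X_cellfm` is just a key chosen here for `adata.obsm`:
```python
# Store the embeddings on the AnnData object and run standard scanpy
# neighbors/UMAP/Leiden on them instead of on raw expression.
adata.obsm['X_cellfm'] = cell_embeddings
sc.pp.neighbors(adata, use_rep='X_cellfm')
sc.tl.umap(adata)
sc.tl.leiden(adata)
sc.pl.umap(adata, color='leiden')
```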
### Fine-tune for Classification
```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig
# Initialize model with classification head
config = CellFMConfig(
    model_name='800M',
    n_genes=24072,
    enc_dims=1536,
    enc_nlayers=40,
    enc_num_heads=48,
    num_cls=10,  # number of cell types
)
model = CellFMModel(config, for_finetuning=True)
# Load pretrained weights
model.load_weights('./weights/cellfm-800m/model.pt')
# Get dataloaders
train_loader = model.get_dataloader(train_data, batch_size=4)['train']
val_loader = model.get_dataloader(val_data, batch_size=4)['train']
# Train
model.train_model(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=10,
    learning_rate=1e-4,
)
```
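`num_cls` must match the number of distinct labels in your training data. A small sketch for deriving it, assuming your cell-type labels live in `adata.obs['cell_type']` (a column name chosen for illustration):
```python
# Derive num_cls from the data rather than hard-coding it. The column
# name 'cell_type' is an assumption; use whatever your AnnData carries.
num_cls = adata.obs['cell_type'].nunique()
print(f"{num_cls} distinct cell types")
```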
### Perturbation Prediction
```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData
# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)
# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-800m')
# Initialize perturbation head from dataset
model.init_perturbation_head_from_dataset(data)
# Train (use smaller batch size)
model.train_model(data, epochs=20, batch_size=4)
# Predict
predictions = model.predict_perturbation(data, split='test')
# Evaluate
metrics = model.evaluate(data, split='test')
print(f"Pearson correlation: {metrics['pearson']:.4f}")
```
## Performance Notes
- **Memory Requirements**: ~3-4GB GPU memory for inference (batch_size=8)
- **Recommended Batch Size**: 4-8 for training, 8-16 for inference
- **Inference Speed**: ~2-3x slower than the 80M model
- **Loading Time**: ~5-10 seconds
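If memory is tight, the standard PyTorch levers apply. A hedged sketch, assuming the model is (or wraps) a `torch.nn.Module` and that half precision is acceptable for your use case:
```python
import torch

from perturblab.model.cellfm import CellFMModel

# Standard PyTorch memory savers; half precision roughly halves weight
# memory (~3.2GB fp32 -> ~1.6GB fp16 for 800M parameters).
model = CellFMModel.from_pretrained('cellfm-800m')
model = model.half().to('cuda')
model.eval()

with torch.inference_mode():  # no autograd bookkeeping during inference
    embeddings = model.predict_embeddings(adata, batch_size=4)
```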
## Model Architecture
- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - 40 retention layers with 48 attention heads each
  - Hidden dimension: 1536
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE)
  - Masks 50% of genes
  - Reconstructs masked gene expression
- **Output**: Gene-level embeddings + CLS token (1536-dimensional)
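As a purely illustrative toy (not CellFM's actual code), the MAE objective amounts to hiding half of the gene positions and scoring reconstruction only where genes were masked:
```python
import torch

# Toy illustration of the MAE objective, not the model's internals:
expr = torch.rand(4, 2048)                  # (cells, genes) expression
mask = torch.rand_like(expr) < 0.5          # hide ~50% of positions
masked_input = expr.masked_fill(mask, 0.0)  # masked genes zeroed out

recon = masked_input                        # stand-in for encoder/decoder output
loss = ((recon - expr)[mask] ** 2).mean()   # penalize only the masked genes
```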
## Comparison with 80M Model
| Feature | 80M | 800M |
|---------|-----|------|
| Parameters | 80M | 800M |
| Hidden Dim | 1536 | 1536 |
| Layers | 2 | 40 |
| Heads | 48 | 48 |
| Genes | 27,855 | 24,072 |
| Memory (Inference) | ~1-2GB | ~3-4GB |
| Speed | Faster | Slower |
| Performance | Good | Better |
The 800M model provides significantly better representation quality due to its deeper architecture (40 layers vs 2 layers), at the cost of increased computational requirements.
## Files
- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict, ~3.0GB)
- `README.md`: This file
- `.gitattributes`: Git LFS configuration
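The checkpoint can be inspected with plain PyTorch. Key names depend on the MindSpore-to-PyTorch conversion, so treat them as implementation details:
```python
import torch

# Load the raw state dict on CPU and peek at a few entries.
state_dict = torch.load('./weights/cellfm-800m/model.pt', map_location='cpu')
print(len(state_dict), 'tensors')
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```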
## Citation
If you use CellFM in your research, please cite:
```bibtex
@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
```
## References
- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]
## Notes
- This model was converted from the original MindSpore checkpoint
- The gene vocabulary (24,072 genes) differs from the 80M model's (27,855 genes)
- For best results, ensure your data preprocessing matches the model's expected input format
- Use `CellFMModel.prepare_data()` to automatically preprocess your data