# CellFM-80M
## Model Description
CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells. It uses a retention-based Transformer architecture (MAE Autobin).
- **Model Size**: 80M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 27,855 genes
- **Pre-training Task**: Masked Autoencoding (MAE)
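Masked autoencoding trains the encoder to predict hidden expression values from the visible ones. The NumPy sketch below illustrates the idea under stated assumptions: a per-cell random mask over all genes at the 50% ratio listed under Model Architecture, masked entries replaced by 0.0 as a stand-in for a learned mask token, and a mean-squared-error reconstruction loss. Whether CellFM masks only expressed genes, and how it encodes masked positions, is not specified here; `mask_genes` and `masked_mse` are hypothetical helpers, not PerturbLab functions.

```python
import numpy as np

def mask_genes(expr: np.ndarray, mask_ratio: float = 0.5, seed: int = 0):
    """Hide a random `mask_ratio` fraction of genes in every cell.

    expr: (n_cells, n_genes) expression matrix.
    Returns (corrupted input, boolean mask with True = masked, masked targets).
    """
    rng = np.random.default_rng(seed)
    n_cells, n_genes = expr.shape
    n_mask = int(round(mask_ratio * n_genes))
    mask = np.zeros(expr.shape, dtype=bool)
    for i in range(n_cells):
        # sample a different gene subset for each cell
        mask[i, rng.choice(n_genes, size=n_mask, replace=False)] = True
    # assumption: 0.0 stands in for a learned [MASK] embedding
    corrupted = np.where(mask, 0.0, expr)
    return corrupted, mask, expr[mask]

def masked_mse(pred: np.ndarray, expr: np.ndarray, mask: np.ndarray) -> float:
    """Reconstruction loss computed only on the masked positions (assumed MSE)."""
    return float(np.mean((pred[mask] - expr[mask]) ** 2))

expr = np.arange(1, 21, dtype=float).reshape(2, 10)
corrupted, mask, targets = mask_genes(expr)
print(mask.sum(axis=1))  # [5 5]: half of the 10 genes are hidden in each cell
```

Computing the loss only on masked positions is what forces the model to infer hidden genes from co-expressed ones rather than copy its input.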
## Model Details
- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See original repository for details
## Usage
### Load Model
```python
from perturblab.model.cellfm import CellFMModel
# Load pretrained model
model = CellFMModel.from_pretrained('cellfm-80m')
# Or from local path
model = CellFMModel.from_pretrained('./weights/cellfm-80m')
```
### Generate Cell Embeddings
```python
import scanpy as sc
from perturblab.model.cellfm import CellFMModel
# `model` is the pretrained CellFMModel loaded in the previous step
# Load your data
adata = sc.read_h5ad('your_data.h5ad')
# Preprocess
adata = CellFMModel.prepare_data(adata)
# Get embeddings
embeddings = model.predict_embeddings(
adata,
batch_size=32,
return_cls_token=True,
)
# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings'] # Shape: (n_cells, enc_dims)
```
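The call above returns cell embeddings of shape `(n_cells, enc_dims)`. A common next step is to compare cells in that space. The sketch below uses plain NumPy on a random stand-in array (no PerturbLab call is assumed) to find each cell's nearest neighbours by cosine similarity; with real output you would pass `embeddings['cell_embeddings']`, assuming it is a NumPy array (convert it first if it is a tensor). `cosine_knn` is a hypothetical helper.

```python
import numpy as np

def cosine_knn(emb: np.ndarray, k: int = 5) -> np.ndarray:
    """Indices of the k most cosine-similar cells for each cell, excluding itself."""
    emb = np.asarray(emb, dtype=np.float64)
    unit = emb / np.clip(np.linalg.norm(emb, axis=1, keepdims=True), 1e-12, None)
    sim = unit @ unit.T                # (n_cells, n_cells) cosine similarities
    np.fill_diagonal(sim, -np.inf)     # a cell is not its own neighbour
    return np.argsort(-sim, axis=1)[:, :k]

# Stand-in for embeddings['cell_embeddings']; replace with the real array.
rng = np.random.default_rng(0)
cell_embeddings = rng.normal(size=(100, 16))
neighbours = cosine_knn(cell_embeddings, k=5)
print(neighbours.shape)  # (100, 5)
```

The full similarity matrix costs O(n_cells²) memory. For large datasets, storing the embeddings in `adata.obsm` (the key `'X_cellfm'` is a hypothetical name) and calling `sc.pp.neighbors(adata, use_rep='X_cellfm')` lets scanpy build the neighbour graph instead.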
### Fine-tune for Classification
```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig
# Initialize model with classification head
config = CellFMConfig(
model_name='80M',
num_cls=10, # Number of cell types
)
model = CellFMModel(config, for_finetuning=True)
# Load pretrained weights
model.load_weights('./weights/cellfm-80m/model.pt')
# Fine-tune on your labeled data
# ... (training code)
```
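The training loop above is elided, and this card does not document PerturbLab's fine-tuning API. A cheaper baseline, and a different technique from full fine-tuning, is a linear probe: a softmax classifier trained on frozen cell embeddings (for example the `cell_embeddings` array from the previous section). The NumPy sketch below is illustrative only; `fit_linear_probe`, `predict` and the hyperparameters are hypothetical and untuned, and the data are synthetic.

```python
import numpy as np

def fit_linear_probe(X, y, num_cls, lr=0.1, epochs=300, l2=1e-4):
    """Softmax (multinomial logistic) regression on frozen embeddings, full-batch gradient descent."""
    X = np.asarray(X, dtype=np.float64)
    n, d = X.shape
    W = np.zeros((d, num_cls))
    b = np.zeros(num_cls)
    Y = np.eye(num_cls)[y]                           # one-hot labels, (n, num_cls)
    for _ in range(epochs):
        logits = X @ W + b
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        P = np.exp(logits)
        P /= P.sum(axis=1, keepdims=True)            # class probabilities
        G = (P - Y) / n                              # d(mean cross-entropy)/d(logits)
        W -= lr * (X.T @ G + l2 * W)                 # L2-regularised weight update
        b -= lr * G.sum(axis=0)
    return W, b

def predict(X, W, b):
    return np.argmax(np.asarray(X, dtype=np.float64) @ W + b, axis=1)

# Synthetic stand-in: 3 well-separated "cell types" in a 2-D embedding space.
rng = np.random.default_rng(0)
centers = np.array([[5.0, 0.0], [0.0, 5.0], [-5.0, -5.0]])
y = rng.integers(0, 3, size=300)
X = centers[y] + rng.normal(scale=0.5, size=(300, 2))
W, b = fit_linear_probe(X, y, num_cls=3)
print((predict(X, W, b) == y).mean())
```

A linear probe never updates the encoder, so it is a quick check of how separable cell types already are in the pretrained embedding rather than a substitute for fine-tuning.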
### Perturbation Prediction
```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData
# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)
# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-80m')
# Initialize perturbation head
model.init_perturbation_head_from_dataset(data)
# Train
model.train_model(data, epochs=20)
# Predict
predictions = model.predict_perturbation(data, split='test')
```
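`split_data(train=0.7, val=0.15, test=0.15)` requests a 70/15/15 split, but this card does not say whether PerturbLab splits by cell or by perturbation. If test perturbations should be unseen during training, a split over perturbation labels can be sketched as below. `split_perturbations` is a hypothetical helper, not part of PerturbLab, and control cells are assumed to be labelled `'ctrl'` and kept in the training set.

```python
import numpy as np

def split_perturbations(labels, train=0.7, val=0.15, test=0.15, control="ctrl", seed=0):
    """Assign whole perturbations (not individual cells) to train/val/test.

    Returns an array with the split name of every cell. Control cells go to train;
    the test split receives whatever perturbations remain after train and val.
    """
    assert abs(train + val + test - 1.0) < 1e-9, "fractions must sum to 1"
    labels = list(labels)
    perts = sorted(set(labels) - {control})
    rng = np.random.default_rng(seed)
    perts = [perts[i] for i in rng.permutation(len(perts))]  # shuffle perturbations
    n_train = int(round(train * len(perts)))
    n_val = int(round(val * len(perts)))
    split_of = {control: "train"}
    for i, p in enumerate(perts):
        split_of[p] = "train" if i < n_train else "val" if i < n_train + n_val else "test"
    return np.array([split_of[label] for label in labels])

labels = ["ctrl"] * 5 + [f"gene{i}" for i in range(20) for _ in range(3)]
splits = split_perturbations(labels)
print({name: int((splits == name).sum()) for name in ("train", "val", "test")})
# {'train': 47, 'val': 9, 'test': 9}
```

Splitting at the perturbation level keeps every cell of a given perturbation in one split, so test scores reflect generalization to perturbations the model has not seen.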
## Model Architecture
- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - Multi-head retention mechanism
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE)
  - Masks 50% of genes
  - Reconstructs masked gene expression
- **Output**: Gene-level embeddings + CLS token
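The auto-discretization layer turns each continuous expression value into an embedding without hard binning: the value is projected onto a set of bins, a softmax gives soft bin weights, and the output is the weighted sum of learnable bin embeddings. The NumPy sketch below follows the formulation described for xTrimoGene/scFoundation; CellFM's exact layer may differ. The bin count, embedding size, residual weight `alpha` and random initialization are illustrative assumptions, and `AutoDiscretizationEmbedding` is a hypothetical class name.

```python
import numpy as np

def leaky_relu(x, slope=0.1):
    return np.where(x > 0, x, slope * x)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

class AutoDiscretizationEmbedding:
    """Map each scalar expression value to a soft mixture of learnable bin embeddings."""

    def __init__(self, n_bins=100, dim=64, alpha=0.1, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(scale=0.1, size=(1, n_bins))       # scalar -> bin logits
        self.W2 = rng.normal(scale=0.1, size=(n_bins, n_bins))  # mixes bin logits
        self.alpha = alpha                                      # residual weight
        self.T = rng.normal(scale=0.1, size=(n_bins, dim))      # bin embedding table

    def __call__(self, values):
        v = np.asarray(values, dtype=np.float64)[..., None]     # (..., 1)
        v1 = leaky_relu(v @ self.w1)                            # (..., n_bins)
        v2 = v1 @ self.W2 + self.alpha * v1                     # (..., n_bins)
        weights = softmax(v2, axis=-1)                          # soft bin assignment
        return weights @ self.T, weights                        # (..., dim), (..., n_bins)

embed = AutoDiscretizationEmbedding(n_bins=8, dim=4)
expr = np.array([[0.0, 1.5, 3.0], [2.0, 0.5, 0.0]])  # (n_cells, n_genes)
gene_emb, bin_weights = embed(expr)
print(gene_emb.shape)  # (2, 3, 4)
```

Because the bin weights come from a softmax rather than a hard threshold, nearby expression values get similar embeddings and the whole mapping stays differentiable, so the bin boundaries are learned during pre-training.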
## Files
- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict)
- `README.md`: This file
## Citation
If you use CellFM in your research, please cite:
```bibtex
@article{cellfm2024,
title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
author={...},
journal={...},
year={2024}
}
```
## References
- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]