# CellFM-800M
## Model Description
CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells, built on a retention-based Transformer architecture (MAE Autobin).
- **Model Size**: 800M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 24,072 genes
- **Pre-training Task**: Masked Autoencoding (MAE)
## Model Details
- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See original repository for details
## Architecture Specifications
- **Hidden Dimension**: 1536
- **Number of Layers**: 40
- **Number of Attention Heads**: 48
- **Dropout**: 0.1
- **Max Sequence Length**: 2048 genes
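One derived figure worth noting (simple arithmetic from the numbers above, assuming the standard even split of the hidden dimension across heads): each head works in a 32-dimensional subspace.

```python
hidden_dim = 1536
num_heads = 48

# Assuming the usual even split of the hidden dimension across heads
head_dim = hidden_dim // num_heads
print(head_dim)  # 32
```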
## Usage
### Load Model
```python
from perturblab.model.cellfm import CellFMModel

# Load pretrained model (automatically downloads if needed)
model = CellFMModel.from_pretrained('cellfm-800m')

# Or use the short name
model = CellFMModel.from_pretrained('800m')

# Or load from a local path
model = CellFMModel.from_pretrained('./weights/cellfm-800m')
```
### Generate Cell Embeddings
```python
import scanpy as sc

from perturblab.model.cellfm import CellFMModel

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess into the model's expected input format
adata = CellFMModel.prepare_data(adata)

# Get embeddings (use a smaller batch size for the 800M model)
embeddings = model.predict_embeddings(
    adata,
    batch_size=8,  # smaller than for the 80M model
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # shape: (n_cells, 1536)
```
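The embeddings drop straight into a standard scanpy workflow. A minimal downstream sketch (the `X_cellfm` key and the `cell_type` column are illustrative names, not part of the PerturbLab API):

```python
import scanpy as sc

# Store the CellFM embeddings on the AnnData object
# ('X_cellfm' is an arbitrary key chosen for this example)
adata.obsm['X_cellfm'] = cell_embeddings

# Neighbor graph and UMAP computed on the 1536-dimensional embeddings
sc.pp.neighbors(adata, use_rep='X_cellfm')
sc.tl.umap(adata)
sc.pl.umap(adata, color='cell_type')  # assumes a 'cell_type' column in adata.obs
```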
### Fine-tune for Classification
```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize model with a classification head
config = CellFMConfig(
    model_name='800M',
    n_genes=24072,
    enc_dims=1536,
    enc_nlayers=40,
    enc_num_heads=48,
    num_cls=10,  # number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-800m/model.pt')

# Get dataloaders
train_loader = model.get_dataloader(train_data, batch_size=4)['train']
val_loader = model.get_dataloader(val_data, batch_size=4)['train']

# Train
model.train_model(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=10,
    learning_rate=1e-4,
)
```
### Perturbation Prediction
```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-800m')

# Initialize the perturbation head from the dataset
model.init_perturbation_head_from_dataset(data)

# Train (use a smaller batch size for the 800M model)
model.train_model(data, epochs=20, batch_size=4)

# Predict
predictions = model.predict_perturbation(data, split='test')

# Evaluate
metrics = model.evaluate(data, split='test')
print(f"Pearson correlation: {metrics['pearson']:.4f}")
```
## Performance Notes
- **Memory Requirements**: ~3-4 GB of GPU memory for inference (batch_size=8)
- **Recommended Batch Size**: 4-8 for training, 8-16 for inference
- **Inference Speed**: ~2-3x slower than the 80M model
- **Loading Time**: ~5-10 seconds
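As a rough sanity check on the figures above (a back-of-envelope estimate, not a measurement of the actual runtime footprint), the weights alone account for most of the inference memory:

```python
# Back-of-envelope weight-memory estimate (illustrative only)
n_params = 800e6  # 800M parameters

for dtype, bytes_per_param in [('fp32', 4), ('fp16/bf16', 2)]:
    gib = n_params * bytes_per_param / 1024**3
    print(f'{dtype}: ~{gib:.1f} GiB of weights')
```

In fp32 this comes to ~3.0 GiB, consistent with the ~3.0GB `model.pt` file; activations, gradients, and optimizer state add to this, which is why fine-tuning calls for smaller batch sizes.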
## Model Architecture
- **Encoder**: Retention-based Transformer (MAE Autobin)
- Auto-discretization embedding layer
- 40 retention layers with 48 attention heads each
- Hidden dimension: 1536
- Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE), sketched after this list
- Masks 50% of genes
- Reconstructs masked gene expression
- **Output**: Gene-level embeddings + CLS token (1536-dimensional)
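For concreteness, here is a minimal, self-contained sketch of the masked-autoencoding objective described above. It illustrates the idea only; it is not CellFM's actual training code, and the function and argument names are invented for this example.

```python
import torch

def mae_expression_loss(model, expr, mask_ratio=0.5):
    """Illustrative MAE objective: hide a fraction of genes, reconstruct them.

    `expr` is a (cells x genes) float tensor; `model` is assumed to map a
    masked expression matrix to a reconstruction of the same shape.
    """
    mask = torch.rand_like(expr) < mask_ratio   # True marks a masked gene
    masked_expr = expr.masked_fill(mask, 0.0)   # hide the masked expression values
    recon = model(masked_expr)
    return ((recon - expr) ** 2)[mask].mean()   # loss only on masked positions
```

With `mask_ratio=0.5` this matches the 50% masking rate listed above.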
## Comparison with 80M Model
| Feature | 80M | 800M |
|---------|-----|------|
| Parameters | 80M | 800M |
| Hidden Dim | 1536 | 1536 |
| Layers | 2 | 40 |
| Heads | 48 | 48 |
| Genes | 27,855 | 24,072 |
| Memory (Inference) | ~1-2GB | ~3-4GB |
| Speed | Faster | Slower |
| Performance | Good | Better |
The 800M model provides significantly better representation quality due to its deeper architecture (40 layers vs 2 layers), at the cost of increased computational requirements.
## Files
- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict, ~3.0GB)
- `README.md`: This file
- `.gitattributes`: Git LFS configuration
## Citation
If you use CellFM in your research, please cite:
```bibtex
@article{cellfm2024,
title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
author={...},
journal={...},
year={2024}
}
```
## References
- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]
## Notes
- This model was converted from the original MindSpore checkpoint
- The gene vocabulary (24,072 genes) differs from the 80M model's (27,855 genes)
- For best results, ensure your data's preprocessing matches the model's expected input format
- Use `CellFMModel.prepare_data()` to preprocess your data automatically; a vocabulary-overlap check is sketched below
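For example, a quick way to see how well a dataset overlaps the model's gene vocabulary before preprocessing (a sketch; `gene_vocab` is a hypothetical attribute name, substitute however the loaded model actually exposes its gene list):

```python
# Hypothetical: 'model.gene_vocab' stands in for however the wrapper
# exposes its 24,072-gene vocabulary
model_genes = set(model.gene_vocab)

shared = [g for g in adata.var_names if g in model_genes]
print(f'{len(shared)}/{adata.n_vars} genes overlap the model vocabulary')

# Restrict to shared genes before embedding or fine-tuning
adata = adata[:, shared].copy()
```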