# CellFM-800M

## Model Description

CellFM is a large-scale foundation model pre-trained on single-cell transcriptomic profiles from 100 million human cells, using a retention-based architecture (MAE Autobin).

- **Model Size**: 800M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 24,072 genes
- **Pre-training Task**: Masked Autoencoding (MAE)
## Model Details

- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See the original repository for details
## Architecture Specifications

- **Hidden Dimension**: 1536
- **Number of Layers**: 40
- **Number of Attention Heads**: 48
- **Dropout**: 0.1
- **Max Sequence Length**: 2048 genes (see the preprocessing sketch after this list)
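Because the model sees at most 2048 genes per cell, larger inputs must be reduced first. `CellFMModel.prepare_data()` (used in the Usage section below) presumably handles this, but a manual reduction is straightforward with standard scanpy preprocessing. A minimal sketch; the normalization choices here are illustrative assumptions, not the documented CellFM recipe:

```python
import scanpy as sc

adata = sc.read_h5ad('your_data.h5ad')

# Illustrative manual preprocessing: normalize, log-transform, then
# keep the top 2048 highly variable genes so each cell fits within
# the model's maximum sequence length.
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
sc.pp.highly_variable_genes(adata, n_top_genes=2048, subset=True)
```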
## Usage

### Load Model

```python
from perturblab.model.cellfm import CellFMModel

# Load pretrained model (automatically downloads if needed)
model = CellFMModel.from_pretrained('cellfm-800m')

# Or use the short name
model = CellFMModel.from_pretrained('800m')

# Or load from a local path
model = CellFMModel.from_pretrained('./weights/cellfm-800m')
```
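Since the checkpoint ships as a PyTorch state dict, the loaded model should behave like a standard `torch.nn.Module`. Under that assumption, move it to the GPU and switch to evaluation mode before embedding large datasets:

```python
import torch

# Assumes the loaded model is a standard torch.nn.Module; at ~800M
# parameters, GPU placement matters far more than for the 80M model.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()  # disable dropout (0.1) for deterministic inference
```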
### Generate Cell Embeddings

```python
import scanpy as sc

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess
adata = CellFMModel.prepare_data(adata)

# Get embeddings (use a smaller batch size for the 800M model)
embeddings = model.predict_embeddings(
    adata,
    batch_size=8,  # smaller batch size for the larger model
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # shape: (n_cells, 1536)
```
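The embeddings drop straight into a standard scanpy workflow. Assuming `cell_embeddings` is a NumPy array whose rows align with `adata.obs`, for example:

```python
# Store the embeddings on the AnnData object and use them as the
# representation for the neighbors graph, UMAP, and clustering.
adata.obsm['X_cellfm'] = cell_embeddings

sc.pp.neighbors(adata, use_rep='X_cellfm')
sc.tl.umap(adata)
sc.tl.leiden(adata)  # requires the leidenalg package
sc.pl.umap(adata, color='leiden')
```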
### Fine-tune for Classification

```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize model with classification head
config = CellFMConfig(
    model_name='800M',
    n_genes=24072,
    enc_dims=1536,
    enc_nlayers=40,
    enc_num_heads=48,
    num_cls=10,  # number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-800m/model.pt')

# Get dataloaders; get_dataloader returns a dict keyed by 'train'
# for whichever dataset is passed in
train_loader = model.get_dataloader(train_data, batch_size=4)['train']
val_loader = model.get_dataloader(val_data, batch_size=4)['train']

# Train
model.train_model(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=10,
    learning_rate=1e-4,
)
```
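With `num_cls=10`, the classification head expects integer labels in `[0, 10)`. If your annotations are cell-type strings in `adata.obs`, encode them first. A minimal sketch; the `cell_type` column name is an assumption about your data, and how `train_data` is built from these labels depends on the PerturbLab API:

```python
from sklearn.preprocessing import LabelEncoder

# Map string cell-type annotations to integer class ids 0..num_cls-1.
# 'cell_type' is an illustrative column name, not a required one.
le = LabelEncoder()
adata.obs['cell_type_id'] = le.fit_transform(adata.obs['cell_type'])

print(dict(enumerate(le.classes_)))  # id -> cell-type name
```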
### Perturbation Prediction

```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-800m')

# Initialize perturbation head from dataset
model.init_perturbation_head_from_dataset(data)

# Train (use a smaller batch size)
model.train_model(data, epochs=20, batch_size=4)

# Predict
predictions = model.predict_perturbation(data, split='test')

# Evaluate
metrics = model.evaluate(data, split='test')
print(f"Pearson correlation: {metrics['pearson']:.4f}")
```
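One common way to compute such a correlation is Pearson r over all (cell, gene) entries. A self-contained sketch for sanity-checking your results; note that PerturbLab's `evaluate` may aggregate differently (e.g., per gene or over differentially expressed genes only):

```python
import numpy as np
from scipy.stats import pearsonr

def flattened_pearson(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Pearson r over all (cell, gene) entries of two matching
    (n_cells, n_genes) expression matrices."""
    r, _ = pearsonr(y_pred.ravel(), y_true.ravel())
    return float(r)
```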
## Performance Notes

- **Memory Requirements**: ~3-4 GB of GPU memory for inference (batch_size=8); see the half-precision sketch after this list
- **Recommended Batch Size**: 4-8 for training, 8-16 for inference
- **Inference Speed**: ~2-3x slower than the 80M model
- **Loading Time**: ~5-10 seconds
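If that footprint is still too large, half-precision inference is a common mitigation. A minimal sketch assuming the model is a standard PyTorch module; whether CellFM is numerically stable at reduced precision has not been verified here, so validate outputs against fp32:

```python
import torch

# Cast weights to bfloat16 (roughly halves weight memory; bf16 is
# generally safer than fp16 on Ampere-or-newer GPUs).
model = model.to(device='cuda', dtype=torch.bfloat16)
model.eval()

with torch.no_grad():
    embeddings = model.predict_embeddings(adata, batch_size=8, return_cls_token=True)
```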
## Model Architecture

- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - 40 retention layers with 48 attention heads each
  - Hidden dimension: 1536
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE), sketched below
  - Masks 50% of genes
  - Reconstructs masked gene expression
- **Output**: Gene-level embeddings + CLS token (1536-dimensional)
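The objective itself is conceptually simple. A self-contained, illustrative sketch of a 50% gene-masking loss; this is not the exact CellFM implementation, which embeds expression via the auto-discretization layer rather than zeroing raw values:

```python
import torch
import torch.nn.functional as F

def mae_loss(model, expr: torch.Tensor, mask_ratio: float = 0.5) -> torch.Tensor:
    """Illustrative masked-autoencoding objective.

    expr: (batch, n_genes) expression values. A random ~50% of genes
    is masked (zeroed here for simplicity); the model reconstructs
    expression at the masked positions only.
    """
    mask = torch.rand_like(expr) < mask_ratio   # True = masked gene
    corrupted = expr.masked_fill(mask, 0.0)
    recon = model(corrupted)                    # (batch, n_genes)
    return F.mse_loss(recon[mask], expr[mask])
```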
## Comparison with 80M Model

| Feature | 80M | 800M |
|---------|-----|------|
| Parameters | 80M | 800M |
| Hidden Dim | 1536 | 1536 |
| Layers | 2 | 40 |
| Heads | 48 | 48 |
| Genes | 27,855 | 24,072 |
| Memory (Inference) | ~1-2 GB | ~3-4 GB |
| Speed | Faster | Slower |
| Performance | Good | Better |

The 800M model provides significantly better representation quality due to its much deeper architecture (40 layers vs. 2), at the cost of increased computational requirements.
## Files

- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict, ~3.0 GB)
- `README.md`: This file
- `.gitattributes`: Git LFS configuration
## Citation

If you use CellFM in your research, please cite:

```bibtex
@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
```
## References

- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]
## Notes

- This model was converted from the original MindSpore checkpoint
- The gene vocabulary (24,072 genes) differs from the 80M model's (27,855 genes)
- For best results, ensure your data preprocessing matches the model's expected input format
- Use `CellFMModel.prepare_data()` to automatically preprocess your data