# CellFM-80M

## Model Description

CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells using a retention-based architecture (MAE Autobin).

- **Model Size**: 80M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 27,855 genes
- **Pre-training Task**: Masked Autoencoding (MAE)
## Model Details

- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See the original repository for details
## Usage

### Load Model

```python
from perturblab.model.cellfm import CellFMModel

# Load the pretrained model by name
model = CellFMModel.from_pretrained('cellfm-80m')

# Or from a local path
model = CellFMModel.from_pretrained('./weights/cellfm-80m')
```
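Since the weights were converted from MindSpore to PyTorch, the loaded model should behave like a standard `torch.nn.Module`. The following sketch is an assumption rather than documented PerturbLab API: it moves the model to a GPU if one is available and sanity-checks the reported ~80M parameter count.

```python
import torch

# Assumes CellFMModel subclasses torch.nn.Module; adjust if PerturbLab
# manages devices internally.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device).eval()

# Rough sanity check against the reported ~80M parameters
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.1f}M parameters')
```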
### Generate Cell Embeddings

```python
import scanpy as sc

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess for CellFM
adata = CellFMModel.prepare_data(adata)

# Get embeddings (uses the `model` loaded above)
embeddings = model.predict_embeddings(
    adata,
    batch_size=32,
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # Shape: (n_cells, enc_dims)
```
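A common downstream use of these embeddings is neighborhood-graph clustering and visualization with scanpy. The sketch below is generic scanpy usage, not part of the CellFM API, and assumes `cell_embeddings` is a NumPy array with one row per cell in the same order as `adata.obs`.

```python
import scanpy as sc

# Attach the CellFM embeddings to the AnnData object
adata.obsm['X_cellfm'] = cell_embeddings

# Standard scanpy workflow on the embedding space
sc.pp.neighbors(adata, use_rep='X_cellfm')
sc.tl.umap(adata)
sc.tl.leiden(adata)
sc.pl.umap(adata, color='leiden')
```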
### Fine-tune for Classification

```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize model with a classification head
config = CellFMConfig(
    model_name='80M',
    num_cls=10,  # Number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-80m/model.pt')

# Fine-tune on your labeled data
# ... (training code)
```
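If you only need a quick baseline, a lighter alternative to full fine-tuning is to train a classifier on frozen CellFM embeddings (linear probing). The sketch below uses scikit-learn and reuses `cell_embeddings` from the embedding section above; `adata.obs['cell_type']` is a hypothetical label column, so substitute your own.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Frozen-embedding (linear probe) baseline
X = cell_embeddings                # (n_cells, enc_dims) CellFM embeddings
y = adata.obs['cell_type'].values  # hypothetical label column

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y
)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
print('Held-out accuracy:', clf.score(X_test, y_test))
```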
### Perturbation Prediction

```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-80m')

# Initialize perturbation head
model.init_perturbation_head_from_dataset(data)

# Train
model.train_model(data, epochs=20)

# Predict
predictions = model.predict_perturbation(data, split='test')
```
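The return format of `predict_perturbation` is not documented here, so the evaluation below uses placeholder arrays: assuming you can extract predicted and observed mean expression profiles per gene, a typical check compares them with Pearson correlation and mean squared error.

```python
import numpy as np
from scipy.stats import pearsonr

# Placeholder arrays; replace with expression vectors extracted from
# `predictions` and the held-out test split.
pred = np.random.rand(2000)  # predicted mean expression per gene
true = np.random.rand(2000)  # observed mean expression per gene

r, _ = pearsonr(pred, true)
mse = np.mean((pred - true) ** 2)
print(f'Pearson r = {r:.3f}, MSE = {mse:.4f}')
```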
## Model Architecture

- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - Multi-head retention mechanism
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE)
  - Masks 50% of genes per cell (see the sketch below)
  - Reconstructs the expression of the masked genes
- **Output**: Gene-level embeddings + CLS token
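To make the pre-training objective concrete, here is an illustrative masking step (a sketch of the general MAE idea, not the original implementation): randomly hide 50% of a cell's gene-expression values and keep the mask so the reconstruction loss can be restricted to the hidden positions.

```python
import numpy as np

rng = np.random.default_rng(0)

def mask_expression(expr: np.ndarray, mask_ratio: float = 0.5):
    """Randomly mask a fraction of gene-expression values (illustrative only)."""
    mask = rng.random(expr.shape) < mask_ratio  # True = masked position
    masked_expr = np.where(mask, 0.0, expr)     # hide masked values
    return masked_expr, mask

expr = rng.poisson(2.0, size=27855).astype(float)  # one cell, 27,855 genes
masked_expr, mask = mask_expression(expr)
# The encoder sees masked_expr; the reconstruction loss is computed
# only on genes where mask is True.
```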
## Files

- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict)
- `README.md`: This file
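For a quick sanity check of the shipped files (assuming a standard JSON config and a plain PyTorch state dict, as described above):

```python
import json
import torch

# Model configuration (hyperparameters, vocabulary size, etc.)
with open('./weights/cellfm-80m/config.json') as f:
    config = json.load(f)
print(config)

# PyTorch state dict: parameter names and tensor shapes
state_dict = torch.load('./weights/cellfm-80m/model.pt', map_location='cpu')
for name, tensor in list(state_dict.items())[:5]:
    print(name, tuple(tensor.shape))
```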
## Citation

If you use CellFM in your research, please cite:

```bibtex
@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
```
## References

- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]