# CellFM-80M

## Model Description

CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells using a retention-based architecture (MAE Autobin).

- **Model Size**: 80M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 27,855 genes
- **Pre-training Task**: Masked Autoencoding (MAE)

## Model Details

- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See the original repository for details

## Usage

### Load Model

```python
from perturblab.model.cellfm import CellFMModel

# Load the pretrained model by name
model = CellFMModel.from_pretrained('cellfm-80m')

# Or from a local path
model = CellFMModel.from_pretrained('./weights/cellfm-80m')
```

### Generate Cell Embeddings

```python
import scanpy as sc

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess
adata = CellFMModel.prepare_data(adata)

# Get embeddings
embeddings = model.predict_embeddings(
    adata,
    batch_size=32,
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # Shape: (n_cells, enc_dims)
```

### Fine-tune for Classification

```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize a model with a classification head
config = CellFMConfig(
    model_name='80M',
    num_cls=10,  # Number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-80m/model.pt')

# Fine-tune on your labeled data
# ... (training code)
```

### Perturbation Prediction

```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize the model
model = CellFMPerturbationModel.from_pretrained('cellfm-80m')

# Initialize the perturbation head
model.init_perturbation_head_from_dataset(data)

# Train
model.train_model(data, epochs=20)

# Predict
predictions = model.predict_perturbation(data, split='test')
```

## Model Architecture

- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - Multi-head retention mechanism
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE)
  - Masks 50% of genes per cell
  - Reconstructs the masked gene expression values
- **Output**: Gene-level embeddings + CLS token

## Files

- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict)
- `README.md`: This file

## Citation

If you use CellFM in your research, please cite:

```bibtex
@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
```

## References

- Original repository: https://github.com/biomed-AI/CellFM
- PyTorch version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]
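## Appendix: MAE Masking Sketch

The masked-autoencoding objective described under Model Architecture can be sketched as follows. This is a toy NumPy illustration of the general MAE recipe (mask 50% of genes, reconstruct the masked values), not CellFM's actual implementation; the 8-gene vector and variable names are invented for the example.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy expression vector for one cell over 8 genes (illustrative only;
# CellFM operates on a 27,855-gene vocabulary with learned embeddings).
expression = rng.random(8).astype(np.float32)

# MAE-style objective: hide 50% of the genes, then train the model to
# reconstruct the hidden values from the visible ones.
n_genes = expression.shape[0]
mask = np.zeros(n_genes, dtype=bool)
mask[rng.choice(n_genes, size=n_genes // 2, replace=False)] = True

visible = np.where(mask, 0.0, expression)   # masked positions zeroed out
targets = expression[mask]                  # values the model must predict

# A real model would encode `visible`, decode per-gene predictions, and
# minimize a reconstruction loss such as MSE(predictions[mask], targets).
```

Because the loss is computed only on the masked positions, the model must learn co-expression structure across genes to fill in the gaps, which is what makes the resulting embeddings useful downstream.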