# CellFM-800M

## Model Description

CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells using a retention-based architecture (MAE Autobin).

- **Model Size**: 800M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 24,072 genes
- **Pre-training Task**: Masked Autoencoding (MAE)

## Model Details

- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See the original repository for details

## Architecture Specifications

- **Hidden Dimension**: 1536
- **Number of Layers**: 40
- **Number of Attention Heads**: 48
- **Dropout**: 0.1
- **Max Sequence Length**: 2048 genes

## Usage

### Load Model

```python
from perturblab.model.cellfm import CellFMModel

# Load the pretrained model (downloads weights automatically if needed)
model = CellFMModel.from_pretrained('cellfm-800m')

# Or use the short name
model = CellFMModel.from_pretrained('800m')

# Or load from a local path
model = CellFMModel.from_pretrained('./weights/cellfm-800m')
```

### Generate Cell Embeddings

```python
import scanpy as sc

from perturblab.model.cellfm import CellFMModel

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess into the model's expected input format
adata = CellFMModel.prepare_data(adata)

# Get embeddings; `model` is the CellFMModel loaded in the previous snippet
embeddings = model.predict_embeddings(
    adata,
    batch_size=8,  # use a smaller batch size for the 800M model
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # shape: (n_cells, 1536)
```

### Fine-tune for Classification

```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize the model with a classification head
config = CellFMConfig(
    model_name='800M',
    n_genes=24072,
    enc_dims=1536,
    enc_nlayers=40,
    enc_num_heads=48,
    num_cls=10,  # number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-800m/model.pt')

# Build dataloaders
train_loader = model.get_dataloader(train_data, batch_size=4)['train']
val_loader = model.get_dataloader(val_data, batch_size=4)['train']

# Train
model.train_model(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=10,
    learning_rate=1e-4,
)
```

### Perturbation Prediction

```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize the model
model = CellFMPerturbationModel.from_pretrained('cellfm-800m')

# Initialize the perturbation head from the dataset
model.init_perturbation_head_from_dataset(data)

# Train (use a smaller batch size for the 800M model)
model.train_model(data, epochs=20, batch_size=4)

# Predict
predictions = model.predict_perturbation(data, split='test')

# Evaluate
metrics = model.evaluate(data, split='test')
print(f"Pearson correlation: {metrics['pearson']:.4f}")
```

## Performance Notes

- **Memory Requirements**: ~3-4 GB of GPU memory for inference (batch_size=8)
- **Recommended Batch Size**: 4-8 for training, 8-16 for inference
- **Inference Speed**: ~2-3x slower than the 80M model
- **Loading Time**: ~5-10 seconds

## Model Architecture

- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - 40 retention layers with 48 attention heads each
  - Hidden dimension: 1536
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE)
  - Masks 50% of genes
  - Reconstructs masked gene expression
- **Output**: Gene-level embeddings + CLS token (1536-dimensional)
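To make the pre-training objective concrete, here is a minimal sketch of MAE-style masking over a batch of expression profiles. It is illustrative only: the `mask_expression` function, the zero-fill masking, the stand-in predictor, and the MSE loss are assumptions for the example, not the actual CellFM training code.

```python
import torch
import torch.nn.functional as F

def mask_expression(expr: torch.Tensor, mask_ratio: float = 0.5):
    """Randomly mask a fraction of gene positions (MAE-style).

    expr: (n_cells, n_genes) expression matrix.
    Returns the masked input and the boolean mask marking masked positions.
    """
    mask = torch.rand_like(expr) < mask_ratio   # True where a gene is masked
    masked_expr = expr.masked_fill(mask, 0.0)   # hide masked values from the encoder
    return masked_expr, mask

# Toy batch: 4 cells x 2048 genes (the model's max sequence length)
expr = torch.rand(4, 2048)
masked_expr, mask = mask_expression(expr, mask_ratio=0.5)

# During pre-training the model reconstructs the masked values; here a zero
# predictor stands in so the loss computation is runnable end to end.
pred = torch.zeros_like(expr)
loss = F.mse_loss(pred[mask], expr[mask])       # loss only on masked positions
print(f"reconstruction MSE on masked genes: {loss.item():.4f}")
```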
## Comparison with 80M Model

| Feature | 80M | 800M |
|---------|-----|------|
| Parameters | 80M | 800M |
| Hidden Dim | 1536 | 1536 |
| Layers | 2 | 40 |
| Heads | 48 | 48 |
| Genes | 27,855 | 24,072 |
| Memory (Inference) | ~1-2 GB | ~3-4 GB |
| Speed | Faster | Slower |
| Performance | Good | Better |

The 800M model provides significantly better representation quality thanks to its much deeper architecture (40 layers vs. 2), at the cost of higher memory use and slower inference.

## Files

- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict, ~3.0 GB)
- `README.md`: This file
- `.gitattributes`: Git LFS configuration

## Citation

If you use CellFM in your research, please cite:

```bibtex
@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
```

## References

- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]

## Notes

- This model was converted from the original MindSpore checkpoint.
- The gene vocabulary (24,072 genes) differs from the 80M model's (27,855 genes); a coverage check is sketched below.
- For best results, ensure your data preprocessing matches the model's expected input format.
- Use `CellFMModel.prepare_data()` to preprocess your data automatically.
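As a quick sanity check before preprocessing, you can measure how many of your genes fall inside the model's 24,072-gene vocabulary. The sketch below is a hypothetical example: the `gene_vocab.txt` file name is an assumption, and `CellFMModel.prepare_data()` remains the supported preprocessing entry point; consult the original repository for the gene list actually shipped with the checkpoint.

```python
import scanpy as sc

adata = sc.read_h5ad('your_data.h5ad')

# Hypothetical: a plain-text list of the model's gene symbols, one per line.
# The file name is an assumption; see the CellFM repository for the real vocabulary.
with open('./weights/cellfm-800m/gene_vocab.txt') as f:
    model_genes = {line.strip() for line in f if line.strip()}

covered = adata.var_names.isin(list(model_genes))
print(f"{covered.sum()} of {adata.n_vars} genes are in the 24,072-gene vocabulary")
```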