# CellFM-80M

## Model Description
CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells, built on a retention-based Transformer architecture (MAE Autobin).
- Model Size: 80M
- Pre-training Data: 100M human cells
- Architecture: Retention-based Transformer (MAE Autobin)
- Vocabulary: 27,855 genes
- Pre-training Task: Masked Autoencoding (MAE)
## Model Details
- Source: biomed-AI/CellFM
- Original Framework: MindSpore
- Converted to: PyTorch (PerturbLab format)
- License: See original repository for details
## Usage

### Load Model
```python
from perturblab.model.cellfm import CellFMModel

# Load pretrained model
model = CellFMModel.from_pretrained('cellfm-80m')

# Or from local path
model = CellFMModel.from_pretrained('./weights/cellfm-80m')
```
### Generate Cell Embeddings
```python
import scanpy as sc

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess
adata = CellFMModel.prepare_data(adata)

# Get embeddings
embeddings = model.predict_embeddings(
    adata,
    batch_size=32,
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # Shape: (n_cells, enc_dims)
```
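The returned matrix can be fed straight into downstream similarity analysis. A minimal numpy sketch, using random placeholder values in place of real model output (the shapes here are illustrative, not the model's actual dimensions):

```python
import numpy as np

# Placeholder for model output: 5 cells with 64-dim embeddings.
cell_embeddings = np.random.rand(5, 64)

# L2-normalize rows so the dot product becomes cosine similarity.
norms = np.linalg.norm(cell_embeddings, axis=1, keepdims=True)
unit = cell_embeddings / norms

# Pairwise cosine similarity between all cells.
similarity = unit @ unit.T

print(similarity.shape)  # (5, 5): one score per cell pair
```

The diagonal of `similarity` is 1.0 by construction, since every cell is identical to itself.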
### Fine-tune for Classification
```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize model with classification head
config = CellFMConfig(
    model_name='80M',
    num_cls=10,  # Number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-80m/model.pt')

# Fine-tune on your labeled data
# ... (training code)
```
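Conceptually, the elided training step fits a classifier over `num_cls` cell types on top of the encoder. A minimal numpy sketch of one gradient step for a linear softmax head on frozen embeddings (all names, shapes, and values are illustrative assumptions, not the PerturbLab API):

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, dim, num_cls = 32, 64, 10

X = rng.normal(size=(n_cells, dim))          # frozen cell embeddings
y = rng.integers(0, num_cls, size=n_cells)   # cell-type labels
W = np.zeros((dim, num_cls))                 # linear classification head

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)     # numerically stable softmax
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Cross-entropy loss before the update.
probs = softmax(X @ W)
loss_before = -np.log(probs[np.arange(n_cells), y]).mean()

# Gradient of the cross-entropy w.r.t. the logits, then one GD step on W.
grad = probs.copy()
grad[np.arange(n_cells), y] -= 1.0
W -= 0.1 * (X.T @ grad) / n_cells

probs = softmax(X @ W)
loss_after = -np.log(probs[np.arange(n_cells), y]).mean()
print(loss_after < loss_before)  # True: the step reduces the loss
```

In practice the head (and optionally the encoder) would be trained with an optimizer over many mini-batches; this only shows the objective being minimized.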
### Perturbation Prediction
```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-80m')

# Initialize perturbation head
model.init_perturbation_head_from_dataset(data)

# Train
model.train_model(data, epochs=20)

# Predict
predictions = model.predict_perturbation(data, split='test')
```
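A common way to evaluate such predictions is the per-gene Pearson correlation between predicted and observed post-perturbation expression. A sketch with random placeholder arrays, standing in for whatever format `predictions` actually uses:

```python
import numpy as np

rng = np.random.default_rng(1)
n_cells, n_genes = 100, 20

# Placeholders: observed and predicted post-perturbation expression.
y_true = rng.normal(size=(n_cells, n_genes))
y_pred = y_true + 0.5 * rng.normal(size=(n_cells, n_genes))  # noisy prediction

def pearson_per_gene(a, b):
    # Center each gene, then correlate across cells.
    a = a - a.mean(axis=0)
    b = b - b.mean(axis=0)
    return (a * b).sum(axis=0) / np.sqrt((a ** 2).sum(axis=0) * (b ** 2).sum(axis=0))

r = pearson_per_gene(y_true, y_pred)
print(r.shape)  # (20,): one correlation per gene
```

Reporting the mean correlation over held-out (test-split) perturbations gives a single summary score.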
## Model Architecture

- Encoder: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - Multi-head retention mechanism
  - Layer normalization and residual connections
- Pre-training: Masked Autoencoding (MAE)
  - Masks 50% of genes
  - Reconstructs masked gene expression
- Output: Gene-level embeddings + CLS token
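The pre-training objective above can be illustrated with a small numpy sketch: mask half of the genes in each cell and score reconstruction only at the masked positions. This is purely conceptual; the real model operates on tokenized, auto-binned expression, and the "reconstruction" here is a trivial per-cell mean used only to show where the loss applies:

```python
import numpy as np

rng = np.random.default_rng(42)
n_cells, n_genes = 4, 10

expr = rng.poisson(3.0, size=(n_cells, n_genes)).astype(float)

# Mask 50% of genes in each cell, as in the MAE pre-training task.
mask = np.zeros_like(expr, dtype=bool)
for i in range(n_cells):
    masked_idx = rng.choice(n_genes, size=n_genes // 2, replace=False)
    mask[i, masked_idx] = True

masked_input = np.where(mask, 0.0, expr)  # masked genes hidden from the model

# A stand-in "reconstruction": the per-cell mean of the visible genes.
recon = masked_input.sum(axis=1, keepdims=True) / (~mask).sum(axis=1, keepdims=True)
recon = np.broadcast_to(recon, expr.shape)

# The loss is computed only over the masked positions.
mse = ((recon - expr) ** 2)[mask].mean()
print(mask.sum(axis=1))  # [5 5 5 5]: half the genes masked per cell
```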
## Files

- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict)
- `README.md`: This file
## Citation

If you use CellFM in your research, please cite:

```bibtex
@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
```
## References
- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]