CellFM-80M

Model Description

CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells, built on a retention-based Transformer architecture (MAE Autobin).

  • Model Size: 80M parameters
  • Pre-training Data: 100M human cells
  • Architecture: Retention-based Transformer (MAE Autobin)
  • Vocabulary: 27,855 genes
  • Pre-training Task: Masked Autoencoding (MAE)
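
As a quick sanity check of the parameter count, the converted model can be inspected with standard PyTorch calls. A minimal sketch, assuming CellFMModel behaves as a standard torch.nn.Module (plausible given the PyTorch conversion, but not stated on this card):

from perturblab.model.cellfm import CellFMModel

# Count parameters; expected to be on the order of 80M
model = CellFMModel.from_pretrained('cellfm-80m')
n_params = sum(p.numel() for p in model.parameters())
print(f'{n_params / 1e6:.1f}M parameters')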

Model Details

  • Source: biomed-AI/CellFM
  • Original Framework: MindSpore
  • Converted to: PyTorch (PerturbLab format)
  • License: See original repository for details

Usage

Load Model

from perturblab.model.cellfm import CellFMModel

# Load pretrained model
model = CellFMModel.from_pretrained('cellfm-80m')

# Or from local path
model = CellFMModel.from_pretrained('./weights/cellfm-80m')
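
For GPU inference, a minimal sketch assuming the converted model is a standard torch.nn.Module (its weights ship as a PyTorch state dict):

import torch

# Pick a device and put the model in inference mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()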

Generate Cell Embeddings

import scanpy as sc
from perturblab.model.cellfm import CellFMModel

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess
adata = CellFMModel.prepare_data(adata)

# Get embeddings (model loaded as in "Load Model" above)
embeddings = model.predict_embeddings(
    adata,
    batch_size=32,
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # Shape: (n_cells, enc_dims)
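
The embeddings plug into standard single-cell workflows. A minimal clustering sketch with scanpy, assuming cell_embeddings is a NumPy array whose rows align with adata.obs:

# Store the embeddings on the AnnData object and cluster on them
adata.obsm['X_cellfm'] = cell_embeddings
sc.pp.neighbors(adata, use_rep='X_cellfm')
sc.tl.leiden(adata)
sc.tl.umap(adata)
sc.pl.umap(adata, color='leiden')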

Fine-tune for Classification

from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize model with classification head
config = CellFMConfig(
    model_name='80M',
    num_cls=10,  # Number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-80m/model.pt')

# Fine-tune on your labeled data
# ... (training code)
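
The elided training code depends on your setup. Below is a generic PyTorch training-loop sketch, not the PerturbLab API: it assumes the model's forward pass returns classification logits and that train_loader is a DataLoader yielding (inputs, labels) batches.

import torch

# Hypothetical generic loop; the real forward signature may differ
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(10):
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        logits = model(inputs)           # assumed to return (batch, num_cls)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()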

Perturbation Prediction

import scanpy as sc
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data from an AnnData file
adata = sc.read_h5ad('your_perturbation_data.h5ad')
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-80m')

# Initialize perturbation head
model.init_perturbation_head_from_dataset(data)

# Train
model.train_model(data, epochs=20)

# Predict
predictions = model.predict_perturbation(data, split='test')
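
A common way to gauge prediction quality is the Pearson correlation between predicted and observed expression. A sketch assuming predictions exposes (n_cells, n_genes) arrays; the dictionary keys below are illustrative, not the actual PerturbLab return format:

import numpy as np
from scipy.stats import pearsonr

pred = np.asarray(predictions['pred'])    # hypothetical key
truth = np.asarray(predictions['truth'])  # hypothetical key

# Correlate per-gene mean expression across test cells
r, _ = pearsonr(pred.mean(axis=0), truth.mean(axis=0))
print(f'Pearson r (mean expression): {r:.3f}')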

Model Architecture

  • Encoder: Retention-based Transformer (MAE Autobin)
    • Auto-discretization embedding layer
    • Multi-head retention mechanism
    • Layer normalization and residual connections
  • Pre-training: Masked Autoencoding (MAE); see the sketch after this list
    • Masks 50% of genes
    • Reconstructs masked gene expression
  • Output: Gene-level embeddings + CLS token
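
To make the pre-training objective concrete, here is an illustrative masked-autoencoding sketch on a toy expression matrix (not CellFM's internal code): half of the genes are hidden and the loss is computed only on the masked entries.

import torch

expr = torch.randn(4, 27855)                # (cells, genes) toy expression
mask = torch.rand_like(expr) < 0.5          # True = gene is masked
masked_input = expr.masked_fill(mask, 0.0)  # hide masked values from the model

recon = masked_input                        # stand-in for the model's output
loss = ((recon - expr)[mask] ** 2).mean()   # MSE on masked genes only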

Files

  • config.json: Model configuration
  • model.pt: Model weights (PyTorch state dict)
  • README.md: This file
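
The files can be inspected directly with the standard json and torch libraries:

import json
import torch

with open('./weights/cellfm-80m/config.json') as f:
    config = json.load(f)

state_dict = torch.load('./weights/cellfm-80m/model.pt', map_location='cpu')
print(config)
print(f'{len(state_dict)} tensors in the state dict')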

Citation

If you use CellFM in your research, please cite:

@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
