# CellFM-800M

## Model Description

CellFM is a large-scale foundation model pre-trained on the transcriptomes of 100 million human cells, using a retention-based Transformer architecture (MAE Autobin).

- **Model Size**: 800M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 24,072 genes
- **Pre-training Task**: Masked Autoencoding (MAE)

## Model Details

- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See original repository for details

## Architecture Specifications

- **Hidden Dimension**: 1536
- **Number of Layers**: 40
- **Number of Attention Heads**: 48
- **Dropout**: 0.1
- **Max Sequence Length**: 2048 genes
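
These specifications map onto the `CellFMConfig` fields used in the fine-tuning example below; as a quick reference, a minimal sketch (field names are taken from that example, and whether every field is required is an assumption):

```python
from perturblab.model.cellfm import CellFMConfig

# 800M architecture specification expressed as a config
# (field names follow the fine-tuning example later in this README)
config = CellFMConfig(
    model_name='800M',
    n_genes=24072,      # gene vocabulary size
    enc_dims=1536,      # hidden dimension
    enc_nlayers=40,     # number of retention layers
    enc_num_heads=48,   # attention heads per layer
)
```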

## Usage

### Load Model

```python
from perturblab.model.cellfm import CellFMModel

# Load pretrained model (automatically downloads if needed)
model = CellFMModel.from_pretrained('cellfm-800m')

# Or use short name
model = CellFMModel.from_pretrained('800m')

# Or from local path
model = CellFMModel.from_pretrained('./weights/cellfm-800m')
```

### Generate Cell Embeddings

```python
import scanpy as sc

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess
adata = CellFMModel.prepare_data(adata)

# Get embeddings (use smaller batch size for 800M model)
embeddings = model.predict_embeddings(
    adata,
    batch_size=8,  # Smaller batch size for larger model
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # Shape: (n_cells, 1536)
```
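
The returned matrix can be treated like any other low-dimensional representation in scanpy. A short follow-up sketch, assuming the rows of `cell_embeddings` align one-to-one with `adata.obs` and that `adata.obs` contains a hypothetical `cell_type` annotation:

```python
import scanpy as sc

# Store the embeddings as a named representation, then build a
# neighbor graph and UMAP on top of them.
adata.obsm['X_cellfm'] = cell_embeddings
sc.pp.neighbors(adata, use_rep='X_cellfm')
sc.tl.umap(adata)
sc.pl.umap(adata, color='cell_type')  # 'cell_type' is a placeholder column
```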

### Fine-tune for Classification

```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize model with classification head
config = CellFMConfig(
    model_name='800M',
    n_genes=24072,
    enc_dims=1536,
    enc_nlayers=40,
    enc_num_heads=48,
    num_cls=10,  # Number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-800m/model.pt')

# Get dataloaders
train_loader = model.get_dataloader(train_data, batch_size=4)['train']
val_loader = model.get_dataloader(val_data, batch_size=4)['train']

# Train
model.train_model(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=10,
    learning_rate=1e-4,
)
```
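
A classification head configured with `num_cls=10` presumably expects integer class labels in `[0, num_cls - 1]`; this is an assumption about the training API, not documented behavior. If your labels live in `adata.obs` as strings, pandas categorical codes are a common way to encode them:

```python
import scanpy as sc

adata = sc.read_h5ad('train_data.h5ad')

# 'cell_type' is a placeholder for your own annotation column.
adata.obs['cell_type'] = adata.obs['cell_type'].astype('category')
adata.obs['label'] = adata.obs['cell_type'].cat.codes  # integers 0..9
print(dict(enumerate(adata.obs['cell_type'].cat.categories)))  # code -> name
```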

### Perturbation Prediction

```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-800m')

# Initialize perturbation head from dataset
model.init_perturbation_head_from_dataset(data)

# Train (use smaller batch size)
model.train_model(data, epochs=20, batch_size=4)

# Predict
predictions = model.predict_perturbation(data, split='test')

# Evaluate
metrics = model.evaluate(data, split='test')
print(f"Pearson correlation: {metrics['pearson']:.4f}")
```
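
To sanity-check the reported metric independently, a Pearson correlation over predicted vs. observed mean expression per gene can be computed with SciPy; the arrays below are toy stand-ins for your own vectors:

```python
import numpy as np
from scipy.stats import pearsonr

# Toy vectors standing in for predicted and observed mean expression
# per gene after a perturbation (replace with your own data).
rng = np.random.default_rng(0)
truth = rng.normal(size=2048)
pred = truth + rng.normal(scale=0.5, size=2048)

r, p = pearsonr(pred, truth)
print(f"Pearson correlation: {r:.4f} (p = {p:.2e})")
```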

## Performance Notes

- **Memory Requirements**: ~3-4GB GPU memory for inference (batch_size=8)
- **Recommended Batch Size**: 4-8 for training, 8-16 for inference (see the out-of-memory fallback sketch below)
- **Inference Speed**: ~2-3x slower than the 80M model
- **Loading Time**: ~5-10 seconds
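
If a chosen batch size still overflows your GPU, a simple pattern is to halve it on out-of-memory errors; a sketch, assuming `predict_embeddings` accepts `batch_size` as shown in the usage section:

```python
import torch

batch_size = 16
embeddings = None
while batch_size >= 1:
    try:
        embeddings = model.predict_embeddings(adata, batch_size=batch_size)
        break
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()  # release cached blocks before retrying
        batch_size //= 2
        print(f"OOM, retrying with batch_size={batch_size}")
```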

## Model Architecture

- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - 40 retention layers with 48 attention heads each
  - Hidden dimension: 1536
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE; see the illustrative sketch below)
  - Masks 50% of genes
  - Reconstructs masked gene expression
- **Output**: Gene-level embeddings + CLS token (1536-dimensional)
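
The pre-training objective is easy to state in code. The sketch below is purely illustrative (it is not the model's actual implementation): a 50% random gene mask with an MSE reconstruction loss computed only on the masked positions:

```python
import torch

def mae_loss(expression, reconstruction, mask_ratio=0.5):
    """Masked-autoencoding loss: reconstruct only the masked genes.

    expression, reconstruction: (batch, n_genes) float tensors.
    """
    mask = torch.rand_like(expression) < mask_ratio  # True = masked gene
    sq_err = (reconstruction - expression) ** 2
    return (sq_err * mask).sum() / mask.sum().clamp(min=1)

x = torch.randn(4, 2048)      # toy expression values
x_hat = torch.randn(4, 2048)  # toy reconstruction
print(mae_loss(x, x_hat))
```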

## Comparison with 80M Model

| Feature | 80M | 800M |
|---------|-----|------|
| Parameters | 80M | 800M |
| Hidden Dim | 1536 | 1536 |
| Layers | 2 | 40 |
| Heads | 48 | 48 |
| Genes | 27,855 | 24,072 |
| Memory (Inference) | ~1-2GB | ~3-4GB |
| Speed | Faster | Slower |
| Performance | Good | Better |

The 800M model provides significantly better representation quality due to its deeper architecture (40 layers vs 2 layers), at the cost of increased computational requirements.

## Files

- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict, ~3.0GB)
- `README.md`: This file
- `.gitattributes`: Git LFS configuration
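
The checkpoint can also be inspected directly with standard tooling, without going through the wrapper class; a minimal sketch, assuming `model.pt` is a plain PyTorch state dict as described above:

```python
import json
import torch

with open('./weights/cellfm-800m/config.json') as f:
    print(json.load(f))

state_dict = torch.load('./weights/cellfm-800m/model.pt', map_location='cpu')
n_params = sum(t.numel() for t in state_dict.values())
print(f"{len(state_dict)} tensors, {n_params / 1e6:.0f}M parameters")
```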

## Citation

If you use CellFM in your research, please cite:

```bibtex
@article{cellfm2024,
  title={CellFM: A Large-Scale Foundation Model for Single-Cell Transcriptomics},
  author={...},
  journal={...},
  year={2024}
}
```

## References

- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]

## Notes

- This model was converted from the original MindSpore checkpoint
- The gene vocabulary (24,072 genes) differs from the 80M model's (27,855 genes); see the overlap check below
- For best results, ensure your data preprocessing matches the model's expected input format
- Use `CellFMModel.prepare_data()` to automatically preprocess your data
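
Because the vocabulary differs between checkpoints, it is worth checking how much of your data overlaps with it before embedding. The sketch below assumes a hypothetical `gene_vocab` collection of the model's 24,072 gene symbols (obtain the real list from the repository); `prepare_data()` may already handle the alignment, so this only makes the overlap explicit:

```python
import scanpy as sc

# Hypothetical vocabulary: in practice, load the model's 24,072 gene
# symbols from the repository; three symbols here just for illustration.
gene_vocab = {'CD3D', 'MS4A1', 'NKG7'}

adata = sc.read_h5ad('your_data.h5ad')
in_vocab = adata.var_names.isin(list(gene_vocab))
print(f"{in_vocab.sum()} / {adata.n_vars} genes are in the model vocabulary")

# Restrict to the shared genes before calling prepare_data()
adata = adata[:, in_vocab].copy()
```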