krkawzq committed · Commit c7cd685 · verified · Parent(s): afb48ec

Update README.md

Files changed (1): README.md (+106 −81)

# CellFM-800M

## Model Description

CellFM is a large-scale foundation model pre-trained on transcriptomics of 100 million human cells using a retention-based architecture (MAE Autobin).

- **Model Size**: 800M parameters
- **Pre-training Data**: 100M human cells
- **Architecture**: Retention-based Transformer (MAE Autobin)
- **Vocabulary**: 24,072 genes
- **Pre-training Task**: Masked Autoencoding (MAE)

## Model Details

- **Source**: [biomed-AI/CellFM](https://github.com/biomed-AI/CellFM)
- **Original Framework**: MindSpore
- **Converted to**: PyTorch (PerturbLab format)
- **License**: See original repository for details

## Architecture Specifications

- **Hidden Dimension**: 1536
- **Number of Layers**: 40
- **Number of Attention Heads**: 48
- **Dropout**: 0.1
- **Max Sequence Length**: 2048 genes

## Usage

### Load Model

```python
from perturblab.model.cellfm import CellFMModel

# Load pretrained model (automatically downloads if needed)
model = CellFMModel.from_pretrained('cellfm-800m')

# Or use short name
model = CellFMModel.from_pretrained('800m')

# Or from local path
model = CellFMModel.from_pretrained('./weights/cellfm-800m')
```

### Generate Cell Embeddings

```python
import scanpy as sc

# Load your data
adata = sc.read_h5ad('your_data.h5ad')

# Preprocess
adata = CellFMModel.prepare_data(adata)

# Get embeddings (use smaller batch size for 800M model)
embeddings = model.predict_embeddings(
    adata,
    batch_size=8,  # smaller batch size for larger model
    return_cls_token=True,
)

# Access cell embeddings
cell_embeddings = embeddings['cell_embeddings']  # shape: (n_cells, 1536)
```

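The embedding matrix can feed directly into standard similarity tooling. A minimal numpy sketch of cosine-similarity cell retrieval, using a random stand-in for the model's `(n_cells, 1536)` output:

```python
import numpy as np

# Random stand-in for the model's output; real embeddings come from
# embeddings['cell_embeddings'] and have shape (n_cells, 1536)
rng = np.random.default_rng(0)
cell_embeddings = rng.standard_normal((100, 1536)).astype(np.float32)

# L2-normalize rows so dot products become cosine similarities
unit = cell_embeddings / np.linalg.norm(cell_embeddings, axis=1, keepdims=True)

# Retrieve the 5 cells most similar to cell 0 (cell 0 itself scores highest)
sims = unit @ unit[0]
top5 = np.argsort(-sims)[1:6]
```

The same normalized embeddings also work as input to neighbor-graph clustering or UMAP in your tool of choice.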
### Fine-tune for Classification

```python
from perturblab.model.cellfm import CellFMModel, CellFMConfig

# Initialize model with classification head
config = CellFMConfig(
    model_name='800M',
    n_genes=24072,
    enc_dims=1536,
    enc_nlayers=40,
    enc_num_heads=48,
    num_cls=10,  # number of cell types
)
model = CellFMModel(config, for_finetuning=True)

# Load pretrained weights
model.load_weights('./weights/cellfm-800m/model.pt')

# Get dataloaders
train_loader = model.get_dataloader(train_data, batch_size=4)['train']
val_loader = model.get_dataloader(val_data, batch_size=4)['train']

# Train
model.train_model(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=10,
    learning_rate=1e-4,
)
```

### Perturbation Prediction

```python
from perturblab.model.cellfm import CellFMPerturbationModel
from perturblab.data import PerturbationData

# Load perturbation data
data = PerturbationData.from_anndata(adata)
data.split_data(train=0.7, val=0.15, test=0.15)

# Initialize model
model = CellFMPerturbationModel.from_pretrained('cellfm-800m')

# Initialize perturbation head from dataset
model.init_perturbation_head_from_dataset(data)

# Train (use smaller batch size)
model.train_model(data, epochs=20, batch_size=4)

# Predict
predictions = model.predict_perturbation(data, split='test')

# Evaluate
metrics = model.evaluate(data, split='test')
print(f"Pearson correlation: {metrics['pearson']:.4f}")
```

## Performance Notes

- **Memory Requirements**: ~3-4GB GPU memory for inference (batch_size=8)
- **Recommended Batch Size**: 4-8 for training, 8-16 for inference
- **Inference Speed**: ~2-3x slower than 80M model
- **Loading Time**: ~5-10 seconds

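The memory figures follow from simple arithmetic on the parameter count; a quick back-of-envelope check (weights only — activation memory comes on top and grows with batch size):

```python
# Checkpoint size estimated from the parameter count
n_params = 800e6                   # 800M parameters
gb_fp32 = n_params * 4 / 1024**3   # 4 bytes/param in float32 -> ~3.0 GB
gb_fp16 = n_params * 2 / 1024**3   # 2 bytes/param in float16 -> ~1.5 GB
print(round(gb_fp32, 2), round(gb_fp16, 2))
```

The float32 estimate (~3.0 GB) matches the size of `model.pt`, and loading it plus a few hundred MB of activations accounts for the 3-4GB inference footprint.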
## Model Architecture

- **Encoder**: Retention-based Transformer (MAE Autobin)
  - Auto-discretization embedding layer
  - 40 retention layers with 48 attention heads each
  - Hidden dimension: 1536
  - Layer normalization and residual connections
- **Pre-training**: Masked Autoencoding (MAE)
  - Masks 50% of genes
  - Reconstructs masked gene expression
- **Output**: Gene-level embeddings + CLS token (1536-dimensional)

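The masking objective can be sketched in a few lines of numpy — a toy illustration of the MAE scheme described above, not the model's actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)
n_genes = 2048  # one cell's gene sequence, up to the max sequence length
expression = rng.poisson(2.0, size=n_genes).astype(np.float32)

# Hide 50% of gene positions from the encoder, as in MAE pre-training
mask = rng.random(n_genes) < 0.5
masked_input = expression.copy()
masked_input[mask] = 0.0

# The reconstruction loss compares predictions against the hidden values only
target = expression[mask]
```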
## Comparison with 80M Model

| Feature | 80M | 800M |
|---------|-----|------|
| Parameters | 80M | 800M |
| Hidden Dim | 1536 | 1536 |
| Layers | 2 | 40 |
| Heads | 48 | 48 |
| Genes | 27,855 | 24,072 |
| Memory (Inference) | ~1-2GB | ~3-4GB |
| Speed | Faster | Slower |
| Performance | Good | Better |

The 800M model provides significantly better representation quality due to its deeper architecture (40 layers vs 2 layers), at the cost of increased computational requirements.

## Files

- `config.json`: Model configuration
- `model.pt`: Model weights (PyTorch state dict, ~3.0GB)
- `README.md`: This file
- `.gitattributes`: Git LFS configuration

## Citation

If you use CellFM in your research, please cite:

```bibtex
@article{cellfm2024,
}
```

## References

- Original Repository: https://github.com/biomed-AI/CellFM
- PyTorch Version: https://github.com/biomed-AI/CellFM-torch
- Paper: [Link to paper when available]

## Notes

- This model was converted from the original MindSpore checkpoint
- The gene vocabulary (24,072 genes) differs from the 80M model's (27,855 genes)
- For best results, ensure your data preprocessing matches the model's expected input format
- Use `CellFMModel.prepare_data()` to automatically preprocess your data
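`prepare_data()` handles preprocessing for you; for intuition, here is a generic numpy sketch of the usual single-cell normalization steps (an assumption about typical pipelines, not the function's actual internals):

```python
import numpy as np

# Toy cells x genes count matrix standing in for adata.X
rng = np.random.default_rng(0)
counts = rng.poisson(1.0, size=(4, 6)).astype(np.float64)

# Scale each cell to a fixed total count, then log1p-transform --
# the standard normalize/log step most single-cell pipelines apply
totals = counts.sum(axis=1, keepdims=True)
normalized = counts / np.maximum(totals, 1.0) * 1e4
logged = np.log1p(normalized)
```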