eugeneyuan committed · Commit d066174 · verified · 1 Parent(s): 5942acb

Add CryoFM model weights and configurations

- Add CryoFM-S and CryoFM-L model variants
- Include model configs and safetensors checkpoints
- Add README with model description and usage examples

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/cryofm_archs.jpg filter=lfs diff=lfs merge=lfs -text
+ assets/cryofm.gif filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,165 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ tags:
+ - cryo-em
+ - flow-matching
+ - 3d-density-maps
+ - foundation-model
+ ---
+
+ # CryoFM: Flow-based Foundation Model for Cryo-EM Density Maps
+
+ <div align="center">
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2410.08631-B31B1B?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2410.08631)
+ [![GitHub](https://img.shields.io/badge/GitHub-cryofm-181717?logo=github&logoColor=white)](https://github.com/ByteDance-Seed/cryofm)
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
+
+ </div>
+
+ <div align="center">
+ <img src="./assets/cryofm.gif" alt="CryoFM Demo" style="max-width: 100%; height: auto; width: 800px;"/>
+ </div>
+
+ ## Model Description
+
+ CryoFM1 is a flow-based foundation model for 3D cryo-electron microscopy (cryo-EM) density maps. It uses a Hierarchical Diffusion Transformer (HDiT) architecture to learn a prior over 3D cryo-EM densities. CryoFM1 supports downstream tasks including density map denoising, anisotropic noise correction, missing wedge restoration, and *ab initio* modeling.
+
+ ### Key Features
+
+ - **Flow Matching Framework**: Uses flow matching for efficient and stable training
+ - **HDiT Architecture**: Hierarchical Diffusion Transformer with local and global attention mechanisms
+ - **Two Model Variants**: CryoFM-S (64³ voxels) and CryoFM-L (128³ voxels) for different resolution needs
+ - **Downstream Task Support**: Denoising, anisotropic noise correction, missing wedge restoration, and more
+
+ ## Model Details
+
+ CryoFM1 employs a Hierarchical Diffusion Transformer (HDiT) architecture that combines local neighborhood attention with global attention. This design lets the model capture both fine-grained local structure and long-range dependencies in 3D cryo-EM density maps. The architecture splits each 3D volume into patches and builds representations hierarchically at multiple scales.
+
+ <div align="center">
+ <img src="./assets/cryofm_archs.jpg" alt="CryoFM Architecture" style="max-width: 100%; height: auto; width: 600px;"/>
+ </div>
+
+ The model is available in two variants optimized for different resolution requirements. The following table summarizes the key architectural and training parameters for each variant:
+
+ | Parameter | CryoFM-S | CryoFM-L |
+ |-----------|----------|----------|
+ | **Parameters** | 335.18 M | 308.54 M |
+ | **GFLOP/forward** | 395.87 | 427.26 |
+ | **Training Steps** | 150k | 300k |
+ | **Batch Size** | 128 | 128 |
+ | **Precision** | bf16 | bf16 |
+ | **Training Hardware** | 8×A100 | 8×A100 |
+ | **Patch Size** | 4 | 4 |
+ | **Levels (Local + Global Attention)** | 1 + 1 | 2 + 1 |
+ | **Depths** | [4, 8] | [2, 2, 12] |
+ | **Widths** | [768, 1536] | [320, 640, 1280] |
+ | **Attention Heads (Width / Head Dim)** | [12, 24] | [5, 10, 20] |
+ | **Attention Head Dim** | 64 | 64 |
+ | **Neighborhood Kernel Size** | 7 | 7 |
+
+ ## Quick Start
+
+ ### Unconditional Generation
+
+ CryoFM1 provides two model variants for different resolution needs:
+ - **CryoFM-S**: Generates 64×64×64 voxel density maps sampled at 1.5 Å/pixel
+ - **CryoFM-L**: Generates 128×128×128 voxel density maps sampled at 3.0 Å/pixel
+
+
+ ```python
+ import torch
+ from mmengine import Config
+ from cryofm.core.utils.mrc_io import save_mrc
+ from cryofm.projects.cryofm1.lit_modules import CryoFM1
+ from cryofm.core.utils.sampling_fm import sample_from_fm
+
+ # Choose model variant: "cryofm-s" or "cryofm-l"
+ model_variant = "cryofm-s"
+ model_config = {
+     "cryofm-s": {
+         "config_path": "cryofm-v1/cryofm-s/config.yaml",
+         "model_path": "cryofm-v1/cryofm-s/model.safetensors",
+         "side_shape": 64,
+         "apix": 1.5
+     },
+     "cryofm-l": {
+         "config_path": "cryofm-v1/cryofm-l/config.yaml",
+         "model_path": "cryofm-v1/cryofm-l/model.safetensors",
+         "side_shape": 128,
+         "apix": 3.0
+     }
+ }
+
+ # Load configuration and model
+ cfg = Config.fromfile(model_config[model_variant]["config_path"])
+ lit_model = CryoFM1.load_from_safetensors(
+     model_config[model_variant]["model_path"],
+     cfg=cfg
+ )
+
+ # Set up device and model
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ lit_model = lit_model.to(device)
+ lit_model.eval()
+
+ # Define vector field function for flow matching
+ def v_xt_t(_xt, _t):
+     return lit_model(_xt, _t)
+
+ # Generate samples
+ # Note: autocast follows the selected device; bfloat16 is fastest on GPUs that support it
+ with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
+     out = sample_from_fm(
+         v_xt_t,
+         lit_model.noise_scheduler,
+         method="euler",
+         num_steps=200,
+         num_samples=3,
+         device=device,
+         side_shape=model_config[model_variant]["side_shape"]
+     )
+     # Undo the z-score normalization if one is configured
+     if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
+         out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean
+
+ # Save each generated density map as an MRC file
+ for i, volume in enumerate(out):
+     save_mrc(
+         volume.float().cpu().numpy(),
+         f"sample-{i}.mrc",
+         apix=model_config[model_variant]["apix"]  # Angstroms per pixel
+     )
+ ```
+
+ ## Ethical Considerations
+
+ This model is intended for scientific research and structural biology applications. Users should:
+ - Ensure proper attribution when using generated structures
+ - Validate generated structures through experimental verification
+ - Be aware of potential biases in the training data
+
+ ## Citation
+
+ If you use CryoFM1 in your research, please cite:
+
+ ```bibtex
+ @inproceedings{
+ zhou2025cryofm,
+ title={Cryo{FM}: A Flow-based Foundation Model for Cryo-{EM} Densities},
+ author={Yi Zhou and Yilai Li and Jing Yuan and Quanquan Gu},
+ booktitle={The Thirteenth International Conference on Learning Representations},
+ year={2025},
+ url={https://openreview.net/forum?id=T4sMzjy7fO}
+ }
+ ```
+
+ ## License
+
+ This model is released under the Apache 2.0 License. See the [LICENSE](https://github.com/ByteDance-Seed/cryofm/blob/main/LICENSE) file for details.
+
+ ## Acknowledgments
+
+ This work was developed by the ByteDance Seed Team. For more information, visit:
+ - [Project Repository](https://github.com/ByteDance-Seed/cryofm)
+ - [ByteDance Seed Team](https://seed.bytedance.com/)
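
Note on the Quick Start snippet in README.md: the call to `sample_from_fm` with `method="euler"` and `num_steps=200` suggests fixed-step Euler integration of the learned vector field. The following is a minimal, dependency-free sketch of that idea, not the `cryofm` implementation. It assumes time runs from 0 (noise) to 1 (data) in uniform steps, which this diff does not confirm. `euler_sample` and `toy_velocity` are hypothetical names, and the toy field simply points from a fixed start to a fixed target.

```python
from typing import Callable, List

Vector = List[float]


def euler_sample(
    v: Callable[[Vector, float], Vector],
    x0: Vector,
    num_steps: int = 200,
) -> Vector:
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler.

    Assumption: t=0 is the noise sample and t=1 is the generated sample.
    """
    x = list(x0)
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt
        velocity = v(x, t)
        # One explicit Euler update: x <- x + dt * v(x, t)
        x = [xi + dt * vi for xi, vi in zip(x, velocity)]
    return x


# Toy stand-ins for the noise sample and the data sample (illustrative only).
noise = [0.5, -1.0, 2.0]
target = [1.0, 0.0, -1.0]


def toy_velocity(x: Vector, t: float) -> Vector:
    # Constant straight-line velocity from `noise` to `target`.
    return [b - a for a, b in zip(noise, target)]


sample = euler_sample(toy_velocity, noise, num_steps=200)
```

Because the toy velocity is constant, Euler integration lands on the target up to floating-point error. A learned field is generally curved, which is why the step count matters in practice.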
assets/cryofm.gif ADDED

Git LFS Details

  • SHA256: dbccb7fd7a941ad09f3154b666b8e3ad83334f8d5c0f11fa59eb5700f684d828
  • Pointer size: 132 Bytes
  • Size of remote file: 2.95 MB
assets/cryofm_archs.jpg ADDED

Git LFS Details

  • SHA256: fef5d88f6988a5a0ffa9f44073147d569fdd79efa9f960b0d27a28da22926502
  • Pointer size: 131 Bytes
  • Size of remote file: 528 kB
cryofm-l/config.yaml ADDED
@@ -0,0 +1,46 @@
+ ckpt_path: null
+ ddpm:
+   prediction_type: v_prediction
+ exp_name: 128-hdit_fm_scale_bf16
+ hdit_model:
+   depths:
+   - 2
+   - 2
+   - 12
+   input_channels: 1
+   input_size:
+   - 128
+   - 128
+   - 128
+   patch_size:
+   - 4
+   - 4
+   - 4
+   self_attns:
+   - d_head: 64
+     kernel_size: 7
+     type: neighborhood
+   - d_head: 64
+     kernel_size: 7
+     type: neighborhood
+   - d_head: 64
+     type: global
+   type: image_transformer_v2
+   widths:
+   - 320
+   - 640
+   - 1280
+ keep_last_k: null
+ model_type: hdit
+ num_val_samples: 3
+ optimizer:
+   lr: 0.0001
+   warmup: 2000
+ patch_size: 128
+ process: fm
+ seed: 42
+ work_dir: work_dirs/128-hdit_fm_scale_bf16_00
+ z_crop: null
+ z_scale:
+   mean: 0.04
+   std: 0.09
cryofm-l/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:818ea9a9e53b21f4d07cef941ceaf99dff226f117b9678cbe63bc24937bc85eb
+ size 1234168600
cryofm-s/config.yaml ADDED
@@ -0,0 +1,42 @@
+ ckpt_path: null
+ ddpm:
+   prediction_type: v_prediction
+ exp_name: 64-hdit_fm_scale_bf16
+ hdit_model:
+   depths:
+   - 4
+   - 8
+   input_channels: 1
+   input_size:
+   - 64
+   - 64
+   - 64
+   patch_size:
+   - 4
+   - 4
+   - 4
+   self_attns:
+   - d_head: 64
+     kernel_size: 7
+     type: neighborhood
+   - d_head: 64
+     type: global
+   type: image_transformer_v2
+   widths:
+   - 768
+   - 1536
+ keep_last_k: null
+ mode: train
+ model_type: hdit
+ num_val_samples: 3
+ optimizer:
+   lr: 0.0001
+   warmup: 2000
+ patch_size: 64
+ process: fm
+ seed: 42
+ work_dir: work_dirs/64-hdit_fm_scale_bf16_00
+ z_crop: null
+ z_scale:
+   mean: 0.04
+   std: 0.09
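
Note on the two config.yaml files: the README table lists attention heads as width divided by head dimension, and both configs patchify with size 4. A small sketch, using values copied from the configs above, shows how those numbers relate. The per-level halving of the token grid (`merge_factor=2`) is an assumption based on how hierarchical transformers commonly downsample and is not stated in the configs. `derive_shapes` is a hypothetical helper.

```python
# Values copied from cryofm-s/config.yaml and cryofm-l/config.yaml.
configs = {
    "cryofm-s": {"input_size": 64, "patch_size": 4, "widths": [768, 1536], "d_head": 64},
    "cryofm-l": {"input_size": 128, "patch_size": 4, "widths": [320, 640, 1280], "d_head": 64},
}


def derive_shapes(cfg: dict, merge_factor: int = 2) -> dict:
    """Derive per-level head counts and token-grid side lengths.

    Assumption: each deeper level shrinks the token grid by `merge_factor`
    along every axis. The configs do not state this factor.
    """
    heads = [width // cfg["d_head"] for width in cfg["widths"]]
    first_side = cfg["input_size"] // cfg["patch_size"]
    token_sides = [first_side // merge_factor**level for level in range(len(cfg["widths"]))]
    return {"heads": heads, "token_sides": token_sides}


derived = {name: derive_shapes(cfg) for name, cfg in configs.items()}
for name, shapes in derived.items():
    print(name, shapes)
```

Under that assumption, both variants would run global attention on the same 8×8×8 token grid, with CryoFM-L adding one extra neighborhood-attention level above it.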
cryofm-s/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:39b8430620c0a2fad85158412cf22c6e62f5034e21e39801219964141ff5e313
+ size 1340760716
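
Note on the checkpoint sizes: the LFS pointer sizes can be compared with the parameter counts in the README table. The sketch below assumes the weights are stored as 32-bit floats (4 bytes per parameter) and ignores the small safetensors header, neither of which this diff confirms. `estimate_params_millions` is a hypothetical helper. Under that assumption the estimates come out to roughly 335.19 M for CryoFM-S and 308.54 M for CryoFM-L, in line with the table's 335.18 M and 308.54 M.

```python
BYTES_PER_PARAM = 4  # assumption: fp32 storage; not stated in the diff

# Sizes copied from the Git LFS pointer files in this commit.
lfs_sizes = {"cryofm-s": 1340760716, "cryofm-l": 1234168600}
# Parameter counts (in millions) copied from the README table.
readme_params_m = {"cryofm-s": 335.18, "cryofm-l": 308.54}


def estimate_params_millions(size_bytes: int, bytes_per_param: int = BYTES_PER_PARAM) -> float:
    """Rough parameter count in millions, ignoring the file header."""
    return size_bytes / bytes_per_param / 1e6


estimates = {name: estimate_params_millions(size) for name, size in lfs_sizes.items()}
for name, estimate in estimates.items():
    print(f"{name}: ~{estimate:.2f} M estimated vs {readme_params_m[name]:.2f} M in README")
```

The comparison also makes it plain that the smaller-resolution variant, CryoFM-S, is the one with more parameters, so the "S" and "L" suffixes track volume size and not model size.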