Add CryoFM model weights and configurations
- Add CryoFM-S and CryoFM-L model variants
- Include model configs and safetensors checkpoints
- Add README with model description and usage examples
- .gitattributes +2 -0
- README.md +165 -3
- assets/cryofm.gif +3 -0
- assets/cryofm_archs.jpg +3 -0
- cryofm-l/config.yaml +46 -0
- cryofm-l/model.safetensors +3 -0
- cryofm-s/config.yaml +42 -0
- cryofm-s/model.safetensors +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/cryofm_archs.jpg filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/cryofm.gif filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,3 +1,165 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: apache-2.0
|
| 3 |
-
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- cryo-em
|
| 5 |
+
- flow-matching
|
| 6 |
+
- 3d-density-maps
|
| 7 |
+
- foundation-model
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# CryoFM: Flow-based Foundation Model for Cryo-EM Density Maps
|
| 11 |
+
|
| 12 |
+
<div align="center">
|
| 13 |
+
|
| 14 |
+
[arXiv](https://arxiv.org/abs/2410.08631)
|
| 15 |
+
[GitHub](https://github.com/ByteDance-Seed/cryofm)
|
| 16 |
+
[License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
|
| 17 |
+
|
| 18 |
+
</div>
|
| 19 |
+
|
| 20 |
+
<div align="center">
|
| 21 |
+
<img src="./assets/cryofm.gif" alt="CryoFM Demo" style="max-width: 100%; height: auto; width: 800px;"/>
|
| 22 |
+
</div>
|
| 23 |
+
|
| 24 |
+
## Model Description
|
| 25 |
+
|
| 26 |
+
CryoFM1 is a flow-based foundation model for 3D cryo-electron microscopy (cryo-EM) density maps. The model employs a Hierarchical Diffusion Transformer (HDiT) architecture, specifically designed to learn deep priors of 3D cryo-EM densities. CryoFM1 supports various downstream tasks including density map denoising, anisotropy noise correction, missing wedge inpainting, and *ab initio* modeling.
|
| 27 |
+
|
| 28 |
+
### Key Features
|
| 29 |
+
|
| 30 |
+
- **Flow Matching Framework**: Uses flow matching for efficient and stable training
|
| 31 |
+
- **HDiT Architecture**: Hierarchical Diffusion Transformer with local and global attention mechanisms
|
| 32 |
+
- **Two Model Variants**: CryoFM-S (64³) and CryoFM-L (128³) for different resolution needs
|
| 33 |
+
- **Downstream Task Support**: Denoising, anisotropy noise correction, missing wedge restoration, and more
|
| 34 |
+
|
| 35 |
+
## Model Details
|
| 36 |
+
|
| 37 |
+
CryoFM1 employs a Hierarchical Diffusion Transformer (HDiT) architecture that combines local neighborhood attention with global attention mechanisms. This design enables the model to effectively capture both fine-grained local structures and long-range dependencies in 3D cryo-EM density maps. The architecture processes 3D volumes through a hierarchical patch-based approach, progressively building representations at multiple scales.
|
| 38 |
+
|
| 39 |
+
<div align="center">
|
| 40 |
+
<img src="./assets/cryofm_archs.jpg" alt="CryoFM Architecture" style="max-width: 100%; height: auto; width: 600px;"/>
|
| 41 |
+
</div>
|
| 42 |
+
|
| 43 |
+
The model is available in two variants optimized for different resolution requirements. The following table summarizes the key architectural and training parameters for each variant:
|
| 44 |
+
|
| 45 |
+
| Parameter | CryoFM-S | CryoFM-L |
|
| 46 |
+
|-----------|----------|----------|
|
| 47 |
+
| **Parameters** | 335.18 M | 308.54 M |
|
| 48 |
+
| **GFLOP/forward** | 395.87 | 427.26 |
|
| 49 |
+
| **Training Steps** | 150k | 300k |
|
| 50 |
+
| **Batch Size** | 128 | 128 |
|
| 51 |
+
| **Precision** | bf16 | bf16 |
|
| 52 |
+
| **Training Hardware** | 8×A100 | 8×A100 |
|
| 53 |
+
| **Patchifying** | 4 | 4 |
|
| 54 |
+
| **Levels (Local + Global Attention)** | 1 + 1 | 2 + 1 |
|
| 55 |
+
| **Depth** | [4, 8] | [2, 2, 12] |
|
| 56 |
+
| **Widths** | [768, 1536] | [320, 640, 1280] |
|
| 57 |
+
| **Attention Heads (Width / Head Dim)** | [12, 24] | [5, 10, 20] |
|
| 58 |
+
| **Attention Head Dim** | 64 | 64 |
|
| 59 |
+
| **Neighborhood Kernel Size** | 7 | 7 |
|
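The attention head counts in the table follow from dividing each level's width by the head dimension. A minimal sketch that re-derives them from the table values (the helper name `heads_per_level` is hypothetical and not part of the cryofm package):

```python
# Values copied from the table above.
HEAD_DIM = 64
WIDTHS = {
    "CryoFM-S": [768, 1536],
    "CryoFM-L": [320, 640, 1280],
}


def heads_per_level(widths, head_dim=HEAD_DIM):
    """Return the number of attention heads per level (width / head dim)."""
    heads = []
    for width in widths:
        if width % head_dim != 0:
            raise ValueError(f"width {width} is not a multiple of {head_dim}")
        heads.append(width // head_dim)
    return heads


for name, widths in WIDTHS.items():
    print(name, heads_per_level(widths))
```

The printed lists should match the "Attention Heads" row for each variant.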
| 60 |
+
|
| 61 |
+
## Quick Start
|
| 62 |
+
|
| 63 |
+
### Unconditional Generation
|
| 64 |
+
|
| 65 |
+
CryoFM1 provides two model variants for different resolution needs:
|
| 66 |
+
- **CryoFM-S**: Generates 64×64×64 voxel density maps with a voxel size of 1.5 Å
|
| 67 |
+
- **CryoFM-L**: Generates 128×128×128 voxel density maps with a voxel size of 3.0 Å
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
import torch
|
| 72 |
+
from mmengine import Config
|
| 73 |
+
from cryofm.core.utils.mrc_io import save_mrc
|
| 74 |
+
from cryofm.projects.cryofm1.lit_modules import CryoFM1
|
| 75 |
+
from cryofm.core.utils.sampling_fm import sample_from_fm
|
| 76 |
+
|
| 77 |
+
# Choose model variant: "cryofm-s" or "cryofm-l"
|
| 78 |
+
model_variant = "cryofm-s" # or "cryofm-l"
|
| 79 |
+
model_config = {
|
| 80 |
+
"cryofm-s": {
|
| 81 |
+
"config_path": "cryofm-v1/cryofm-s/config.yaml",
|
| 82 |
+
"model_path": "cryofm-v1/cryofm-s/model.safetensors",
|
| 83 |
+
"side_shape": 64,
|
| 84 |
+
"apix": 1.5
|
| 85 |
+
},
|
| 86 |
+
"cryofm-l": {
|
| 87 |
+
"config_path": "cryofm-v1/cryofm-l/config.yaml",
|
| 88 |
+
"model_path": "cryofm-v1/cryofm-l/model.safetensors",
|
| 89 |
+
"side_shape": 128,
|
| 90 |
+
"apix": 3.0
|
| 91 |
+
}
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
# Load configuration and model
|
| 95 |
+
cfg = Config.fromfile(model_config[model_variant]["config_path"])
|
| 96 |
+
lit_model = CryoFM1.load_from_safetensors(
|
| 97 |
+
model_config[model_variant]["model_path"],
|
| 98 |
+
cfg=cfg
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Set up device and model
|
| 102 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 103 |
+
lit_model = lit_model.to(device)
|
| 104 |
+
lit_model.eval()
|
| 105 |
+
|
| 106 |
+
# Define vector field function for flow matching
|
| 107 |
+
def v_xt_t(_xt, _t):
|
| 108 |
+
    return lit_model(_xt, _t)
|
| 109 |
+
|
| 110 |
+
# Generate samples
|
| 111 |
+
# Note: Enable bfloat16 if your GPU supports it for better performance
|
| 112 |
+
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
|
| 113 |
+
    out = sample_from_fm(
|
| 114 |
+
v_xt_t,
|
| 115 |
+
lit_model.noise_scheduler,
|
| 116 |
+
method="euler",
|
| 117 |
+
num_steps=200,
|
| 118 |
+
num_samples=3,
|
| 119 |
+
device=device,
|
| 120 |
+
side_shape=model_config[model_variant]["side_shape"]
|
| 121 |
+
)
|
| 122 |
+
# Apply z-scaling normalization if configured
|
| 123 |
+
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
|
| 124 |
+
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean
|
| 125 |
+
|
| 126 |
+
# Save generated density maps
|
| 127 |
+
for i in range(3):
|
| 128 |
+
    save_mrc(
|
| 129 |
+
out[i].float().cpu().numpy(),
|
| 130 |
+
f"sample-{i}.mrc",
|
| 131 |
+
apix=model_config[model_variant]["apix"] # Angstroms per pixel
|
| 132 |
+
)
|
| 133 |
+
```
|
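The Quick Start above passes `method="euler"` and `num_steps=200` to `sample_from_fm`. As a rough illustration of what fixed-step Euler integration of a vector field does, here is a self-contained toy sketch. It is not the implementation of `sample_from_fm`: the names `euler_sample`, `toy_v`, and `TARGET` are hypothetical, the state is a single float rather than a 3D volume, and the time direction (0 to 1) is an assumption.

```python
def euler_sample(v, x0, num_steps=200):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler."""
    x = x0
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt  # time at the start of the step, always < 1
        x = x + dt * v(x, t)
    return x


# Toy vector field (assumption): straight-line transport toward one target.
TARGET = 2.0


def toy_v(x, t):
    """Velocity that moves x to TARGET along a straight line by t=1."""
    return (TARGET - x) / (1.0 - t)


result = euler_sample(toy_v, x0=-1.0, num_steps=200)
print(result)
```

Because the toy path is a straight line, Euler steps land very close to `TARGET`. A learned vector field is generally curved, which is one reason the number of steps matters in practice.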
| 134 |
+
|
| 135 |
+
### Ethical Considerations
|
| 136 |
+
|
| 137 |
+
This model is intended for scientific research and structural biology applications. Users should:
|
| 138 |
+
- Ensure proper attribution when using generated structures
|
| 139 |
+
- Validate generated structures through experimental verification
|
| 140 |
+
- Be aware of potential biases in the training data
|
| 141 |
+
|
| 142 |
+
## Citation
|
| 143 |
+
|
| 144 |
+
If you use CryoFM1 in your research, please cite:
|
| 145 |
+
|
| 146 |
+
```bibtex
|
| 147 |
+
@inproceedings{
|
| 148 |
+
zhou2025cryofm,
|
| 149 |
+
title={Cryo{FM}: A Flow-based Foundation Model for Cryo-{EM} Densities},
|
| 150 |
+
author={Yi Zhou and Yilai Li and Jing Yuan and Quanquan Gu},
|
| 151 |
+
booktitle={The Thirteenth International Conference on Learning Representations},
|
| 152 |
+
year={2025},
|
| 153 |
+
url={https://openreview.net/forum?id=T4sMzjy7fO}
|
| 154 |
+
}
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
## License
|
| 158 |
+
|
| 159 |
+
This model is released under the Apache 2.0 License. See the [LICENSE](https://github.com/ByteDance-Seed/cryofm/blob/main/LICENSE) file for details.
|
| 160 |
+
|
| 161 |
+
## Acknowledgments
|
| 162 |
+
|
| 163 |
+
This work was developed by the ByteDance Seed Team. For more information, visit:
|
| 164 |
+
- [Project Repository](https://github.com/ByteDance-Seed/cryofm)
|
| 165 |
+
- [ByteDance Seed Team](https://seed.bytedance.com/)
|
assets/cryofm.gif
ADDED
|
assets/cryofm_archs.jpg
ADDED
|
cryofm-l/config.yaml
ADDED
|
@@ -0,0 +1,46 @@
|
| 1 |
+
ckpt_path: null
|
| 2 |
+
ddpm:
|
| 3 |
+
prediction_type: v_prediction
|
| 4 |
+
exp_name: 128-hdit_fm_scale_bf16
|
| 5 |
+
hdit_model:
|
| 6 |
+
depths:
|
| 7 |
+
- 2
|
| 8 |
+
- 2
|
| 9 |
+
- 12
|
| 10 |
+
input_channels: 1
|
| 11 |
+
input_size:
|
| 12 |
+
- 128
|
| 13 |
+
- 128
|
| 14 |
+
- 128
|
| 15 |
+
patch_size:
|
| 16 |
+
- 4
|
| 17 |
+
- 4
|
| 18 |
+
- 4
|
| 19 |
+
self_attns:
|
| 20 |
+
- d_head: 64
|
| 21 |
+
kernel_size: 7
|
| 22 |
+
type: neighborhood
|
| 23 |
+
- d_head: 64
|
| 24 |
+
kernel_size: 7
|
| 25 |
+
type: neighborhood
|
| 26 |
+
- d_head: 64
|
| 27 |
+
type: global
|
| 28 |
+
type: image_transformer_v2
|
| 29 |
+
widths:
|
| 30 |
+
- 320
|
| 31 |
+
- 640
|
| 32 |
+
- 1280
|
| 33 |
+
keep_last_k: null
|
| 34 |
+
model_type: hdit
|
| 35 |
+
num_val_samples: 3
|
| 36 |
+
optimizer:
|
| 37 |
+
lr: 0.0001
|
| 38 |
+
warmup: 2000
|
| 39 |
+
patch_size: 128
|
| 40 |
+
process: fm
|
| 41 |
+
seed: 42
|
| 42 |
+
work_dir: work_dirs/128-hdit_fm_scale_bf16_00
|
| 43 |
+
z_crop: null
|
| 44 |
+
z_scale:
|
| 45 |
+
mean: 0.04
|
| 46 |
+
std: 0.09
|
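To get a feel for the sequence lengths implied by `input_size` and `patch_size`, the sketch below computes the token grid at each level. The first level follows directly from the config (input size divided by patch size). The halving between levels is an assumption based on the general HDiT design (2× token merging per level) and is not stated in the config. `token_grids` and `MERGE_FACTOR` are hypothetical names.

```python
# input_size, patch_size and the number of levels (len(depths)) are taken
# from the cryofm-s and cryofm-l config.yaml files in this commit.
CONFIGS = {
    "cryofm-s": {"input_size": (64, 64, 64), "patch_size": (4, 4, 4), "num_levels": 2},
    "cryofm-l": {"input_size": (128, 128, 128), "patch_size": (4, 4, 4), "num_levels": 3},
}

MERGE_FACTOR = 2  # assumption: each deeper level halves the grid per axis


def token_grids(input_size, patch_size, num_levels, merge_factor=MERGE_FACTOR):
    """Return the token grid (per-axis token counts) at each level."""
    grid = tuple(s // p for s, p in zip(input_size, patch_size))
    grids = []
    for _ in range(num_levels):
        grids.append(grid)
        grid = tuple(g // merge_factor for g in grid)
    return grids


for name, cfg in CONFIGS.items():
    grids = token_grids(**cfg)
    counts = [g[0] * g[1] * g[2] for g in grids]
    print(name, grids, counts)
```

Under these assumptions both variants end at the same 8×8×8 grid for the global-attention level, while the neighborhood-attention levels handle the larger grids.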
cryofm-l/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:818ea9a9e53b21f4d07cef941ceaf99dff226f117b9678cbe63bc24937bc85eb
|
| 3 |
+
size 1234168600
|
cryofm-s/config.yaml
ADDED
|
@@ -0,0 +1,42 @@
|
| 1 |
+
ckpt_path: null
|
| 2 |
+
ddpm:
|
| 3 |
+
prediction_type: v_prediction
|
| 4 |
+
exp_name: 64-hdit_fm_scale_bf16
|
| 5 |
+
hdit_model:
|
| 6 |
+
depths:
|
| 7 |
+
- 4
|
| 8 |
+
- 8
|
| 9 |
+
input_channels: 1
|
| 10 |
+
input_size:
|
| 11 |
+
- 64
|
| 12 |
+
- 64
|
| 13 |
+
- 64
|
| 14 |
+
patch_size:
|
| 15 |
+
- 4
|
| 16 |
+
- 4
|
| 17 |
+
- 4
|
| 18 |
+
self_attns:
|
| 19 |
+
- d_head: 64
|
| 20 |
+
kernel_size: 7
|
| 21 |
+
type: neighborhood
|
| 22 |
+
- d_head: 64
|
| 23 |
+
type: global
|
| 24 |
+
type: image_transformer_v2
|
| 25 |
+
widths:
|
| 26 |
+
- 768
|
| 27 |
+
- 1536
|
| 28 |
+
keep_last_k: null
|
| 29 |
+
mode: train
|
| 30 |
+
model_type: hdit
|
| 31 |
+
num_val_samples: 3
|
| 32 |
+
optimizer:
|
| 33 |
+
lr: 0.0001
|
| 34 |
+
warmup: 2000
|
| 35 |
+
patch_size: 64
|
| 36 |
+
process: fm
|
| 37 |
+
seed: 42
|
| 38 |
+
work_dir: work_dirs/64-hdit_fm_scale_bf16_00
|
| 39 |
+
z_crop: null
|
| 40 |
+
z_scale:
|
| 41 |
+
mean: 0.04
|
| 42 |
+
std: 0.09
|
cryofm-s/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:39b8430620c0a2fad85158412cf22c6e62f5034e21e39801219964141ff5e313
|
| 3 |
+
size 1340760716
|
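As a sanity check, the `size` fields of the two LFS pointers can be compared with the parameter counts in the README table. The sketch assumes the weights are stored as 4-byte float32 values and ignores the small safetensors header. Both are assumptions, since the diff does not state the stored dtype. `approx_params_millions` is a hypothetical helper.

```python
BYTES_PER_PARAM = 4  # assumption: float32 storage

# Sizes in bytes, copied from the Git LFS pointer files in this commit.
CHECKPOINT_SIZES = {
    "cryofm-s": 1340760716,
    "cryofm-l": 1234168600,
}


def approx_params_millions(size_bytes, bytes_per_param=BYTES_PER_PARAM):
    """Estimate the parameter count (in millions) from a checkpoint size."""
    return size_bytes / bytes_per_param / 1e6


for name, size in CHECKPOINT_SIZES.items():
    print(f"{name}: ~{approx_params_millions(size):.2f} M parameters")
```

Under the float32 assumption, the estimates land within about 0.01 M of the 335.18 M and 308.54 M figures in the table; the remainder is plausibly file metadata.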