---
license: apache-2.0
tags:
- cryo-em
- flow-matching
- 3d-density-maps
- foundation-model
- conditional-sampling
---

# CryoFM2: A Generative Foundation Model for Cryo-EM Densities

<div align="center">

[](https://doi.org/10.64898/2025.12.29.696802)
[](https://github.com/ByteDance-Seed/cryofm)
[](https://opensource.org/licenses/Apache-2.0)
[](https://bytedance-seed.github.io/cryofm/docs/)

</div>

<div align="center">
<img src="./assets/cryofm2_overview.jpg" alt="CryoFM2 Overview" style="max-width: 100%; height: auto; width: 800px;"/>
</div>

## Overview

**CryoFM2** is a flow-based generative foundation model for cryo-EM density maps. It is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities and can be fine-tuned for downstream tasks.

The model learns a continuous mapping from a simple Gaussian distribution to the complex distribution of cryo-EM densities, enabling stable generation and flexible adaptation. CryoFM2 can also act as a **Bayesian prior**, integrating naturally with task-specific likelihoods to support applications such as anisotropy-aware refinement, non-uniform reconstruction, and controlled density modification.

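
Conceptually, sampling from a flow-matching model amounts to integrating a learned velocity field from noise to data with an ODE solver. The toy sketch below uses a hand-written (not learned) velocity field for the straight-line probability path and plain Euler steps, mirroring the `method="euler"` option of `sample_from_fm`; it is an illustration of the mechanism, not CryoFM2's actual field.

```python
import numpy as np

def make_velocity(target):
    # For the straight-line path x_t = (1 - t) * x0 + t * x1, the velocity
    # toward a known target x1 is (x1 - x_t) / (1 - t). In CryoFM2 this field
    # is predicted by the 3D UNet; here it is closed-form for illustration.
    def v(x, t):
        return (target - x) / (1.0 - t)
    return v

def euler_sample(v, x0, num_steps=200, t_end=0.999):
    # Plain Euler integration of dx/dt = v(x, t) from t = 0 toward t = 1,
    # analogous to sample_from_fm(..., method="euler", num_steps=200).
    x, t = x0, 0.0
    dt = t_end / num_steps
    for _ in range(num_steps):
        x = x + v(x, t) * dt
        t += dt
    return x

rng = np.random.default_rng(0)
x0 = rng.standard_normal(5)              # "noise" sample, stands in for a 64^3 volume
x1 = euler_sample(make_velocity(3.0), x0)
print(np.max(np.abs(x1 - 3.0)) < 0.05)   # True: every coordinate converges to the target
```

In the real model, `v` is the network's output and `x0` is a Gaussian volume of shape 64×64×64.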

## Model Details

The pretrained model learns an unconditional prior over high-quality cryo-EM densities from curated EMDB half maps; fine-tuning adapts this prior to downstream tasks such as density map enhancement and post-processing.

**Pre-training Architecture:**

<div align="center">
<img src="./assets/cryofm2_arch-pretrain.jpg" alt="CryoFM2 architecture for pre-training." style="max-width: 100%; height: auto; width: 800px;"/>
</div>

**Fine-tuning Architecture (for EMhancer/EMReady-style post-processing):**

<div align="center">
<img src="./assets/cryofm2_arch-finetune.jpg" alt="CryoFM2 architecture for fine-tuning." style="max-width: 100%; height: auto; width: 800px;"/>
</div>

### Architecture

- **Architecture Type**: 3D UNet
- **Input Size**: 64×64×64 voxels
- **Input Channels**: 2 for the pretrained model, 3 for the fine-tuned models
- **Output Channels**: 1
- **Down Blocks**: DownBlock3D, DownBlock3D, AttnDownBlock3D, AttnDownBlock3D
- **Up Blocks**: AttnUpBlock3D, AttnUpBlock3D, UpBlock3D, UpBlock3D
- **Block Output Channels**: (64, 128, 256, 512)
- **Layers per Block**: 2
- **Attention Head Dimension**: 8
- **Normalization**: GroupNorm (32 groups)
- **Activation**: SiLU
- **Time Embedding**: Positional encoding
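
The hyperparameters above can be collected into a diffusers-style config dict. The key names below are assumptions chosen to match common diffusers conventions; consult the `config.yaml` shipped with each checkpoint for the exact schema CryoFM2 uses.

```python
# Illustrative config only; key names follow diffusers-style conventions
# and may differ from CryoFM2's actual config.yaml.
unet_config = {
    "sample_size": 64,                      # 64x64x64 voxel input
    "in_channels": 2,                       # 3 for the fine-tuned (conditional) models
    "out_channels": 1,
    "down_block_types": ("DownBlock3D", "DownBlock3D",
                         "AttnDownBlock3D", "AttnDownBlock3D"),
    "up_block_types": ("AttnUpBlock3D", "AttnUpBlock3D",
                       "UpBlock3D", "UpBlock3D"),
    "block_out_channels": (64, 128, 256, 512),
    "layers_per_block": 2,
    "attention_head_dim": 8,
    "norm_num_groups": 32,
    "act_fn": "silu",
}

# Sanity check: every block width must be divisible by the GroupNorm group count.
print(all(c % unet_config["norm_num_groups"] == 0
          for c in unet_config["block_out_channels"]))  # True
```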

### Model Variants

1. **cryofm2-pretrain**: Unconditional pretrained model for general density map generation
2. **cryofm2-emhancer**: Fine-tuned model for density map enhancement (EMhancer style)
3. **cryofm2-emready**: Fine-tuned model for density map enhancement (EMReady style)

## Play with CryoFM2

### Installation

Set up the environment and install the package:

```bash
# Clone the repository
git clone https://github.com/ByteDance-Seed/cryofm.git
cd cryofm

# Create a new conda environment for CryoFM (recommended)
conda create -n cryofm python=3.10 -y
conda activate cryofm

# Install CryoFM
pip install .
```

### Unconditional Generation (Explore Training Data Distribution)

Generate samples from the pretrained model to explore the learned data distribution:

**Pretrained Model:**

```python
import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Uncond

# Update the path to your model directory
model_dir = "path/to/cryofm-v2/cryofm2-pretrain"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Uncond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()


def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)


# Enable bfloat16 for faster inference if your GPU supports it
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=3,
        device=lit_model.device,
        side_shape=64,
    )

# Undo z-score normalization if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"sample-{i}.mrc", voxel_size=1.5)
```

**Fine-tuned Models (EMhancer/EMReady):**

```python
import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Cond

# Choose style: "emhancer" or "emready"
style = "emhancer"
model_dir = f"path/to/cryofm-v2/cryofm2-{style}"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Cond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)
output_tag = 1 if style == "emhancer" else 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()


def v_xt_t(_xt, _t):
    bs = _xt.shape[0]
    # All volume/input conditions disabled -> unconditional generation in the chosen style
    unconditional_generation_conds = {
        "input_cond": None,
        "output_cond": torch.tensor([output_tag] * bs).to(device),
        "vol_cond": None,  # shape should be [bs, d, h, w] when provided
    }
    return lit_model(_xt, _t, generation_conds=unconditional_generation_conds)


# Enable bfloat16 for faster inference if your GPU supports it
with torch.no_grad(), torch.autocast("cuda", dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=3,
        device=lit_model.device,
        side_shape=64,
    )

# Undo z-score normalization if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(3):
    save_mrc(out[i].float().cpu().numpy(), f"{style}-sample-{i}.mrc", voxel_size=1.5)
```

### Density Map Modification

CryoFM2 supports various density map modification operations, using the pretrained model as a Bayesian prior. Supported operators include:

- **denoise**: Remove noise from density maps
- **inpaint**: Fill in missing regions (e.g., a missing wedge)
- **denoise inpaint**: Combined denoising and inpainting
- **non-uniform weight**: Apply non-uniform weighting during reconstruction
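
These operators share one mechanism: at each sampling step, the prior's velocity is nudged by the gradient of a task-specific log-likelihood. The 1D toy below sketches that idea with a Gaussian likelihood; the `norm_grad` option loosely mirrors the spirit of the `--norm-grad` flag by rescaling the gradient to the prior velocity's magnitude. Both the likelihood and the normalization scheme here are illustrative assumptions, not CryoFM2's actual operators.

```python
import numpy as np

def guided_velocity(v_prior, grad_log_lik, scale=1.0, norm_grad=True):
    # Posterior guidance: steer the prior flow with the likelihood gradient.
    # norm_grad rescales the gradient to match the prior velocity's magnitude
    # (an illustrative stand-in for gradient normalization).
    g = grad_log_lik
    if norm_grad:
        g = g * (np.linalg.norm(v_prior) / (np.linalg.norm(g) + 1e-8))
    return v_prior + scale * g

# Toy example: the "prior" pulls x toward 0, while a noisy observation
# y = 2.0 (Gaussian likelihood, sigma = 0.5) pulls it toward the data.
x = np.array([5.0])
y, sigma = 2.0, 0.5
num_steps, t_end = 200, 0.99
dt = t_end / num_steps
for _ in range(num_steps):
    v_prior = -x                      # stand-in for the learned prior velocity
    grad = (y - x) / sigma**2         # d/dx log N(y | x, sigma^2)
    x = x + guided_velocity(v_prior, grad) * dt
print(abs(x[0] - y) < 0.2)            # True: the sample settles near the observation
```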

**Basic Usage:**

```bash
python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op denoise \
    --norm-grad \
    --use-lamb-w
```

**For inpainting tasks**, you also need to provide the path to a RELION STAR file:

```bash
python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op inpaint \
    --data-starfile-path path/to/relion_data.star \
    --norm-grad \
    --use-lamb-w
```

### Density Map Post-Processing

CryoFM2 provides fine-tuned models for density map enhancement in different styles, similar to EMhancer and EMReady.

#### EMhancer Style Enhancement

```bash
python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emhancer \
    --model-dir path/to/cryofm-v2/cryofm2-emhancer \
    --output-tag 1
```

#### EMReady Style Enhancement

```bash
python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emready \
    --model-dir path/to/cryofm-v2/cryofm2-emready \
    --output-tag 0 \
    --cfg-weight 0.5
```

**Parameters:**

- `-i`: Input density map file (MRC format)
- `-o`: Output directory
- `--model-dir`: Path to the model directory containing `config.yaml` and `model.safetensors`
- `--output-tag`: Style tag (1 for EMhancer, 0 for EMReady)
- `--cfg-weight`: Classifier-free guidance weight (optional; the default varies by model)
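
Classifier-free guidance blends the conditional and unconditional velocity predictions, with the weight controlling how strongly the condition is enforced. The sketch below shows the standard combination rule; CryoFM2's exact weighting convention may differ.

```python
import numpy as np

def cfg_velocity(v_cond, v_uncond, w):
    # Classifier-free guidance: extrapolate from the unconditional prediction
    # toward (and, for w > 1, past) the conditional one.
    return v_uncond + w * (v_cond - v_uncond)

v_uncond = np.array([0.0, 0.0])
v_cond = np.array([1.0, 2.0])
print(cfg_velocity(v_cond, v_uncond, 0.5))  # halfway between the two predictions
print(cfg_velocity(v_cond, v_uncond, 1.0))  # recovers the conditional prediction
```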

## Performance Tips

- **Multi-GPU Inference**: Use `accelerate launch` for faster inference on multiple GPUs:

  ```bash
  NCCL_DEBUG=ERROR accelerate launch --num_processes=${NUM_GPUS} --main_process_port=8881 \
      python -m cryofm.projects.cryofm2.cond_sampling ...
  ```

- **Mixed Precision**: Pass the `--bf16` flag when available to reduce memory usage and speed up inference.
- **Batch Processing**: Adjust the batch size to match your GPU memory capacity.

## Limitations

- The network input is fixed at 64×64×64 voxels
- Performance may vary with the quality of the input density map
- Fine-tuned models are optimized for their specific enhancement styles

## Ethical Considerations

This model is intended for scientific research and structural biology applications. Users should:

- Ensure proper attribution when using generated structures
- Validate generated structures through experimental verification
- Be aware of potential biases in the training data
- Use the model responsibly and in accordance with scientific best practices

## Citation

If you find CryoFM2 useful, please cite:

```bibtex
@article{Li2025.12.29.696802,
  author = {Li, Yilai and Yuan, Jing and Zhou, Yi and Wang, Zhenghua and Chen, Suyi and Yang, Fengyu and Ling, Haibin and Kovalsky, Shahar Z and Zheng, Xiaoqing and Gu, Quanquan},
  title = {A Generative Foundation Model for Cryo-EM Densities},
  elocation-id = {2025.12.29.696802},
  year = {2025},
  doi = {10.64898/2025.12.29.696802},
  publisher = {Cold Spring Harbor Laboratory},
  URL = {https://www.biorxiv.org/content/early/2025/12/29/2025.12.29.696802},
  eprint = {https://www.biorxiv.org/content/early/2025/12/29/2025.12.29.696802.full.pdf},
  journal = {bioRxiv}
}
```

## License

This model is released under the Apache 2.0 License. See the [LICENSE](https://github.com/ByteDance-Seed/cryofm/blob/main/LICENSE) file for details.

## Acknowledgments

This work was developed by the ByteDance Seed Team. For more information, visit:

- [Project Repository](https://github.com/ByteDance-Seed/cryofm)
- [ByteDance Seed Team](https://seed.bytedance.com/)