---
license: apache-2.0
tags:
- cryo-em
- flow-matching
- 3d-density-maps
- foundation-model
---

# CryoFM: Flow-based Foundation Model for Cryo-EM Density Maps

<div align="center">

[arXiv](https://arxiv.org/abs/2410.08631) ·
[GitHub](https://github.com/ByteDance-Seed/cryofm) ·
[License: Apache 2.0](https://opensource.org/licenses/Apache-2.0) ·
[Documentation](https://bytedance-seed.github.io/cryofm/docs/)

</div>

<div align="center">
<img src="./assets/cryofm.gif" alt="CryoFM Demo" style="max-width: 100%; height: auto; width: 800px;"/>
</div>

## Model Description

CryoFM1 is a flow-based foundation model for 3D cryo-electron microscopy (cryo-EM) density maps. It uses a Hierarchical Diffusion Transformer (HDiT) backbone to learn a deep prior over 3D cryo-EM densities. This prior supports downstream tasks including density map denoising, anisotropic noise correction, missing-wedge inpainting, and *ab initio* modeling.

### Key Features

- **Flow Matching Framework**: Trained with flow matching for efficient and stable training
- **HDiT Architecture**: Hierarchical Diffusion Transformer combining local (neighborhood) and global attention
- **Two Model Variants**: CryoFM-S (64³ voxels at 1.5 Å/voxel) and CryoFM-L (128³ voxels at 3.0 Å/voxel)
- **Downstream Task Support**: Denoising, anisotropic noise correction, missing-wedge restoration, and more
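
Flow matching trains a network to predict a velocity field v(x_t, t). Samples are then drawn by integrating dx/dt = v(x, t) from noise to data with an ODE solver; the Quick Start below uses `method="euler"` with 200 steps. The following is a toy illustration of fixed-step Euler sampling, not CryoFM's `sample_from_fm`. It assumes the linear-interpolation convention x_t = (1 - t)·x0 + t·x1, with t = 0 at noise and t = 1 at data, which may differ from the convention used by CryoFM's `noise_scheduler`. It also replaces the trained network with the closed-form velocity for a data distribution consisting of a single known volume. `euler_sample` and `toy_velocity` are hypothetical helpers.

```python
import numpy as np


def euler_sample(velocity, x0, num_steps):
    """Integrate dx/dt = velocity(x, t) from t=0 (noise) to t=1 (data) with fixed-step Euler."""
    x = x0.copy()
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt  # left endpoint of each step; t stays below 1
        x = x + dt * velocity(x, t)
    return x


def toy_velocity(target):
    """Exact velocity when the 'data distribution' is the single volume `target`.

    Under x_t = (1 - t) * x0 + t * target, dx_t/dt = (target - x_t) / (1 - t).
    """
    def v(x, t):
        return (target - x) / (1.0 - t)
    return v


rng = np.random.default_rng(0)
side = 8  # tiny stand-in for a 64^3 or 128^3 density map
target = rng.standard_normal((side, side, side))
noise = rng.standard_normal((side, side, side))

sample = euler_sample(toy_velocity(target), noise, num_steps=200)
print("max |sample - target|:", np.abs(sample - target).max())
```

In this toy case every path is a straight line, so Euler lands on the target up to floating-point error. With a learned velocity field the paths are generally curved, and the number of steps trades sampling time against integration accuracy.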

## Model Details

CryoFM1 uses a Hierarchical Diffusion Transformer (HDiT) that combines local neighborhood attention with global attention. Local attention captures fine-grained structure, while global attention models long-range dependencies across the 3D density map. The network patchifies the input volume and builds representations hierarchically at multiple scales.

<div align="center">
<img src="./assets/cryofm_archs.jpg" alt="CryoFM Architecture" style="max-width: 100%; height: auto; width: 600px;"/>
</div>

The model is available in two variants with different box and voxel sizes. The following table summarizes the key architectural and training parameters of each variant:

| Parameter | CryoFM-S | CryoFM-L |
|-----------|----------|----------|
| **Parameters** | 335.18 M | 308.54 M |
| **GFLOPs per Forward Pass** | 395.87 | 427.26 |
| **Training Steps** | 150k | 300k |
| **Batch Size** | 128 | 128 |
| **Precision** | bf16 | bf16 |
| **Training Hardware** | 8×A100 | 8×A100 |
| **Patchifying (Patch Size)** | 4 | 4 |
| **Levels (Local + Global Attention)** | 1 + 1 | 2 + 1 |
| **Depth (per Level)** | [4, 8] | [2, 2, 12] |
| **Widths (per Level)** | [768, 1536] | [320, 640, 1280] |
| **Attention Heads (Width / Head Dim)** | [12, 24] | [5, 10, 20] |
| **Attention Head Dim** | 64 | 64 |
| **Neighborhood Kernel Size** | 7 | 7 |
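
Several table entries are related by simple arithmetic. The head count at each level is the width divided by the 64-dimensional head size, and patchifying sets how many tokens the transformer sees. The sketch below computes these quantities per level. It assumes that patchifying splits the volume into non-overlapping 4×4×4 patches and that each deeper level merges 2×2×2 neighboring tokens, a 3D analogue of HDiT-style token merging. Both are assumptions, not details confirmed by this model card. `Variant` and `level_summary` are hypothetical helpers.

```python
from dataclasses import dataclass


@dataclass
class Variant:
    side: int        # voxels per side of the input volume
    patch: int       # patch size per axis (table: "Patchifying")
    widths: list     # channel width per level
    depths: list     # transformer blocks per level
    head_dim: int = 64
    kernel: int = 7  # neighborhood attention kernel size per axis


def level_summary(v, merge=2):
    """Token grid, token count, heads, and blocks per level (hypothetical helper).

    Assumes each level after the first merges `merge` tokens along every axis.
    """
    rows = []
    grid = v.side // v.patch
    for level, (width, depth) in enumerate(zip(v.widths, v.depths)):
        rows.append({
            "level": level,
            "grid": grid,
            "tokens": grid ** 3,
            "heads": width // v.head_dim,
            "blocks": depth,
        })
        grid //= merge
    return rows


cryofm_s = Variant(side=64, patch=4, widths=[768, 1536], depths=[4, 8])
cryofm_l = Variant(side=128, patch=4, widths=[320, 640, 1280], depths=[2, 2, 12])

for name, v in [("CryoFM-S", cryofm_s), ("CryoFM-L", cryofm_l)]:
    print(name, "neighbors per query in local attention:", v.kernel ** 3)
    for row in level_summary(v):
        print("  ", row)
```

The computed head counts reproduce the table. Under the merging assumption, the final global-attention level of both variants works on an 8×8×8 grid of 512 tokens. The larger token grids at the earlier levels are handled by local attention, where a 7×7×7 neighborhood gives each query 343 neighbors.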

## Quick Start

### Installation

Setting up CryoFM1 takes two steps.

#### 1. Install CryoFM with compatible dependencies

CryoFM1 uses the HDiT architecture, which depends on the `natten` package. Each `natten` release requires specific PyTorch and CUDA versions, so pin the versions as follows for a reproducible installation:

```bash
# natten 0.17.5 uses the `X | Y` type-union syntax, so Python >= 3.10 is required
conda create -n cryofm python=3.10 -y
conda activate cryofm

# Install PyTorch 2.5.1 with CUDA 12.4 support
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124

# Install the natten 0.17.5 wheel built for PyTorch 2.5 and CUDA 12.4
pip install natten==0.17.5+torch250cu124 -f https://whl.natten.org

# Clone and install CryoFM
git clone https://github.com/ByteDance-Seed/cryofm
cd cryofm
pip install .
```
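
After installation, you can confirm that the pinned versions are actually present in the environment, because a `natten` wheel that does not match the installed PyTorch/CUDA build can cause import or CUDA errors. The following minimal sketch uses only the standard library's `importlib.metadata`. The `EXPECTED` pins mirror the commands above, and `installed_version` and `check_environment` are hypothetical helpers, not part of CryoFM.

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Pins mirroring the installation commands above
EXPECTED = {"torch": "2.5.1", "torchvision": "0.20.1", "natten": "0.17.5"}


def installed_version(dist):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist)
    except PackageNotFoundError:
        return None


def check_environment(expected=EXPECTED):
    """Map each distribution to (expected, installed, ok).

    Local version labels such as '+cu124' or '+torch250cu124' are ignored.
    """
    report = {}
    for dist, want in expected.items():
        have = installed_version(dist)
        ok = have is not None and have.split("+")[0] == want
        report[dist] = (want, have, ok)
    return report


print("Python >= 3.10:", sys.version_info >= (3, 10))
for dist, (want, have, ok) in check_environment().items():
    print(f"{dist}: expected {want}, found {have} -> {'OK' if ok else 'MISMATCH'}")
```

Run it inside the activated `cryofm` environment. It checks package metadata only and does not exercise the CUDA kernels.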
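
#### 2. Download model checkpoints and configuration files

Download the CryoFM1 model weights and configuration files from the [Hugging Face repository](https://huggingface.co/ByteDance-Seed/cryofm-v1). The generation example below reads them from a local `cryofm-v1/` directory (for example, `cryofm-v1/cryofm-s/config.yaml`), so either place the files there or adjust the paths.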
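
One way to fetch the files programmatically is the `snapshot_download` function from the `huggingface_hub` library (install it with `pip install huggingface_hub` if needed). This sketch assumes that the repository stores each variant in a `cryofm-s/` or `cryofm-l/` subfolder containing `config.yaml` and `model.safetensors`, matching the paths in the generation example. `expected_files` and `download_variant` are hypothetical helpers.

```python
from pathlib import Path

REPO_ID = "ByteDance-Seed/cryofm-v1"


def expected_files(local_dir, variant):
    """Paths the generation example reads for a variant (repository layout is an assumption)."""
    root = Path(local_dir) / variant
    return root / "config.yaml", root / "model.safetensors"


def download_variant(variant="cryofm-s", local_dir="cryofm-v1"):
    """Download one variant's folder from the Hub and return its config and weight paths."""
    from huggingface_hub import snapshot_download  # imported lazily; requires huggingface_hub

    snapshot_download(
        repo_id=REPO_ID,
        local_dir=local_dir,
        allow_patterns=[f"{variant}/*"],  # fetch only the requested variant
    )
    paths = expected_files(local_dir, variant)
    missing = [p for p in paths if not p.exists()]
    if missing:
        raise FileNotFoundError(f"Download finished but expected files are missing: {missing}")
    return paths
```

For example, `config_path, model_path = download_variant("cryofm-s")` fetches only the CryoFM-S files into `cryofm-v1/cryofm-s/`, assuming the layout above. It raises an error if that assumption does not hold.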

### Unconditional Generation

CryoFM1 provides two model variants with different box and voxel sizes:

- **CryoFM-S**: Generates 64×64×64-voxel density maps with a voxel size of 1.5 Å
- **CryoFM-L**: Generates 128×128×128-voxel density maps with a voxel size of 3.0 Å
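
The physical box edge is the number of voxels per side times the voxel size. By the Nyquist criterion, the finest resolution a map can represent is twice its voxel size. The short calculation below applies both rules to the two variants. `describe_variant` is an illustrative helper, not a CryoFM API.

```python
def describe_variant(side_shape, apix):
    """Box edge length and Nyquist limit, both in Å, for a cubic map."""
    return {
        "box_edge_A": side_shape * apix,  # physical extent of one side of the box
        "nyquist_A": 2 * apix,            # finest representable resolution
    }


for name, side, apix in [("CryoFM-S", 64, 1.5), ("CryoFM-L", 128, 3.0)]:
    info = describe_variant(side, apix)
    print(f"{name}: {side}^3 voxels, box {info['box_edge_A']:.0f} Å, Nyquist {info['nyquist_A']:.1f} Å")
# CryoFM-S: 64^3 voxels, box 96 Å, Nyquist 3.0 Å
# CryoFM-L: 128^3 voxels, box 384 Å, Nyquist 6.0 Å
```

CryoFM-S therefore covers a smaller box at finer sampling, while CryoFM-L covers a box four times wider per side at coarser sampling.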

```python
import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm1.lit_modules import CryoFM1

# Choose model variant: "cryofm-s" or "cryofm-l"
model_variant = "cryofm-s"
num_samples = 3

model_config = {
    "cryofm-s": {
        "config_path": "cryofm-v1/cryofm-s/config.yaml",
        "model_path": "cryofm-v1/cryofm-s/model.safetensors",
        "side_shape": 64,  # voxels per side
        "apix": 1.5,       # Å per voxel
    },
    "cryofm-l": {
        "config_path": "cryofm-v1/cryofm-l/config.yaml",
        "model_path": "cryofm-v1/cryofm-l/model.safetensors",
        "side_shape": 128,
        "apix": 3.0,
    },
}
variant_cfg = model_config[model_variant]

# Load configuration and model weights
cfg = Config.fromfile(variant_cfg["config_path"])
lit_model = CryoFM1.load_from_safetensors(variant_cfg["model_path"], cfg=cfg)

# Move the model to the GPU if available and switch to inference mode
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()

# Velocity field v(x_t, t) predicted by the network, integrated by the flow-matching sampler
def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)

# Use bfloat16 autocast only on GPUs that support it
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()

# Generate samples by integrating the learned flow with 200 Euler steps
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16, enabled=use_bf16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=device,
        side_shape=variant_cfg["side_shape"],
    )
out = out.float()

# Map samples from the normalized (z-scored) space back to the density scale, if configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save the generated density maps as MRC files
for i in range(num_samples):
    save_mrc(
        out[i].cpu().numpy(),
        f"sample-{i}.mrc",
        apix=variant_cfg["apix"],  # Å per voxel
    )
```

### Downstream Tasks

CryoFM1 can be applied to several downstream tasks, including density map denoising, anisotropic noise correction, and missing-wedge restoration. For instructions on running these tasks, see the [Downstream Tasks documentation](https://bytedance-seed.github.io/cryofm/docs/model-guides/cryofm1/downstream-tasks.html).

## Ethical Considerations

This model is intended for scientific research and structural biology applications. Users should:

- Ensure proper attribution when using generated structures
- Validate generated structures through experimental verification
- Be aware of potential biases in the training data

## Citation

If you use CryoFM1 in your research, please cite:

```bibtex
@inproceedings{zhou2025cryofm,
  title={Cryo{FM}: A Flow-based Foundation Model for Cryo-{EM} Densities},
  author={Yi Zhou and Yilai Li and Jing Yuan and Quanquan Gu},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=T4sMzjy7fO}
}
```

## License

This model is released under the Apache 2.0 License. See the [LICENSE](https://github.com/ByteDance-Seed/cryofm/blob/main/LICENSE) file for details.

## Acknowledgments

This work was developed by the ByteDance Seed Team. For more information, visit:

- [Project Repository](https://github.com/ByteDance-Seed/cryofm)
- [ByteDance Seed Team](https://seed.bytedance.com/)