---
license: apache-2.0
pipeline_tag: other
library_name: cryofm
tags:
- cryo-em
- flow-matching
- 3d-density-maps
- foundation-model
---

# CryoFM: Flow-based Foundation Model for Cryo-EM Density Maps

<div align="center">

[![arXiv](https://img.shields.io/badge/arXiv-2410.08631-B31B1B?logo=arxiv&logoColor=white)](https://arxiv.org/abs/2410.08631)
[![GitHub](https://img.shields.io/badge/GitHub-cryofm-181717?logo=github&logoColor=white)](https://github.com/ByteDance-Seed/cryofm)
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
[![Docs](https://img.shields.io/badge/Docs-cryofm-4CAF50?logo=read-the-docs&logoColor=white)](https://bytedance-seed.github.io/cryofm/docs/)

</div>

<div align="center">
  <img src="./assets/cryofm.gif" alt="CryoFM Demo" style="max-width: 100%; height: auto; width: 800px;"/>
</div>

## Model Description

CryoFM1 is a flow-based foundation model for 3D cryo-electron microscopy (cryo-EM) density maps. The model employs a Hierarchical Diffusion Transformer (HDiT) architecture, specifically designed to learn deep priors of 3D cryo-EM densities. CryoFM1 supports various downstream tasks including density map denoising, anisotropy noise correction, missing wedge inpainting, and *ab initio* modeling.

### Key Features

- **Flow Matching Framework**: Uses flow matching for efficient and stable training
- **HDiT Architecture**: Hierarchical Diffusion Transformer with local and global attention mechanisms
- **Two Model Variants**: CryoFM-S (64³) and CryoFM-L (128³) for different resolution needs
- **Downstream Task Support**: Denoising, anisotropy noise correction, missing wedge restoration, and more
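For readers new to flow matching, the training objective can be sketched as below. This is a generic conditional flow-matching (rectified-flow style) loss in NumPy. It assumes a straight-line path from Gaussian noise at t = 0 to data at t = 1. CryoFM's actual noise scheduler and time convention may differ, and `velocity_model` is a hypothetical stand-in for the HDiT network.

```python
import numpy as np

def flow_matching_loss(velocity_model, x1, rng):
    """Conditional flow-matching loss on a batch of volumes x1 of shape (B, D, H, W).

    Assumes the linear path x_t = (1 - t) * x0 + t * x1 with x0 ~ N(0, I),
    whose target velocity is dx_t/dt = x1 - x0.
    """
    batch = x1.shape[0]
    x0 = rng.standard_normal(x1.shape)               # noise sample
    t = rng.uniform(size=batch)                      # one time per volume
    t_b = t.reshape(batch, *([1] * (x1.ndim - 1)))   # broadcast t over voxels
    xt = (1.0 - t_b) * x0 + t_b * x1                 # point on the straight path
    target = x1 - x0                                 # ground-truth velocity
    pred = velocity_model(xt, t)
    return np.mean((pred - target) ** 2)

# Toy usage with a hypothetical model that always predicts zero velocity.
rng = np.random.default_rng(0)
x1 = rng.standard_normal((2, 8, 8, 8))
loss = flow_matching_loss(lambda xt, t: np.zeros_like(xt), x1, rng)
print(loss)
```

A trained network drives this loss down by learning the average velocity that carries noise toward the data distribution. Sampling then integrates that velocity field.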

## Model Details

CryoFM1 employs a Hierarchical Diffusion Transformer (HDiT) architecture that combines local neighborhood attention with global attention mechanisms. This design enables the model to effectively capture both fine-grained local structures and long-range dependencies in 3D cryo-EM density maps. The architecture processes 3D volumes through a hierarchical patch-based approach, progressively building representations at multiple scales.
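The first step of this patch-based processing can be illustrated by splitting a volume into non-overlapping cubic patches and flattening each patch into a token. The NumPy sketch below assumes a single-channel volume and the patch size of 4 listed in the table below. The real model also projects each token to the level width with a learned layer, which is omitted here, and `patchify_3d` is a hypothetical name.

```python
import numpy as np

def patchify_3d(volume, patch):
    """Split a (D, H, W) volume into non-overlapping patch^3 cubes.

    Returns an array of shape (num_tokens, patch**3), where
    num_tokens = (D // patch) * (H // patch) * (W // patch).
    """
    d, h, w = volume.shape
    assert d % patch == 0 and h % patch == 0 and w % patch == 0
    x = volume.reshape(d // patch, patch, h // patch, patch, w // patch, patch)
    # Bring the three grid axes to the front, then the three intra-patch axes.
    x = x.transpose(0, 2, 4, 1, 3, 5)
    return x.reshape(-1, patch ** 3)

vol_s = np.zeros((64, 64, 64), dtype=np.float32)     # CryoFM-S box
vol_l = np.zeros((128, 128, 128), dtype=np.float32)  # CryoFM-L box
print(patchify_3d(vol_s, 4).shape)  # (4096, 64)
print(patchify_3d(vol_l, 4).shape)  # (32768, 64)
```

Each token holds 4³ = 64 voxel values. The token count grows with the cube of the box edge, which is why the larger variant adds an extra local-attention level.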

<div align="center">
  <img src="./assets/cryofm_archs.jpg" alt="CryoFM Architecture" style="max-width: 100%; height: auto; width: 600px;"/>
</div>

The model is available in two variants optimized for different resolution requirements. The following table summarizes the key architectural and training parameters for each variant:

| Parameter | CryoFM-S | CryoFM-L |
|-----------|----------|----------|
| **Parameters** | 335.18 M | 308.54 M |
| **GFLOP/forward** | 395.87 | 427.26 |
| **Training Steps** | 150k | 300k |
| **Batch Size** | 128 | 128 |
| **Precision** | bf16 | bf16 |
| **Training Hardware** | 8×A100 | 8×A100 |
| **Patch Size** | 4 | 4 |
| **Levels (Local + Global Attention)** | 1 + 1 | 2 + 1 |
| **Depth** | [4, 8] | [2, 2, 12] |
| **Widths** | [768, 1536] | [320, 640, 1280] |
| **Attention Heads (Width / Head Dim)** | [12, 24] | [5, 10, 20] |
| **Attention Head Dim** | 64 | 64 |
| **Neighborhood Kernel Size** | 7 | 7 |
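The head counts in the table follow directly from the widths and the fixed head dimension of 64. The sketch below checks that arithmetic. It also lists the token grid each level would operate on under an additional assumption that this card does not state: each level after the first merges 2×2×2 neighboring tokens. `level_stats` is a hypothetical helper.

```python
HEAD_DIM = 64  # "Attention Head Dim" row
PATCH = 4      # "Patch Size" row

def level_stats(side_shape, widths, head_dim=HEAD_DIM, patch=PATCH):
    """Return (heads per level, token-grid edge per level).

    heads = width // head_dim, as in the table. The grid edges assume, as a
    hypothesis, that each level after the first halves the grid (2x2x2 merge).
    """
    heads = [w // head_dim for w in widths]
    grids = [side_shape // patch // 2 ** level for level in range(len(widths))]
    return heads, grids

for name, side, widths in [("CryoFM-S", 64, [768, 1536]),
                           ("CryoFM-L", 128, [320, 640, 1280])]:
    heads, grids = level_stats(side, widths)
    print(name, "heads:", heads, "token grids:", [f"{g}^3" for g in grids])
# CryoFM-S heads: [12, 24] token grids: ['16^3', '8^3']
# CryoFM-L heads: [5, 10, 20] token grids: ['32^3', '16^3', '8^3']
```

If the merge assumption holds, both variants would run global attention over the same 8³ = 512-token grid. The extra local level in CryoFM-L absorbs the larger box.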

## Quick Start

### Installation

Before using CryoFM1, complete the following steps.

#### 1. Install CryoFM with compatible dependencies

CryoFM1 uses the HDiT model architecture, which depends on the `natten` package. Different versions of `natten` have varying requirements for PyTorch and CUDA versions. For a reproducible installation, follow these steps:

```bash
# natten 0.17.5 uses the X | Y type-union syntax, so Python >= 3.10 is required
conda create -n cryofm python=3.10 -y
conda activate cryofm

# Install PyTorch 2.5.1 with CUDA 12.4 support
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu124

# Install the natten 0.17.5 wheel built for PyTorch 2.5 (torch250) and CUDA 12.4 (cu124)
pip install natten==0.17.5+torch250cu124 -f https://whl.natten.org

# Clone and install CryoFM
git clone https://github.com/ByteDance-Seed/cryofm
cd cryofm
pip install .
```
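If your PyTorch or CUDA versions differ from the ones above, the natten wheel tag has to match them. The helper below is a hypothetical convenience. It assumes the `+torch<major><minor>0cu<cuda>` naming pattern seen in the `torch250cu124` tag above. Not every combination is published, so check https://whl.natten.org before installing.

```python
from typing import Optional

def natten_local_tag(torch_version: str, cuda_version: Optional[str]) -> str:
    """Build a natten wheel tag such as "torch250cu124".

    Hypothetical helper: assumes the "+torch<major><minor>0cu<cuda digits>"
    pattern of the tag used in the install commands above.
    """
    if cuda_version is None:
        raise ValueError("CPU-only PyTorch build: no CUDA version to match")
    # "2.5.1+cu124" -> ("2", "5"); the patch release is replaced by 0
    major, minor = torch_version.split("+")[0].split(".")[:2]
    return f"torch{major}{minor}0cu{cuda_version.replace('.', '')}"
```

In the environment above, `natten_local_tag(torch.__version__, torch.version.cuda)` should yield `torch250cu124`. `torch.version.cuda` is `None` on CPU-only builds, which the helper rejects.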

#### 2. Download model checkpoints and configuration files

Download the CryoFM1 model weights and configuration files from the [Hugging Face repository](https://huggingface.co/ByteDance-Seed/cryofm-v1).
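One way to fetch the files is sketched below. It assumes the `huggingface_hub` package, which is not part of the install steps above. It also assumes the repository layout matches the `cryofm-v1/<variant>/config.yaml` and `cryofm-v1/<variant>/model.safetensors` paths used in the generation example. `expected_files` and `download_checkpoints` are hypothetical helpers.

```python
from pathlib import Path

REPO_ID = "ByteDance-Seed/cryofm-v1"
LOCAL_DIR = Path("cryofm-v1")

def expected_files(variant: str, root: Path = LOCAL_DIR) -> list:
    """Paths that the generation example expects for a given variant."""
    return [root / variant / "config.yaml", root / variant / "model.safetensors"]

def download_checkpoints(root: Path = LOCAL_DIR) -> str:
    """Download the whole model repository into `root` and return its path."""
    # Imported lazily so the path helper works without huggingface_hub installed.
    from huggingface_hub import snapshot_download
    return snapshot_download(repo_id=REPO_ID, local_dir=str(root))

# Usage:
#   download_checkpoints()
#   missing = [p for v in ("cryofm-s", "cryofm-l")
#              for p in expected_files(v) if not p.exists()]
```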


### Unconditional Generation

CryoFM1 provides two model variants for different resolution needs:
- **CryoFM-S**: Generates 64×64×64-voxel density maps at a pixel size of 1.5 Å/pixel
- **CryoFM-L**: Generates 128×128×128-voxel density maps at a pixel size of 3.0 Å/pixel
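Pixel size and box size together determine the physical extent of each map and the finest detail it can represent. The arithmetic below uses only the values listed above. The box edge equals side × pixel size, and the Nyquist limit equals 2 × pixel size. `map_geometry` is a hypothetical helper.

```python
variants = {"cryofm-s": (64, 1.5), "cryofm-l": (128, 3.0)}  # (voxels per side, Å per voxel)

def map_geometry(side_shape: int, apix: float) -> tuple:
    """Return (box edge length in Å, Nyquist resolution in Å)."""
    return side_shape * apix, 2.0 * apix

for name, (side, apix) in variants.items():
    box, nyquist = map_geometry(side, apix)
    print(f"{name}: {box:.0f} Å box, Nyquist {nyquist:.1f} Å")
# cryofm-s: 96 Å box, Nyquist 3.0 Å
# cryofm-l: 384 Å box, Nyquist 6.0 Å
```

CryoFM-L therefore covers a region four times wider at coarser sampling, while CryoFM-S covers a smaller region at finer sampling.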


```python
import torch
from mmengine import Config
from cryofm.core.utils.mrc_io import save_mrc
from cryofm.projects.cryofm1.lit_modules import CryoFM1
from cryofm.core.utils.sampling_fm import sample_from_fm

# Choose model variant: "cryofm-s" or "cryofm-l"
model_variant = "cryofm-s"  # or "cryofm-l"
model_config = {
    "cryofm-s": {
        "config_path": "cryofm-v1/cryofm-s/config.yaml",
        "model_path": "cryofm-v1/cryofm-s/model.safetensors",
        "side_shape": 64,
        "apix": 1.5
    },
    "cryofm-l": {
        "config_path": "cryofm-v1/cryofm-l/config.yaml",
        "model_path": "cryofm-v1/cryofm-l/model.safetensors",
        "side_shape": 128,
        "apix": 3.0
    }
}

# Load configuration and model
cfg = Config.fromfile(model_config[model_variant]["config_path"])
lit_model = CryoFM1.load_from_safetensors(
    model_config[model_variant]["model_path"],
    cfg=cfg
)

# Set up device and model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()

# Define vector field function for flow matching
def v_xt_t(_xt, _t):
    return lit_model(_xt, _t)

# Generate samples
# bfloat16 autocast is enabled only on CUDA GPUs that support it;
# otherwise sampling runs in the default precision.
num_samples = 3
use_bf16 = device.type == "cuda" and torch.cuda.is_bf16_supported()
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16, enabled=use_bf16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=device,
        side_shape=model_config[model_variant]["side_shape"]
    )
    # Undo z-score normalization (map back to the original density scale) if configured
    if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
        out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated density maps
for i in range(num_samples):
    save_mrc(
        out[i].float().cpu().numpy(),
        f"sample-{i}.mrc",
        apix=model_config[model_variant]["apix"]  # Angstroms per pixel
    )
```
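Calling `sample_from_fm(..., method="euler", num_steps=200)` integrates the learned vector field `v_xt_t` from noise to data. The NumPy sketch below shows what a fixed-step Euler integrator does in general. It assumes time runs from 0 (noise) to 1 (data) along a straight-line path, which may not match the time convention of `lit_model.noise_scheduler`. It is not the library's implementation, and `euler_sample` is a hypothetical name. The toy velocity field is the exact one for a distribution concentrated on a single map, so the integrator should land on that map.

```python
import numpy as np

def euler_sample(velocity, x0, num_steps):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with fixed Euler steps."""
    x = x0.copy()
    dt = 1.0 / num_steps
    for k in range(num_steps):
        t = k * dt                      # never reaches 1, so 1 - t > 0 below
        x = x + dt * velocity(x, t)
    return x

# Toy check: if all data equal one map `target`, the straight-line flow has
# velocity (target - x) / (1 - t) for t < 1.
rng = np.random.default_rng(0)
target = rng.standard_normal((8, 8, 8))
noise = rng.standard_normal((8, 8, 8))
out = euler_sample(lambda x, t: (target - x) / (1.0 - t), noise, num_steps=200)
print(np.abs(out - target).max())  # close to 0
```

More steps reduce the discretization error of a learned, curved field at a proportional cost in network evaluations, which is the trade-off behind `num_steps`.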

### Downstream Tasks

CryoFM1 demonstrates various downstream tasks including density map denoising, anisotropy noise correction, and missing wedge restoration. For detailed instructions on how to run these tasks, please refer to the [Downstream Tasks documentation](https://bytedance-seed.github.io/cryofm/docs/model-guides/cryofm1/downstream-tasks.html).
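The missing-wedge task can be made concrete as follows. In a single-axis tilt series, Fourier components beyond the maximum tilt angle are never measured, which leaves a wedge-shaped gap in Fourier space. The NumPy sketch below builds such a mask and applies it to a volume. It assumes axes ordered (z, y, x), the beam along z, tilting about y, and a ±60° tilt range. It illustrates the kind of degradation the task restores and is not the project's own implementation. `missing_wedge_mask` and `apply_missing_wedge` are hypothetical helpers.

```python
import numpy as np

def missing_wedge_mask(n, max_tilt_deg=60.0):
    """Boolean (n, n, n) Fourier mask, True where data are measured.

    Axes are (z, y, x) in np.fft.fftn ordering. The beam is along z and the
    specimen is tilted about y, so measured components satisfy
    |k_z| <= |k_x| * tan(max_tilt).
    """
    k = np.fft.fftfreq(n)
    kz, _, kx = np.meshgrid(k, k, k, indexing="ij")
    return np.abs(kz) <= np.abs(kx) * np.tan(np.deg2rad(max_tilt_deg))

def apply_missing_wedge(volume, max_tilt_deg=60.0):
    """Zero the unmeasured Fourier components of a cubic volume."""
    mask = missing_wedge_mask(volume.shape[0], max_tilt_deg)
    # The mask is symmetric under k -> -k, so the result is real up to rounding.
    return np.fft.ifftn(np.fft.fftn(volume) * mask).real
```

The wedge is centred on the beam axis (k_x = 0, k_z ≠ 0), which is why uncorrected reconstructions appear elongated along z.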


## Ethical Considerations

This model is intended for scientific research and structural biology applications. Users should:
- Ensure proper attribution when using generated structures
- Validate generated structures through experimental verification
- Be aware of potential biases in the training data

## Citation

If you use CryoFM1 in your research, please cite:

```bibtex
@inproceedings{
  zhou2025cryofm,
  title={Cryo{FM}: A Flow-based Foundation Model for Cryo-{EM} Densities},
  author={Yi Zhou and Yilai Li and Jing Yuan and Quanquan Gu},
  booktitle={The Thirteenth International Conference on Learning Representations},
  year={2025},
  url={https://openreview.net/forum?id=T4sMzjy7fO}
}
```

## License

This model is released under the Apache 2.0 License. See the [LICENSE](https://github.com/ByteDance-Seed/cryofm/blob/main/LICENSE) file for details.

## Acknowledgments

This work was developed by the ByteDance Seed Team. For more information, visit:
- [Project Repository](https://github.com/ByteDance-Seed/cryofm)
- [ByteDance Seed Team](https://seed.bytedance.com/)