---
license: apache-2.0
tags:
- cryo-em
- flow-matching
- 3d-density-maps
- foundation-model
- conditional-sampling
---

# CryoFM2: A Generative Foundation Model for Cryo-EM Densities

[![Tech Report](https://img.shields.io/badge/Tech%20Report-bioRxiv-0066CC?logo=doi&logoColor=white)](https://doi.org/10.64898/2025.12.29.696802) [![GitHub](https://img.shields.io/badge/GitHub-cryofm-181717?logo=github&logoColor=white)](https://github.com/ByteDance-Seed/cryofm) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Docs](https://img.shields.io/badge/Docs-cryofm-4CAF50?logo=read-the-docs&logoColor=white)](https://bytedance-seed.github.io/cryofm/docs/)
CryoFM2 Overview

## Overview

**CryoFM2** is a flow-based generative foundation model for cryo-EM density maps. It is pretrained on curated EMDB half maps to learn general priors of high-quality cryo-EM densities and can be fine-tuned for downstream tasks. The model learns a continuous mapping from a simple Gaussian distribution to the complex distribution of cryo-EM densities, enabling stable generation and flexible adaptation.

CryoFM2 can also act as a **Bayesian prior**, integrating naturally with task-specific likelihoods to support applications such as anisotropy-aware refinement, non-uniform reconstruction, and controlled density modification.

## Model Details

The pretrained model serves as the base for all variants. It can be fine-tuned for downstream tasks such as density map enhancement and post-processing.

**Pre-training Architecture:**

CryoFM2 architecture for pre-training.
**Fine-tuning Architecture (for EMhancer/EMReady style post-processing):**
CryoFM2 architecture for fine-tuning.
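
The overview describes CryoFM2 as a continuous mapping from a Gaussian to the data distribution, and the sampling examples in this card use `method="euler"` with `num_steps=200`. The toy below illustrates that idea in one dimension. It is only a sketch under stated assumptions: the linear interpolation path, the time convention (noise at `t=0`, data at `t=1`), and the helper names `velocity` and `euler_sample` are hypothetical and are not taken from the CryoFM2 code, where the velocity field is a learned 3D UNet rather than a closed-form expression.

```python
import random


def velocity(x, t, mu, s):
    """Closed-form marginal velocity for a 1D toy flow (not the CryoFM2 network).

    Assumes the linear path x_t = (1 - t) * x0 + t * x1 with noise
    x0 ~ N(0, 1) at t=0 and data x1 ~ N(mu, s^2) at t=1, drawn independently.
    """
    var_t = (1.0 - t) ** 2 + (t * s) ** 2   # Var[x_t]
    cov = t * s * s - (1.0 - t)             # Cov[x1 - x0, x_t]
    return mu + cov / var_t * (x - t * mu)  # E[x1 - x0 | x_t = x]


def euler_sample(x0, mu, s, num_steps=200):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed-step Euler."""
    x, h = x0, 1.0 / num_steps
    for k in range(num_steps):
        x += h * velocity(x, k * h, mu, s)
    return x


if __name__ == "__main__":
    rng = random.Random(0)
    mu, s = 3.0, 0.5
    samples = [euler_sample(rng.gauss(0.0, 1.0), mu, s) for _ in range(5000)]
    mean = sum(samples) / len(samples)
    std = (sum((v - mean) ** 2 for v in samples) / len(samples)) ** 0.5
    # The empirical mean and std should land close to mu and s.
    print(f"mean={mean:.2f} std={std:.2f}")
```

In this toy the exact flow sends a noise value `x0` to `mu + s * x0`, so the Euler result can be checked directly. More steps reduce the discretization error at the cost of more velocity evaluations, which in the real model means more network forward passes.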

### Architecture

- **Architecture Type**: 3D UNet
- **Input Size**: 64×64×64 voxels
- **Input Channels**: 2 for pre-trained model, 3 for fine-tuned model
- **Output Channels**: 1
- **Down Blocks**: DownBlock3D, DownBlock3D, AttnDownBlock3D, AttnDownBlock3D
- **Up Blocks**: AttnUpBlock3D, AttnUpBlock3D, UpBlock3D, UpBlock3D
- **Block Output Channels**: (64, 128, 256, 512)
- **Layers per Block**: 2
- **Attention Head Dimension**: 8
- **Normalization**: GroupNorm (32 groups)
- **Activation**: SiLU
- **Time Embedding**: Positional encoding

### Model Variants

1. **cryofm2-pretrain**: Unconditional pretrained model for general density map generation
2. **cryofm2-emhancer**: Fine-tuned model for density map enhancement (EMhancer style)
3. **cryofm2-emready**: Fine-tuned model for density map enhancement (EMReady style)

## Play with CryoFM2

### Installation

Before using CryoFM2, you need to set up the environment and install the package. Follow these steps to get started:

```bash
# Clone the repository
git clone https://github.com/ByteDance-Seed/cryofm.git
cd cryofm

# Create a new conda environment for CryoFM (recommended)
conda create -n cryofm python=3.10 -y
conda activate cryofm

# Install CryoFM
pip install .
```

### Unconditional Generation (Explore Training Data Distribution)

Generate samples from the pretrained model to explore the learned data distribution:

**Pretrained Model:**

```python
import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Uncond

# Update the path to your model directory
model_dir = "path/to/cryofm-v2/cryofm2-pretrain"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Uncond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()

num_samples = 3


def v_xt_t(_xt, _t):
    # Velocity field evaluated by the model at state _xt and time _t
    return lit_model(_xt, _t)


# Enable bfloat16 for faster inference if your hardware supports it;
# device.type keeps autocast consistent with the device chosen above
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=lit_model.device,
        side_shape=64,
    )

# Undo the z-score normalization if it is configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(num_samples):
    save_mrc(out[i].float().cpu().numpy(), f"sample-{i}.mrc", voxel_size=1.5)
```

**Fine-tuned Models (EMhancer/EMReady):**

```python
import torch
from mmengine import Config

from cryofm.core.utils.mrc_io import save_mrc
from cryofm.core.utils.sampling_fm import sample_from_fm
from cryofm.projects.cryofm2.lit_modules import CryoFM2Cond

# Choose style: "emhancer" or "emready"
style = "emhancer"
model_dir = f"path/to/cryofm-v2/cryofm2-{style}"
cfg = Config.fromfile(f"{model_dir}/config.yaml")
lit_model = CryoFM2Cond.load_from_safetensors(f"{model_dir}/model.safetensors", cfg=cfg)

# Style tag: 1 selects EMhancer-style output, 0 selects EMReady-style output
output_tag = 1 if style == "emhancer" else 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lit_model = lit_model.to(device)
lit_model.eval()

num_samples = 3


def v_xt_t(_xt, _t):
    bs = _xt.shape[0]
    # No input map or volume is supplied, so only the style tag conditions the sample
    unconditional_generation_conds = {
        "input_cond": None,
        "output_cond": torch.tensor([output_tag] * bs, device=device),
        "vol_cond": None,  # dimension should be [bs, d, h, w]
    }
    return lit_model(_xt, _t, generation_conds=unconditional_generation_conds)


# Enable bfloat16 for faster inference if your hardware supports it;
# device.type keeps autocast consistent with the device chosen above
with torch.no_grad(), torch.autocast(device.type, dtype=torch.bfloat16):
    out = sample_from_fm(
        v_xt_t,
        lit_model.noise_scheduler,
        method="euler",
        num_steps=200,
        num_samples=num_samples,
        device=lit_model.device,
        side_shape=64,
    )

# Undo the z-score normalization if it is configured
if hasattr(lit_model.cfg, "z_scale") and lit_model.cfg.z_scale.mean is not None:
    out = out * lit_model.cfg.z_scale.std + lit_model.cfg.z_scale.mean

# Save generated samples
for i in range(num_samples):
    save_mrc(out[i].float().cpu().numpy(), f"{style}-sample-{i}.mrc", voxel_size=1.5)
```

### Density Map Modification

CryoFM2 supports various density map modification operations using the pretrained model as a Bayesian prior.

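
As a rough illustration of what using the model as a Bayesian prior means, the posterior over a clean map $x$ given observed data $y$ factors into a likelihood and a prior, and so does its gradient. The noise model and the linear operator $A$ below are assumptions chosen for illustration. This card does not specify the likelihoods or the guidance schedule CryoFM2 uses for each operator.

$$
p(x \mid y) \propto p(y \mid x)\, p(x)
\quad\Longrightarrow\quad
\nabla_x \log p(x \mid y) = \nabla_x \log p(y \mid x) + \nabla_x \log p(x)
$$

Assuming a linear observation $y = A x + n$ with Gaussian noise $n \sim \mathcal{N}(0, \sigma^2 I)$, the likelihood term is

$$
\nabla_x \log p(y \mid x) = \frac{1}{\sigma^2} A^{\top} (y - A x).
$$

Under this assumption, denoising corresponds to $A = I$, and inpainting corresponds to an $A$ that keeps only the observed components (for a complex Fourier-space operator, $A^{\top}$ becomes the conjugate transpose). In this picture the generative model contributes the prior term, and the chosen operator determines the likelihood term.
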
Supported operators include:

- **denoise**: Remove noise from density maps
- **inpaint**: Fill missing regions (e.g., missing wedge)
- **denoise inpaint**: Combined denoising and inpainting
- **non-uniform weight**: Apply non-uniform weighting during reconstruction

**Basic Usage:**

```bash
python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op denoise \
    --norm-grad \
    --use-lamb-w
```

**For inpainting tasks**, you need to provide a RELION starfile path:

```bash
python -m cryofm.projects.cryofm2.uncond_sampling \
    -i1 half_map_1.mrc \
    -i2 half_map_2.mrc \
    -o ./output \
    --model-dir path/to/cryofm-v2/cryofm2-pretrain \
    --op inpaint \
    --data-starfile-path path/to/relion_data.star \
    --norm-grad \
    --use-lamb-w
```

### Density Map Post-Processing

CryoFM2 provides fine-tuned models for density map enhancement in different styles, similar to EMhancer and EMReady.

#### EMhancer Style Enhancement

```bash
python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emhancer \
    --model-dir path/to/cryofm-v2/cryofm2-emhancer \
    --output-tag 1
```

#### EMReady Style Enhancement

```bash
python -m cryofm.projects.cryofm2.cond_sampling \
    -i input_map.mrc \
    -o ./output_emready \
    --model-dir path/to/cryofm-v2/cryofm2-emready \
    --output-tag 0 \
    --cfg-weight 0.5
```

**Parameters:**

- `-i`: Input density map file (MRC format)
- `-o`: Output directory
- `--model-dir`: Path to the model directory containing `config.yaml` and `model.safetensors`
- `--output-tag`: Style tag (1 for EMhancer, 0 for EMReady)
- `--cfg-weight`: Classifier-free guidance weight (optional, default varies by model)

## Performance Tips

- **Multi-GPU Inference**: Use `accelerate launch` for faster inference on multiple GPUs:

```bash
NCCL_DEBUG=ERROR accelerate launch --num_processes=${NUM_GPUS} --main_process_port=8881 \
    python -m cryofm.projects.cryofm2.cond_sampling ...
```

- **Mixed Precision**: Use the `--bf16` flag when available to reduce memory usage and speed up inference.
- **Batch Processing**: Adjust batch size based on your GPU memory capacity.

## Limitations

- Input size is fixed at 64×64×64 voxels
- Model performance may vary depending on the input density map quality
- Fine-tuned models are optimized for specific enhancement styles

## Ethical Considerations

This model is intended for scientific research and structural biology applications. Users should:

- Ensure proper attribution when using generated structures
- Validate generated structures through experimental verification
- Be aware of potential biases in the training data
- Use the model responsibly and in accordance with scientific best practices

## Citation

If you find CryoFM2 useful, please cite:

```bibtex
@article{Li2025.12.29.696802,
  author={Li, Yilai and Yuan, Jing and Zhou, Yi and Wang, Zhenghua and Chen, Suyi and Yang, Fengyu and Ling, Haibin and Kovalsky, Shahar Z and Zheng, Xiaoqing and Gu, Quanquan},
  title={A Generative Foundation Model for Cryo-EM Densities},
  elocation-id={2025.12.29.696802},
  year={2025},
  doi={10.64898/2025.12.29.696802},
  publisher={Cold Spring Harbor Laboratory},
  URL={https://www.biorxiv.org/content/early/2025/12/29/2025.12.29.696802},
  eprint={https://www.biorxiv.org/content/early/2025/12/29/2025.12.29.696802.full.pdf},
  journal={bioRxiv}
}
```

## License

This model is released under the Apache 2.0 License. See the [LICENSE](https://github.com/ByteDance-Seed/cryofm/blob/main/LICENSE) file for details.

## Acknowledgments

This work was developed by the ByteDance Seed Team. For more information, visit:

- [Project Repository](https://github.com/ByteDance-Seed/cryofm)
- [ByteDance Seed Team](https://seed.bytedance.com/)