---
license: mit
tags:
- medical-imaging
- registration
- diffusion
- 3d
- image-generation
- image-restoration
- pytorch
library_name: pytorch
---
# OmniMorph
**Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).**
OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) supporting:
- **Generation**: text-conditioned image synthesis via BERT embeddings.
- **Restoration**: recover anatomically plausible images from degraded inputs.
- **Registration**: paired / unpaired / flexible-resolution registration via diffused deformation vector fields.
## Repository Contents
| Path | Description |
|---|---|
| `OM_train*.py` | Training entrypoints (single-/2-/3-mode variants, CUDA + Intel XPU) |
| `OM_aug*.py`, `OM_reg*.py`, `OM_contrastive*.py` | Inference / augmentation / registration / contrastive scripts |
| `Diffusion/` | DeformDDPM core: `diffuser.py`, networks, losses, spatial utils |
| `OMorpher/` | Higher-level model wrapper |
| `Dataloader/` | Multi-modality dataloaders + dataset mappings (16 datasets) |
| `Config/` | YAML training/inference configs |
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
| `tests/` | Pytest suite for `OMorpher` and loss functions |
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint: production multi-modal `recmulmodmutattnnet` (epoch 110, ~3.0 GB) |
| `Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth` | Earlier `recmulmodmutattnnet` run (epoch 10, ~906 MB) |
> **Note** Only the final checkpoint of each training run is shipped; intermediate epochs and the `bert_large_uncased` weights are not bundled. Download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.
## Setup
```bash
git clone https://huggingface.co/DRDMsig/Omini3D
cd Omini3D
pip install -r requirements.txt
```
On Intel XPU (e.g. the Dawn cluster), install the matching `intel-extension-for-pytorch` build before installing the rest of the requirements.
## Quick Start
### Training
```bash
# Single-mode diffusion
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml
# Dual mode (diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml
# Triple mode (diffusion + contrastive + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml
# Intel XPU (single node)
sbatch bash_train_single_node.sh
```
### Inference
```bash
# Augmentation / restoration with a trained model
python OM_aug.py -C Config/config_om.yaml
# Paired registration
python OM_reg.py -C Config/config_om.yaml
# Flexible-resolution registration
python OM_reg_flexres.py -C Config/config_om.yaml
```
### Loading the checkpoint
```python
import torch
from Diffusion.networks import get_net
# Production network (multi-modal recmulmodmutattnnet)
net = get_net("recmulmodmutattnnet")
# Production checkpoint (epoch 110)
ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
# Or earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"
state = torch.load(ckpt_path, map_location="cpu")
net.load_state_dict(state["model"] if "model" in state else state)
net.eval()
```
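If `load_state_dict` reports missing or unexpected keys, the checkpoint may wrap its weights differently from what the snippet above expects. The helper below is a hypothetical sketch (`extract_state_dict` is not part of the repository): it applies the same `"model"` unwrapping as above and, as an additional assumption about how a run was saved, strips the `module.` prefix that `torch.nn.DataParallel` / `DistributedDataParallel` add to parameter names.

```python
from collections import OrderedDict
from collections.abc import Mapping
from typing import Any


def extract_state_dict(ckpt: Any, prefix: str = "module.") -> Mapping:
    """Hypothetical helper: return a plain parameter mapping from a loaded checkpoint.

    Assumes the checkpoint is either a bare state dict or a dict holding one
    under the "model" key (the same convention the loading snippet checks for).
    """
    state = ckpt["model"] if isinstance(ckpt, Mapping) and "model" in ckpt else ckpt
    # (Distributed)DataParallel saves keys as "module.<name>"; strip the prefix
    # only when every key carries it, so unwrapped checkpoints pass through unchanged.
    if state and all(key.startswith(prefix) for key in state):
        state = OrderedDict((key[len(prefix):], value) for key, value in state.items())
    return state
```

Under these assumptions it would be used as `net.load_state_dict(extract_state_dict(torch.load(ckpt_path, map_location="cpu")))`.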
## Architecture
```
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
```
- **`DeformDDPM`** (`Diffusion/diffuser.py`): forward/reverse diffusion over deformation vector fields (DVFs); multi-scale DDFs at control-point ratios `[4, 8, 16, 32, 64]`.
- **Networks** (`Diffusion/networks.py`), selectable via `get_net(name)`:
  - `recmulmodmutattnnet`: current production multi-modal multi-head-attention net (used by `000110_all_om_net.pth`)
  - `recmutattnnet`, `recmutattnnet_contrastive`, `recresacnet`, `defrecmutattnnet`
- **`STN`**: Spatial Transformer for differentiable warping; composes deformations as `comp_ddf = dvf + stn(ddf, dvf)`.
- **Losses** (`Diffusion/losses.py`, `losses_ncc0.py`): `Grad`, `LNCC`, `LMSE`, `NCC`, `MRSE`, `RMSE`.
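The multi-scale DDF idea can be illustrated with a small 2D NumPy sketch. It rests on assumptions the repository does not confirm: that a control-point ratio `r` means one control point every `r` pixels, that each scale draws i.i.d. Gaussian displacements on that coarse grid, and that the scales are bilinearly upsampled and summed. `upsample_bilinear` and `random_multiscale_ddf` are hypothetical names; the real `DeformDDPM` may sample, scale, or combine the fields differently.

```python
import numpy as np


def upsample_bilinear(coarse: np.ndarray, h: int, w: int) -> np.ndarray:
    """Bilinearly resize a (C, gh, gw) control-point grid to (C, h, w), corners aligned."""
    _, gh, gw = coarse.shape
    ys = np.linspace(0, gh - 1, h)
    xs = np.linspace(0, gw - 1, w)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, gh - 1)
    x1 = np.minimum(x0 + 1, gw - 1)
    wy = (ys - y0)[:, None]  # (h, 1) row weights
    wx = (xs - x0)[None, :]  # (1, w) column weights
    top = coarse[:, y0][:, :, x0] * (1 - wx) + coarse[:, y0][:, :, x1] * wx
    bottom = coarse[:, y1][:, :, x0] * (1 - wx) + coarse[:, y1][:, :, x1] * wx
    return top * (1 - wy) + bottom * wy


def random_multiscale_ddf(h, w, ratios=(4, 8, 16, 32, 64), sigma=1.0, seed=None):
    """Hypothetical: sum of smooth random 2D displacement fields, one per control-point ratio."""
    rng = np.random.default_rng(seed)
    ddf = np.zeros((2, h, w))
    for r in ratios:
        # Assumed meaning of the ratio: one control point every r pixels (at least a 2x2 grid).
        gh, gw = max(h // r, 1) + 1, max(w // r, 1) + 1
        coarse = rng.normal(0.0, sigma, size=(2, gh, gw))
        ddf += upsample_bilinear(coarse, h, w)
    return ddf
```

Large ratios yield few control points and therefore smooth, near-global displacements, while small ratios add local detail.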
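The composition rule `comp_ddf = dvf + stn(ddf, dvf)` can be checked with a 2D NumPy sketch. It assumes `stn(x, d)` resamples `x` at `p + d(p)` with bilinear interpolation (displacements in pixels, channel 0 = rows, channel 1 = columns, border clamping). `warp` and `compose` are hypothetical stand-ins for the repository's `STN`, which operates on PyTorch tensors and may use a different padding mode or coordinate normalisation. Under these assumptions the composed field sends a point `p` to `p + dvf(p) + ddf(p + dvf(p))`, i.e. it follows `dvf` and then `ddf`.

```python
import numpy as np


def warp(field: np.ndarray, disp: np.ndarray) -> np.ndarray:
    """Sample a (C, H, W) array at p + disp(p) with bilinear interpolation.

    disp is a (2, H, W) displacement in pixels (channel 0 = rows, 1 = columns);
    sample positions outside the grid are clamped to the border.
    """
    _, h, w = field.shape
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    y = np.clip(rows + disp[0], 0, h - 1)
    x = np.clip(cols + disp[1], 0, w - 1)
    y0 = np.floor(y).astype(int)
    x0 = np.floor(x).astype(int)
    y1 = np.minimum(y0 + 1, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    wy, wx = y - y0, x - x0
    # Blend the four neighbouring samples for every channel at once.
    top = field[:, y0, x0] * (1 - wx) + field[:, y0, x1] * wx
    bottom = field[:, y1, x0] * (1 - wx) + field[:, y1, x1] * wx
    return top * (1 - wy) + bottom * wy


def compose(ddf: np.ndarray, dvf: np.ndarray) -> np.ndarray:
    """comp_ddf = dvf + stn(ddf, dvf): follow dvf, then ddf evaluated at the moved point."""
    return dvf + warp(ddf, dvf)
```

With the same conventions, a moving image would be warped by the composed deformation as `warp(moving, compose(ddf, dvf))`.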
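Two of the listed losses are simple enough to sketch directly. The NumPy versions below are illustrative assumptions, not the code in `Diffusion/losses.py`: `grad_loss` is a finite-difference smoothness penalty of the kind commonly called `Grad` in registration code, and `ncc` is global normalised cross-correlation. `LNCC` computes the same statistic inside local windows, which is omitted here for brevity.

```python
import numpy as np


def grad_loss(disp: np.ndarray, penalty: str = "l2") -> float:
    """Finite-difference smoothness penalty on a (2, H, W) displacement field."""
    dy = np.diff(disp, axis=1)  # differences between neighbouring rows
    dx = np.diff(disp, axis=2)  # differences between neighbouring columns
    if penalty == "l2":
        dy, dx = dy ** 2, dx ** 2
    elif penalty == "l1":
        dy, dx = np.abs(dy), np.abs(dx)
    else:
        raise ValueError(f"unknown penalty: {penalty}")
    return float((dy.mean() + dx.mean()) / 2.0)


def ncc(a: np.ndarray, b: np.ndarray, eps: float = 1e-8) -> float:
    """Global normalised cross-correlation in [-1, 1]; higher means more similar."""
    a = a - a.mean()
    b = b - b.mean()
    denom = np.sqrt((a ** 2).sum() * (b ** 2).sum()) + eps
    return float((a * b).sum() / denom)
```

As a training objective the similarity term would be negated (for example `1 - ncc(warped, fixed)`) and combined with a weighted `grad_loss`; the weighting is a configuration choice not shown here.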
## Datasets Supported
`Dataloader/nifty_mappings/` contains pre-computed mappings for 16 public medical-imaging datasets, including:
AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion, CIA.
The dataset files themselves are **not** included; obtain them from their respective sources and update the mapping paths.
## Citation
```bibtex
@article{omnimorph,
title = {OmniMorph: Deform All-in-One Framework for Medical Image Generation,
Restoration and Registration via Conditional Deformation-Recovery
Diffusion Models},
author = {Zheng, J. and Mo, M. and others},
year = {2025}
}
```
## License
MIT. See `LICENSE`.