---
license: mit
tags:
- medical-imaging
- registration
- diffusion
- 3d
- image-generation
- image-restoration
- pytorch
library_name: pytorch
---

# OmniMorph

**Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).**

OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) that supports:

- **Generation**: text-conditioned image synthesis using BERT text embeddings.
- **Restoration**: recovery of anatomically plausible images from degraded inputs.
- **Registration**: paired, unpaired and flexible-resolution registration via diffused deformation vector fields.

## Repository Contents

| Path | Description |
|---|---|
| `OM_train*.py` | Training entry points (single-, dual- and triple-mode variants; CUDA and Intel XPU) |
| `OM_aug*.py`, `OM_reg*.py`, `OM_contrastive*.py` | Inference scripts for augmentation/restoration, registration and contrastive tasks |
| `Diffusion/` | DeformDDPM core: `diffuser.py`, networks, losses, spatial utilities |
| `OMorpher/` | Higher-level model wrapper |
| `Dataloader/` | Multi-modality dataloaders and dataset mappings (16 datasets) |
| `Config/` | YAML training and inference configs |
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
| `tests/` | Pytest suite for `OMorpher` and the loss functions |
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA and Intel XPU/Dawn) |
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint: production multi-modal `recmulmodmutattnnet` (epoch 110, ~3.0 GB) |
| `Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth` | Earlier `recmulmodmutattnnet` run (epoch 10, ~906 MB) |

> **Note:** Only the final checkpoint of each training run is shipped. Intermediate epochs and the `bert_large_uncased` weights are not bundled. If you need the contrastive text encoder, download `bert-large-uncased` from the official Hugging Face repository.

## Setup

```bash
# Large files (e.g. the checkpoints) are served via Git LFS; make sure it is installed.
git lfs install
git clone https://huggingface.co/DRDMsig/Omini3D
cd Omini3D
pip install -r requirements.txt
```

For Intel XPU (e.g. the Dawn cluster), install the `intel-extension-for-pytorch` build that matches your PyTorch version before installing the remaining requirements.

## Quick Start

### Training

```bash
# Single-mode diffusion
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml

# Dual mode (diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml

# Triple mode (diffusion + contrastive + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml

# Intel XPU (single node)
sbatch bash_train_single_node.sh
```

### Inference

```bash
# Augmentation / restoration with a trained model
python OM_aug.py -C Config/config_om.yaml

# Paired registration
python OM_reg.py -C Config/config_om.yaml

# Flexible-resolution registration
python OM_reg_flexres.py -C Config/config_om.yaml
```

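### Loading the checkpoint

```python
import torch
from Diffusion.networks import get_net

# Production network (multi-modal recmulmodmutattnnet)
net = get_net("recmulmodmutattnnet")

# Production checkpoint (epoch 110)
ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
# Or the earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"

# Load onto the CPU first, then move the network to the target device.
# PyTorch >= 2.6 defaults to weights_only=True; if the checkpoint stores
# non-tensor objects, pass weights_only=False (only for trusted files).
state = torch.load(ckpt_path, map_location="cpu")
# Weights may be nested under a "model" key; otherwise the file is a bare state_dict.
net.load_state_dict(state["model"] if "model" in state else state)
net.eval()  # inference mode: disables dropout, uses running norm statistics
```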
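
The loader above assumes the checkpoint's keys match the network's parameter names exactly. The multi-GPU training commands (`CUDA_VISIBLE_DEVICES=0,1`) suggest, but do not confirm, that some runs may have wrapped the model in `torch.nn.DataParallel`. That wrapper prefixes every key with `module.`. If `load_state_dict` reports unexpected `module.*` keys, a small helper such as the hypothetical `strip_prefix` below can normalize them. It is a minimal sketch demonstrated on a toy `nn.Linear` and is not part of this repository.

```python
from collections import OrderedDict

import torch
import torch.nn as nn


def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: drop `prefix` from every key that starts with it."""
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items()
    )


# Toy stand-in for a checkpoint saved from a DataParallel-wrapped model:
# every parameter name carries a "module." prefix.
source = nn.Linear(4, 2)
checkpoint = {"model": {"module." + k: v for k, v in source.state_dict().items()}}

# Same unwrapping logic as the loader above, plus prefix stripping.
state = checkpoint["model"] if "model" in checkpoint else checkpoint
target = nn.Linear(4, 2)
result = target.load_state_dict(strip_prefix(state), strict=True)
print(result)  # reports no missing or unexpected keys
```

With `strict=True`, any remaining mismatch raises immediately, which is usually preferable to silently loading a partially initialized network.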

## Architecture

```
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
```

- **`DeformDDPM`** (`Diffusion/diffuser.py`): forward/reverse diffusion over deformation vector fields (DVFs), with multi-scale dense displacement fields (DDFs) at control-point ratios `[4, 8, 16, 32, 64]`.
- **Networks** (`Diffusion/networks.py`): selectable via `get_net(name)`:
  - `recmulmodmutattnnet`: current production multi-modal multi-head-attention network (used by `000110_all_om_net.pth`)
  - other variants: `recmutattnnet`, `recmutattnnet_contrastive`, `recresacnet`, `defrecmutattnnet`
- **`STN`**: Spatial Transformer for differentiable warping; composes deformations as `comp_ddf = dvf + stn(ddf, dvf)`.
- **Losses** (`Diffusion/losses.py`, `losses_ncc0.py`): `Grad`, `LNCC`, `LMSE`, `NCC`, `MRSE`, `RMSE`.
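
The composition rule in the `STN` bullet can be illustrated with a minimal 2D sketch built on `torch.nn.functional.grid_sample`. The sketch makes three assumptions:

- displacements are stored in pixels;
- channel 0 is the offset along the width (x) and channel 1 the offset along the height (y);
- `stn(field, disp)` resamples `field` at `x + disp(x)`.

Under that convention, `dvf + stn(ddf, dvf)` is the single displacement equivalent to warping by `ddf` first and then by `dvf`. The names `warp` and `compose` are hypothetical. The repository's `STN` operates on 3D volumes and may use a different channel order or normalization.

```python
import torch
import torch.nn.functional as F


def warp(field: torch.Tensor, disp: torch.Tensor) -> torch.Tensor:
    """Resample `field` (N, C, H, W) at positions x + disp(x).

    `disp` is (N, 2, H, W) in pixels: channel 0 is the x (width) offset,
    channel 1 the y (height) offset. Bilinear interpolation, border padding.
    """
    _, _, h, w = disp.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=disp.dtype, device=disp.device),
        torch.arange(w, dtype=disp.dtype, device=disp.device),
        indexing="ij",
    )
    # Absolute sampling positions in pixel units, shape (N, H, W).
    x = xs + disp[:, 0]
    y = ys + disp[:, 1]
    # grid_sample expects (N, H, W, 2) with (x, y) normalised to [-1, 1].
    grid = torch.stack((2 * x / (w - 1) - 1, 2 * y / (h - 1) - 1), dim=-1)
    return F.grid_sample(field, grid, mode="bilinear",
                         padding_mode="border", align_corners=True)


def compose(ddf: torch.Tensor, dvf: torch.Tensor) -> torch.Tensor:
    """Single displacement equivalent to warping by `ddf`, then by `dvf`."""
    return dvf + warp(ddf, dvf)


# Two constant shifts along x: 1 px, then 2 px.
img = torch.rand(1, 1, 8, 8)
ddf = torch.zeros(1, 2, 8, 8)
ddf[:, 0] = 1.0
dvf = torch.zeros(1, 2, 8, 8)
dvf[:, 0] = 2.0

comp = compose(ddf, dvf)
once = warp(img, comp)             # one resampling with the composite field
twice = warp(warp(img, ddf), dvf)  # two successive resamplings
```

For two constant shifts, the composite is simply their sum, and resampling once with it matches resampling twice. That makes it a quick sanity check when changing conventions. For general fields, resampling once with the composite also avoids the extra interpolation blur of warping twice.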
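The `DeformDDPM` bullet mentions multi-scale DDFs at control-point ratios `[4, 8, 16, 32, 64]`. A common way to obtain smooth random deformations at a given scale is to draw displacements on a coarse control-point grid and upsample them. The following are all assumptions made for illustration, not the repository's implementation:

- each ratio is read as the spacing between control points in pixels;
- displacements are sampled uniformly and upsampled bilinearly;
- the names `control_point_dvf` and `mean_abs_gradient` are hypothetical.

`mean_abs_gradient` is a simple finite-difference roughness measure in the spirit of a `Grad`-style smoothness penalty, although the repository's `Grad` loss may be defined differently.

```python
import torch
import torch.nn.functional as F


def control_point_dvf(size, ratio, max_disp, generator=None):
    """Hypothetical random smooth 2D DVF of shape (1, 2, H, W), in pixels.

    Displacements are drawn uniformly from [-max_disp, max_disp) on a coarse
    grid with roughly one control point every `ratio` pixels (at least 2x2),
    then bilinearly upsampled. Bilinear weights are convex, so the dense
    field stays within the same bound.
    """
    h, w = size
    ch, cw = max(h // ratio, 2), max(w // ratio, 2)
    coarse = (torch.rand(1, 2, ch, cw, generator=generator) * 2 - 1) * max_disp
    return F.interpolate(coarse, size=(h, w), mode="bilinear", align_corners=True)


def mean_abs_gradient(dvf):
    """Mean absolute forward difference along H and W (a roughness measure)."""
    dy = (dvf[..., 1:, :] - dvf[..., :-1, :]).abs().mean()
    dx = (dvf[..., :, 1:] - dvf[..., :, :-1]).abs().mean()
    return (dx + dy) / 2


gen = torch.Generator().manual_seed(0)
fields = {
    r: control_point_dvf((128, 128), r, max_disp=4.0, generator=gen)
    for r in [4, 8, 16, 32, 64]
}
for r, f in fields.items():
    # Coarser control grids (larger ratios) give smoother fields.
    print(r, tuple(f.shape), round(mean_abs_gradient(f).item(), 4))
```

This README does not describe how `DeformDDPM` combines the different scales (for example, by summing or composing them), so the sketch stops at producing one field per ratio.
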
## Datasets Supported

`Dataloader/nifty_mappings/` contains pre-computed mappings for 16 public medical-imaging datasets, including AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion and CIA.

The dataset files themselves are **not** included. Obtain them from their respective sources and update the mapping paths accordingly.

## Citation

```bibtex
@article{omnimorph,
  title  = {OmniMorph: Deform All-in-One Framework for Medical Image Generation,
            Restoration and Registration via Conditional Deformation-Recovery
            Diffusion Models},
  author = {Zheng, J. and Mo, M. and others},
  year   = {2025}
}
```

## License

MIT; see `LICENSE`.