---
license: mit
tags:
- medical-imaging
- registration
- diffusion
- 3d
- image-generation
- image-restoration
- pytorch
library_name: pytorch
---

# OmniMorph

**Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).**

OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) supporting:

- **Generation**: text-conditioned image synthesis via BERT embeddings.
- **Restoration**: recover anatomically plausible images from degraded inputs.
- **Registration**: paired / unpaired / flexible-resolution registration via diffused deformation vector fields.

## Repository Contents

| Path | Description |
|---|---|
| `OM_train*.py` | Training entrypoints (single-/2-/3-mode variants, CUDA + Intel XPU) |
| `OM_aug*.py`, `OM_reg*.py`, `OM_contrastive*.py` | Inference / augmentation / registration / contrastive scripts |
| `Diffusion/` | DeformDDPM core: `diffuser.py`, networks, losses, spatial utils |
| `OMorpher/` | Higher-level model wrapper |
| `Dataloader/` | Multi-modality dataloaders + dataset mappings (16 datasets) |
| `Config/` | YAML training/inference configs |
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
| `tests/` | Pytest suite for `OMorpher` and loss functions |
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint: production multi-modal `recmulmodmutattnnet` (epoch 110, ~3.0 GB) |
| `Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth` | Earlier `recmulmodmutattnnet` run (epoch 10, ~906 MB) |

> **Note**
> Only the final checkpoint of each training run is shipped; intermediate epochs and the `bert_large_uncased` weights are not bundled. Download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.

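
After cloning, a quick sanity check that the entries listed in the Repository Contents table are present can be sketched as follows. The helper name `missing_entries` and the assumption that it is run from the repository root are illustrative only; nothing like it ships with the repository.

```python
from pathlib import Path

# Patterns copied from the Repository Contents table.
EXPECTED = [
    "OM_train*.py",
    "OM_aug*.py",
    "OM_reg*.py",
    "OM_contrastive*.py",
    "Diffusion",
    "OMorpher",
    "Dataloader",
    "Config",
    "Scripts",
    "tests",
    "Models/all_om_net/000110_all_om_net.pth",
    "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth",
]


def missing_entries(root=".", patterns=EXPECTED):
    """Return the patterns that match nothing under `root` (hypothetical helper)."""
    root = Path(root)
    return [p for p in patterns if not any(root.glob(p))]


if __name__ == "__main__":
    missing = missing_entries()
    print("all listed paths found" if not missing else f"missing: {missing}")
```

An empty result only means the paths exist; it does not verify that the checkpoint files were fully downloaded.
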
## Setup

```bash
git clone https://huggingface.co/DRDMsig/Omini3D
cd Omini3D
pip install -r requirements.txt
```

For Intel XPU / Dawn cluster, install the matching `intel-extension-for-pytorch` build before installing the rest of the requirements.

## Quick Start

### Training

```bash
# Single-mode diffusion
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml

# Dual mode (diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml

# Triple mode (diffusion + contrastive + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml

# Intel XPU (single node)
sbatch bash_train_single_node.sh
```

### Inference

```bash
# Augmentation / restoration with a trained model
python OM_aug.py -C Config/config_om.yaml

# Paired registration
python OM_reg.py -C Config/config_om.yaml

# Flexible-resolution registration
python OM_reg_flexres.py -C Config/config_om.yaml
```

### Loading the checkpoint

```python
import torch
from Diffusion.networks import get_net

# Production network (multi-modal recmulmodmutattnnet)
net = get_net("recmulmodmutattnnet")

# Production checkpoint (epoch 110)
ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
# Or earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"

state = torch.load(ckpt_path, map_location="cpu")
net.load_state_dict(state["model"] if "model" in state else state)
net.eval()
```

## Architecture

```
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
```

- **`DeformDDPM`** (`Diffusion/diffuser.py`): forward/reverse diffusion over deformation vector fields (DVFs); multi-scale DDFs at control-point ratios `[4, 8, 16, 32, 64]`.
- **Networks** (`Diffusion/networks.py`), selectable via `get_net(name)`:
  - `recmulmodmutattnnet`: current production multi-modal multi-head-attention net (used by `000110_all_om_net.pth`)
  - `recmutattnnet`, `recmutattnnet_contrastive`, `recresacnet`, `defrecmutattnnet`
- **`STN`**: Spatial Transformer for differentiable warping; composes deformations as `comp_ddf = dvf + stn(ddf, dvf)`.
- **Losses** (`Diffusion/losses.py`, `losses_ncc0.py`): `Grad`, `LNCC`, `LMSE`, `NCC`, `MRSE`, `RMSE`.

## Datasets Supported

`Dataloader/nifty_mappings/` contains pre-computed mappings for 16 public medical-imaging datasets, including: AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion, CIA.

The dataset files themselves are **not** included; obtain them from their respective sources and update the mapping paths.

## Citation

```bibtex
@article{omnimorph,
  title  = {OmniMorph: Deform All-in-One Framework for Medical Image Generation, Restoration and Registration via Conditional Deformation-Recovery Diffusion Models},
  author = {Zheng, J. and Mo, M. and others},
  year   = {2025}
}
```

## License

MIT; see `LICENSE`.
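
## Appendix: Deformation Composition Sketch

The composition rule from the Architecture section, `comp_ddf = dvf + stn(ddf, dvf)`, can be read as: resample `ddf` at the positions displaced by `dvf`, then add `dvf`. This reading assumes the usual spatial-transformer convention of sampling the source at `x + dvf(x)`. The sketch below illustrates it in 1D with plain Python. It is not the repository's `STN`, which warps 2D/3D tensors; the helper names `warp_1d` and `compose_1d`, the linear interpolation, the border clamping, and the use of index (voxel) units for displacements are all assumptions made for illustration.

```python
def warp_1d(values, disp):
    """Resample `values` at positions i + disp[i] (linear interpolation, border-clamped)."""
    n = len(values)
    out = []
    for i, d in enumerate(disp):
        x = min(max(i + d, 0.0), n - 1.0)  # clamp the sampling position to the grid
        lo = int(x)                        # floor, valid because x >= 0
        hi = min(lo + 1, n - 1)
        w = x - lo
        out.append((1.0 - w) * values[lo] + w * values[hi])
    return out


def compose_1d(ddf, dvf):
    """1D analogue of comp_ddf = dvf + stn(ddf, dvf)."""
    warped = warp_1d(ddf, dvf)  # plays the role of stn(ddf, dvf)
    return [v + w for v, w in zip(dvf, warped)]


# Two constant shifts add up: a shift of 1 composed with a shift of 2 gives 3.
print(compose_1d([1.0] * 8, [2.0] * 8))

# A zero DVF leaves the DDF unchanged.
print(compose_1d([0.0, 1.0, 2.0, 3.0], [0.0] * 4))
```

For spatially varying fields the result differs from plain addition, because `ddf` is read at the displaced location rather than at the original one.
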