---
license: mit
tags:
- medical-imaging
- registration
- diffusion
- 3d
- image-generation
- image-restoration
- pytorch
library_name: pytorch
---
# OmniMorph
**Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).**
OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) supporting:
- **Generation**: text-conditioned image synthesis via BERT embeddings.
- **Restoration**: recover anatomically plausible images from degraded inputs.
- **Registration**: paired / unpaired / flexible-resolution registration via diffused deformation vector fields.
## Repository Contents
| Path | Description |
|---|---|
| `OM_train*.py` | Training entrypoints (single-/2-/3-mode variants, CUDA + Intel XPU) |
| `OM_aug*.py`, `OM_reg*.py`, `OM_contrastive*.py` | Inference / augmentation / registration / contrastive scripts |
| `Diffusion/` | DeformDDPM core: `diffuser.py`, networks, losses, spatial utils |
| `OMorpher/` | Higher-level model wrapper |
| `Dataloader/` | Multi-modality dataloaders + dataset mappings (16 datasets) |
| `Config/` | YAML training/inference configs |
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
| `tests/` | Pytest suite for `OMorpher` and loss functions |
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint: production multi-modal `recmulmodmutattnnet` (epoch 110, ~3.0 GB) |
| `Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth` | Earlier `recmulmodmutattnnet` run (epoch 10, ~906 MB) |
> **Note** Only the final checkpoint of each training run is shipped β€” intermediate epochs and the `bert_large_uncased` weights are not bundled. Download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.
## Setup
```bash
git clone https://huggingface.co/DRDMsig/Omini3D
cd Omini3D
pip install -r requirements.txt
```
For Intel XPU (e.g. the Dawn cluster), install the matching `intel-extension-for-pytorch` build before installing the rest of the requirements.
## Quick Start
### Training
```bash
# Single-mode diffusion
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml
# Dual mode (diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml
# Triple mode (diffusion + contrastive + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml
# Intel XPU (single node)
sbatch bash_train_single_node.sh
```
### Inference
```bash
# Augmentation / restoration with a trained model
python OM_aug.py -C Config/config_om.yaml
# Paired registration
python OM_reg.py -C Config/config_om.yaml
# Flexible-resolution registration
python OM_reg_flexres.py -C Config/config_om.yaml
```
### Loading the checkpoint
```python
import torch
from Diffusion.networks import get_net
# Production network (multi-modal recmulmodmutattnnet)
net = get_net("recmulmodmutattnnet")
# Production checkpoint (epoch 110)
ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
# Or earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"
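# Trusted local file. On PyTorch >= 2.6 (weights_only=True by default) you may
# need weights_only=False if the checkpoint stores non-tensor Python objects.
state = torch.load(ckpt_path, map_location="cpu")
# Weights may be nested under a "model" key alongside other training state.
net.load_state_dict(state["model"] if "model" in state else state)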
net.eval()
```
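If `load_state_dict` reports unexpected `module.`-prefixed keys, the weights were probably saved from a `torch.nn.DataParallel` wrapper. Whether the multi-GPU runs (e.g. `OM_train_2modes.py`) actually do this is an assumption, not something the repo confirms. A small helper can normalize the checkpoint before loading. `extract_state_dict` is a hypothetical name; this is a sketch, not part of the repo.

```python
from collections.abc import Mapping


def extract_state_dict(ckpt, key="model", prefix="module."):
    """Return a parameter mapping suitable for `load_state_dict`.

    Assumptions (not confirmed by the repo): weights may be nested under
    `key`, and parameter names may carry a DataParallel-style `prefix`.
    """
    # Unwrap a {"model": state_dict, "epoch": ..., ...} style checkpoint.
    if isinstance(ckpt, Mapping) and key in ckpt and isinstance(ckpt[key], Mapping):
        ckpt = ckpt[key]
    # Strip the wrapper prefix only when every key carries it.
    keys = list(ckpt.keys())
    if keys and all(k.startswith(prefix) for k in keys):
        return {k[len(prefix):]: v for k, v in ckpt.items()}
    # Otherwise return the mapping untouched (keeps OrderedDict metadata).
    return ckpt
```

Usage: `net.load_state_dict(extract_state_dict(state))` replaces the inline conditional above. The prefix is stripped only when *all* keys share it, so a model that legitimately has a submodule named `module` is left alone.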
## Architecture
```
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
```
- **`DeformDDPM`** (`Diffusion/diffuser.py`): forward/reverse diffusion over deformation vector fields (DVFs); multi-scale DDFs at control-point ratios `[4, 8, 16, 32, 64]`.
- **Networks** (`Diffusion/networks.py`): selectable via `get_net(name)`:
  - `recmulmodmutattnnet`: current production multi-modal multi-head-attention net (used by `000110_all_om_net.pth`)
  - `recmutattnnet`, `recmutattnnet_contrastive`, `recresacnet`, `defrecmutattnnet`
- **`STN`**: Spatial Transformer for differentiable warping; composes deformations as `comp_ddf = dvf + stn(ddf, dvf)`.
- **Losses** (`Diffusion/losses.py`, `losses_ncc0.py`): `Grad`, `LNCC`, `LMSE`, `NCC`, `MRSE`, `RMSE`.
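The `STN` composition rule can be illustrated with a minimal 2D NumPy sketch. It assumes displacements in voxel units with channel order (row, column), bilinear sampling with border clamping, and reads `stn(ddf, dvf)` as "resample `ddf` at `x + dvf(x)`". The repo's PyTorch `STN` also handles 3D and may differ in coordinate normalization and padding; `warp` and `compose` are hypothetical names.

```python
import numpy as np


def warp(field, ddf):
    """Bilinearly resample `field` (C, H, W) at x + ddf(x).

    ddf has shape (2, H, W): ddf[0] is the row displacement, ddf[1] the
    column displacement, in voxel units. Sample points are clamped to the
    image border (an assumption; the repo's STN may pad differently).
    """
    _, h, w = field.shape
    yy, xx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    y = np.clip(yy + ddf[0], 0, h - 1)
    x = np.clip(xx + ddf[1], 0, w - 1)
    y0 = np.floor(y).astype(int)
    x0 = np.floor(x).astype(int)
    y1 = np.minimum(y0 + 1, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    wy = y - y0
    wx = x - x0
    # Weighted sum of the four neighbouring samples, per channel.
    return (
        field[:, y0, x0] * (1 - wy) * (1 - wx)
        + field[:, y0, x1] * (1 - wy) * wx
        + field[:, y1, x0] * wy * (1 - wx)
        + field[:, y1, x1] * wy * wx
    )


def compose(ddf, dvf):
    """Composite displacement following comp_ddf = dvf + stn(ddf, dvf)."""
    return dvf + warp(ddf, dvf)
```

Since `compose(ddf, dvf)(x) = dvf(x) + ddf(x + dvf(x))`, resampling an image once with the composite approximates warping it by `ddf` and then by `dvf`. It avoids the extra blur of two successive interpolations. For two constant fields, the composite is simply their sum.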
## Datasets Supported
`Dataloader/nifty_mappings/` contains pre-computed mappings for 16 public medical-imaging datasets, including:
AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion, CIA.
The dataset files themselves are **not** included; obtain them from their respective sources and update the mapping paths.
## Citation
```bibtex
@article{omnimorph,
title = {OmniMorph: Deform All-in-One Framework for Medical Image Generation,
Restoration and Registration via Conditional Deformation-Recovery
Diffusion Models},
author = {Zheng, J. and Mo, M. and others},
year = {2025}
}
```
## License
MIT; see `LICENSE`.