---
license: mit
tags:
  - medical-imaging
  - registration
  - diffusion
  - 3d
  - image-generation
  - image-restoration
  - pytorch
library_name: pytorch
---

# OmniMorph

**Deform All-in-One Framework for Medical Image Generation, Restoration and Registration based on a conditional Deformation-Recovery Diffusion Model (DeformDDPM).**

OmniMorph is a unified framework for 2D/3D multi-modal medical imaging (CT, MRI, PET) supporting:

- **Generation**: text-conditioned image synthesis via BERT embeddings.
- **Restoration**: recovery of anatomically plausible images from degraded inputs.
- **Registration**: paired / unpaired / flexible-resolution registration via diffused deformation vector fields.

## Repository Contents

| Path | Description |
|---|---|
| `OM_train*.py` | Training entrypoints (single-/2-/3-mode variants, CUDA + Intel XPU) |
| `OM_aug*.py`, `OM_reg*.py`, `OM_contrastive*.py` | Inference / augmentation / registration / contrastive scripts |
| `Diffusion/` | DeformDDPM core: `diffuser.py`, networks, losses, spatial utils |
| `OMorpher/` | Higher-level model wrapper |
| `Dataloader/` | Multi-modality dataloaders + dataset mappings (16 datasets) |
| `Config/` | YAML training/inference configs |
| `Scripts/` | Auxiliary scripts (registration, evaluation) |
| `tests/` | Pytest suite for `OMorpher` and loss functions |
| `bash_*.sh`, `*.slurm` | SLURM submission scripts (CUDA + Intel XPU/Dawn) |
| `Models/all_om_net/000110_all_om_net.pth` | Trained checkpoint: production multi-modal `recmulmodmutattnnet` (epoch 110, ~3.0 GB) |
| `Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth` | Earlier `recmulmodmutattnnet` run (epoch 10, ~906 MB) |

> **Note** Only the final checkpoint of each training run is shipped; intermediate epochs and the `bert_large_uncased` weights are not bundled. Download `bert-large-uncased` from the official Hugging Face repo if you need the contrastive text encoder.

## Setup

```bash
git clone https://huggingface.co/DRDMsig/Omini3D
cd Omini3D
pip install -r requirements.txt
```

For Intel XPU / Dawn cluster, install the matching `intel-extension-for-pytorch` build before installing the rest of the requirements.

## Quick Start

### Training

```bash
# Single-mode diffusion
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml

# Dual mode (diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml

# Triple mode (diffusion + contrastive + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_3modes.py -C Config/config_om.yaml

# Intel XPU (single node)
sbatch bash_train_single_node.sh
```

### Inference

```bash
# Augmentation / restoration with a trained model
python OM_aug.py -C Config/config_om.yaml

# Paired registration
python OM_reg.py -C Config/config_om.yaml

# Flexible-resolution registration
python OM_reg_flexres.py -C Config/config_om.yaml
```

### Loading the checkpoint

```python
import torch
from Diffusion.networks import get_net

# Production network (multi-modal recmulmodmutattnnet)
net = get_net("recmulmodmutattnnet")

# Production checkpoint (epoch 110)
ckpt_path = "Models/all_om_net/000110_all_om_net.pth"
# Or earlier run: "Models/all_recmulmodmutattnnet/000010_all_recmulmodmutattnnet.pth"

state = torch.load(ckpt_path, map_location="cpu")
net.load_state_dict(state["model"] if "model" in state else state)
net.eval()
```
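Checkpoints saved from a `torch.nn.DataParallel` or `DistributedDataParallel` wrapper carry a `module.` prefix on every key, which makes `load_state_dict` fail on an unwrapped network. Whether the shipped checkpoints were saved this way is not stated here, so the following is only a minimal sketch under that assumption; `extract_state_dict` is a hypothetical helper, not part of this repository.

```python
from collections import OrderedDict


def extract_state_dict(ckpt, key="model", prefix="module."):
    """Return a plain state dict from a loaded checkpoint (hypothetical helper).

    Assumptions: the weights sit either at the top level or under `key`,
    and keys may carry a DataParallel-style `prefix`.
    """
    # Unwrap {"model": {...}} style checkpoints; otherwise use the dict as-is.
    state = ckpt[key] if isinstance(ckpt, dict) and key in ckpt else ckpt
    # Strip the wrapper prefix only where it is actually present.
    return OrderedDict(
        (name[len(prefix):] if name.startswith(prefix) else name, value)
        for name, value in state.items()
    )
```

With this helper, the loading step above would become `net.load_state_dict(extract_state_dict(torch.load(ckpt_path, map_location="cpu")))`.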

## Architecture

```
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
```

- **`DeformDDPM`** (`Diffusion/diffuser.py`): forward/reverse diffusion over deformation vector fields (DVFs); multi-scale DDFs at control-point ratios `[4, 8, 16, 32, 64]`.
- **Networks** (`Diffusion/networks.py`), selectable via `get_net(name)`:
  - `recmulmodmutattnnet`: current production multi-modal multi-head-attention net (used by `000110_all_om_net.pth`)
  - `recmutattnnet`, `recmutattnnet_contrastive`, `recresacnet`, `defrecmutattnnet`
- **`STN`**: Spatial Transformer for differentiable warping; composes deformations as `comp_ddf = dvf + stn(ddf, dvf)`.
- **Losses** (`Diffusion/losses.py`, `losses_ncc0.py`): `Grad`, `LNCC`, `LMSE`, `NCC`, `MRSE`, `RMSE`.
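
The composition rule `comp_ddf = dvf + stn(ddf, dvf)` can be read as "resample the existing field at the displaced positions, then add the new displacement". The sketch below illustrates that reading in 2D with plain PyTorch. It is not the repository's `STN`: the function names `warp` and `compose`, the pixel-unit displacements, and the `(dy, dx)` channel order are all assumptions made for illustration.

```python
import torch
import torch.nn.functional as F


def warp(src, disp):
    """Sample src (N, C, H, W) at x + disp(x).

    Assumption: disp is (N, 2, H, W) in pixel units, channels ordered (dy, dx).
    """
    _, _, h, w = src.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=src.dtype, device=src.device),
        torch.arange(w, dtype=src.dtype, device=src.device),
        indexing="ij",
    )
    # Absolute sampling positions in pixel units, shape (N, H, W).
    y = ys + disp[:, 0]
    x = xs + disp[:, 1]
    # grid_sample expects (x, y) order, normalized to [-1, 1].
    grid = torch.stack((2 * x / (w - 1) - 1, 2 * y / (h - 1) - 1), dim=-1)
    return F.grid_sample(
        src, grid, mode="bilinear", padding_mode="border", align_corners=True
    )


def compose(ddf, dvf):
    """Compose two displacement fields: comp(x) = dvf(x) + ddf(x + dvf(x))."""
    return dvf + warp(ddf, dvf)
```

For two constant translations the composed field is simply their sum, which makes a quick sanity check: composing a shift of one pixel in `y` with a shift of two pixels in `x` yields `(1, 2)` at every location.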

## Datasets Supported

`Dataloader/nifty_mappings/` contains pre-computed mappings for 16 public medical-imaging datasets, including:
AbdomenAtlas, AbdomenCT-1k, BraTS 2019/2020/2021, MSD, OASIS-1/2, OAI-ZIB, MnMs, Kaggle OSIC, TotalSegmentator (CT+MRI), PSMA-FDG-PET-CT-Lesion, CIA.

The dataset files themselves are **not** included; obtain them from their respective sources and update the mapping paths.

## Citation

```bibtex
@article{omnimorph,
  title  = {OmniMorph: Deform All-in-One Framework for Medical Image Generation,
            Restoration and Registration via Conditional Deformation-Recovery
            Diffusion Models},
  author = {Zheng, J. and Mo, M. and others},
  year   = {2025}
}
```

## License

MIT; see `LICENSE`.