# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

OmniMorph is a medical image framework for generation, restoration, and registration using a conditional Deformation-Recovery Diffusion Model (DeformDDPM). It supports 2D and 3D multi-modal medical imaging (CT, MRI, PET) with text-conditioned generation via BERT embeddings.

## Common Commands

```bash
# Training (single-mode diffusion)
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml

# Training (dual-mode: diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml

# Contrastive learning (text-image alignment)
python OM_contrastive.py -C Config/config_om_contrastive.yaml

# XPU testing with dummy data (no real dataset needed)
python OM_contrastive_xpu.py --dummy-samples 20

# Augmentation / inference with a trained model
python OM_aug.py -C Config/config_om.yaml

# Background training (production style)
nohup python -u OM_train_2modes.py -C Config/config_om.yaml > train_log.txt 2>&1 &
```

## Architecture

### Core Pipeline

```
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
```

### Diffusion Module (`Diffusion/`)

- **diffuser.py**: `DeformDDPM`, forward/reverse diffusion over deformation vector fields (DVFs). Generates multi-scale DDFs via control points at ratios [4, 8, 16, 32, 64]. Key methods: `diffuse()`, `recover()`.
- **networks.py**: Network architectures selected by `get_net(net_name)`:
  - `recresacnet`: Atrous convolution UNet (2D CMR)
  - `recmutattnnet`: Multi-head attention network (main 3D, channels [1,16,32,64,128,256])
  - `recmutattnnet_contrastive`: Outputs 1024-dim image embeddings for contrastive training
  - `defrecmutattnnet`: Deformable variant
- **networks.py: `STN`**: Spatial Transformer Network for differentiable image warping via DDFs. Composes deformations: `comp_ddf = dvf + stn(ddf, dvf)`.
- **losses.py**: `Grad` (L1 + negative Jacobian determinant + range penalties), `LNCC` (local normalized cross-correlation), `LMSE` (labeled MSE), `NCC`, `MRSE`.
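
The composition rule `comp_ddf = dvf + stn(ddf, dvf)` can be illustrated with a minimal 1D sketch. It assumes that `stn(ddf, dvf)` resamples `ddf` at positions displaced by `dvf`, that displacements are in voxel units, and that samples outside the grid take the border value. The repository's `STN` works on 2D/3D tensors and may use different conventions; `warp_1d` and `compose_1d` are hypothetical names used only for illustration.

```python
def warp_1d(field, dvf):
    """Resample `field` at x + dvf[x] with linear interpolation (border values held)."""
    n = len(field)
    out = []
    for x, d in enumerate(dvf):
        pos = min(max(x + d, 0.0), n - 1.0)  # clamp the sample position to the grid
        lo = int(pos)                        # left neighbour (pos is non-negative)
        hi = min(lo + 1, n - 1)              # right neighbour, clamped at the border
        w = pos - lo                         # interpolation weight of the right neighbour
        out.append((1.0 - w) * field[lo] + w * field[hi])
    return out


def compose_1d(ddf, dvf):
    """1D analogue of `comp_ddf = dvf + stn(ddf, dvf)`."""
    return [d + w for d, w in zip(dvf, warp_1d(ddf, dvf))]


ddf = [0.0, 1.0, 2.0, 3.0]  # accumulated displacement so far
dvf = [1.0, 1.0, 1.0, 1.0]  # new step: shift every sample position by one voxel
comp_ddf = compose_1d(ddf, dvf)
```

The sum is not a plain `ddf + dvf`: the earlier field is first looked up where the new step points, which is what makes the result a composition of the two deformations rather than an addition.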

### Training Modes

| Script | Purpose | DataLoader | Key Loss |
|--------|---------|------------|----------|
| `OM_train.py` | Single diffusion | `OminiDataset_v1` | Grad + MRSE + NCC |
| `OM_train_2modes.py` | Diffusion + registration | `OMDataset_indiv` + `OMDataset_pair` | Above + LNCC + LMSE |
| `OM_train_3modes.py` | Extended dual-mode | Same as 2modes | Different loss weights |
| `OM_contrastive.py` | Text-image alignment | `OMDataset_indiv` | Cosine similarity |
| `OM_reg.py` | Registration only | Paired data | Registration losses |
| `OM_train_uncon.py` | Unconditional generation | Generic | Standard |

All DDP-enabled training scripts use NCCL backend on `localhost:12355`.
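
As a rough sketch of what that rendezvous setting implies, the helper below sets the environment variables that PyTorch's default `env://` initialisation reads and then joins an NCCL process group. `ddp_env` and `setup_ddp` are hypothetical names; the training scripts may structure this differently.

```python
import os


def ddp_env(addr="localhost", port=12355):
    """Rendezvous settings read by PyTorch's env:// initialisation."""
    return {"MASTER_ADDR": addr, "MASTER_PORT": str(port)}


def setup_ddp(rank, world_size):
    """Join the NCCL process group; intended to be called once per worker process."""
    # Imported lazily so that ddp_env() can be used without PyTorch installed.
    import torch.distributed as dist

    os.environ.update(ddp_env())
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
```

Because the port is fixed, two DDP jobs started on the same machine would compete for `localhost:12355`, so a second job needs a different port.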

### DataLoader (`Dataloader/`)

- **dataLoader.py**: All dataset classes. Data comes from JSON mapping files in `nifty_mappings/` that map NIfTI file paths to metadata (Modality, ROI, Size, Spacing_mm, BERT embeddings).
  - `OMDataset_indiv` → returns `[volume, embd]` (shapes: `[1,sz,sz,sz]`, `[1024]`)
  - `OMDataset_pair` → returns `[volume_A, volume_B, embd_A, embd_B]`
  - `DummyOMDataset_indiv` / `DummyOMDataset_pair` → random tensors for XPU testing without data
- **dataloader_utils.py**: `get_sizeRange_dict()` for ROI-based filtering, image thresholding, DICOM reading.
- **bert_helper.py** / **embding_gen.py**: BERT text embedding generation.
- Filtering chain: min dimension → modality → ROI → label presence.
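
The filtering chain can be sketched over an in-memory mapping as follows. This is an illustration under assumptions, not the code in `dataLoader.py`: the mapping is taken to be a dict from file path to metadata, the `Label` key and the sample entries are hypothetical, and `filter_mapping` is an invented name.

```python
def filter_mapping(mapping, min_dim, modalities, rois, require_label=False):
    """Apply the chain in order: min dimension, modality, ROI, label presence."""
    kept = {}
    for path, meta in mapping.items():
        if min(meta["Size"]) < min_dim:
            continue  # smallest axis is too short
        if meta["Modality"] not in modalities:
            continue
        if meta["ROI"] not in rois:
            continue
        # "Label" is a hypothetical key standing in for however labels are recorded.
        if require_label and not meta.get("Label"):
            continue
        kept[path] = meta
    return kept


# Hypothetical entries, for illustration only.
mapping = {
    "a.nii.gz": {"Modality": "CT", "ROI": "abdomen", "Size": [128, 128, 96], "Label": "a_seg.nii.gz"},
    "b.nii.gz": {"Modality": "MRI", "ROI": "brain", "Size": [160, 160, 160]},
    "c.nii.gz": {"Modality": "CT", "ROI": "abdomen", "Size": [128, 128, 32]},
}
kept = filter_mapping(mapping, min_dim=64, modalities={"CT"}, rois={"abdomen"}, require_label=True)
```

Here only `a.nii.gz` survives: `b.nii.gz` fails the modality check and `c.nii.gz` fails the minimum-dimension check.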



### Config (`Config/`)

YAML files with keys: `data_name`, `net_name`, `ndims` (2 or 3), `img_size`, `batchsize`, `timesteps` (default 80), `v_scale`, `lr`, `epoch`, `noise_scale`, `condition_type` (`'uncon'`, `'adding'`, `'project'`, etc.), augmentation params (`start_noise_step`, `noise_step`, `aug_coe`).
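
A small sanity check on a loaded config might look like the sketch below. It assumes every listed non-augmentation key is required, which the repository may not enforce, and all values in `example_cfg` other than `timesteps: 80` and the network name are placeholders. `check_config` is a hypothetical helper.

```python
REQUIRED_KEYS = (
    "data_name", "net_name", "ndims", "img_size", "batchsize", "timesteps",
    "v_scale", "lr", "epoch", "noise_scale", "condition_type",
)


def check_config(cfg):
    """Raise early if a key is missing or `ndims` is not 2 or 3."""
    missing = [key for key in REQUIRED_KEYS if key not in cfg]
    if missing:
        raise KeyError(f"missing config keys: {missing}")
    if cfg["ndims"] not in (2, 3):
        raise ValueError("ndims must be 2 or 3")
    return cfg


# Placeholder values for illustration; not taken from Config/config_om.yaml.
example_cfg = {
    "data_name": "om",
    "net_name": "recmutattnnet",
    "ndims": 3,
    "img_size": 128,
    "batchsize": 2,
    "timesteps": 80,
    "v_scale": 1.0,
    "lr": 1e-4,
    "epoch": 100,
    "noise_scale": 1.0,
    "condition_type": "uncon",
}
check_config(example_cfg)
```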



### Augmentation (`OM_aug.py`, `OM_aug_highres.py`)

Loads a trained checkpoint and generates augmented samples. Generation is controlled by `start_noise_step` (higher = less deformation) and `aug_coe` (samples per input, typically 32-64). Outputs are saved to `Data/Aug_data/{dataset}/img|msk|ddf/`.
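
The output layout and sample count can be sketched as below. The helper names and the `demo` dataset name are hypothetical; only the directory pattern and the meaning of `aug_coe` come from the description above.

```python
from pathlib import PurePosixPath


def aug_output_dirs(dataset, root="Data/Aug_data"):
    """Return the three output folders (images, masks, DDFs) for one dataset."""
    base = PurePosixPath(root) / dataset
    return {kind: base / kind for kind in ("img", "msk", "ddf")}


def n_augmented(n_inputs, aug_coe):
    """Total generated samples when each input yields `aug_coe` samples."""
    return n_inputs * aug_coe


dirs = aug_output_dirs("demo")  # "demo" is a placeholder dataset name
total = n_augmented(10, 32)     # 10 inputs with aug_coe = 32
```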



## Key Conventions

- Models saved as `{epoch:06d}_{data_name}_{net_name}.pth` containing `model_state_dict`, `optimizer_state_dict`, `epoch`.
- CT images clamped to [-400, 400] HU before normalization.
- SimpleITK axis order is reversed from NumPy (`reverse_axis_order()`).
- Mapping JSON files in `nifty_mappings/` are Git LFS tracked (large files).
- `utils.py` provides `get_transformer()` for random affine augmentations and `get_random_deformed_mask()` for blind masks.
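
Two of these conventions can be made concrete with a short sketch. The checkpoint name follows the pattern given above; the CT step assumes that "normalization" means min-max scaling of the clamped range to [0, 1], which is an assumption rather than something stated here. Both function names and the `om` data name are hypothetical.

```python
def checkpoint_name(epoch, data_name, net_name):
    """Build a filename following `{epoch:06d}_{data_name}_{net_name}.pth`."""
    return f"{epoch:06d}_{data_name}_{net_name}.pth"


def clamp_and_scale_ct(hu_values, lo=-400.0, hi=400.0):
    """Clamp HU values to [lo, hi], then scale to [0, 1] (assumed normalization)."""
    return [(min(max(v, lo), hi) - lo) / (hi - lo) for v in hu_values]


name = checkpoint_name(120, "om", "recmutattnnet")
scaled = clamp_and_scale_ct([-1000.0, 0.0, 400.0, 1200.0])
```

With these inputs, air (-1000 HU) and dense bone (1200 HU) both saturate at the ends of the range, so contrast is spent on soft tissue.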



## Dependencies

PyTorch 1.12+ with CUDA, SimpleITK, nibabel, scikit-image, einops, pydicom, transformers (HuggingFace), swanlab (optional, for experiment tracking). See `requirements.txt`.