# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
OmniMorph is a medical image framework for generation, restoration, and registration using a conditional Deformation-Recovery Diffusion Model (DeformDDPM). It supports 2D and 3D multi-modal medical imaging (CT, MRI, PET) with text-conditioned generation via BERT embeddings.
## Common Commands

```bash
# Training (single-mode diffusion)
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml

# Training (dual-mode: diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml

# Contrastive learning (text-image alignment)
python OM_contrastive.py -C Config/config_om_contrastive.yaml

# XPU testing with dummy data (no real dataset needed)
python OM_contrastive_xpu.py --dummy-samples 20

# Augmentation / inference with a trained model
python OM_aug.py -C Config/config_om.yaml

# Background training (production style)
nohup python -u OM_train_2modes.py -C Config/config_om.yaml > train_log.txt 2>&1 &
```
## Architecture

### Core Pipeline

Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
### Diffusion Module (`Diffusion/`)

- `diffuser.py`: `DeformDDPM`, forward/reverse diffusion over deformation vector fields (DVFs). Generates multi-scale DDFs via control points at ratios [4, 8, 16, 32, 64]. Key methods: `diffuse()`, `recover()`.
- `networks.py`: network architectures selected by `get_net(net_name)`:
  - `recresacnet`: Atrous convolution UNet (2D CMR)
  - `recmutattnnet`: Multi-head attention network (main 3D, channels [1, 16, 32, 64, 128, 256])
  - `recmutattnnet_contrastive`: Outputs 1024-dim image embeddings for contrastive training
  - `defrecmutattnnet`: Deformable variant
- `networks.py` also defines `STN`: Spatial Transformer Network for differentiable image warping via DDFs. Composes deformations: `comp_ddf = dvf + stn(ddf, dvf)`.
- `losses.py`: `Grad` (L1 + negative Jacobian determinant + range penalties), `LNCC` (local normalized cross-correlation), `LMSE` (labeled MSE), `NCC`, `MRSE`.
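The composition rule `comp_ddf = dvf + stn(ddf, dvf)` and the negative-Jacobian term in `Grad` are easiest to see in one dimension. The sketch below is a pure-Python illustration, not the repository's `STN` or `Grad`: `warp`, `compose`, and `neg_jacobian_penalty` are hypothetical names, it assumes `stn(field, disp)` resamples `field` at `x + disp(x)` with linear interpolation and border clamping, and the hinge `max(0, -J)` is just one common way to penalize folding.

```python
from typing import List


def sample_linear(field: List[float], pos: float) -> float:
    """Linearly interpolate a 1D field at a fractional index, clamping at the borders."""
    n = len(field)
    pos = min(max(pos, 0.0), n - 1.0)
    i0 = int(pos)  # floor, since pos >= 0 after clamping
    i1 = min(i0 + 1, n - 1)
    w = pos - i0
    return (1.0 - w) * field[i0] + w * field[i1]


def warp(field: List[float], disp: List[float]) -> List[float]:
    """1D stand-in for stn(field, disp): resample `field` at x + disp[x] (assumed convention)."""
    return [sample_linear(field, x + d) for x, d in enumerate(disp)]


def compose(dvf: List[float], ddf: List[float]) -> List[float]:
    """comp_ddf = dvf + stn(ddf, dvf), using the 1D warp above in place of the STN."""
    warped = warp(ddf, dvf)
    return [a + b for a, b in zip(dvf, warped)]


def jacobian_1d(disp: List[float]) -> List[float]:
    """Forward-difference Jacobian of the map x -> x + u(x): J = 1 + du/dx."""
    return [1.0 + (disp[i + 1] - disp[i]) for i in range(len(disp) - 1)]


def neg_jacobian_penalty(disp: List[float]) -> float:
    """Sum of max(0, -J): nonzero only where the deformation folds (J < 0)."""
    return sum(max(0.0, -j) for j in jacobian_1d(disp))
```

Two constant shifts compose additively (`compose([1.0] * 5, [0.5] * 5)` gives `[1.5] * 5`), while a displacement such as `[0.0, -2.0, 0.0]` makes neighbouring points cross, which shows up as a negative Jacobian and a positive penalty. In the actual code these operations presumably run on 2D/3D tensors through `STN`; the 1D version only shows the bookkeeping.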
### Training Modes

| Script | Purpose | DataLoader | Key Loss |
|---|---|---|---|
| `OM_train.py` | Single diffusion | `OminiDataset_v1` | Grad + MRSE + NCC |
| `OM_train_2modes.py` | Diffusion + registration | `OMDataset_indiv` + `OMDataset_pair` | Grad + MRSE + NCC, plus LNCC + LMSE |
| `OM_train_3modes.py` | Extended dual-mode | Same as 2modes | Different loss weights |
| `OM_contrastive.py` | Text-image alignment | `OMDataset_indiv` | Cosine similarity |
| `OM_reg.py` | Registration only | Paired data | Registration losses |
| `OM_train_uncon.py` | Unconditional generation | Generic | Standard |

All DDP-enabled training scripts use the NCCL backend on `localhost:12355`.
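A minimal sketch of what NCCL on `localhost:12355` typically looks like with PyTorch's env:// rendezvous. This is an assumption about the pattern, not code copied from the training scripts; `setup_ddp_env` and `init_worker` are hypothetical names. `torch` is imported inside `init_worker` so the environment helper can be used on its own.

```python
import os


def setup_ddp_env(addr: str = "localhost", port: int = 12355) -> None:
    """Point every process at the same rendezvous address (read by env:// init)."""
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = str(port)


def init_worker(rank: int, world_size: int) -> None:
    """Join the NCCL process group and bind this process to one GPU."""
    import torch
    import torch.distributed as dist

    setup_ddp_env()
    # With no init_method given, init_process_group uses env:// (MASTER_ADDR/MASTER_PORT).
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
```

To launch one process per visible GPU, `torch.multiprocessing.spawn(init_worker, args=(world_size,), nprocs=world_size)` passes the process index as `rank`; call `torch.distributed.destroy_process_group()` when training ends. If a second job on the same machine fails at startup, port 12355 being taken is a likely suspect.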
### DataLoader (`Dataloader/`)

- `dataLoader.py`: all dataset classes. Data comes from JSON mapping files in `nifty_mappings/` that map NIfTI file paths to metadata (Modality, ROI, Size, Spacing_mm, BERT embeddings).
  - `OMDataset_indiv`: returns `[volume, embd]` (shapes `[1, sz, sz, sz]`, `[1024]`)
  - `OMDataset_pair`: returns `[volume_A, volume_B, embd_A, embd_B]`
  - `DummyOMDataset_indiv` / `DummyOMDataset_pair`: random tensors for XPU testing without data
- `dataloader_utils.py`: `get_sizeRange_dict()` for ROI-based filtering, image thresholding, DICOM reading.
- `bert_helper.py` / `embding_gen.py`: BERT text embedding generation.
- Filtering chain: min dimension → modality → ROI → label presence.
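The filtering chain can be sketched as successive checks over a mapping file. The JSON layout (`{nifti_path: metadata}`) and the exact field values are assumptions; the `"Label"` key used for label presence is hypothetical, and `load_mapping` / `filter_mapping` are illustrative names rather than functions from `dataLoader.py`.

```python
import json
from typing import Any, Dict, Iterable


def load_mapping(path: str) -> Dict[str, Dict[str, Any]]:
    """Read a nifty_mappings/ JSON file, assumed to be {nifti_path: metadata}."""
    with open(path) as f:
        return json.load(f)


def filter_mapping(
    mapping: Dict[str, Dict[str, Any]],
    min_dim: int,
    modalities: Iterable[str],
    rois: Iterable[str],
    label_key: str = "Label",  # hypothetical key name
) -> Dict[str, Dict[str, Any]]:
    """Apply min dimension -> modality -> ROI -> label presence, in that order."""
    modalities, rois = set(modalities), set(rois)
    kept = {}
    for path, meta in mapping.items():
        if min(meta["Size"]) < min_dim:         # 1. smallest axis too short
            continue
        if meta["Modality"] not in modalities:  # 2. e.g. CT, MRI, PET
            continue
        if meta["ROI"] not in rois:             # 3. region of interest
            continue
        if not meta.get(label_key):             # 4. no label file attached
            continue
        kept[path] = meta
    return kept
```

Ordering the cheap size check first means later checks only see volumes large enough to be cropped to `img_size`; whether the real loader orders its checks for this reason is not stated.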
### Config (`Config/`)

YAML files with keys: `data_name`, `net_name`, `ndims` (2 or 3), `img_size`, `batchsize`, `timesteps` (default 80), `v_scale`, `lr`, `epoch`, `noise_scale`, `condition_type` (`'uncon'`, `'adding'`, `'project'`, etc.), and augmentation params (`start_noise_step`, `noise_step`, `aug_coe`).
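A hedged sketch of loading and sanity-checking such a config. Which keys are mandatory is an assumption (`REQUIRED_KEYS` is hypothetical); only the `ndims` range and the `timesteps` default of 80 come from the description above. PyYAML is imported lazily so the validator has no third-party dependency.

```python
from typing import Any, Dict

# Hypothetical choice of mandatory keys, drawn from the key list above.
REQUIRED_KEYS = ("data_name", "net_name", "ndims", "img_size", "batchsize", "lr", "epoch")


def validate_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """Check required keys and ndims, and fill in the documented timesteps default."""
    missing = [k for k in REQUIRED_KEYS if k not in cfg]
    if missing:
        raise KeyError(f"missing config keys: {missing}")
    if cfg["ndims"] not in (2, 3):
        raise ValueError(f"ndims must be 2 or 3, got {cfg['ndims']}")
    out = dict(cfg)  # copy so the caller's dict is left untouched
    out.setdefault("timesteps", 80)
    return out


def load_config(path: str) -> Dict[str, Any]:
    """Parse a YAML file from Config/ and validate it."""
    import yaml  # PyYAML, assumed installed

    with open(path) as f:
        return validate_config(yaml.safe_load(f))
```

For example, `load_config("Config/config_om.yaml")` would fail fast on a missing key instead of surfacing a `KeyError` deep inside training.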
### Augmentation (`OM_aug.py`, `OM_aug_highres.py`)

Loads a trained checkpoint and generates augmented samples. Two parameters control the output: `start_noise_step` (higher = less deformation) and `aug_coe` (samples per input, typically 32-64). Outputs are saved to `Data/Aug_data/{dataset}/img|msk|ddf/`.
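The output layout can be sketched with `pathlib`. The three subfolders and the `aug_coe` samples-per-input count follow the description above; the per-sample file naming (`{case_id}_aug{i:03d}.nii.gz`) and the function names are hypothetical.

```python
from pathlib import Path
from typing import Dict, List


def aug_output_dirs(dataset: str, root: str = "Data") -> Dict[str, Path]:
    """Data/Aug_data/{dataset}/{img,msk,ddf}/ as described above."""
    base = Path(root) / "Aug_data" / dataset
    return {kind: base / kind for kind in ("img", "msk", "ddf")}


def aug_sample_paths(dataset: str, case_id: str, aug_coe: int,
                     root: str = "Data") -> List[Dict[str, Path]]:
    """One img/msk/ddf triple per augmented sample, aug_coe samples per input."""
    dirs = aug_output_dirs(dataset, root)
    samples = []
    for i in range(aug_coe):
        name = f"{case_id}_aug{i:03d}.nii.gz"  # hypothetical naming scheme
        samples.append({kind: d / name for kind, d in dirs.items()})
    return samples
```

Before writing, each directory can be created with `d.mkdir(parents=True, exist_ok=True)`. With `aug_coe=64`, a single input case produces 64 images, 64 masks, and 64 displacement fields, so disk usage grows quickly for 3D volumes.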
## Key Conventions

- Models are saved as `{epoch:06d}_{data_name}_{net_name}.pth`, containing `model_state_dict`, `optimizer_state_dict`, and `epoch`.
- CT images are clamped to [-400, 400] HU before normalization.
- SimpleITK axis order is reversed from NumPy (`reverse_axis_order()`).
- Mapping JSON files in `nifty_mappings/` are Git LFS tracked (large files).
- `utils.py` provides `get_transformer()` for random affine augmentations and `get_random_deformed_mask()` for blind masks.
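Three of these conventions in runnable form, as a sketch only: the helper names are hypothetical, and rescaling the clamped HU range to [0, 1] is an assumption, since the normalization target is not stated. The axis note relies on `SimpleITK.GetArrayFromImage` returning arrays indexed (z, y, x) while `Image.GetSize()` reports (x, y, z).

```python
from typing import List, Sequence, Tuple


def checkpoint_name(epoch: int, data_name: str, net_name: str) -> str:
    """{epoch:06d}_{data_name}_{net_name}.pth, zero-padded so files sort by epoch."""
    return f"{epoch:06d}_{data_name}_{net_name}.pth"


def normalize_ct(hu: Sequence[float], lo: float = -400.0, hi: float = 400.0) -> List[float]:
    """Clamp HU to [lo, hi], then min-max rescale to [0, 1] (target range assumed)."""
    return [(min(max(v, lo), hi) - lo) / (hi - lo) for v in hu]


def sitk_size_to_numpy_shape(sitk_size: Sequence[int]) -> Tuple[int, ...]:
    """SimpleITK sizes are (x, y, z); the matching NumPy array shape is (z, y, x)."""
    return tuple(reversed(sitk_size))
```

For example, `normalize_ct([-1000, 0, 3000])` maps air below the window to 0.0, water to 0.5, and dense bone above the window to 1.0. A matching save would be `torch.save({"model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch}, checkpoint_name(epoch, data_name, net_name))`.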
## Dependencies

PyTorch 1.12+ with CUDA, SimpleITK, nibabel, scikit-image, einops, pydicom, transformers (HuggingFace), swanlab (optional, for experiment tracking). See `requirements.txt`.