Yuanhan Mo
Add dummy datasets for XPU testing, XPU contrastive training script, and CLAUDE.md
be5d479

CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

Project Overview

OmniMorph is a medical image framework for generation, restoration, and registration using a conditional Deformation-Recovery Diffusion Model (DeformDDPM). It supports 2D and 3D multi-modal medical imaging (CT, MRI, PET) with text-conditioned generation via BERT embeddings.

Common Commands

```bash
# Training (single-mode diffusion)
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml

# Training (dual-mode: diffusion + registration)
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml

# Contrastive learning (text-image alignment)
python OM_contrastive.py -C Config/config_om_contrastive.yaml

# XPU testing with dummy data (no real dataset needed)
python OM_contrastive_xpu.py --dummy-samples 20

# Augmentation / inference with a trained model
python OM_aug.py -C Config/config_om.yaml

# Background training (production style)
nohup python -u OM_train_2modes.py -C Config/config_om.yaml > train_log.txt 2>&1 &
```

Architecture

Core Pipeline

Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint

Diffusion Module (Diffusion/)

  • diffuser.py - DeformDDPM: forward and reverse diffusion over deformation vector fields (DVFs). Generates multi-scale dense displacement fields (DDFs) via control points at ratios [4, 8, 16, 32, 64]. Key methods: diffuse(), recover().
  • networks.py - Network architectures, selected by get_net(net_name):
    • recresacnet - Atrous convolution UNet (2D CMR)
    • recmutattnnet - Multi-head attention network (main 3D model, channels [1, 16, 32, 64, 128, 256])
    • recmutattnnet_contrastive - Outputs 1024-dim image embeddings for contrastive training
    • defrecmutattnnet - Deformable variant
  • networks.py: STN - Spatial Transformer Network for differentiable image warping via DDFs. Composes deformations as comp_ddf = dvf + stn(ddf, dvf).
  • losses.py - Grad (L1 + negative Jacobian determinant + range penalties), LNCC (local normalized cross-correlation), LMSE (labeled MSE), NCC, MRSE.
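To make the STN composition rule concrete, here is a minimal 2D NumPy sketch. It assumes displacements are stored in voxel units as arrays of shape (2, H, W), and that stn(ddf, dvf) means "resample ddf at x + dvf(x)" using bilinear interpolation with border clamping. The helper names bilinear_sample, warp and compose_ddf are hypothetical; the repository's STN is a PyTorch module and may use a different layout, coordinate normalization and padding mode.

```python
import numpy as np

def bilinear_sample(img, ys, xs):
    """Sample a 2D array at float coordinates (ys, xs), clamping to the border."""
    h, w = img.shape
    ys = np.clip(ys, 0, h - 1)
    xs = np.clip(xs, 0, w - 1)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, h - 1)
    x1 = np.minimum(x0 + 1, w - 1)
    wy, wx = ys - y0, xs - x0
    top = (1 - wx) * img[y0, x0] + wx * img[y0, x1]
    bottom = (1 - wx) * img[y1, x0] + wx * img[y1, x1]
    return (1 - wy) * top + wy * bottom

def warp(field, dvf):
    """Resample every channel of `field` (C, H, W) at x + dvf(x); a stand-in for STN."""
    _, h, w = dvf.shape
    gy, gx = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    ys, xs = gy + dvf[0], gx + dvf[1]
    return np.stack([bilinear_sample(c, ys, xs) for c in field])

def compose_ddf(ddf, dvf):
    """comp_ddf = dvf + stn(ddf, dvf): warping by comp_ddf == warping by ddf, then by dvf."""
    return dvf + warp(ddf, dvf)
```

For integer shifts the identity is exact: warping an image by compose_ddf(ddf, dvf) gives the same result as warping it by ddf and then by dvf. For general fields the two agree only up to interpolation error.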

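The multi-scale DDF generation in diffuser.py can be pictured with the hedged 2D NumPy sketch below. It assumes that a ratio r means a coarse grid of roughly img_size / r control points per axis, each holding a random displacement. The grids are then interpolated to full resolution and summed. The function names, the Gaussian sampling, the equal amplitude at every scale and the linear upsampling are all illustrative choices, not the repository's implementation.

```python
import numpy as np

def upsample_linear(grid, size):
    """Linearly interpolate a 2D control-point grid to a (size, size) field."""
    gh, gw = grid.shape
    # Control points and target pixels both span normalized [0, 1] coordinates.
    src_y = np.linspace(0.0, 1.0, gh)
    src_x = np.linspace(0.0, 1.0, gw)
    dst = np.linspace(0.0, 1.0, size)
    rows = np.stack([np.interp(dst, src_x, row) for row in grid])  # (gh, size)
    return np.stack([np.interp(dst, src_y, col) for col in rows.T], axis=1)  # (size, size)

def random_multiscale_ddf(size, ratios=(4, 8, 16, 32, 64), scale=1.0, seed=None):
    """Sum of random control-point displacement fields, one per ratio (2D, voxel units)."""
    rng = np.random.default_rng(seed)
    ddf = np.zeros((2, size, size))
    for r in ratios:
        n_ctrl = max(2, size // r)  # larger ratio -> fewer control points -> smoother field
        for axis in range(2):
            ctrl = rng.normal(0.0, scale, size=(n_ctrl, n_ctrl))
            ddf[axis] += upsample_linear(ctrl, size)
    return ddf
```

Summing fields from coarse to fine grids mixes global, smooth deformations with more local ones. This is the usual motivation for a multi-scale control-point scheme.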
Training Modes

| Script | Purpose | DataLoader | Key loss |
|---|---|---|---|
| OM_train.py | Single diffusion | OminiDataset_v1 | Grad + MRSE + NCC |
| OM_train_2modes.py | Diffusion + registration | OMDataset_indiv + OMDataset_pair | Above + LNCC + LMSE |
| OM_train_3modes.py | Extended dual-mode | Same as 2modes | Different loss weights |
| OM_contrastive.py | Text-image alignment | OMDataset_indiv | Cosine similarity |
| OM_reg.py | Registration only | Paired data | Registration losses |
| OM_train_uncon.py | Unconditional generation | Generic | Standard |
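OM_contrastive.py aligns image embeddings with 1024-dim BERT text embeddings using a cosine-similarity objective. The exact loss is not specified here. The NumPy sketch below assumes a CLIP-style symmetric cross-entropy over the batch cosine-similarity matrix, with a hypothetical temperature of 0.07. It mainly illustrates the shapes: image embeddings (e.g. from recmutattnnet_contrastive) and text embeddings are both [B, 1024].

```python
import numpy as np

def cosine_similarity_matrix(img_emb, txt_emb, eps=1e-8):
    """Pairwise cosine similarity between [B, D] image and [B, D] text embeddings."""
    img = img_emb / (np.linalg.norm(img_emb, axis=1, keepdims=True) + eps)
    txt = txt_emb / (np.linalg.norm(txt_emb, axis=1, keepdims=True) + eps)
    return img @ txt.T

def log_softmax(x, axis):
    """Numerically stable log-softmax along `axis`."""
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def symmetric_contrastive_loss(img_emb, txt_emb, temperature=0.07):
    """CLIP-style loss (an assumption): matched pairs on the diagonal should score highest."""
    logits = cosine_similarity_matrix(img_emb, txt_emb) / temperature
    diag = np.arange(len(logits))
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()  # image -> text
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()  # text -> image
    return 0.5 * (loss_i2t + loss_t2i)
```

A plain per-pair loss of 1 - cos(image, text) would also qualify as "cosine similarity". The batch-softmax form shown here additionally pushes mismatched pairs apart.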

All DDP-enabled training scripts use the NCCL backend and rendezvous at localhost:12355.
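For reference, the common PyTorch pattern for this kind of setup is sketched below. It assumes the env:// rendezvous (MASTER_ADDR and MASTER_PORT environment variables) and one process per GPU. The helper names ddp_init_kwargs, setup_ddp and cleanup_ddp are hypothetical, and the training scripts may initialize the process group differently.

```python
import os

def ddp_init_kwargs(rank, world_size, addr="localhost", port=12355):
    """Set the env:// rendezvous variables and return init_process_group kwargs."""
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = str(port)
    return {"backend": "nccl", "init_method": "env://", "rank": rank, "world_size": world_size}

def setup_ddp(rank, world_size):
    """Join the process group; call once in each spawned process (one per GPU)."""
    import torch  # imported lazily so ddp_init_kwargs has no torch dependency
    import torch.distributed as dist
    torch.cuda.set_device(rank)
    dist.init_process_group(**ddp_init_kwargs(rank, world_size))

def cleanup_ddp():
    """Tear down the process group at the end of training."""
    import torch.distributed as dist
    dist.destroy_process_group()
```

A worker function would typically call setup_ddp(rank, world_size) first. The workers would be launched with torch.multiprocessing.spawn(worker, args=(world_size,), nprocs=world_size). Only one job can bind port 12355 on a machine at a time.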

DataLoader (Dataloader/)

  • dataLoader.py - All dataset classes. Data comes from JSON mapping files in nifty_mappings/ that map NIfTI file paths to metadata (Modality, ROI, Size, Spacing_mm, BERT embeddings).
    • OMDataset_indiv → returns [volume, embd] (shapes: [1, sz, sz, sz] and [1024])
    • OMDataset_pair → returns [volume_A, volume_B, embd_A, embd_B]
    • DummyOMDataset_indiv / DummyOMDataset_pair → random tensors for XPU testing without data
  • dataloader_utils.py - get_sizeRange_dict() for ROI-based filtering, image thresholding, DICOM reading.
  • bert_helper.py / embding_gen.py - BERT text embedding generation.
  • Filtering chain: min dimension → modality → ROI → label presence.
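The filtering chain can be sketched in plain Python over a mapping that goes from NIfTI path to metadata. The Modality, ROI and Size keys come from the mapping description above. The Label key used for the label-presence check, the example entries and the function name filter_entries are hypothetical.

```python
def filter_entries(entries, min_size, modalities=None, rois=None, require_label=False):
    """Apply the filters in order: min dimension, modality, ROI, label presence."""
    kept = []
    for path, meta in entries.items():
        if min(meta["Size"]) < min_size:  # smallest image dimension is too small
            continue
        if modalities is not None and meta["Modality"] not in modalities:
            continue
        if rois is not None and meta["ROI"] not in rois:
            continue
        if require_label and not meta.get("Label"):  # hypothetical key for the label file
            continue
        kept.append(path)
    return kept

# Made-up entries in the shape a json.load() of a mapping file might return.
mapping = {
    "ct_001.nii.gz": {"Modality": "CT", "ROI": "Abdomen", "Size": [512, 512, 120], "Label": "ct_001_seg.nii.gz"},
    "mr_002.nii.gz": {"Modality": "MRI", "ROI": "Brain", "Size": [256, 256, 32]},
    "ct_003.nii.gz": {"Modality": "CT", "ROI": "Chest", "Size": [512, 512, 90]},
}
```

Running the cheap size check first means the more specific checks only see entries that could be used at all.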

Config (Config/)

Each YAML file defines keys such as data_name, net_name, ndims (2 or 3), img_size, batchsize, timesteps (default 80), v_scale, lr, epoch, noise_scale, condition_type ('uncon', 'adding', 'project', etc.), and the augmentation parameters start_noise_step, noise_step and aug_coe.
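A loaded config is just a dict, so a small sanity check can catch missing keys early. The sketch below is illustrative only. The choice of required keys, the helper name check_config and the example values are assumptions. Only the ndims range and the timesteps default of 80 come from the list above.

```python
REQUIRED_KEYS = ("data_name", "net_name", "ndims", "img_size", "batchsize", "lr", "epoch")  # assumed subset

def check_config(cfg):
    """Return a copy of cfg with defaults filled in; raise on missing or invalid keys."""
    missing = [k for k in REQUIRED_KEYS if k not in cfg]
    if missing:
        raise KeyError(f"missing config keys: {missing}")
    if cfg["ndims"] not in (2, 3):
        raise ValueError("ndims must be 2 or 3")
    out = dict(cfg)
    out.setdefault("timesteps", 80)  # documented default
    return out

# Hypothetical values, for illustration only.
example_cfg = {
    "data_name": "demo", "net_name": "recmutattnnet", "ndims": 3,
    "img_size": 128, "batchsize": 1, "lr": 1e-4, "epoch": 100,
}
```

In practice cfg would come from yaml.safe_load on one of the files in Config/.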

Augmentation (OM_aug.py, OM_aug_highres.py)

Loads a trained checkpoint and generates augmented samples. Behavior is controlled by start_noise_step (higher = less deformation) and aug_coe (number of samples per input, typically 32-64). Outputs are saved to the img/, msk/ and ddf/ subfolders of Data/Aug_data/{dataset}/.
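Where the augmented copies of one input land can be sketched with pathlib. Only the Data/Aug_data/{dataset}/ layout with img, msk and ddf subfolders comes from the description above. The file-name pattern {case_id}_{k:03d}.nii.gz and the helper name aug_output_paths are hypothetical.

```python
from pathlib import Path

def aug_output_paths(dataset, case_id, aug_coe, root="Data/Aug_data"):
    """List (img, msk, ddf) output paths for aug_coe augmented copies of one input."""
    base = Path(root) / dataset
    paths = []
    for k in range(aug_coe):
        name = f"{case_id}_{k:03d}.nii.gz"  # hypothetical naming scheme
        paths.append(tuple(base / sub / name for sub in ("img", "msk", "ddf")))
    return paths
```

Keeping image, mask and DDF under the same file name makes it easy to pair them again when the augmented data is used for training.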

Key Conventions

  • Checkpoints are saved as {epoch:06d}_{data_name}_{net_name}.pth and contain model_state_dict, optimizer_state_dict and epoch.
  • CT images are clamped to [-400, 400] HU before normalization.
  • SimpleITK's axis order is the reverse of NumPy's ((x, y, z) vs. (z, y, x)); convert with reverse_axis_order().
  • Mapping JSON files in nifty_mappings/ are tracked with Git LFS (large files).
  • utils.py provides get_transformer() for random affine augmentations and get_random_deformed_mask() for blind masks.
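The checkpoint naming convention is easy to reproduce and to invert, for example to resume from the newest file. The helper names below are hypothetical. Saving itself would be a torch.save call on a dict with the keys model_state_dict, optimizer_state_dict and epoch, written to checkpoint_name(...).

```python
import re

def checkpoint_name(epoch, data_name, net_name):
    """Build a checkpoint file name following the {epoch:06d}_{data_name}_{net_name}.pth convention."""
    return f"{epoch:06d}_{data_name}_{net_name}.pth"

CKPT_RE = re.compile(r"^(\d{6})_.+\.pth$")  # six-digit epoch prefix

def latest_checkpoint(names):
    """Return the file name with the highest epoch prefix, or None if none match."""
    epochs = [(int(m.group(1)), n) for n in names if (m := CKPT_RE.match(n))]
    return max(epochs)[1] if epochs else None
```

Because the epoch is zero-padded to six digits, a plain lexicographic sort of the file names also yields epoch order.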

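The CT clamping and axis-order conventions can be illustrated in NumPy. The [-400, 400] HU window comes from the conventions above. Rescaling to [0, 1] afterwards is an assumption, since the repository's normalization target is not stated here. The helper names are hypothetical and are not the repository's reverse_axis_order().

```python
import numpy as np

def normalize_ct(hu, lo=-400.0, hi=400.0):
    """Clamp to [lo, hi] HU, then rescale linearly to [0, 1] (target range assumed)."""
    hu = np.clip(np.asarray(hu, dtype=np.float32), lo, hi)
    return (hu - lo) / (hi - lo)

def sitk_to_numpy_order(values):
    """SimpleITK reports size/spacing as (x, y, z); arrays from GetArrayFromImage are (z, y, x)."""
    return tuple(reversed(values))
```

Forgetting the reversal typically shows up as volumes that look transposed or as spacing applied to the wrong axis.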
Dependencies

PyTorch 1.12+ with CUDA, SimpleITK, nibabel, scikit-image, einops, pydicom, transformers (HuggingFace), swanlab (optional, for experiment tracking). See requirements.txt.