Yuanhan Mo commited on
Commit ·
be5d479
1
Parent(s): e612cdb
Add dummy datasets for XPU testing, XPU contrastive training script, and CLAUDE.md
Browse files- DummyOMDataset_indiv/pair in dataLoader.py for testing without real data
- OM_contrastive_xpu.py with XPU/CUDA/CPU auto-detection
- CLAUDE.md for codebase guidance
- CLAUDE.md +91 -0
- Dataloader/dataLoader.py +41 -0
- OM_contrastive_xpu.py +71 -0
CLAUDE.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md
|
| 2 |
+
|
| 3 |
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
| 4 |
+
|
| 5 |
+
## Project Overview
|
| 6 |
+
|
| 7 |
+
OmniMorph is a medical image framework for generation, restoration, and registration using a conditional Deformation-Recovery Diffusion Model (DeformDDPM). It supports 2D and 3D multi-modal medical imaging (CT, MRI, PET) with text-conditioned generation via BERT embeddings.
|
| 8 |
+
|
| 9 |
+
## Common Commands
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
# Training (single-mode diffusion)
|
| 13 |
+
CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml
|
| 14 |
+
|
| 15 |
+
# Training (dual-mode: diffusion + registration)
|
| 16 |
+
CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml
|
| 17 |
+
|
| 18 |
+
# Contrastive learning (text-image alignment)
|
| 19 |
+
python OM_contrastive.py -C Config/config_om_contrastive.yaml
|
| 20 |
+
|
| 21 |
+
# XPU testing with dummy data (no real dataset needed)
|
| 22 |
+
python OM_contrastive_xpu.py --dummy-samples 20
|
| 23 |
+
|
| 24 |
+
# Augmentation / inference with a trained model
|
| 25 |
+
python OM_aug.py -C Config/config_om.yaml
|
| 26 |
+
|
| 27 |
+
# Background training (production style)
|
| 28 |
+
nohup python -u OM_train_2modes.py -C Config/config_om.yaml > train_log.txt 2>&1 &
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Architecture
|
| 32 |
+
|
| 33 |
+
### Core Pipeline
|
| 34 |
+
|
| 35 |
+
```
|
| 36 |
+
Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
### Diffusion Module (`Diffusion/`)
|
| 40 |
+
|
| 41 |
+
- **diffuser.py** — `DeformDDPM`: forward/reverse diffusion over deformation vector fields (DVFs). Generates multi-scale DDFs via control points at ratios [4, 8, 16, 32, 64]. Key methods: `diffuse()`, `recover()`.
|
| 42 |
+
- **networks.py** — Network architectures selected by `get_net(net_name)`:
|
| 43 |
+
- `recresacnet` — Atrous convolution UNet (2D CMR)
|
| 44 |
+
- `recmutattnnet` — Multi-head attention network (main 3D, channels [1,16,32,64,128,256])
|
| 45 |
+
- `recmutattnnet_contrastive` — Outputs 1024-dim image embeddings for contrastive training
|
| 46 |
+
- `defrecmutattnnet` — Deformable variant
|
| 47 |
+
- **networks.py: `STN`** — Spatial Transformer Network for differentiable image warping via DDFs. Composes deformations: `comp_ddf = dvf + stn(ddf, dvf)`.
|
| 48 |
+
- **losses.py** — `Grad` (L1 + negative Jacobian determinant + range penalties), `LNCC` (local normalized cross-correlation), `LMSE` (labeled MSE), `NCC`, `MRSE`.
|
| 49 |
+
|
| 50 |
+
### Training Modes
|
| 51 |
+
|
| 52 |
+
| Script | Purpose | DataLoader | Key Loss |
|
| 53 |
+
|--------|---------|------------|----------|
|
| 54 |
+
| `OM_train.py` | Single diffusion | `OminiDataset_v1` | Grad + MRSE + NCC |
|
| 55 |
+
| `OM_train_2modes.py` | Diffusion + registration | `OMDataset_indiv` + `OMDataset_pair` | Above + LNCC + LMSE |
|
| 56 |
+
| `OM_train_3modes.py` | Extended dual-mode | Same as 2modes | Different loss weights |
|
| 57 |
+
| `OM_contrastive.py` | Text-image alignment | `OMDataset_indiv` | Cosine similarity |
|
| 58 |
+
| `OM_reg.py` | Registration only | Paired data | Registration losses |
|
| 59 |
+
| `OM_train_uncon.py` | Unconditional generation | Generic | Standard |
|
| 60 |
+
|
| 61 |
+
All DDP-enabled training scripts use NCCL backend on `localhost:12355`.
|
| 62 |
+
|
| 63 |
+
### DataLoader (`Dataloader/`)
|
| 64 |
+
|
| 65 |
+
- **dataLoader.py** — All dataset classes. Data comes from JSON mapping files in `nifty_mappings/` that map NIfTI file paths to metadata (Modality, ROI, Size, Spacing_mm, BERT embeddings).
|
| 66 |
+
- `OMDataset_indiv` → returns `[volume, embd]` (shape: `[1,sz,sz,sz]`, `[1024]`)
|
| 67 |
+
- `OMDataset_pair` → returns `[volume_A, volume_B, embd_A, embd_B]`
|
| 68 |
+
- `DummyOMDataset_indiv` / `DummyOMDataset_pair` → random tensors for XPU testing without data
|
| 69 |
+
- **dataloader_utils.py** — `get_sizeRange_dict()` for ROI-based filtering, image thresholding, DICOM reading.
|
| 70 |
+
- **bert_helper.py** / **embding_gen.py** — BERT text embedding generation.
|
| 71 |
+
- Filtering chain: min dimension → modality → ROI → label presence.
|
| 72 |
+
|
| 73 |
+
### Config (`Config/`)
|
| 74 |
+
|
| 75 |
+
YAML files with keys: `data_name`, `net_name`, `ndims` (2 or 3), `img_size`, `batchsize`, `timesteps` (default 80), `v_scale`, `lr`, `epoch`, `noise_scale`, `condition_type` (`'uncon'`, `'adding'`, `'project'`, etc.), augmentation params (`start_noise_step`, `noise_step`, `aug_coe`).
|
| 76 |
+
|
| 77 |
+
### Augmentation (`OM_aug.py`, `OM_aug_highres.py`)
|
| 78 |
+
|
| 79 |
+
Loads a trained checkpoint and generates augmented samples. Controlled by `start_noise_step` (higher = less deformation), `aug_coe` (samples per input, typically 32-64). Outputs saved to `Data/Aug_data/{dataset}/img|msk|ddf/`.
|
| 80 |
+
|
| 81 |
+
## Key Conventions
|
| 82 |
+
|
| 83 |
+
- Models saved as `{epoch:06d}_{data_name}_{net_name}.pth` containing `model_state_dict`, `optimizer_state_dict`, `epoch`.
|
| 84 |
+
- CT images clamped to [-400, 400] HU before normalization.
|
| 85 |
+
- SimpleITK axis order is reversed from NumPy (`reverse_axis_order()`).
|
| 86 |
+
- Mapping JSON files in `nifty_mappings/` are Git LFS tracked (large files).
|
| 87 |
+
- `utils.py` provides `get_transformer()` for random affine augmentations and `get_random_deformed_mask()` for blind masks.
|
| 88 |
+
|
| 89 |
+
## Dependencies
|
| 90 |
+
|
| 91 |
+
PyTorch 1.12+ with CUDA, SimpleITK, nibabel, scikit-image, einops, pydicom, transformers (HuggingFace), swanlab (optional, for experiment tracking). See `requirements.txt`.
|
Dataloader/dataLoader.py
CHANGED
|
@@ -74,6 +74,47 @@ def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high')
|
|
| 74 |
sample_value = np.random.uniform(low, high=sample_value)
|
| 75 |
return sample_value
|
| 76 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
class OminiDataset(object):
|
| 78 |
"""Base class for OmniMorph datasets."""
|
| 79 |
def init(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
|
|
|
|
| 74 |
sample_value = np.random.uniform(low, high=sample_value)
|
| 75 |
return sample_value
|
| 76 |
|
| 77 |
+
class DummyOMDataset_indiv(Dataset):
|
| 78 |
+
"""Dummy dataset that generates random 3D volumes and embeddings for XPU testing."""
|
| 79 |
+
def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
|
| 80 |
+
self.out_sz = out_sz
|
| 81 |
+
self.num_samples = num_samples
|
| 82 |
+
self.embd_dim = embd_dim
|
| 83 |
+
self.transform = transform
|
| 84 |
+
|
| 85 |
+
def __len__(self):
|
| 86 |
+
return self.num_samples
|
| 87 |
+
|
| 88 |
+
def __getitem__(self, idx):
|
| 89 |
+
volume = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
|
| 90 |
+
embd = np.random.randn(self.embd_dim).astype(np.float32)
|
| 91 |
+
if self.transform is not None:
|
| 92 |
+
volume = self.transform(volume)
|
| 93 |
+
return volume, embd
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class DummyOMDataset_pair(Dataset):
|
| 97 |
+
"""Dummy dataset that generates random paired 3D volumes and embeddings for XPU testing."""
|
| 98 |
+
def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
|
| 99 |
+
self.out_sz = out_sz
|
| 100 |
+
self.num_samples = num_samples
|
| 101 |
+
self.embd_dim = embd_dim
|
| 102 |
+
self.transform = transform
|
| 103 |
+
|
| 104 |
+
def __len__(self):
|
| 105 |
+
return self.num_samples
|
| 106 |
+
|
| 107 |
+
def __getitem__(self, idx):
|
| 108 |
+
volume_A = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
|
| 109 |
+
volume_B = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
|
| 110 |
+
embd_A = np.random.randn(self.embd_dim).astype(np.float32)
|
| 111 |
+
embd_B = np.random.randn(self.embd_dim).astype(np.float32)
|
| 112 |
+
if self.transform is not None:
|
| 113 |
+
volume_A = self.transform(volume_A)
|
| 114 |
+
volume_B = self.transform(volume_B)
|
| 115 |
+
return [volume_A, volume_B, embd_A, embd_B]
|
| 116 |
+
|
| 117 |
+
|
| 118 |
class OminiDataset(object):
|
| 119 |
"""Base class for OmniMorph datasets."""
|
| 120 |
def init(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
|
OM_contrastive_xpu.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
from torch.optim import Adam
|
| 4 |
+
from torch.utils.data import DataLoader
|
| 5 |
+
from Diffusion.networks import get_net
|
| 6 |
+
from Dataloader.dataLoader import DummyOMDataset_indiv
|
| 7 |
+
import argparse
|
| 8 |
+
import yaml
|
| 9 |
+
import os
|
| 10 |
+
import time
|
| 11 |
+
|
| 12 |
+
parser = argparse.ArgumentParser()
|
| 13 |
+
parser.add_argument("--config", "-C", type=str, default="Config/config_om_contrastive.yaml")
|
| 14 |
+
parser.add_argument("--dummy-samples", type=int, default=100, help="Number of dummy samples")
|
| 15 |
+
args = parser.parse_args()
|
| 16 |
+
|
| 17 |
+
with open(args.config, 'r') as file:
|
| 18 |
+
hyp = yaml.safe_load(file)
|
| 19 |
+
|
| 20 |
+
# Setup device: prefer XPU, fallback to CUDA, then CPU
|
| 21 |
+
if hasattr(torch, 'xpu') and torch.xpu.is_available():
|
| 22 |
+
device = torch.device('xpu')
|
| 23 |
+
print(f"Using XPU device: {torch.xpu.get_device_name(0)}")
|
| 24 |
+
elif torch.cuda.is_available():
|
| 25 |
+
device = torch.device(hyp['device'])
|
| 26 |
+
print(f"Using CUDA device")
|
| 27 |
+
else:
|
| 28 |
+
device = torch.device('cpu')
|
| 29 |
+
print(f"Using CPU device")
|
| 30 |
+
|
| 31 |
+
data_name = hyp['data_name']
|
| 32 |
+
net_name = hyp['net_name']
|
| 33 |
+
ndims = hyp['ndims']
|
| 34 |
+
img_size = hyp['img_size']
|
| 35 |
+
model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
|
| 36 |
+
os.makedirs(model_save_path, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
# Model
|
| 39 |
+
Net = get_net(net_name)
|
| 40 |
+
model = Net(n_steps=hyp['timesteps'], ndims=ndims, num_input_chn=hyp['num_input_chn'], res=img_size).to(device)
|
| 41 |
+
optimizer = Adam(model.parameters(), lr=hyp['lr'])
|
| 42 |
+
|
| 43 |
+
# Data - dummy dataset for XPU testing
|
| 44 |
+
dataset = DummyOMDataset_indiv(out_sz=img_size, num_samples=args.dummy_samples)
|
| 45 |
+
train_loader = DataLoader(dataset, batch_size=hyp['batchsize'], shuffle=True, drop_last=True)
|
| 46 |
+
|
| 47 |
+
# Training
|
| 48 |
+
print(f'Start training on {device} with {len(dataset)} dummy samples...')
|
| 49 |
+
for epoch in range(hyp['epoch']):
|
| 50 |
+
epoch_loss = 0.0
|
| 51 |
+
|
| 52 |
+
for i, (volume, embd) in enumerate(train_loader):
|
| 53 |
+
t0 = time.time()
|
| 54 |
+
volume = volume.float().to(device)
|
| 55 |
+
embd = embd.to(device) # [B, 1024] GT text embedding
|
| 56 |
+
t = torch.randint(0, hyp['timesteps'], (volume.shape[0],)).to(device)
|
| 57 |
+
|
| 58 |
+
_, img_embd = model(x=volume, y=volume, t=t) # img_embd: [B, 1024]
|
| 59 |
+
|
| 60 |
+
# Cosine similarity loss: align img_embd with GT text embedding
|
| 61 |
+
loss = 1 - F.cosine_similarity(img_embd, embd, dim=-1).mean()
|
| 62 |
+
|
| 63 |
+
optimizer.zero_grad()
|
| 64 |
+
loss.backward()
|
| 65 |
+
optimizer.step()
|
| 66 |
+
epoch_loss += loss.item()
|
| 67 |
+
t1 = time.time()
|
| 68 |
+
dt = t1 - t0
|
| 69 |
+
print(f" Batch {i:04d} | Loss: {loss.item():.6f} | Time: {dt:.2f}s")
|
| 70 |
+
avg_loss = epoch_loss / max(len(train_loader), 1)
|
| 71 |
+
print(f"Epoch {epoch:04d} | Avg Loss: {avg_loss:.6f}")
|