Yuanhan Mo commited on
Commit
be5d479
·
1 Parent(s): e612cdb

Add dummy datasets for XPU testing, XPU contrastive training script, and CLAUDE.md

Browse files

- DummyOMDataset_indiv/pair in dataLoader.py for testing without real data
- OM_contrastive_xpu.py with XPU/CUDA/CPU auto-detection
- CLAUDE.md for codebase guidance

Files changed (3) hide show
  1. CLAUDE.md +91 -0
  2. Dataloader/dataLoader.py +41 -0
  3. OM_contrastive_xpu.py +71 -0
CLAUDE.md ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CLAUDE.md
2
+
3
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+ ## Project Overview
6
+
7
+ OmniMorph is a medical image framework for generation, restoration, and registration using a conditional Deformation-Recovery Diffusion Model (DeformDDPM). It supports 2D and 3D multi-modal medical imaging (CT, MRI, PET) with text-conditioned generation via BERT embeddings.
8
+
9
+ ## Common Commands
10
+
11
+ ```bash
12
+ # Training (single-mode diffusion)
13
+ CUDA_VISIBLE_DEVICES=0 python OM_train.py -C Config/config_om.yaml
14
+
15
+ # Training (dual-mode: diffusion + registration)
16
+ CUDA_VISIBLE_DEVICES=0,1 python OM_train_2modes.py -C Config/config_om.yaml
17
+
18
+ # Contrastive learning (text-image alignment)
19
+ python OM_contrastive.py -C Config/config_om_contrastive.yaml
20
+
21
+ # XPU testing with dummy data (no real dataset needed)
22
+ python OM_contrastive_xpu.py --dummy-samples 20
23
+
24
+ # Augmentation / inference with a trained model
25
+ python OM_aug.py -C Config/config_om.yaml
26
+
27
+ # Background training (production style)
28
+ nohup python -u OM_train_2modes.py -C Config/config_om.yaml > train_log.txt 2>&1 &
29
+ ```
30
+
31
+ ## Architecture
32
+
33
+ ### Core Pipeline
34
+
35
+ ```
36
+ Config YAML → DataLoader(s) → DeformDDPM(Network, STN) → Loss → Checkpoint
37
+ ```
38
+
39
+ ### Diffusion Module (`Diffusion/`)
40
+
41
+ - **diffuser.py** — `DeformDDPM`: forward/reverse diffusion over deformation vector fields (DVFs). Generates multi-scale DDFs via control points at ratios [4, 8, 16, 32, 64]. Key methods: `diffuse()`, `recover()`.
42
+ - **networks.py** — Network architectures selected by `get_net(net_name)`:
43
+ - `recresacnet` — Atrous convolution UNet (2D CMR)
44
+ - `recmutattnnet` — Multi-head attention network (main 3D, channels [1,16,32,64,128,256])
45
+ - `recmutattnnet_contrastive` — Outputs 1024-dim image embeddings for contrastive training
46
+ - `defrecmutattnnet` — Deformable variant
47
+ - **networks.py: `STN`** — Spatial Transformer Network for differentiable image warping via DDFs. Composes deformations: `comp_ddf = dvf + stn(ddf, dvf)`.
48
+ - **losses.py** — `Grad` (L1 + negative Jacobian determinant + range penalties), `LNCC` (local normalized cross-correlation), `LMSE` (labeled MSE), `NCC`, `MRSE`.
49
+
50
+ ### Training Modes
51
+
52
+ | Script | Purpose | DataLoader | Key Loss |
53
+ |--------|---------|------------|----------|
54
+ | `OM_train.py` | Single diffusion | `OminiDataset_v1` | Grad + MRSE + NCC |
55
+ | `OM_train_2modes.py` | Diffusion + registration | `OMDataset_indiv` + `OMDataset_pair` | Above + LNCC + LMSE |
56
+ | `OM_train_3modes.py` | Extended dual-mode | Same as 2modes | Different loss weights |
57
+ | `OM_contrastive.py` | Text-image alignment | `OMDataset_indiv` | Cosine similarity |
58
+ | `OM_reg.py` | Registration only | Paired data | Registration losses |
59
+ | `OM_train_uncon.py` | Unconditional generation | Generic | Standard |
60
+
61
+ All DDP-enabled training scripts use NCCL backend on `localhost:12355`.
62
+
63
+ ### DataLoader (`Dataloader/`)
64
+
65
+ - **dataLoader.py** — All dataset classes. Data comes from JSON mapping files in `nifty_mappings/` that map NIfTI file paths to metadata (Modality, ROI, Size, Spacing_mm, BERT embeddings).
66
+ - `OMDataset_indiv` → returns `[volume, embd]` (shape: `[1,sz,sz,sz]`, `[1024]`)
67
+ - `OMDataset_pair` → returns `[volume_A, volume_B, embd_A, embd_B]`
68
+ - `DummyOMDataset_indiv` / `DummyOMDataset_pair` → random tensors for XPU testing without data
69
+ - **dataloader_utils.py** — `get_sizeRange_dict()` for ROI-based filtering, image thresholding, DICOM reading.
70
+ - **bert_helper.py** / **embding_gen.py** — BERT text embedding generation.
71
+ - Filtering chain: min dimension → modality → ROI → label presence.
72
+
73
+ ### Config (`Config/`)
74
+
75
+ YAML files with keys: `data_name`, `net_name`, `ndims` (2 or 3), `img_size`, `batchsize`, `timesteps` (default 80), `v_scale`, `lr`, `epoch`, `noise_scale`, `condition_type` (`'uncon'`, `'adding'`, `'project'`, etc.), augmentation params (`start_noise_step`, `noise_step`, `aug_coe`).
76
+
77
+ ### Augmentation (`OM_aug.py`, `OM_aug_highres.py`)
78
+
79
+ Loads a trained checkpoint and generates augmented samples. Controlled by `start_noise_step` (higher = less deformation), `aug_coe` (samples per input, typically 32-64). Outputs saved to `Data/Aug_data/{dataset}/img|msk|ddf/`.
80
+
81
+ ## Key Conventions
82
+
83
+ - Models saved as `{epoch:06d}_{data_name}_{net_name}.pth` containing `model_state_dict`, `optimizer_state_dict`, `epoch`.
84
+ - CT images clamped to [-400, 400] HU before normalization.
85
+ - SimpleITK axis order is reversed from NumPy (`reverse_axis_order()`).
86
+ - Mapping JSON files in `nifty_mappings/` are Git LFS tracked (large files).
87
+ - `utils.py` provides `get_transformer()` for random affine augmentations and `get_random_deformed_mask()` for blind masks.
88
+
89
+ ## Dependencies
90
+
91
+ PyTorch 1.12+ with CUDA, SimpleITK, nibabel, scikit-image, einops, pydicom, transformers (HuggingFace), swanlab (optional, for experiment tracking). See `requirements.txt`.
Dataloader/dataLoader.py CHANGED
@@ -74,6 +74,47 @@ def sample_random_uniform_multi_order(high=1., low=0., order_num=2, type='high')
74
  sample_value = np.random.uniform(low, high=sample_value)
75
  return sample_value
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  class OminiDataset(object):
78
  """Base class for OmniMorph datasets."""
79
  def init(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
 
74
  sample_value = np.random.uniform(low, high=sample_value)
75
  return sample_value
76
 
77
+ class DummyOMDataset_indiv(Dataset):
78
+ """Dummy dataset that generates random 3D volumes and embeddings for XPU testing."""
79
+ def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
80
+ self.out_sz = out_sz
81
+ self.num_samples = num_samples
82
+ self.embd_dim = embd_dim
83
+ self.transform = transform
84
+
85
+ def __len__(self):
86
+ return self.num_samples
87
+
88
+ def __getitem__(self, idx):
89
+ volume = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
90
+ embd = np.random.randn(self.embd_dim).astype(np.float32)
91
+ if self.transform is not None:
92
+ volume = self.transform(volume)
93
+ return volume, embd
94
+
95
+
96
+ class DummyOMDataset_pair(Dataset):
97
+ """Dummy dataset that generates random paired 3D volumes and embeddings for XPU testing."""
98
+ def __init__(self, out_sz=128, num_samples=100, embd_dim=1024, transform=None):
99
+ self.out_sz = out_sz
100
+ self.num_samples = num_samples
101
+ self.embd_dim = embd_dim
102
+ self.transform = transform
103
+
104
+ def __len__(self):
105
+ return self.num_samples
106
+
107
+ def __getitem__(self, idx):
108
+ volume_A = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
109
+ volume_B = np.random.rand(1, self.out_sz, self.out_sz, self.out_sz).astype(np.float64)
110
+ embd_A = np.random.randn(self.embd_dim).astype(np.float32)
111
+ embd_B = np.random.randn(self.embd_dim).astype(np.float32)
112
+ if self.transform is not None:
113
+ volume_A = self.transform(volume_A)
114
+ volume_B = self.transform(volume_B)
115
+ return [volume_A, volume_B, embd_A, embd_B]
116
+
117
+
118
  class OminiDataset(object):
119
  """Base class for OmniMorph datasets."""
120
  def init(self, out_sz, transform, clamp_range, min_crop_ratio, ROIs, modality,reverse_axis_order ,min_dim,mapping_files):
OM_contrastive_xpu.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch.optim import Adam
4
+ from torch.utils.data import DataLoader
5
+ from Diffusion.networks import get_net
6
+ from Dataloader.dataLoader import DummyOMDataset_indiv
7
+ import argparse
8
+ import yaml
9
+ import os
10
+ import time
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--config", "-C", type=str, default="Config/config_om_contrastive.yaml")
14
+ parser.add_argument("--dummy-samples", type=int, default=100, help="Number of dummy samples")
15
+ args = parser.parse_args()
16
+
17
+ with open(args.config, 'r') as file:
18
+ hyp = yaml.safe_load(file)
19
+
20
+ # Setup device: prefer XPU, fallback to CUDA, then CPU
21
+ if hasattr(torch, 'xpu') and torch.xpu.is_available():
22
+ device = torch.device('xpu')
23
+ print(f"Using XPU device: {torch.xpu.get_device_name(0)}")
24
+ elif torch.cuda.is_available():
25
+ device = torch.device(hyp['device'])
26
+ print(f"Using CUDA device")
27
+ else:
28
+ device = torch.device('cpu')
29
+ print(f"Using CPU device")
30
+
31
+ data_name = hyp['data_name']
32
+ net_name = hyp['net_name']
33
+ ndims = hyp['ndims']
34
+ img_size = hyp['img_size']
35
+ model_save_path = os.path.join('Models', f'{data_name}_{net_name}/')
36
+ os.makedirs(model_save_path, exist_ok=True)
37
+
38
+ # Model
39
+ Net = get_net(net_name)
40
+ model = Net(n_steps=hyp['timesteps'], ndims=ndims, num_input_chn=hyp['num_input_chn'], res=img_size).to(device)
41
+ optimizer = Adam(model.parameters(), lr=hyp['lr'])
42
+
43
+ # Data - dummy dataset for XPU testing
44
+ dataset = DummyOMDataset_indiv(out_sz=img_size, num_samples=args.dummy_samples)
45
+ train_loader = DataLoader(dataset, batch_size=hyp['batchsize'], shuffle=True, drop_last=True)
46
+
47
+ # Training
48
+ print(f'Start training on {device} with {len(dataset)} dummy samples...')
49
+ for epoch in range(hyp['epoch']):
50
+ epoch_loss = 0.0
51
+
52
+ for i, (volume, embd) in enumerate(train_loader):
53
+ t0 = time.time()
54
+ volume = volume.float().to(device)
55
+ embd = embd.to(device) # [B, 1024] GT text embedding
56
+ t = torch.randint(0, hyp['timesteps'], (volume.shape[0],)).to(device)
57
+
58
+ _, img_embd = model(x=volume, y=volume, t=t) # img_embd: [B, 1024]
59
+
60
+ # Cosine similarity loss: align img_embd with GT text embedding
61
+ loss = 1 - F.cosine_similarity(img_embd, embd, dim=-1).mean()
62
+
63
+ optimizer.zero_grad()
64
+ loss.backward()
65
+ optimizer.step()
66
+ epoch_loss += loss.item()
67
+ t1 = time.time()
68
+ dt = t1 - t0
69
+ print(f" Batch {i:04d} | Loss: {loss.item():.6f} | Time: {dt:.2f}s")
70
+ avg_loss = epoch_loss / max(len(train_loader), 1)
71
+ print(f"Epoch {epoch:04d} | Avg Loss: {avg_loss:.6f}")