# Omini3D / OMorpher / omorpher.py
# maxmo2009's picture
# Sync from local: code + epoch-110 checkpoint, clean README
# 2af0e94 verified
"""
OMorpher — Object-oriented wrapper for OmniMorph diffusion-based deformation.
Stores original high-res images and composes all intermediate deformations as
deformation fields (DDFs), resampling only once at the end to avoid blurring.
Independent of DeformDDPM at runtime; reimplements the diffusion logic using
the network / STN / loss building blocks from Diffusion.*.
"""
import os
import glob
import math
import random
from typing import Optional, Union, List, Tuple, Dict
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import yaml
import SimpleITK as sitk
from skimage.transform import resize as sk_resize
from Diffusion.networks import get_net, STN, DefRec_MutAttnNet
from Diffusion.losses import Grad, MRSE, NCC
EPS = 1e-8
class OMorpher:
"""High-level interface for OmniMorph deformation diffusion.
All images are kept at their original resolution internally. Deformation
fields are composed at model resolution and up-scaled on demand so that the
original image is resampled at most *once*.
"""
# ------------------------------------------------------------------
# Construction
# ------------------------------------------------------------------
def __init__(
    self,
    config: Union[str, dict],
    checkpoint_path: Optional[str] = None,
    device: Optional[str] = None,
    bert_model_path: Optional[str] = None,
):
    """Build the network, the spatial transformers and the empty state.

    Args:
        config: Either a path to a YAML file or an already-parsed dict.
        checkpoint_path: Weights to load. When ``None`` the newest file
            found by ``_auto_find_checkpoint()`` is used, if any.
        device: Explicit torch device string. When ``None`` the
            ``"device"`` config entry (or auto-detection) decides.
        bert_model_path: Location of the BERT weights used for text
            conditioning; falls back to ``External/Models/bert_large_uncased``
            two directories above this file.
    """
    # ---- Config ----
    if isinstance(config, str):
        with open(config, "r") as cfg_file:
            config = yaml.safe_load(cfg_file)
    self.config: dict = config
    read = config.get
    self.net_name: str = read("net_name", "recmutattnnet")
    self.ndims: int = read("ndims", 3)
    self.img_size: int = read("img_size", 128)
    self.timesteps: int = read("timesteps", 80)
    self.v_scale: float = read("v_scale", 5e-5)
    self.noise_scale: float = read("noise_scale", 0.1)
    self.condition_type: str = read("condition_type", "none")
    self.num_input_chn: int = read("num_input_chn", 1)
    self.img_pad_mode: str = read("img_pad_mode", "zeros")
    self.ddf_pad_mode: str = read("ddf_pad_mode", "border")
    self.padding_mode: str = read("padding_mode", "border")
    self.resample_mode: str = read("resample_mode", "bilinear")
    # NOTE: the config key is spelled "batchsize" (no underscore).
    self.batch_size: int = read("batchsize", 1)
    self.data_name: str = read("data_name", "all")
    self.clamp_range: list = read("clamp_range", [-400, 400])
    self.inf_mode: bool = read("inf_mode", True)

    # ---- Device ----
    if device is None:
        self.device = self._resolve_device(read("device", None))
    else:
        self.device = torch.device(device)

    # ---- BERT (loaded lazily by _ensure_bert) ----
    if not bert_model_path:
        repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        bert_model_path = os.path.join(
            repo_root, "External", "Models", "bert_large_uncased",
        )
    self.bert_model_path = bert_model_path
    self._bert_model = None
    self._bert_tokenizer = None

    # ---- Network ----
    net_cls = get_net(self.net_name)
    self.network = net_cls(
        n_steps=self.timesteps,
        ndims=self.ndims,
        num_input_chn=self.num_input_chn,
        res=self.img_size,
    )
    self.network.to(self.device)

    # ---- STN instances ----
    self.ctl_ratio = 4
    self.ctl_sz = self.img_size // self.ctl_ratio
    # Keyword arguments shared by every STN working at model resolution.
    model_res = dict(img_sz=self.img_size, ndims=self.ndims, device=self.device)
    self.stn_full = STN(padding_mode=self.padding_mode, **model_res)
    self.stn_ctl = STN(
        img_sz=self.ctl_sz,
        ndims=self.ndims,
        padding_mode=self.ddf_pad_mode,
        device=self.device,
    )
    # "bilinear" is forwarded as None; any other mode is passed through.
    img_resample = None if self.resample_mode == "bilinear" else self.resample_mode
    self.img_stn = STN(
        padding_mode=self.img_pad_mode, resample_mode=img_resample, **model_res,
    )
    self.msk_stn = STN(
        padding_mode=self.img_pad_mode, resample_mode="nearest", **model_res,
    )

    # ---- Loss functions (for fine-tuning) ----
    self._loss_grad = Grad(penalty=["l1"], ndims=self.ndims)
    self._loss_dist = MRSE(img_sz=self.img_size)
    self._loss_ang = NCC(img_sz=self.img_size)

    # ---- Load checkpoint (explicit path wins over auto-discovery) ----
    ckpt_to_load = checkpoint_path
    if ckpt_to_load is None:
        ckpt_to_load = self._auto_find_checkpoint()
    if ckpt_to_load is not None:
        self._load_checkpoint(ckpt_to_load)
    self.network.eval()

    # ---- State ----
    self._init_img: Optional[torch.Tensor] = None  # [B,1,S,S,S] model-res
    self._init_img_raw: Optional[torch.Tensor] = None  # [B,1,D,H,W] full-res
    self._init_img_original_shape: Optional[tuple] = None
    self._init_ddf: Optional[torch.Tensor] = None  # [B,ndims,S,S,S]
    self._cond_img: Optional[torch.Tensor] = None  # [B,1,S,S,S]
    self._cond_txt: Optional[torch.Tensor] = None  # [B,1024]
    self._predicted_ddf: Optional[torch.Tensor] = None  # [B,ndims,S,S,S]
    self._intermediate_ddfs: List[Tuple[int, torch.Tensor]] = []

    # ---- Fine-tuning state ----
    self._optimizer: Optional[torch.optim.Optimizer] = None
# ------------------------------------------------------------------
# Device resolution
# ------------------------------------------------------------------
@staticmethod
def _resolve_device(hint: Optional[str] = None) -> torch.device:
if hint is not None:
s = str(hint).lower()
if s not in ("auto", ""):
return torch.device(s)
# XPU → CUDA → CPU
try:
import intel_extension_for_pytorch # noqa: F401
if torch.xpu.is_available():
return torch.device("xpu")
except (ImportError, AttributeError):
pass
if torch.cuda.is_available():
return torch.device("cuda")
return torch.device("cpu")
# ------------------------------------------------------------------
# Checkpoint helpers
# ------------------------------------------------------------------
def _auto_find_checkpoint(self) -> Optional[str]:
pattern = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"Models",
f"{self.data_name}_{self.net_name}",
"*.pth",
)
files = sorted(glob.glob(pattern))
return files[-1] if files else None
def _load_checkpoint(self, path: str):
ckpt = torch.load(path, map_location="cpu")
state_dict = ckpt.get("model_state_dict", ckpt)
# Strip DDP 'module.' prefix and DeformDDPM wrapper keys
cleaned = {}
for k, v in state_dict.items():
k = k.replace("module.", "")
if k.startswith("network."):
k = k[len("network."):]
cleaned[k] = v
# Only load keys that exist in the network
net_keys = set(self.network.state_dict().keys())
filtered = {k: v for k, v in cleaned.items() if k in net_keys}
if filtered:
self.network.load_state_dict(filtered, strict=False)
# ------------------------------------------------------------------
# Public — Input setters
# ------------------------------------------------------------------
def set_init_img(
self,
img,
modality: Optional[str] = None,
) -> "OMorpher":
"""Set the initial image. Accepts numpy, torch, path, or (img, ddf) tuple."""
init_ddf = None
if isinstance(img, (tuple, list)):
img, init_ddf = img[0], img[1]
model_tensor, fullres_tensor, orig_shape = self._standardize_img(
img, modality=modality, keep_raw=True,
)
self._init_img = model_tensor
self._init_img_raw = fullres_tensor
self._init_img_original_shape = orig_shape
if init_ddf is not None:
self._init_ddf = self._to_ddf_tensor(init_ddf)
else:
B = self._init_img.shape[0]
S = self.img_size
self._init_ddf = torch.zeros(
[B, self.ndims] + [S] * self.ndims,
dtype=torch.float32, device=self.device,
)
return self
def set_cond_img(
self,
img=None,
modality: Optional[str] = None,
) -> "OMorpher":
"""Set the conditioning image. Default: Gaussian noise sigma=0.1."""
if img is None:
B = self._init_img.shape[0] if self._init_img is not None else self.batch_size
S = self.img_size
self._cond_img = torch.randn(
[B, 1] + [S] * self.ndims,
dtype=torch.float32, device=self.device,
) * 0.1
else:
tensor, _, _ = self._standardize_img(img, modality=modality, keep_raw=False)
self._cond_img = tensor
return self
def set_cond_txt(self, txt=None) -> "OMorpher":
"""Set the text conditioning. Accepts string, numpy [1024], torch [1024], or None."""
self._cond_txt = self._standardize_txt(txt)
return self
def set_init_def(self, ddf=None) -> "OMorpher":
"""Set or regenerate the initial deformation field.
If *ddf* is ``None``, a random DDF is generated using the forward
diffusion parameters (useful for data augmentation).
"""
if ddf is None:
if self._init_img is None:
raise RuntimeError("set_init_img() must be called before set_init_def()")
t_val = self.config.get("start_noise_step", self.timesteps // 2)
t = torch.tensor([t_val], dtype=torch.long, device=self.device)
_, _, random_ddf = self._get_random_ddf(self._init_img, t)
self._init_ddf = random_ddf
else:
self._init_ddf = self._to_ddf_tensor(ddf)
return self
# ------------------------------------------------------------------
# Public — Core operations (inference)
# ------------------------------------------------------------------
def predict(
self,
T: Optional[list] = None,
proc_type: Optional[str] = None,
t_save: Optional[list] = None,
) -> "OMorpher":
"""Run reverse diffusion and store predicted DDF. Returns ``self`` for chaining.

Args:
    T: Two-element list. ``T[0]`` is only consulted when the initial DDF
        is all zeros: a positive value triggers a random forward
        deformation of the init image at that timestep. ``T[1]`` is the
        number of reverse steps (timesteps ``T[1]-1 .. 0``). Defaults to
        ``[config["start_noise_step"] or 0, self.timesteps]``.
    proc_type: Conditioning strategy forwarded to ``_proc_cond_img``;
        defaults to ``self.condition_type``.
    t_save: Timesteps whose composed DDF should be kept for ``get_def``.

Raises:
    RuntimeError: if no init image has been set.
"""
if self._init_img is None:
raise RuntimeError("set_init_img() must be called before predict()")
# Defaults
start_noise = self.config.get("start_noise_step", 0)
if T is None:
T = [start_noise, self.timesteps]
if proc_type is None:
proc_type = self.condition_type
B = self._init_img.shape[0]
S = self.img_size
# Conditioning: fall back to a detached copy of the init image when no
# conditioning image was set. `mask` and `cond_ratio` are not used below.
cond_img_src = self._cond_img if self._cond_img is not None else self._init_img.clone().detach()
cond_img, mask, cond_ratio = self._proc_cond_img(cond_img_src, proc_type=proc_type)
# Text embedding (all-zero [B, 1024] placeholder when none was set)
txt = self._cond_txt
if txt is None:
txt = torch.zeros([B, 1024], dtype=torch.float32, device=self.device)
# Reshape text for network consumption
# NOTE(review): view(B, ...) assumes the stored text embedding has the same
# batch size as the init image — confirm for B > 1.
if isinstance(self.network, DefRec_MutAttnNet):
txt = txt.view(B, -1, *([1] * self.ndims))
# Initial state: (a) a non-zero initial DDF is applied to the init image,
# (b) otherwise a random forward deformation when T[0] > 0,
# (c) otherwise the unmodified image with an identity (zero) DDF.
init_ddf_is_zero = (self._init_ddf is None) or torch.all(self._init_ddf == 0)
if not init_ddf_is_zero:
ddf_comp = self._init_ddf.clone()
img_rec = self.img_stn(self._init_img, ddf_comp)
elif T[0] is not None and T[0] > 0:
t_start = torch.tensor(np.array([T[0]]), device=self.device)
img_rec, _, ddf_comp = self._get_random_ddf(self._init_img, t_start)
else:
img_rec = self._init_img.clone()
ddf_comp = torch.zeros(
[B, self.ndims] + [S] * self.ndims,
dtype=torch.float32, device=self.device,
)
# Reverse diffusion loop
self._intermediate_ddfs = []
rec_num = 2 # matches DeformDDPM.rec_num default
if isinstance(self.network, DefRec_MutAttnNet):
# DefRec network: pass full time list at once
t_list = list(range(T[1] - 1, -1, -1))
with torch.no_grad():
pre_dvf = self.network(
x=img_rec, y=cond_img, t=t_list, rec_num=rec_num, text=txt,
)
# Compose: warp the accumulated DDF by the predicted DVF, then add the DVF.
ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
# Always resample from the init image so interpolation is applied once.
img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
# Only one composed DDF exists in this branch; it is stored under key 0
# whenever t_save is non-empty.
if t_save:
self._intermediate_ddfs.append((0, ddf_comp.clone()))
else:
# Standard iterative recovery
time_steps = range(T[1] - 1, -1, -1)
for i in time_steps:
t = torch.tensor(np.array([i]), device=self.device)
with torch.no_grad():
pre_dvf = self.network(
x=img_rec, y=cond_img, t=t, rec_num=rec_num, text=txt,
)
# Same composition rule as above, applied once per timestep.
ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
if t_save is not None and i in t_save:
self._intermediate_ddfs.append((i, ddf_comp.clone()))
self._predicted_ddf = ddf_comp
return self
def get_def(
self,
t_list: Optional[list] = None,
) -> Union[torch.Tensor, Dict[int, torch.Tensor]]:
"""Return the final predicted DDF, or intermediate DDFs for given timesteps."""
if t_list is None:
if self._predicted_ddf is None:
raise RuntimeError("predict() must be called before get_def()")
return self._predicted_ddf
out = {}
for t, ddf in self._intermediate_ddfs:
if t in t_list:
out[t] = ddf
return out
def apply_def(
self,
img=None,
ddf: Optional[torch.Tensor] = None,
padding_mode: Optional[str] = None,
resample_mode: Optional[str] = None,
) -> torch.Tensor:
"""Apply a DDF to an image. Auto-upscales DDF when sizes differ.
Defaults: init image at full resolution, predicted DDF.
"""
if padding_mode is None:
padding_mode = self.padding_mode
if resample_mode is None:
resample_mode = "bilinear"
# Default DDF
if ddf is None:
if self._predicted_ddf is None:
raise RuntimeError("predict() must be called before apply_def()")
ddf = self._predicted_ddf
# Default image: full-res init image tensor
if img is None:
if self._init_img_raw is not None:
vol_tensor = self._init_img_raw
else:
vol_tensor = self._init_img
else:
vol_tensor = self._ensure_tensor(img)
# Upscale DDF if sizes differ
target_sz = list(vol_tensor.shape[2:])
ddf_sz = list(ddf.shape[2:])
if target_sz != ddf_sz:
ddf = F.interpolate(
ddf, size=target_sz,
mode="bilinear" if self.ndims == 2 else "trilinear",
align_corners=False,
)
return self._apply_ddf(vol_tensor, ddf, padding_mode=padding_mode, resample_mode=resample_mode)
# ------------------------------------------------------------------
# Public — Fine-tuning
# ------------------------------------------------------------------
def finetune_setup(
self,
lr: float = 1e-4,
optimizer_cls=None,
) -> "OMorpher":
"""Switch to training mode and create an optimizer."""
self.network.train()
self.inf_mode = False
if optimizer_cls is None:
optimizer_cls = torch.optim.Adam
self._optimizer = optimizer_cls(self.network.parameters(), lr=lr)
return self
def finetune_step(
self,
img_batch,
cond_batch=None,
text_batch=None,
t=None,
proc_type=None,
) -> dict:
"""Single training step. Returns loss dict.

Args:
    img_batch: Training image(s), standardized via ``_standardize_img``.
    cond_batch: Conditioning image(s); a clone of the standardized
        training image is used when omitted.
    text_batch: Text condition, standardized via ``_standardize_txt``.
    t: Timestep(s); drawn uniformly from ``[0, self.timesteps)`` per
        sample when omitted.
    proc_type: Conditioning strategy; falls back to
        ``self.condition_type`` when falsy.

Returns:
    Dict with float entries ``loss_total``, ``loss_grad``, ``loss_dist``
    and ``loss_ang``.

Raises:
    RuntimeError: if ``finetune_setup()`` has not been called.
"""
if self._optimizer is None:
raise RuntimeError("finetune_setup() must be called first")
img, _, _ = self._standardize_img(img_batch, keep_raw=False)
cond = self._standardize_img(cond_batch, keep_raw=False)[0] if cond_batch is not None else img.clone()
text = self._standardize_txt(text_batch)
B = img.shape[0]
if t is None:
t = torch.randint(0, self.timesteps, (B,), device=self.device)
else:
t = torch.tensor(t, device=self.device) if not isinstance(t, torch.Tensor) else t.to(self.device)
proc_type = proc_type or self.condition_type
cond_img, mask, cond_ratio = self._proc_cond_img(cond, proc_type=proc_type)
# Forward diffusion: randomly deformed input plus the ground-truth field.
# NOTE(review): the random field is generated with self.batch_size, not B —
# confirm the two agree for the batches passed in.
noisy_img, dvf_gt, _ = self._get_random_ddf(img, t)
# Reshape text for network; DefRec networks also receive t wrapped in a list
if isinstance(self.network, DefRec_MutAttnNet):
if text is not None:
text = text.view(B, -1, *([1] * self.ndims))
t_input = [t]
else:
t_input = t
# The conditioning mask is applied to the network input here.
pre_dvf = self.network(x=noisy_img * mask, y=cond_img, t=t_input, rec_num=2, text=text)
# Gradient loss on the predicted field
loss_grad = self._loss_grad(y_pred=pre_dvf, img=img)
# Warp the prediction by the ground-truth field before comparing the two.
trm_pred = self.stn_full(pre_dvf, dvf_gt)
loss_dist = self._loss_dist(pred=trm_pred, inv_lab=dvf_gt)
loss_ang = self._loss_ang(pred=trm_pred, inv_lab=dvf_gt)
# Fixed loss weighting: 2 x ang + 1 x dist + 16 x grad.
loss_total = 2.0 * loss_ang + 1.0 * loss_dist + 16.0 * loss_grad
self._optimizer.zero_grad()
loss_total.backward()
self._optimizer.step()
return {
"loss_total": loss_total.item(),
"loss_grad": loss_grad.item(),
"loss_dist": loss_dist.item(),
"loss_ang": loss_ang.item(),
}
def finetune_save(self, path: str, epoch: int = 0):
"""Save checkpoint in the standard OmniMorph format."""
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
torch.save(
{
"model_state_dict": self.network.state_dict(),
"optimizer_state_dict": self._optimizer.state_dict() if self._optimizer else None,
"epoch": epoch,
},
path,
)
def finetune_teardown(self) -> "OMorpher":
"""Switch back to eval mode."""
self.network.eval()
self.inf_mode = True
self._optimizer = None
return self
# ------------------------------------------------------------------
# Private — Diffusion logic
# ------------------------------------------------------------------
def _get_ddf_scale(
self, t: torch.Tensor, divide_num: int = 1, max_ddf_num: int = 200,
) -> Tuple[int, torch.Tensor, torch.Tensor]:
"""Timestep-dependent deformation magnitude. Mirrors DeformDDPM._get_ddf_scale()."""
rec_num = 1
mul_num_ddf = torch.floor_divide(2 * torch.pow(t.float(), 1.3), 3 * divide_num).int()
mul_num_dvf = torch.floor_divide(torch.pow(t.float(), 0.6), divide_num).int()
mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
return rec_num, mul_num_ddf, mul_num_dvf
def _sample_random_uniform_multi_order(
self, high=None, low=0.0, order_num=3,
) -> float:
sample_value = low
for _ in range(order_num):
sample_value = np.random.uniform(low=sample_value, high=high)
return sample_value
def _multiscale_dvf_generate(
self, v_scale: float, ctl_szs: list = None, rand_v_scale: bool = True,
) -> torch.Tensor:
"""Multi-scale Gaussian DVF at control-point sizes."""
if ctl_szs is None:
ctl_szs = [4, 8, 16, 32, 64]
dvf = 0
for ctl_sz in ctl_szs:
_v = (
self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2)
if rand_v_scale
else v_scale
)
if ctl_sz <= 2:
_v = _v / 2
dvf_comp = torch.randn(
[self.batch_size, self.ndims] + [ctl_sz] * self.ndims
) * _v
dvf_comp = F.interpolate(
dvf_comp * self.ctl_sz / ctl_sz,
[self.ctl_sz] * self.ndims,
align_corners=False,
mode="bilinear" if self.ndims == 2 else "trilinear",
)
dvf = dvf + dvf_comp
return dvf
def _random_ddf_generate(
self,
rec_num: int = 3,
mul_num: list = None,
noise_ratio: float = 0.08,
select_num: int = 4,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compose DVFs to build a DDF. Mirrors DeformDDPM._random_ddf_generate().

Args:
    rec_num: Number of rounds; each round draws a fresh multi-scale DVF.
    mul_num: Two tensors ``[ddf_count, dvf_count]`` giving how many times
        the round's DVF is composed into each of the two returned fields.
        Defaults to ``[tensor([5]), tensor([5])]``.
    noise_ratio: Relative magnitude of the extra noise DVF mixed into the
        first field; ``0`` disables it.
    select_num: Maximum number of control-grid scales sampled per round.

Returns:
    ``(ddf, dddf)`` at model resolution. ``ddf`` is composed from the
    noise-perturbed DVF, ``dddf`` from the unperturbed DVF.
"""
if mul_num is None:
mul_num = [torch.tensor([5]), torch.tensor([5])]
# The fields are upsampled to crop_rate x img_size and centre-cropped below.
crop_rate = 2
# unsqueeze mul_num for broadcasting
for _ in range(self.ndims + 1):
mul_num = [torch.unsqueeze(n, -1) for n in mul_num]
ctl_ddf_sz = [self.batch_size, self.ndims] + [self.ctl_sz] * self.ndims
# NOTE(review): both accumulators are created without a device argument while
# the DVFs below are moved to self.device — confirm STN accepts that.
ddf = torch.zeros(ctl_ddf_sz)
dddf = torch.zeros(ctl_ddf_sz)
# Candidate control-grid sizes: ctl_sz, ctl_sz/2, ... (at most 8 levels).
scale_num = min(8, int(math.log2(self.ctl_sz)))
ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]
for _i in range(rec_num):
# Use a random subset of scales when more than select_num are available.
if len(ctl_szs_all) > select_num:
ctl_szs = random.sample(ctl_szs_all, select_num)
else:
ctl_szs = ctl_szs_all
dvf = self._multiscale_dvf_generate(self.v_scale, ctl_szs=ctl_szs).to(self.device)
if noise_ratio == 0:
dvf0 = dvf
else:
# Perturb the DVF with a smaller fixed-magnitude DVF warped by the DVF itself.
dvf0 = dvf + self.stn_ctl(
self._multiscale_dvf_generate(
self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False,
).to(self.device),
dvf,
)
# Repeated composition; `flag` zeroes the DVF for entries whose count is
# already reached. The loop length comes from mul_num[0] only.
# (presumably executed once per round above — TODO confirm nesting)
for j in range(torch.max(mul_num[0]).item()):
flag = [(n > j).int().to(self.device) for n in mul_num]
ddf = dvf0 * flag[0] + self.stn_ctl(ddf, dvf0 * flag[0])
dddf = dvf * flag[1] + self.stn_ctl(dddf, dvf * flag[1])
# Upscale and center-crop
interp_mode = "bilinear" if self.ndims == 2 else "trilinear"
ddf = F.interpolate(
ddf * self.img_size / self.ctl_sz,
self.img_size * crop_rate,
mode=interp_mode,
)
dddf = F.interpolate(
dddf * self.img_size / self.ctl_sz,
self.img_size * crop_rate,
mode=interp_mode,
)
# Keep the central img_size window of the 2 x img_size field.
half = self.img_size // 2
three_half = self.img_size * 3 // 2
if self.ndims == 2:
ddf = ddf[..., half:three_half, half:three_half]
dddf = dddf[..., half:three_half, half:three_half]
else:
ddf = ddf[..., half:three_half, half:three_half, half:three_half]
dddf = dddf[..., half:three_half, half:three_half, half:three_half]
return ddf, dddf
def _get_random_ddf(
    self, img: torch.Tensor, t: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Forward diffusion: sample a random deformation for timestep *t*.

    Returns ``(warped_img, dvf_forward, ddf_forward)`` where the image is
    *img* warped by ``ddf_forward`` through ``self.img_stn``.
    """
    rec_num, ddf_repeats, dvf_repeats = self._get_ddf_scale(t=t)
    ddf_forward, dvf_forward = self._random_ddf_generate(
        rec_num=rec_num, mul_num=[ddf_repeats, dvf_repeats],
    )
    return self.img_stn(img, ddf_forward), dvf_forward, ddf_forward
# ------------------------------------------------------------------
# Private — Conditioning processing
# ------------------------------------------------------------------
def _proc_cond_img(
self,
img: torch.Tensor,
proc_type: Optional[str] = None,
noise_scale: float = 0.1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Conditioning strategies. Mirrors DeformDDPM.proc_cond_img().

Args:
    img: Conditioning image tensor.
    proc_type: One of ``"none"``, ``"uncon"``, ``"adding"``,
        ``"independ"``, ``"downsample"``, ``"slice"``, ``"project"``.
        ``None`` draws one of the first six at random (``"uncon"`` has
        weight 3, the others weight 1). Unrecognized values return the
        image unchanged.
    noise_scale: Amplitude passed to ``_create_noise_map``.

Returns:
    ``(proc_img, mask, cond_ratio)``: the processed condition, the mask
    that was applied (scalar tensor 1 when none was), and a scalar tensor
    set per strategy (1.0 when the image is returned unchanged, 0 for
    ``"uncon"``).
"""
proc_img = img.clone().detach()
if proc_type is None:
proc_type = random.choices(
["adding", "independ", "downsample", "slice", "none", "uncon"],
weights=[1, 1, 1, 1, 1, 3],
k=1,
)[0]
mask = torch.tensor(1, device=img.device)
cond_ratio = torch.tensor(1.0, device=img.device)
if proc_type in ["none", None, "", "None"]:
return proc_img, mask, cond_ratio
# Drawn for every remaining strategy, including those that never use it.
noise_type = random.choice(["gaussian", "uniform", "none"])
# Unconditional: the condition is replaced by pure noise; mask and ratio are 0.
if proc_type == "uncon":
noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
return noise_map, torch.tensor(0, device=img.device), torch.tensor(0, device=img.device)
noise_map = None
if proc_type in ["adding", "independ", "slice"]:
noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
if proc_type == "adding":
# Blend image and noise with a uniformly drawn ratio.
noise_ratio = np.random.uniform(0.0, 1.0)
proc_img = proc_img * (1 - noise_ratio) + noise_map * noise_ratio
cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
elif proc_type == "independ":
# Random binary mask applied element-wise.
mask = self._create_noise_map(img, noise_type="binary")
proc_img = img * mask
cond_ratio = mask.float().mean()
elif proc_type == "downsample":
# Down-sample each axis by a random factor in [1/64, 1), then resize back.
down_ratio = list(np.random.uniform(1.0 / 64, 1, [self.ndims]))
down_img = F.interpolate(
proc_img, scale_factor=down_ratio,
mode="bilinear" if self.ndims == 2 else "trilinear",
)
proc_img = F.interpolate(
down_img, size=[self.img_size] * self.ndims,
mode="bilinear" if self.ndims == 2 else "trilinear",
align_corners=False,
)
cond_ratio = torch.tensor(np.sqrt(np.prod(down_ratio)), device=img.device)
elif proc_type == "slice":
# Two nested draws pick the upper bound on the number of kept slices.
slice_num_max = random.randint(1, 64)
slice_num_max = random.randint(1, slice_num_max)
mask, sample_ratio = self._get_slice_mask(img, slice_num_range=[0, slice_num_max])
proc_img = img * mask
cond_ratio = torch.tensor(sample_ratio, device=img.device)
elif proc_type == "project":
# Mean projections along a random subset of spatial axes, averaged.
# This strategy is never picked by the random draw above.
proj_img = torch.zeros_like(img)
rand_bourn = np.random.randint(0, 2, size=[self.ndims])
proj_dim_num = np.sum(rand_bourn)
for i, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
if pflag:
proj_img += torch.mean(img, dim=i, keepdim=True)
proc_img = proj_img / (proj_dim_num + EPS)
# NOTE(review): 128 is hard-coded here rather than self.img_size — confirm intent.
cond_ratio = torch.tensor(proj_dim_num / (128 * self.ndims), device=img.device)
return proc_img, mask, cond_ratio
def _create_noise_map(
self,
img: torch.Tensor,
noise_type: str = "gaussian",
noise_scale: float = 0.1,
) -> torch.Tensor:
if noise_type == "gaussian":
return (torch.randn_like(img) * noise_scale).to(img.device)
elif noise_type == "uniform":
return (torch.rand_like(img) * noise_scale * 2 - noise_scale).to(img.device)
elif noise_type == "binary":
return torch.bernoulli(torch.rand_like(img)).to(img.device)
return torch.zeros_like(img).to(img.device)
def _get_slice_mask(
self,
img: torch.Tensor,
slice_num_range: list = None,
) -> Tuple[torch.Tensor, float]:
if slice_num_range is None:
slice_num_range = [0, 32]
slice_num_range[1] = min(slice_num_range[1], self.img_size)
mask = torch.zeros_like(img)
sample_ratio = 0.0
for i in range(self.ndims):
if self.inf_mode:
slice_num = 1
slice_idx = [self.img_size // 2]
else:
slice_num = random.randint(slice_num_range[0], slice_num_range[1])
slice_idx = random.sample(range(self.img_size), slice_num)
transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
for idx in slice_idx:
mask[..., idx] = 1
mask = mask.permute(*transpose_list)
sample_ratio += np.sqrt(slice_num / self.img_size) / self.ndims
return mask, sample_ratio
# ------------------------------------------------------------------
# Private — Standardization
# ------------------------------------------------------------------
def _standardize_img(
self,
img,
modality: Optional[str] = None,
keep_raw: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]:
"""Deterministic inference variant of the dataloader pipeline.
Returns ``(model_tensor, fullres_tensor_or_None, orig_shape_or_None)``.
* *model_tensor*: ``[B, C, S, S, S]`` at model resolution.
* *fullres_tensor*: ``[B, C, D, H, W]`` at original padded resolution
(only when *keep_raw=True*).
* *orig_shape*: spatial dims of padded volume before resize.
Accepts numpy arrays, torch tensors (any dimensionality), or a
file path (loaded via SimpleITK). Torch tensors with >= 4 dims
are treated as already-batched and are passed through with
appropriate device/dtype conversion.
"""
fullres_tensor = None
orig_shape = None
# 1. Load from path
if isinstance(img, str):
sitk_img = sitk.ReadImage(img)
vol = sitk.GetArrayFromImage(sitk_img)
vol = self._reverse_axis_order(vol)
elif isinstance(img, np.ndarray):
vol = img.copy()
elif isinstance(img, torch.Tensor):
# If already a batched tensor [B,C,...], pass through
if img.ndim >= 4:
t = img.float().to(self.device)
if keep_raw:
fullres_tensor = t.clone()
return t, fullres_tensor, None
# 1-3D tensor — treat as spatial-only numpy
vol = img.numpy()
else:
raise TypeError(f"Unsupported image type: {type(img)}")
# 2. Extract 3D from 4D
if vol.ndim == 4:
vol = vol[:, :, :, 0]
# 3. CT clamping
if modality is not None and modality.upper() == "CT" and self.clamp_range is not None:
vol = np.clip(vol, self.clamp_range[0], self.clamp_range[1])
# 4. Normalize [0, 1]
vol = vol.astype(np.float64)
vol = (vol - np.min(vol)) / (np.ptp(vol) + 1e-7)
# 5. Center-pad to cube
vol = self._center_pad_to_cube(vol)
orig_shape = vol.shape[:3]
# 6. Full-res tensor (before resize)
if keep_raw:
fullres_tensor = torch.tensor(
vol[None, None, ...], dtype=torch.float32, device=self.device,
)
# 7. Resize to model resolution
target_sz = [self.img_size] * self.ndims
vol_resized = sk_resize(
vol, target_sz, anti_aliasing=True, preserve_range=True,
)
# 8. Add batch + channel dims
model_tensor = torch.tensor(
vol_resized[None, None, ...], dtype=torch.float32, device=self.device,
)
return model_tensor, fullres_tensor, orig_shape
def _standardize_label(
self,
label,
fill_value: float = -1,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Standardize a label volume for inference.
Returns ``(model_tensor, fullres_tensor)``.
* *model_tensor*: ``[1, C, S, S, S]`` at model resolution
(nearest-neighbor resize, no anti-aliasing).
* *fullres_tensor*: ``[1, C, D, H, W]`` at original padded resolution.
If *label* is ``None``, returns *fill_value*-filled placeholders
shaped to match the current init image (model-res and full-res).
Accepts numpy arrays or torch tensors. Does NOT apply
normalization or clamping (labels are discrete indices).
"""
# --- Placeholder for missing labels ---
if label is None:
model_sz = [self.img_size] * self.ndims
model_t = torch.full(
[1, 1] + model_sz, fill_value,
dtype=torch.float32, device=self.device,
)
if self._init_img_raw is not None:
fullres_sz = list(self._init_img_raw.shape[2:])
else:
fullres_sz = model_sz
fullres_t = torch.full(
[1, 1] + fullres_sz, fill_value,
dtype=torch.float32, device=self.device,
)
return model_t, fullres_t
# --- Convert to numpy if needed ---
if isinstance(label, torch.Tensor):
if label.ndim >= 4:
# Already batched tensor — pass through
fullres_t = label.float().to(self.device)
target_sz = [self.img_size] * self.ndims
model_t = F.interpolate(
fullres_t, size=target_sz, mode="nearest",
)
return model_t, fullres_t
lab = label.numpy()
elif isinstance(label, np.ndarray):
lab = label.copy()
else:
raise TypeError(f"Unsupported label type: {type(label)}")
# --- Center-pad to cube ---
lab = self._center_pad_to_cube(lab)
# --- Channel dim: 3D→[C=1,...], 4D→channels-first [C,...] ---
if lab.ndim == 3:
lab = lab[None, :, :, :] # [1, D, H, W]
elif lab.ndim > 3:
lab = np.transpose(lab, (3, 0, 1, 2)) # [C, D, H, W]
# --- Full-res tensor ---
fullres_t = torch.tensor(
lab[None, ...], dtype=torch.float32, device=self.device,
) # [1, C, D, H, W]
# --- Resize to model resolution (nearest-neighbor) ---
target_sz = [self.img_size] * self.ndims
# Resize each channel separately to avoid resizing the channel dim
channels = []
for c in range(lab.shape[0]):
ch = sk_resize(
lab[c], target_sz,
anti_aliasing=False, preserve_range=True, order=0,
)
channels.append(ch)
lab_model = np.stack(channels, axis=0) # [C, S, S, S]
model_t = torch.tensor(
lab_model[None, ...], dtype=torch.float32, device=self.device,
) # [1, C, S, S, S]
return model_t, fullres_t
def _standardize_txt(self, txt) -> Optional[torch.Tensor]:
"""Convert text input to [B, 1024] tensor."""
if txt is None:
return None
if isinstance(txt, str):
self._ensure_bert()
from Dataloader.bert_helper import str2emb
emb = str2emb(
txt, max_words_num=100,
embeder=self._bert_model, tokenizer=self._bert_tokenizer,
reduce_method="mean",
)
return emb.to(self.device) # [1, 1024]
if isinstance(txt, np.ndarray):
t = torch.tensor(txt, dtype=torch.float32, device=self.device)
if t.ndim == 1:
t = t.unsqueeze(0)
return t
if isinstance(txt, torch.Tensor):
t = txt.float().to(self.device)
if t.ndim == 1:
t = t.unsqueeze(0)
return t
raise TypeError(f"Unsupported text type: {type(txt)}")
def _ensure_bert(self):
if self._bert_model is None:
from Dataloader.bert_helper import get_frozen_embeder
self._bert_model, self._bert_tokenizer = get_frozen_embeder(self.bert_model_path)
# ------------------------------------------------------------------
# Private — Spatial utilities
# ------------------------------------------------------------------
@staticmethod
def _reverse_axis_order(arr: np.ndarray) -> np.ndarray:
"""SimpleITK → NumPy axis order."""
return np.ascontiguousarray(arr.transpose(tuple(range(arr.ndim)[::-1])))
@staticmethod
def _center_pad_to_cube(volume: np.ndarray) -> np.ndarray:
"""Pad volume to a cube using the max dimension, with symmetric padding."""
max_dim = max(volume.shape[:3])
pad_width = []
for s in volume.shape[:3]:
total_pad = max_dim - s
pad_before = total_pad // 2
pad_after = total_pad - pad_before
pad_width.append((pad_before, pad_after))
for _ in range(volume.ndim - 3):
pad_width.append((0, 0))
return np.pad(volume, pad_width, mode="constant", constant_values=0)
def _apply_ddf(
self,
volume_tensor: torch.Tensor,
ddf: torch.Tensor,
padding_mode: str = "border",
resample_mode: str = "bilinear",
) -> torch.Tensor:
"""Apply DDF to volume tensor at any resolution via grid_sample."""
device = ddf.device
ndims = self.ndims
img_sz = list(volume_tensor.shape[2:])
max_sz = torch.reshape(
torch.tensor(img_sz, dtype=torch.float32, device=device),
[1, ndims] + [1] * ndims,
)
ref_grid = torch.reshape(
torch.stack(
torch.meshgrid(
[torch.arange(s, device=device, dtype=torch.float32) for s in img_sz],
indexing="ij",
),
0,
),
[1, ndims] + img_sz,
)
img_shape = torch.reshape(
torch.tensor(
[(s - 1) / 2.0 for s in img_sz], dtype=torch.float32, device=device,
),
[1] + [1] * ndims + [ndims],
)
grid = torch.flip(
(ddf * max_sz + ref_grid).permute(
[0] + list(range(2, 2 + ndims)) + [1]
)
/ img_shape
- 1,
dims=[-1],
)
return F.grid_sample(
volume_tensor.to(device),
grid.float(),
mode=resample_mode,
padding_mode=padding_mode,
align_corners=True,
)
def _ensure_tensor(self, img) -> torch.Tensor:
"""Convert numpy/torch input to a [B, C, ...] float tensor on device."""
if isinstance(img, np.ndarray):
t = torch.tensor(img, dtype=torch.float32, device=self.device)
elif isinstance(img, torch.Tensor):
t = img.float().to(self.device)
else:
raise TypeError(f"Unsupported image type: {type(img)}")
if t.ndim == self.ndims: # spatial only → [B=1, C=1, ...]
t = t[None, None, ...]
elif t.ndim == self.ndims + 1: # [C, ...] → [B=1, C, ...]
t = t[None, ...]
return t
def _to_ddf_tensor(self, ddf) -> torch.Tensor:
"""Convert ddf input to proper tensor on device."""
if isinstance(ddf, np.ndarray):
ddf = torch.tensor(ddf, dtype=torch.float32)
ddf = ddf.float().to(self.device)
if ddf.ndim == self.ndims + 1:
ddf = ddf.unsqueeze(0)
# Resize to model resolution if needed
model_sz = [self.img_size] * self.ndims
if list(ddf.shape[2:]) != model_sz:
ddf = F.interpolate(
ddf, size=model_sz,
mode="bilinear" if self.ndims == 2 else "trilinear",
align_corners=False,
)
return ddf
# ------------------------------------------------------------------
# Convenience / repr
# ------------------------------------------------------------------
def __repr__(self) -> str:
status_parts = []
if self._init_img is not None:
status_parts.append(f"init_img={list(self._init_img.shape)}")
if self._cond_img is not None:
status_parts.append(f"cond_img={list(self._cond_img.shape)}")
if self._predicted_ddf is not None:
status_parts.append(f"predicted_ddf={list(self._predicted_ddf.shape)}")
status = ", ".join(status_parts) if status_parts else "empty"
return (
f"OMorpher(net={self.net_name}, ndims={self.ndims}, "
f"img_size={self.img_size}, device={self.device}, {status})"
)