"""
OMorpher — Object-oriented wrapper for OmniMorph diffusion-based deformation.

Stores original high-res images and composes all intermediate deformations as
deformation fields (DDFs), resampling only once at the end to avoid blurring.

Independent of DeformDDPM at runtime; reimplements the diffusion logic using
the network / STN / loss building blocks from Diffusion.*.
"""

import os
import glob
import math
import random
import warnings
from typing import Optional, Union, List, Tuple, Dict

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import yaml
import SimpleITK as sitk
from skimage.transform import resize as sk_resize

from Diffusion.networks import get_net, STN, DefRec_MutAttnNet
from Diffusion.losses import Grad, MRSE, NCC

EPS = 1e-8


class OMorpher:
    """High-level interface for OmniMorph deformation diffusion.

    All images are kept at their original resolution internally.  Deformation
    fields are composed at model resolution and up-scaled on demand so that
    the original image is resampled at most *once*.
    """

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def __init__(
        self,
        config: Union[str, dict],
        checkpoint_path: Optional[str] = None,
        device: Optional[str] = None,
        bert_model_path: Optional[str] = None,
    ):
        """Build the network, the STN resamplers and the loss modules.

        Args:
            config: Either a configuration dict or a path to a YAML file that
                is parsed with ``yaml.safe_load``.
            checkpoint_path: Checkpoint to load.  When ``None`` the newest
                ``*.pth`` found by :meth:`_auto_find_checkpoint` is used, if
                there is one.
            device: Explicit torch device string.  When ``None`` the device is
                resolved from ``config["device"]`` (XPU → CUDA → CPU fallback).
            bert_model_path: Directory of the BERT model used to embed text
                conditions.  The model is only loaded on first use.
        """
        # ---- Config ----
        if isinstance(config, str):
            with open(config, "r", encoding="utf-8") as f:
                config = yaml.safe_load(f)
        self.config: dict = config

        self.net_name: str = config.get("net_name", "recmutattnnet")
        self.ndims: int = config.get("ndims", 3)
        self.img_size: int = config.get("img_size", 128)
        self.timesteps: int = config.get("timesteps", 80)
        self.v_scale: float = config.get("v_scale", 5e-5)
        self.noise_scale: float = config.get("noise_scale", 0.1)
        self.condition_type: str = config.get("condition_type", "none")
        self.num_input_chn: int = config.get("num_input_chn", 1)
        self.img_pad_mode: str = config.get("img_pad_mode", "zeros")
        self.ddf_pad_mode: str = config.get("ddf_pad_mode", "border")
        self.padding_mode: str = config.get("padding_mode", "border")
        self.resample_mode: str = config.get("resample_mode", "bilinear")
        self.batch_size: int = config.get("batchsize", 1)
        self.data_name: str = config.get("data_name", "all")
        self.clamp_range: list = config.get("clamp_range", [-400, 400])
        self.inf_mode: bool = config.get("inf_mode", True)

        # ---- Device ----
        if device is not None:
            self.device = torch.device(device)
        else:
            self.device = self._resolve_device(config.get("device", None))

        # ---- BERT (lazy) ----
        self.bert_model_path = bert_model_path or os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "External", "Models", "bert_large_uncased",
        )
        self._bert_model = None
        self._bert_tokenizer = None

        # ---- Network ----
        Net = get_net(self.net_name)
        self.network = Net(
            n_steps=self.timesteps,
            ndims=self.ndims,
            num_input_chn=self.num_input_chn,
            res=self.img_size,
        )
        self.network.to(self.device)

        # ---- STN instances ----
        self.ctl_ratio = 4
        self.ctl_sz = self.img_size // self.ctl_ratio

        self.stn_full = STN(
            img_sz=self.img_size,
            ndims=self.ndims,
            padding_mode=self.padding_mode,
            device=self.device,
        )
        self.stn_ctl = STN(
            img_sz=self.ctl_sz,
            ndims=self.ndims,
            padding_mode=self.ddf_pad_mode,
            device=self.device,
        )
        self.img_stn = STN(
            img_sz=self.img_size,
            ndims=self.ndims,
            padding_mode=self.img_pad_mode,
            device=self.device,
            resample_mode=self.resample_mode if self.resample_mode != "bilinear" else None,
        )
        self.msk_stn = STN(
            img_sz=self.img_size,
            ndims=self.ndims,
            padding_mode=self.img_pad_mode,
            device=self.device,
            resample_mode="nearest",
        )

        # ---- Loss functions (for fine-tuning) ----
        self._loss_grad = Grad(penalty=["l1"], ndims=self.ndims)
        self._loss_dist = MRSE(img_sz=self.img_size)
        self._loss_ang = NCC(img_sz=self.img_size)

        # ---- Load checkpoint ----
        if checkpoint_path is not None:
            self._load_checkpoint(checkpoint_path)
        else:
            auto_path = self._auto_find_checkpoint()
            if auto_path is not None:
                self._load_checkpoint(auto_path)
        self.network.eval()

        # ---- State ----
        self._init_img: Optional[torch.Tensor] = None  # [B,1,S,S,S] model-res
        self._init_img_raw: Optional[torch.Tensor] = None  # [B,1,D,H,W] full-res
        self._init_img_original_shape: Optional[tuple] = None
        self._init_ddf: Optional[torch.Tensor] = None  # [B,ndims,S,S,S]
        self._cond_img: Optional[torch.Tensor] = None  # [B,1,S,S,S]
        self._cond_txt: Optional[torch.Tensor] = None  # [B,1024]
        self._predicted_ddf: Optional[torch.Tensor] = None  # [B,ndims,S,S,S]
        self._intermediate_ddfs: List[Tuple[int, torch.Tensor]] = []

        # ---- Fine-tuning state ----
        self._optimizer: Optional[torch.optim.Optimizer] = None

    # ------------------------------------------------------------------
    # Device resolution
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_device(hint: Optional[str] = None) -> torch.device:
        """Return the device named by *hint*, or auto-detect one.

        A hint other than ``"auto"`` / ``""`` is used verbatim.  Otherwise the
        first available backend in the order XPU → CUDA → CPU is returned.
        """
        if hint is not None:
            s = str(hint).lower()
            if s not in ("auto", ""):
                return torch.device(s)
        # XPU → CUDA → CPU
        try:
            import intel_extension_for_pytorch  # noqa: F401
            if torch.xpu.is_available():
                return torch.device("xpu")
        except (ImportError, AttributeError):
            # No Intel extension / no torch.xpu: fall through to CUDA / CPU.
            pass
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    # ------------------------------------------------------------------
    # Checkpoint helpers
    # ------------------------------------------------------------------

    def _auto_find_checkpoint(self) -> Optional[str]:
        """Return the lexicographically last ``*.pth`` under
        ``<repo>/Models/<data_name>_<net_name>/``, or ``None`` if there is none."""
        pattern = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "Models",
            f"{self.data_name}_{self.net_name}",
            "*.pth",
        )
        files = sorted(glob.glob(pattern))
        return files[-1] if files else None

    def _load_checkpoint(self, path: str):
        """Load network weights from *path*.

        Accepts either a raw state dict or a dict with a ``"model_state_dict"``
        entry.  Keys are normalised and only those present in the network are
        loaded (non-strict).  A warning is emitted when nothing matches, since
        the network would otherwise silently keep its initial weights.
        """
        ckpt = torch.load(path, map_location="cpu")
        state_dict = ckpt.get("model_state_dict", ckpt)

        # Strip DDP 'module.' prefix and DeformDDPM wrapper keys
        # NOTE(review): ``replace`` removes "module." anywhere in the key, not
        # only as a prefix; kept as-is for compatibility with existing
        # checkpoints.
        cleaned = {}
        for k, v in state_dict.items():
            k = k.replace("module.", "")
            if k.startswith("network."):
                k = k[len("network."):]
            cleaned[k] = v

        # Only load keys that exist in the network
        net_keys = set(self.network.state_dict().keys())
        filtered = {k: v for k, v in cleaned.items() if k in net_keys}
        if filtered:
            self.network.load_state_dict(filtered, strict=False)
        else:
            warnings.warn(
                f"No parameter in checkpoint '{path}' matches the network; "
                "weights were not loaded."
            )

    # ------------------------------------------------------------------
    # Public — Input setters
    # ------------------------------------------------------------------

    def set_init_img(
        self,
        img,
        modality: Optional[str] = None,
    ) -> "OMorpher":
        """Set the initial image.  Accepts numpy, torch, path, or (img, ddf) tuple.

        When no DDF is supplied the initial DDF is reset to zeros.
        Returns ``self`` for chaining.
        """
        init_ddf = None
        if isinstance(img, (tuple, list)):
            img, init_ddf = img[0], img[1]

        model_tensor, fullres_tensor, orig_shape = self._standardize_img(
            img, modality=modality, keep_raw=True,
        )
        self._init_img = model_tensor
        self._init_img_raw = fullres_tensor
        self._init_img_original_shape = orig_shape

        if init_ddf is not None:
            self._init_ddf = self._to_ddf_tensor(init_ddf)
        else:
            B = self._init_img.shape[0]
            S = self.img_size
            self._init_ddf = torch.zeros(
                [B, self.ndims] + [S] * self.ndims,
                dtype=torch.float32,
                device=self.device,
            )
        return self

    def set_cond_img(
        self,
        img=None,
        modality: Optional[str] = None,
    ) -> "OMorpher":
        """Set the conditioning image.  Default: Gaussian noise sigma=0.1.

        Returns ``self`` for chaining.
        """
        if img is None:
            B = self._init_img.shape[0] if self._init_img is not None else self.batch_size
            S = self.img_size
            self._cond_img = torch.randn(
                [B, 1] + [S] * self.ndims,
                dtype=torch.float32,
                device=self.device,
            ) * 0.1
        else:
            tensor, _, _ = self._standardize_img(img, modality=modality, keep_raw=False)
            self._cond_img = tensor
        return self

    def set_cond_txt(self, txt=None) -> "OMorpher":
        """Set the text conditioning.  Accepts string, numpy [1024], torch [1024], or None."""
        self._cond_txt = self._standardize_txt(txt)
        return self

    def set_init_def(self, ddf=None) -> "OMorpher":
        """Set or regenerate the initial deformation field.

        If *ddf* is ``None``, a random DDF is generated using the forward
        diffusion parameters (useful for data augmentation).

        Raises:
            RuntimeError: if *ddf* is ``None`` and no init image has been set.
        """
        if ddf is None:
            if self._init_img is None:
                raise RuntimeError("set_init_img() must be called before set_init_def()")
            t_val = self.config.get("start_noise_step", self.timesteps // 2)
            t = torch.tensor([t_val], dtype=torch.long, device=self.device)
            _, _, random_ddf = self._get_random_ddf(self._init_img, t)
            self._init_ddf = random_ddf
        else:
            self._init_ddf = self._to_ddf_tensor(ddf)
        return self

    # ------------------------------------------------------------------
    # Public — Core operations (inference)
    # ------------------------------------------------------------------

    def predict(
        self,
        T: Optional[list] = None,
        proc_type: Optional[str] = None,
        t_save: Optional[list] = None,
    ) -> "OMorpher":
        """Run reverse diffusion and store predicted DDF.  Returns ``self`` for chaining.

        Args:
            T: ``[start_noise_step, n_reverse_steps]``.  Defaults to
                ``[config["start_noise_step"] or 0, timesteps]``.
            proc_type: Conditioning strategy passed to :meth:`_proc_cond_img`;
                defaults to ``self.condition_type``.
            t_save: Timesteps whose composed DDF should be kept for
                :meth:`get_def`.

        Raises:
            RuntimeError: if no init image has been set.
        """
        if self._init_img is None:
            raise RuntimeError("set_init_img() must be called before predict()")

        # Defaults
        start_noise = self.config.get("start_noise_step", 0)
        if T is None:
            T = [start_noise, self.timesteps]
        if proc_type is None:
            proc_type = self.condition_type

        B = self._init_img.shape[0]
        S = self.img_size

        # Conditioning
        cond_img_src = self._cond_img if self._cond_img is not None else self._init_img.clone().detach()
        cond_img, _, _ = self._proc_cond_img(cond_img_src, proc_type=proc_type)

        # Text embedding
        txt = self._cond_txt
        if txt is None:
            txt = torch.zeros([B, 1024], dtype=torch.float32, device=self.device)

        # Reshape text for network consumption
        if isinstance(self.network, DefRec_MutAttnNet):
            txt = self._reshape_txt_for_net(txt, B)

        # Initial state
        init_ddf_is_zero = (self._init_ddf is None) or torch.all(self._init_ddf == 0)
        if not init_ddf_is_zero:
            ddf_comp = self._init_ddf.clone()
            img_rec = self.img_stn(self._init_img, ddf_comp)
        elif T[0] is not None and T[0] > 0:
            t_start = torch.tensor(np.array([T[0]]), device=self.device)
            img_rec, _, ddf_comp = self._get_random_ddf(self._init_img, t_start)
        else:
            img_rec = self._init_img.clone()
            ddf_comp = torch.zeros(
                [B, self.ndims] + [S] * self.ndims,
                dtype=torch.float32,
                device=self.device,
            )

        # Reverse diffusion loop
        self._intermediate_ddfs = []
        rec_num = 2  # matches DeformDDPM.rec_num default

        if isinstance(self.network, DefRec_MutAttnNet):
            # DefRec network: pass full time list at once
            t_list = list(range(T[1] - 1, -1, -1))
            with torch.no_grad():
                pre_dvf = self.network(
                    x=img_rec, y=cond_img, t=t_list,
                    rec_num=rec_num, text=txt,
                )
            ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
            img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
            if t_save:
                self._intermediate_ddfs.append((0, ddf_comp.clone()))
        else:
            # Standard iterative recovery
            time_steps = range(T[1] - 1, -1, -1)
            for i in time_steps:
                t = torch.tensor(np.array([i]), device=self.device)
                with torch.no_grad():
                    pre_dvf = self.network(
                        x=img_rec, y=cond_img, t=t,
                        rec_num=rec_num, text=txt,
                    )
                # Compose the new step onto the running DDF, then always
                # resample from the untouched init image (single resampling).
                ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
                img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
                if t_save is not None and i in t_save:
                    self._intermediate_ddfs.append((i, ddf_comp.clone()))

        self._predicted_ddf = ddf_comp
        return self

    def get_def(
        self,
        t_list: Optional[list] = None,
    ) -> Union[torch.Tensor, Dict[int, torch.Tensor]]:
        """Return the final predicted DDF, or intermediate DDFs for given timesteps.

        Raises:
            RuntimeError: if *t_list* is ``None`` and :meth:`predict` has not run.
        """
        if t_list is None:
            if self._predicted_ddf is None:
                raise RuntimeError("predict() must be called before get_def()")
            return self._predicted_ddf
        out = {}
        for t, ddf in self._intermediate_ddfs:
            if t in t_list:
                out[t] = ddf
        return out

    def apply_def(
        self,
        img=None,
        ddf: Optional[torch.Tensor] = None,
        padding_mode: Optional[str] = None,
        resample_mode: Optional[str] = None,
    ) -> torch.Tensor:
        """Apply a DDF to an image.  Auto-upscales DDF when sizes differ.

        Defaults: init image at full resolution, predicted DDF.

        Raises:
            RuntimeError: if a default is needed but the corresponding state
                (predicted DDF / init image) has not been set yet.
        """
        if padding_mode is None:
            padding_mode = self.padding_mode
        if resample_mode is None:
            resample_mode = "bilinear"

        # Default DDF
        if ddf is None:
            if self._predicted_ddf is None:
                raise RuntimeError("predict() must be called before apply_def()")
            ddf = self._predicted_ddf

        # Default image: full-res init image tensor
        if img is None:
            if self._init_img_raw is not None:
                vol_tensor = self._init_img_raw
            else:
                vol_tensor = self._init_img
            if vol_tensor is None:
                raise RuntimeError(
                    "set_init_img() must be called before apply_def() when img is not given"
                )
        else:
            vol_tensor = self._ensure_tensor(img)

        # Upscale DDF if sizes differ
        target_sz = list(vol_tensor.shape[2:])
        ddf_sz = list(ddf.shape[2:])
        if target_sz != ddf_sz:
            ddf = F.interpolate(
                ddf,
                size=target_sz,
                mode="bilinear" if self.ndims == 2 else "trilinear",
                align_corners=False,
            )

        return self._apply_ddf(vol_tensor, ddf, padding_mode=padding_mode, resample_mode=resample_mode)

    # ------------------------------------------------------------------
    # Public — Fine-tuning
    # ------------------------------------------------------------------

    def finetune_setup(
        self,
        lr: float = 1e-4,
        optimizer_cls=None,
    ) -> "OMorpher":
        """Switch to training mode and create an optimizer (Adam by default)."""
        self.network.train()
        self.inf_mode = False
        if optimizer_cls is None:
            optimizer_cls = torch.optim.Adam
        self._optimizer = optimizer_cls(self.network.parameters(), lr=lr)
        return self

    def finetune_step(
        self,
        img_batch,
        cond_batch=None,
        text_batch=None,
        t=None,
        proc_type=None,
    ) -> dict:
        """Single training step.  Returns loss dict.

        The returned dict has the float entries ``loss_total``, ``loss_grad``,
        ``loss_dist`` and ``loss_ang``.

        Raises:
            RuntimeError: if :meth:`finetune_setup` has not been called.
        """
        if self._optimizer is None:
            raise RuntimeError("finetune_setup() must be called first")

        img, _, _ = self._standardize_img(img_batch, keep_raw=False)
        cond = self._standardize_img(cond_batch, keep_raw=False)[0] if cond_batch is not None else img.clone()
        text = self._standardize_txt(text_batch)

        B = img.shape[0]
        if t is None:
            t = torch.randint(0, self.timesteps, (B,), device=self.device)
        else:
            t = torch.tensor(t, device=self.device) if not isinstance(t, torch.Tensor) else t.to(self.device)

        proc_type = proc_type or self.condition_type
        cond_img, mask, cond_ratio = self._proc_cond_img(cond, proc_type=proc_type)

        noisy_img, dvf_gt, _ = self._get_random_ddf(img, t)

        # Reshape text for network
        if isinstance(self.network, DefRec_MutAttnNet):
            if text is not None:
                text = self._reshape_txt_for_net(text, B)
            t_input = [t]
        else:
            t_input = t

        pre_dvf = self.network(x=noisy_img * mask, y=cond_img, t=t_input, rec_num=2, text=text)

        loss_grad = self._loss_grad(y_pred=pre_dvf, img=img)
        trm_pred = self.stn_full(pre_dvf, dvf_gt)
        loss_dist = self._loss_dist(pred=trm_pred, inv_lab=dvf_gt)
        loss_ang = self._loss_ang(pred=trm_pred, inv_lab=dvf_gt)

        loss_total = 2.0 * loss_ang + 1.0 * loss_dist + 16.0 * loss_grad

        self._optimizer.zero_grad()
        loss_total.backward()
        self._optimizer.step()

        return {
            "loss_total": loss_total.item(),
            "loss_grad": loss_grad.item(),
            "loss_dist": loss_dist.item(),
            "loss_ang": loss_ang.item(),
        }

    def finetune_save(self, path: str, epoch: int = 0):
        """Save checkpoint in the standard OmniMorph format."""
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save(
            {
                "model_state_dict": self.network.state_dict(),
                "optimizer_state_dict": self._optimizer.state_dict() if self._optimizer else None,
                "epoch": epoch,
            },
            path,
        )

    def finetune_teardown(self) -> "OMorpher":
        """Switch back to eval mode and drop the optimizer."""
        self.network.eval()
        self.inf_mode = True
        self._optimizer = None
        return self

    # ------------------------------------------------------------------
    # Private — Diffusion logic
    # ------------------------------------------------------------------

    def _reshape_txt_for_net(self, txt: torch.Tensor, batch: int) -> torch.Tensor:
        """Reshape a text embedding to ``[batch, -1, 1, ..., 1]``.

        A single embedding (leading dim 1) is first broadcast along the batch
        axis; reshaping it directly would silently split its features across
        the batch instead.
        """
        if txt.shape[0] == 1 and batch > 1:
            txt = txt.expand(batch, *txt.shape[1:])
        return txt.reshape(batch, -1, *([1] * self.ndims))

    def _get_ddf_scale(
        self,
        t: torch.Tensor,
        divide_num: int = 1,
        max_ddf_num: int = 200,
    ) -> Tuple[int, torch.Tensor, torch.Tensor]:
        """Timestep-dependent deformation magnitude.  Mirrors DeformDDPM._get_ddf_scale().

        Returns ``(rec_num, mul_num_ddf, mul_num_dvf)`` where the two tensors
        are integer composition counts derived from *t*, clamped to
        ``[1, max_ddf_num]`` and ``[0, max_ddf_num]`` respectively.
        """
        rec_num = 1
        mul_num_ddf = torch.floor_divide(2 * torch.pow(t.float(), 1.3), 3 * divide_num).int()
        mul_num_dvf = torch.floor_divide(torch.pow(t.float(), 0.6), divide_num).int()
        mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
        mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
        return rec_num, mul_num_ddf, mul_num_dvf

    def _sample_random_uniform_multi_order(
        self,
        high=None,
        low=0.0,
        order_num=3,
    ) -> float:
        """Draw *order_num* nested uniform samples, each using the previous
        draw as its lower bound, and return the last one."""
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
        return sample_value

    def _multiscale_dvf_generate(
        self,
        v_scale: float,
        ctl_szs: list = None,
        rand_v_scale: bool = True,
        batch_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Multi-scale Gaussian DVF at control-point sizes.

        Args:
            v_scale: Velocity magnitude (upper bound when *rand_v_scale*).
            ctl_szs: Control-grid sizes to sum over.
            rand_v_scale: Sample the magnitude per scale instead of using
                *v_scale* directly.
            batch_size: Batch dimension of the result; defaults to
                ``self.batch_size``.

        Returns:
            CPU tensor ``[batch, ndims, ctl_sz, ...]`` (``self.ctl_sz`` per
            spatial axis).
        """
        if ctl_szs is None:
            ctl_szs = [4, 8, 16, 32, 64]
        if batch_size is None:
            batch_size = self.batch_size
        dvf = 0
        for ctl_sz in ctl_szs:
            _v = (
                self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2)
                if rand_v_scale
                else v_scale
            )
            if ctl_sz <= 2:
                _v = _v / 2
            dvf_comp = torch.randn(
                [batch_size, self.ndims] + [ctl_sz] * self.ndims
            ) * _v
            dvf_comp = F.interpolate(
                dvf_comp * self.ctl_sz / ctl_sz,
                [self.ctl_sz] * self.ndims,
                align_corners=False,
                mode="bilinear" if self.ndims == 2 else "trilinear",
            )
            dvf = dvf + dvf_comp
        return dvf

    def _random_ddf_generate(
        self,
        rec_num: int = 3,
        mul_num: list = None,
        noise_ratio: float = 0.08,
        select_num: int = 4,
        batch_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compose DVFs to build a DDF.  Mirrors DeformDDPM._random_ddf_generate().

        Args:
            rec_num: Number of independent DVF draws to compose.
            mul_num: ``[n_ddf, n_dvf]`` per-sample composition counts.
            noise_ratio: Relative magnitude of the extra noise DVF added to the
                DDF branch (0 disables it).
            select_num: Maximum number of control scales used per draw.
            batch_size: Batch dimension of the fields; defaults to
                ``self.batch_size``.

        Returns:
            ``(ddf, dddf)`` at model resolution.
        """
        if mul_num is None:
            mul_num = [torch.tensor([5]), torch.tensor([5])]
        if batch_size is None:
            batch_size = self.batch_size
        crop_rate = 2

        # unsqueeze mul_num for broadcasting
        for _ in range(self.ndims + 1):
            mul_num = [torch.unsqueeze(n, -1) for n in mul_num]

        # Accumulators live on the same device as the DVFs they are composed with.
        ctl_ddf_sz = [batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        ddf = torch.zeros(ctl_ddf_sz, device=self.device)
        dddf = torch.zeros(ctl_ddf_sz, device=self.device)

        scale_num = min(8, int(math.log2(self.ctl_sz)))
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]

        for _i in range(rec_num):
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all

            dvf = self._multiscale_dvf_generate(
                self.v_scale, ctl_szs=ctl_szs, batch_size=batch_size,
            ).to(self.device)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                dvf0 = dvf + self.stn_ctl(
                    self._multiscale_dvf_generate(
                        self.v_scale * noise_ratio,
                        ctl_szs=ctl_szs,
                        rand_v_scale=False,
                        batch_size=batch_size,
                    ).to(self.device),
                    dvf,
                )

            for j in range(torch.max(mul_num[0]).item()):
                # flag is 1 for samples that still need composition step j.
                flag = [(n > j).int().to(self.device) for n in mul_num]
                ddf = dvf0 * flag[0] + self.stn_ctl(ddf, dvf0 * flag[0])
                dddf = dvf * flag[1] + self.stn_ctl(dddf, dvf * flag[1])

        # Upscale and center-crop
        interp_mode = "bilinear" if self.ndims == 2 else "trilinear"
        ddf = F.interpolate(
            ddf * self.img_size / self.ctl_sz,
            self.img_size * crop_rate,
            mode=interp_mode,
        )
        dddf = F.interpolate(
            dddf * self.img_size / self.ctl_sz,
            self.img_size * crop_rate,
            mode=interp_mode,
        )
        half = self.img_size // 2
        three_half = self.img_size * 3 // 2
        if self.ndims == 2:
            ddf = ddf[..., half:three_half, half:three_half]
            dddf = dddf[..., half:three_half, half:three_half]
        else:
            ddf = ddf[..., half:three_half, half:three_half, half:three_half]
            dddf = dddf[..., half:three_half, half:three_half, half:three_half]

        return ddf, dddf

    def _get_random_ddf(
        self,
        img: torch.Tensor,
        t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward-diffuse: generate random DDF and warp image.

        The random fields are generated with the batch size of *img*.

        Returns:
            ``(warped_img, dvf_forward, ddf_forward)``.
        """
        rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
        ddf_forward, dvf_forward = self._random_ddf_generate(
            rec_num=rec_num,
            mul_num=[mul_num_ddf, mul_num_dvf],
            batch_size=img.shape[0],
        )
        warped_img = self.img_stn(img, ddf_forward)
        return warped_img, dvf_forward, ddf_forward

    # ------------------------------------------------------------------
    # Private — Conditioning processing
    # ------------------------------------------------------------------

    def _proc_cond_img(
        self,
        img: torch.Tensor,
        proc_type: Optional[str] = None,
        noise_scale: float = 0.1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Conditioning strategies.  Mirrors DeformDDPM.proc_cond_img().

        Args:
            img: Conditioning image.
            proc_type: One of ``"adding"``, ``"independ"``, ``"downsample"``,
                ``"slice"``, ``"project"``, ``"uncon"`` or ``"none"``.  When
                ``None`` a strategy is drawn at random.
            noise_scale: Amplitude of generated noise maps.

        Returns:
            ``(proc_img, mask, cond_ratio)``.
        """
        proc_img = img.clone().detach()

        if proc_type is None:
            proc_type = random.choices(
                ["adding", "independ", "downsample", "slice", "none", "uncon"],
                weights=[1, 1, 1, 1, 1, 3],
                k=1,
            )[0]

        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1.0, device=img.device)

        if proc_type in ["none", None, "", "None"]:
            return proc_img, mask, cond_ratio

        noise_type = random.choice(["gaussian", "uniform", "none"])

        if proc_type == "uncon":
            noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
            return noise_map, torch.tensor(0, device=img.device), torch.tensor(0, device=img.device)

        noise_map = None
        if proc_type in ["adding", "independ", "slice"]:
            noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)

        if proc_type == "adding":
            noise_ratio = np.random.uniform(0.0, 1.0)
            proc_img = proc_img * (1 - noise_ratio) + noise_map * noise_ratio
            cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
        elif proc_type == "independ":
            mask = self._create_noise_map(img, noise_type="binary")
            proc_img = img * mask
            cond_ratio = mask.float().mean()
        elif proc_type == "downsample":
            down_ratio = list(np.random.uniform(1.0 / 64, 1, [self.ndims]))
            down_img = F.interpolate(
                proc_img,
                scale_factor=down_ratio,
                mode="bilinear" if self.ndims == 2 else "trilinear",
            )
            proc_img = F.interpolate(
                down_img,
                size=[self.img_size] * self.ndims,
                mode="bilinear" if self.ndims == 2 else "trilinear",
                align_corners=False,
            )
            cond_ratio = torch.tensor(np.sqrt(np.prod(down_ratio)), device=img.device)
        elif proc_type == "slice":
            slice_num_max = random.randint(1, 64)
            slice_num_max = random.randint(1, slice_num_max)
            mask, sample_ratio = self._get_slice_mask(img, slice_num_range=[0, slice_num_max])
            proc_img = img * mask
            cond_ratio = torch.tensor(sample_ratio, device=img.device)
        elif proc_type == "project":
            proj_img = torch.zeros_like(img)
            rand_bourn = np.random.randint(0, 2, size=[self.ndims])
            proj_dim_num = np.sum(rand_bourn)
            for i, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
                if pflag:
                    proj_img += torch.mean(img, dim=i, keepdim=True)
            proc_img = proj_img / (proj_dim_num + EPS)
            cond_ratio = torch.tensor(proj_dim_num / (128 * self.ndims), device=img.device)

        return proc_img, mask, cond_ratio

    def _create_noise_map(
        self,
        img: torch.Tensor,
        noise_type: str = "gaussian",
        noise_scale: float = 0.1,
    ) -> torch.Tensor:
        """Return a noise tensor shaped like *img*.

        ``"gaussian"``: N(0, noise_scale²); ``"uniform"``: U(-noise_scale,
        noise_scale); ``"binary"``: Bernoulli mask with per-voxel random
        probability; anything else: zeros.
        """
        if noise_type == "gaussian":
            return (torch.randn_like(img) * noise_scale).to(img.device)
        elif noise_type == "uniform":
            return (torch.rand_like(img) * noise_scale * 2 - noise_scale).to(img.device)
        elif noise_type == "binary":
            return torch.bernoulli(torch.rand_like(img)).to(img.device)
        return torch.zeros_like(img).to(img.device)

    def _get_slice_mask(
        self,
        img: torch.Tensor,
        slice_num_range: list = None,
    ) -> Tuple[torch.Tensor, float]:
        """Build a mask that keeps a few slices along every spatial axis.

        In inference mode the single centre slice of each axis is kept;
        otherwise a random number of random slices within *slice_num_range*
        (upper bound clamped to ``img_size``).  The caller's list is not
        modified.

        Returns:
            ``(mask, sample_ratio)``.
        """
        if slice_num_range is None:
            slice_num_range = [0, 32]
        # Clamp into locals instead of writing back into the caller's list.
        slice_num_lo = slice_num_range[0]
        slice_num_hi = min(slice_num_range[1], self.img_size)

        mask = torch.zeros_like(img)
        sample_ratio = 0.0
        for i in range(self.ndims):
            if self.inf_mode:
                slice_num = 1
                slice_idx = [self.img_size // 2]
            else:
                slice_num = random.randint(slice_num_lo, slice_num_hi)
                slice_idx = random.sample(range(self.img_size), slice_num)
            # Rotate the spatial axes so that the next iteration marks slices
            # along a different axis; after ndims rotations the order is restored.
            transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
            for idx in slice_idx:
                mask[..., idx] = 1
            mask = mask.permute(*transpose_list)
            sample_ratio += np.sqrt(slice_num / self.img_size) / self.ndims
        return mask, sample_ratio

    # ------------------------------------------------------------------
    # Private — Standardization
    # ------------------------------------------------------------------

    def _standardize_img(
        self,
        img,
        modality: Optional[str] = None,
        keep_raw: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]:
        """Deterministic inference variant of the dataloader pipeline.

        Returns ``(model_tensor, fullres_tensor_or_None, orig_shape_or_None)``.

        * *model_tensor*: ``[B, C, S, S, S]`` at model resolution.
        * *fullres_tensor*: ``[B, C, D, H, W]`` at original padded resolution
          (only when *keep_raw=True*).
        * *orig_shape*: spatial dims of padded volume before resize.

        Accepts numpy arrays, torch tensors (any dimensionality), or a file
        path (loaded via SimpleITK).  Torch tensors with >= 4 dims are treated
        as already-batched and are passed through with appropriate
        device/dtype conversion.

        Raises:
            TypeError: for any other input type.
        """
        fullres_tensor = None
        orig_shape = None

        # 1. Load from path
        if isinstance(img, str):
            sitk_img = sitk.ReadImage(img)
            vol = sitk.GetArrayFromImage(sitk_img)
            vol = self._reverse_axis_order(vol)
        elif isinstance(img, np.ndarray):
            vol = img.copy()
        elif isinstance(img, torch.Tensor):
            # If already a batched tensor [B,C,...], pass through
            if img.ndim >= 4:
                t = img.float().to(self.device)
                if keep_raw:
                    fullres_tensor = t.clone()
                return t, fullres_tensor, None
            # 1-3D tensor — treat as spatial-only numpy.
            # detach()/cpu() so CUDA or grad-tracking tensors convert too.
            vol = img.detach().cpu().numpy()
        else:
            raise TypeError(f"Unsupported image type: {type(img)}")

        # 2. Extract 3D from 4D
        if vol.ndim == 4:
            vol = vol[:, :, :, 0]

        # 3. CT clamping
        if modality is not None and modality.upper() == "CT" and self.clamp_range is not None:
            vol = np.clip(vol, self.clamp_range[0], self.clamp_range[1])

        # 4. Normalize [0, 1]
        vol = vol.astype(np.float64)
        vol = (vol - np.min(vol)) / (np.ptp(vol) + 1e-7)

        # 5. Center-pad to cube
        vol = self._center_pad_to_cube(vol)
        orig_shape = vol.shape[:3]

        # 6. Full-res tensor (before resize)
        if keep_raw:
            fullres_tensor = torch.tensor(
                vol[None, None, ...], dtype=torch.float32, device=self.device,
            )

        # 7. Resize to model resolution
        target_sz = [self.img_size] * self.ndims
        vol_resized = sk_resize(
            vol, target_sz, anti_aliasing=True, preserve_range=True,
        )

        # 8. Add batch + channel dims
        model_tensor = torch.tensor(
            vol_resized[None, None, ...], dtype=torch.float32, device=self.device,
        )
        return model_tensor, fullres_tensor, orig_shape

    def _standardize_label(
        self,
        label,
        fill_value: float = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standardize a label volume for inference.

        Returns ``(model_tensor, fullres_tensor)``.

        * *model_tensor*: ``[1, C, S, S, S]`` at model resolution
          (nearest-neighbor resize, no anti-aliasing).
        * *fullres_tensor*: ``[1, C, D, H, W]`` at original padded resolution.

        If *label* is ``None``, returns *fill_value*-filled placeholders shaped
        to match the current init image (model-res and full-res).

        Accepts numpy arrays or torch tensors.  Does NOT apply normalization
        or clamping (labels are discrete indices).

        Raises:
            TypeError: for any other input type.
        """
        # --- Placeholder for missing labels ---
        if label is None:
            model_sz = [self.img_size] * self.ndims
            model_t = torch.full(
                [1, 1] + model_sz, fill_value,
                dtype=torch.float32, device=self.device,
            )
            if self._init_img_raw is not None:
                fullres_sz = list(self._init_img_raw.shape[2:])
            else:
                fullres_sz = model_sz
            fullres_t = torch.full(
                [1, 1] + fullres_sz, fill_value,
                dtype=torch.float32, device=self.device,
            )
            return model_t, fullres_t

        # --- Convert to numpy if needed ---
        if isinstance(label, torch.Tensor):
            if label.ndim >= 4:
                # Already batched tensor — pass through
                fullres_t = label.float().to(self.device)
                target_sz = [self.img_size] * self.ndims
                model_t = F.interpolate(
                    fullres_t, size=target_sz, mode="nearest",
                )
                return model_t, fullres_t
            # detach()/cpu() so CUDA or grad-tracking tensors convert too.
            lab = label.detach().cpu().numpy()
        elif isinstance(label, np.ndarray):
            lab = label.copy()
        else:
            raise TypeError(f"Unsupported label type: {type(label)}")

        # --- Center-pad to cube ---
        lab = self._center_pad_to_cube(lab)

        # --- Channel dim: 3D→[C=1,...], 4D→channels-first [C,...] ---
        if lab.ndim == 3:
            lab = lab[None, :, :, :]  # [1, D, H, W]
        elif lab.ndim > 3:
            lab = np.transpose(lab, (3, 0, 1, 2))  # [C, D, H, W]

        # --- Full-res tensor ---
        fullres_t = torch.tensor(
            lab[None, ...], dtype=torch.float32, device=self.device,
        )  # [1, C, D, H, W]

        # --- Resize to model resolution (nearest-neighbor) ---
        target_sz = [self.img_size] * self.ndims
        # Resize each channel separately to avoid resizing the channel dim
        channels = []
        for c in range(lab.shape[0]):
            ch = sk_resize(
                lab[c], target_sz,
                anti_aliasing=False, preserve_range=True, order=0,
            )
            channels.append(ch)
        lab_model = np.stack(channels, axis=0)  # [C, S, S, S]

        model_t = torch.tensor(
            lab_model[None, ...], dtype=torch.float32, device=self.device,
        )  # [1, C, S, S, S]

        return model_t, fullres_t

    def _standardize_txt(self, txt) -> Optional[torch.Tensor]:
        """Convert text input to [B, 1024] tensor.

        Strings are embedded with the lazily loaded BERT model; 1-D arrays /
        tensors get a leading batch dimension; ``None`` is returned unchanged.

        Raises:
            TypeError: for any other input type.
        """
        if txt is None:
            return None
        if isinstance(txt, str):
            self._ensure_bert()
            from Dataloader.bert_helper import str2emb
            emb = str2emb(
                txt,
                max_words_num=100,
                embeder=self._bert_model,
                tokenizer=self._bert_tokenizer,
                reduce_method="mean",
            )
            return emb.to(self.device)  # [1, 1024]
        if isinstance(txt, np.ndarray):
            t = torch.tensor(txt, dtype=torch.float32, device=self.device)
            if t.ndim == 1:
                t = t.unsqueeze(0)
            return t
        if isinstance(txt, torch.Tensor):
            t = txt.float().to(self.device)
            if t.ndim == 1:
                t = t.unsqueeze(0)
            return t
        raise TypeError(f"Unsupported text type: {type(txt)}")

    def _ensure_bert(self):
        """Load the frozen BERT embedder and tokenizer on first use."""
        if self._bert_model is None:
            from Dataloader.bert_helper import get_frozen_embeder
            self._bert_model, self._bert_tokenizer = get_frozen_embeder(self.bert_model_path)

    # ------------------------------------------------------------------
    # Private — Spatial utilities
    # ------------------------------------------------------------------

    @staticmethod
    def _reverse_axis_order(arr: np.ndarray) -> np.ndarray:
        """SimpleITK → NumPy axis order (all axes reversed, contiguous copy)."""
        return np.ascontiguousarray(arr.transpose(tuple(range(arr.ndim)[::-1])))

    @staticmethod
    def _center_pad_to_cube(volume: np.ndarray) -> np.ndarray:
        """Pad volume to a cube using the max dimension, with symmetric padding.

        Only the first three axes are padded (zeros); when the padding is odd
        the extra element goes after the data.  Further axes are left alone.
        """
        max_dim = max(volume.shape[:3])
        pad_width = []
        for s in volume.shape[:3]:
            total_pad = max_dim - s
            pad_before = total_pad // 2
            pad_after = total_pad - pad_before
            pad_width.append((pad_before, pad_after))
        for _ in range(volume.ndim - 3):
            pad_width.append((0, 0))
        return np.pad(volume, pad_width, mode="constant", constant_values=0)

    def _apply_ddf(
        self,
        volume_tensor: torch.Tensor,
        ddf: torch.Tensor,
        padding_mode: str = "border",
        resample_mode: str = "bilinear",
    ) -> torch.Tensor:
        """Apply DDF to volume tensor at any resolution via grid_sample.

        The DDF is scaled by the volume size per axis, added to the identity
        voxel grid, normalised to ``[-1, 1]`` and passed to ``F.grid_sample``
        with ``align_corners=True``.
        """
        device = ddf.device
        ndims = self.ndims
        img_sz = list(volume_tensor.shape[2:])

        max_sz = torch.reshape(
            torch.tensor(img_sz, dtype=torch.float32, device=device),
            [1, ndims] + [1] * ndims,
        )
        ref_grid = torch.reshape(
            torch.stack(
                torch.meshgrid(
                    [torch.arange(s, device=device, dtype=torch.float32) for s in img_sz],
                    indexing="ij",
                ),
                0,
            ),
            [1, ndims] + img_sz,
        )
        img_shape = torch.reshape(
            torch.tensor(
                [(s - 1) / 2.0 for s in img_sz],
                dtype=torch.float32,
                device=device,
            ),
            [1] + [1] * ndims + [ndims],
        )

        # grid_sample expects the coordinate channel last and in reversed
        # axis order, hence the permute and the flip.
        grid = torch.flip(
            (ddf * max_sz + ref_grid).permute(
                [0] + list(range(2, 2 + ndims)) + [1]
            ) / img_shape - 1,
            dims=[-1],
        )
        return F.grid_sample(
            volume_tensor.to(device),
            grid.float(),
            mode=resample_mode,
            padding_mode=padding_mode,
            align_corners=True,
        )

    def _ensure_tensor(self, img) -> torch.Tensor:
        """Convert numpy/torch input to a [B, C, ...] float tensor on device.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(img, np.ndarray):
            t = torch.tensor(img, dtype=torch.float32, device=self.device)
        elif isinstance(img, torch.Tensor):
            t = img.float().to(self.device)
        else:
            raise TypeError(f"Unsupported image type: {type(img)}")

        if t.ndim == self.ndims:
            # spatial only → [B=1, C=1, ...]
            t = t[None, None, ...]
        elif t.ndim == self.ndims + 1:
            # [C, ...] → [B=1, C, ...]
            t = t[None, ...]
        return t

    def _to_ddf_tensor(self, ddf) -> torch.Tensor:
        """Convert ddf input to proper tensor on device.

        Adds a batch dimension to unbatched fields and resizes the field to
        model resolution when its spatial size differs.
        """
        if isinstance(ddf, np.ndarray):
            ddf = torch.tensor(ddf, dtype=torch.float32)
        ddf = ddf.float().to(self.device)
        if ddf.ndim == self.ndims + 1:
            ddf = ddf.unsqueeze(0)
        # Resize to model resolution if needed
        model_sz = [self.img_size] * self.ndims
        if list(ddf.shape[2:]) != model_sz:
            ddf = F.interpolate(
                ddf,
                size=model_sz,
                mode="bilinear" if self.ndims == 2 else "trilinear",
                align_corners=False,
            )
        return ddf

    # ------------------------------------------------------------------
    # Convenience / repr
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        status_parts = []
        if self._init_img is not None:
            status_parts.append(f"init_img={list(self._init_img.shape)}")
        if self._cond_img is not None:
            status_parts.append(f"cond_img={list(self._cond_img.shape)}")
        if self._predicted_ddf is not None:
            status_parts.append(f"predicted_ddf={list(self._predicted_ddf.shape)}")
        status = ", ".join(status_parts) if status_parts else "empty"
        return (
            f"OMorpher(net={self.net_name}, ndims={self.ndims}, "
            f"img_size={self.img_size}, device={self.device}, {status})"
        )