| """
|
| OMorpher — Object-oriented wrapper for OmniMorph diffusion-based deformation.
|
|
|
| Stores original high-res images and composes all intermediate deformations as
|
| deformation fields (DDFs), resampling only once at the end to avoid blurring.
|
| Independent of DeformDDPM at runtime; reimplements the diffusion logic using
|
| the network / STN / loss building blocks from Diffusion.*.
|
| """
|
|
|
| import os
|
| import glob
|
| import math
|
| import random
|
| from typing import Optional, Union, List, Tuple, Dict
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
| from torch import nn
|
|
|
| import yaml
|
| import SimpleITK as sitk
|
| from skimage.transform import resize as sk_resize
|
|
|
| from Diffusion.networks import get_net, STN, DefRec_MutAttnNet
|
| from Diffusion.losses import Grad, MRSE, NCC
|
|
|
# Small constant added to denominators to avoid division by zero.
EPS: float = 1.0e-8
|
|
|
|
|
class OMorpher:
    """High-level interface for OmniMorph deformation diffusion.

    All images are kept at their original resolution internally. Deformation
    fields are composed at model resolution and up-scaled on demand so that the
    original image is resampled at most *once*.

    The setters (``set_init_img``, ``set_cond_img``, ``set_cond_txt``,
    ``set_init_def``) and ``predict`` return ``self`` so calls can be chained,
    e.g. ``OMorpher(cfg).set_init_img(vol).predict().apply_def()``.
    """

    # ------------------------------------------------------------------
    # Construction
    # ------------------------------------------------------------------

    def __init__(
        self,
        config: Union[str, dict],
        checkpoint_path: Optional[str] = None,
        device: Optional[str] = None,
        bert_model_path: Optional[str] = None,
    ):
        """Build the network, spatial transformers and losses from *config*.

        Args:
            config: Path to a YAML config file, or an already-parsed dict.
            checkpoint_path: Checkpoint to load. When ``None``, the last
                (natural-sort order) ``*.pth`` file in
                ``Models/<data_name>_<net_name>/`` next to this file's parent
                directory is loaded, if one exists.
            device: Explicit torch device string. When ``None``, the config's
                ``device`` entry is used; a missing/empty/``"auto"`` entry
                falls back to XPU, then CUDA, then CPU.
            bert_model_path: Directory of the BERT model used to embed string
                prompts. Defaults to ``External/Models/bert_large_uncased``
                under this file's parent directory; loaded lazily.
        """
        if isinstance(config, str):
            with open(config, "r") as f:
                config = yaml.safe_load(f)
        self.config: dict = config

        # Hyper-parameters: config keys with their fallback defaults.
        self.net_name: str = config.get("net_name", "recmutattnnet")
        self.ndims: int = config.get("ndims", 3)
        self.img_size: int = config.get("img_size", 128)
        self.timesteps: int = config.get("timesteps", 80)
        self.v_scale: float = config.get("v_scale", 5e-5)
        self.noise_scale: float = config.get("noise_scale", 0.1)
        self.condition_type: str = config.get("condition_type", "none")
        self.num_input_chn: int = config.get("num_input_chn", 1)
        self.img_pad_mode: str = config.get("img_pad_mode", "zeros")
        self.ddf_pad_mode: str = config.get("ddf_pad_mode", "border")
        self.padding_mode: str = config.get("padding_mode", "border")
        self.resample_mode: str = config.get("resample_mode", "bilinear")
        self.batch_size: int = config.get("batchsize", 1)
        self.data_name: str = config.get("data_name", "all")
        self.clamp_range: list = config.get("clamp_range", [-400, 400])
        self.inf_mode: bool = config.get("inf_mode", True)

        if device is not None:
            self.device = torch.device(device)
        else:
            self.device = self._resolve_device(config.get("device", None))

        # The text encoder is loaded on first use (see _ensure_bert).
        self.bert_model_path = bert_model_path or os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "External", "Models", "bert_large_uncased",
        )
        self._bert_model = None
        self._bert_tokenizer = None

        # Deformation network.
        Net = get_net(self.net_name)
        self.network = Net(
            n_steps=self.timesteps,
            ndims=self.ndims,
            num_input_chn=self.num_input_chn,
            res=self.img_size,
        )
        self.network.to(self.device)

        # Random DVFs are generated on a coarse control grid of
        # img_size // ctl_ratio points and up-sampled afterwards.
        self.ctl_ratio = 4
        self.ctl_sz = self.img_size // self.ctl_ratio

        # Spatial transformers: DDF composition at full / control-grid
        # resolution, and image / mask warping at model resolution.
        self.stn_full = STN(
            img_sz=self.img_size,
            ndims=self.ndims,
            padding_mode=self.padding_mode,
            device=self.device,
        )
        self.stn_ctl = STN(
            img_sz=self.ctl_sz,
            ndims=self.ndims,
            padding_mode=self.ddf_pad_mode,
            device=self.device,
        )
        self.img_stn = STN(
            img_sz=self.img_size,
            ndims=self.ndims,
            padding_mode=self.img_pad_mode,
            device=self.device,
            resample_mode=self.resample_mode if self.resample_mode != "bilinear" else None,
        )
        self.msk_stn = STN(
            img_sz=self.img_size,
            ndims=self.ndims,
            padding_mode=self.img_pad_mode,
            device=self.device,
            resample_mode="nearest",
        )

        # Loss terms used by finetune_step().
        self._loss_grad = Grad(penalty=["l1"], ndims=self.ndims)
        self._loss_dist = MRSE(img_sz=self.img_size)
        self._loss_ang = NCC(img_sz=self.img_size)

        if checkpoint_path is None:
            checkpoint_path = self._auto_find_checkpoint()
        if checkpoint_path is not None:
            self._load_checkpoint(checkpoint_path)

        self.network.eval()

        # Per-sample state populated by the set_*() / predict() methods.
        self._init_img: Optional[torch.Tensor] = None  # model resolution
        self._init_img_raw: Optional[torch.Tensor] = None  # padded full resolution
        self._init_img_original_shape: Optional[tuple] = None
        self._init_ddf: Optional[torch.Tensor] = None
        self._cond_img: Optional[torch.Tensor] = None
        self._cond_txt: Optional[torch.Tensor] = None
        self._predicted_ddf: Optional[torch.Tensor] = None
        self._intermediate_ddfs: List[Tuple[int, torch.Tensor]] = []

        # Created by finetune_setup(), cleared by finetune_teardown().
        self._optimizer: Optional[torch.optim.Optimizer] = None

    # ------------------------------------------------------------------
    # Device / checkpoint helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _resolve_device(hint: Optional[str] = None) -> torch.device:
        """Return ``torch.device(hint)`` unless *hint* is ``None``, empty or ``"auto"``.

        Auto-detection prefers Intel XPU (only when Intel Extension for
        PyTorch imports successfully), then CUDA, then CPU.
        """
        if hint is not None:
            s = str(hint).lower()
            if s not in ("auto", ""):
                return torch.device(s)

        try:
            # Presumably imported for its side effect of enabling torch.xpu.
            import intel_extension_for_pytorch  # noqa: F401
            if torch.xpu.is_available():
                return torch.device("xpu")
        except (ImportError, AttributeError):
            pass
        if torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @staticmethod
    def _natural_sort_key(path: str):
        """Sort key that compares embedded integers numerically.

        ``epoch_9.pth`` sorts before ``epoch_10.pth`` (a plain string sort
        would place it after). The full path is appended as a tie-breaker.
        """
        import re

        # With a capturing group, re.split alternates text / digit runs, so
        # odd positions are always digits and compare as ints across paths.
        parts = re.split(r"(\d+)", os.path.basename(path))
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], path

    def _auto_find_checkpoint(self) -> Optional[str]:
        """Return the last ``*.pth`` in the default model directory, or ``None``."""
        pattern = os.path.join(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
            "Models",
            f"{self.data_name}_{self.net_name}",
            "*.pth",
        )
        files = glob.glob(pattern)
        return max(files, key=self._natural_sort_key) if files else None

    @staticmethod
    def _clean_state_dict_key(key: str) -> str:
        """Strip leading wrapper prefixes from a checkpoint parameter name.

        Handles keys saved from the bare network, from ``nn.DataParallel``
        (``module.``), from a wrapper holding the net as ``network.``, and
        combinations such as ``module.network.module.conv.weight``. Only
        *leading* prefixes are removed, so names that merely contain
        ``module.`` (e.g. ``attn_module.weight``) stay intact.
        """
        for prefix in ("module.", "network.", "module."):
            if key.startswith(prefix):
                key = key[len(prefix):]
        return key

    def _load_checkpoint(self, path: str):
        """Load network weights from *path*.

        Accepts a raw state dict or a dict with a ``model_state_dict`` entry.
        Keys not present in ``self.network`` are dropped and loading is
        non-strict, so partially matching checkpoints still load. A
        ``RuntimeWarning`` is emitted when no key matches at all.
        """
        import warnings

        # NOTE(security): torch.load unpickles data; only load trusted files.
        ckpt = torch.load(path, map_location="cpu")
        state_dict = ckpt.get("model_state_dict", ckpt)

        cleaned = {self._clean_state_dict_key(k): v for k, v in state_dict.items()}

        net_keys = set(self.network.state_dict().keys())
        filtered = {k: v for k, v in cleaned.items() if k in net_keys}
        if filtered:
            self.network.load_state_dict(filtered, strict=False)
        else:
            warnings.warn(
                f"No parameters in checkpoint '{path}' match the network; "
                "weights were not loaded.",
                RuntimeWarning,
            )

    # ------------------------------------------------------------------
    # Inputs
    # ------------------------------------------------------------------

    def set_init_img(
        self,
        img,
        modality: Optional[str] = None,
    ) -> "OMorpher":
        """Set the initial image. Accepts numpy, torch, path, or (img, ddf) tuple.

        The image is standardised (see ``_standardize_img``) and kept both at
        model resolution and at padded full resolution. When no DDF is given,
        the initial DDF is reset to zeros at model resolution.
        """
        init_ddf = None
        if isinstance(img, (tuple, list)):
            img, init_ddf = img[0], img[1]

        model_tensor, fullres_tensor, orig_shape = self._standardize_img(
            img, modality=modality, keep_raw=True,
        )
        self._init_img = model_tensor
        self._init_img_raw = fullres_tensor
        self._init_img_original_shape = orig_shape

        if init_ddf is not None:
            self._init_ddf = self._to_ddf_tensor(init_ddf)
        else:
            B = self._init_img.shape[0]
            S = self.img_size
            self._init_ddf = torch.zeros(
                [B, self.ndims] + [S] * self.ndims,
                dtype=torch.float32, device=self.device,
            )
        return self

    def set_cond_img(
        self,
        img=None,
        modality: Optional[str] = None,
    ) -> "OMorpher":
        """Set the conditioning image. Default: Gaussian noise sigma=0.1."""
        if img is None:
            B = self._init_img.shape[0] if self._init_img is not None else self.batch_size
            S = self.img_size
            self._cond_img = torch.randn(
                [B, 1] + [S] * self.ndims,
                dtype=torch.float32, device=self.device,
            ) * 0.1
        else:
            tensor, _, _ = self._standardize_img(img, modality=modality, keep_raw=False)
            self._cond_img = tensor
        return self

    def set_cond_txt(self, txt=None) -> "OMorpher":
        """Set the text conditioning. Accepts string, numpy [1024], torch [1024], or None."""
        self._cond_txt = self._standardize_txt(txt)
        return self

    def set_init_def(self, ddf=None) -> "OMorpher":
        """Set or regenerate the initial deformation field.

        If *ddf* is ``None``, a random DDF is generated using the forward
        diffusion parameters at timestep ``config["start_noise_step"]``
        (default ``timesteps // 2``), which is useful for data augmentation.

        Raises:
            RuntimeError: if *ddf* is ``None`` and no init image is set.
        """
        if ddf is None:
            if self._init_img is None:
                raise RuntimeError("set_init_img() must be called before set_init_def()")
            t_val = self.config.get("start_noise_step", self.timesteps // 2)
            t = torch.tensor([t_val], dtype=torch.long, device=self.device)
            _, _, random_ddf = self._get_random_ddf(self._init_img, t)
            self._init_ddf = random_ddf
        else:
            self._init_ddf = self._to_ddf_tensor(ddf)
        return self

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def predict(
        self,
        T: Optional[list] = None,
        proc_type: Optional[str] = None,
        t_save: Optional[list] = None,
    ) -> "OMorpher":
        """Run reverse diffusion and store the predicted DDF.

        Args:
            T: ``[start_step, num_steps]``. If the initial DDF is all zeros and
                ``start_step > 0``, the init image is first forward-diffused
                with a random DDF at ``start_step``. The reverse loop then
                visits timesteps ``num_steps - 1`` down to ``0``. Defaults to
                ``[config["start_noise_step"] (or 0), self.timesteps]``.
            proc_type: Conditioning strategy for ``_proc_cond_img``; defaults
                to ``self.condition_type``.
            t_save: Timesteps whose composed DDF is kept for
                ``get_def(t_list)``. ``DefRec_MutAttnNet`` networks process all
                timesteps in a single call, so only the final DDF is stored
                (under key ``0``) when *t_save* is non-empty.

        Returns:
            ``self``, for chaining.

        Raises:
            RuntimeError: if ``set_init_img`` has not been called.
        """
        if self._init_img is None:
            raise RuntimeError("set_init_img() must be called before predict()")

        start_noise = self.config.get("start_noise_step", 0)
        if T is None:
            T = [start_noise, self.timesteps]
        if proc_type is None:
            proc_type = self.condition_type

        B = self._init_img.shape[0]
        S = self.img_size

        # Conditioning image: explicit one if set, otherwise the init image.
        cond_img_src = self._cond_img if self._cond_img is not None else self._init_img.clone().detach()
        cond_img, mask, cond_ratio = self._proc_cond_img(cond_img_src, proc_type=proc_type)

        # Missing text conditioning is replaced by an all-zero embedding.
        txt = self._cond_txt
        if txt is None:
            txt = torch.zeros([B, 1024], dtype=torch.float32, device=self.device)
        if isinstance(self.network, DefRec_MutAttnNet):
            # This network expects text with singleton spatial dims.
            txt = txt.view(B, -1, *([1] * self.ndims))

        # Starting point: user DDF, a random forward DDF, or identity.
        init_ddf_is_zero = (self._init_ddf is None) or torch.all(self._init_ddf == 0)
        if not init_ddf_is_zero:
            ddf_comp = self._init_ddf.clone()
            img_rec = self.img_stn(self._init_img, ddf_comp)
        elif T[0] is not None and T[0] > 0:
            t_start = torch.tensor(np.array([T[0]]), device=self.device)
            img_rec, _, ddf_comp = self._get_random_ddf(self._init_img, t_start)
        else:
            img_rec = self._init_img.clone()
            ddf_comp = torch.zeros(
                [B, self.ndims] + [S] * self.ndims,
                dtype=torch.float32, device=self.device,
            )

        self._intermediate_ddfs = []
        rec_num = 2

        if isinstance(self.network, DefRec_MutAttnNet):
            # All reverse timesteps are handed to the network in one call.
            t_list = list(range(T[1] - 1, -1, -1))
            with torch.no_grad():
                pre_dvf = self.network(
                    x=img_rec, y=cond_img, t=t_list, rec_num=rec_num, text=txt,
                )
            ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
            img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
            if t_save:
                self._intermediate_ddfs.append((0, ddf_comp.clone()))
        else:
            for i in range(T[1] - 1, -1, -1):
                t = torch.tensor(np.array([i]), device=self.device)
                with torch.no_grad():
                    pre_dvf = self.network(
                        x=img_rec, y=cond_img, t=t, rec_num=rec_num, text=txt,
                    )
                # Compose the step DVF into the running DDF, then re-warp the
                # *original* image once so blur does not accumulate.
                ddf_comp = self.stn_full(ddf_comp, pre_dvf) + pre_dvf
                img_rec = self.img_stn(self._init_img.clone().detach(), ddf_comp)
                if t_save is not None and i in t_save:
                    self._intermediate_ddfs.append((i, ddf_comp.clone()))

        self._predicted_ddf = ddf_comp
        return self

    def get_def(
        self,
        t_list: Optional[list] = None,
    ) -> Union[torch.Tensor, Dict[int, torch.Tensor]]:
        """Return the final predicted DDF, or intermediate DDFs for given timesteps.

        Raises:
            RuntimeError: if *t_list* is ``None`` and ``predict`` has not run.
        """
        if t_list is None:
            if self._predicted_ddf is None:
                raise RuntimeError("predict() must be called before get_def()")
            return self._predicted_ddf
        return {t: ddf for t, ddf in self._intermediate_ddfs if t in t_list}

    def apply_def(
        self,
        img=None,
        ddf: Optional[torch.Tensor] = None,
        padding_mode: Optional[str] = None,
        resample_mode: Optional[str] = None,
    ) -> torch.Tensor:
        """Warp an image with a DDF, up-scaling the DDF if sizes differ.

        Args:
            img: Image to warp (numpy or torch, with or without batch/channel
                dims). Defaults to the padded full-resolution init image, or
                the model-resolution init image if no full-res copy exists.
            ddf: Displacement field ``[B, ndims, ...]``; defaults to the
                predicted DDF. It is resized with (bi/tri)linear interpolation
                to the image's spatial size. Values are not rescaled because
                ``_apply_ddf`` multiplies them by the image size.
            padding_mode: ``grid_sample`` padding mode; defaults to
                ``self.padding_mode``.
            resample_mode: ``grid_sample`` mode; defaults to ``"bilinear"``.

        Raises:
            RuntimeError: if no DDF is given and ``predict`` has not run, or
                no image is given and no init image is set.
        """
        if padding_mode is None:
            padding_mode = self.padding_mode
        if resample_mode is None:
            resample_mode = "bilinear"

        if ddf is None:
            if self._predicted_ddf is None:
                raise RuntimeError("predict() must be called before apply_def()")
            ddf = self._predicted_ddf

        if img is not None:
            vol_tensor = self._ensure_tensor(img)
        elif self._init_img_raw is not None:
            vol_tensor = self._init_img_raw
        else:
            vol_tensor = self._init_img
        if vol_tensor is None:
            raise RuntimeError(
                "set_init_img() must be called before apply_def() when img is None"
            )

        target_sz = list(vol_tensor.shape[2:])
        if target_sz != list(ddf.shape[2:]):
            ddf = F.interpolate(
                ddf, size=target_sz,
                mode="bilinear" if self.ndims == 2 else "trilinear",
                align_corners=False,
            )

        return self._apply_ddf(vol_tensor, ddf, padding_mode=padding_mode, resample_mode=resample_mode)

    # ------------------------------------------------------------------
    # Fine-tuning
    # ------------------------------------------------------------------

    def finetune_setup(
        self,
        lr: float = 1e-4,
        optimizer_cls=None,
    ) -> "OMorpher":
        """Switch to training mode and create an optimizer (Adam by default)."""
        self.network.train()
        self.inf_mode = False
        if optimizer_cls is None:
            optimizer_cls = torch.optim.Adam
        self._optimizer = optimizer_cls(self.network.parameters(), lr=lr)
        return self

    def finetune_step(
        self,
        img_batch,
        cond_batch=None,
        text_batch=None,
        t=None,
        proc_type=None,
    ) -> dict:
        """Single training step. Returns loss dict.

        Args:
            img_batch: Training image(s); standardised like ``set_init_img``.
            cond_batch: Conditioning image(s); defaults to a copy of the
                standardised *img_batch*.
            text_batch: Text conditioning (see ``set_cond_txt``) or ``None``.
            t: Timestep(s) as int, sequence or tensor. Defaults to one uniform
                random timestep in ``[0, timesteps)`` per sample. Scalars are
                turned into a 1-element tensor.
            proc_type: Conditioning strategy; defaults to
                ``self.condition_type``.

        Returns:
            Dict of python floats: ``loss_total``, ``loss_grad``,
            ``loss_dist`` and ``loss_ang``.

        Raises:
            RuntimeError: if ``finetune_setup`` has not been called.
        """
        if self._optimizer is None:
            raise RuntimeError("finetune_setup() must be called first")

        img, _, _ = self._standardize_img(img_batch, keep_raw=False)
        cond = self._standardize_img(cond_batch, keep_raw=False)[0] if cond_batch is not None else img.clone()
        text = self._standardize_txt(text_batch)

        B = img.shape[0]
        if t is None:
            t = torch.randint(0, self.timesteps, (B,), device=self.device)
        else:
            # Always 1-D (as in predict) so scalars behave like a batch of one.
            t = torch.as_tensor(t, device=self.device).reshape(-1)

        proc_type = proc_type or self.condition_type
        cond_img, mask, cond_ratio = self._proc_cond_img(cond, proc_type=proc_type)
        noisy_img, dvf_gt, _ = self._get_random_ddf(img, t)

        if isinstance(self.network, DefRec_MutAttnNet):
            if text is not None:
                text = text.view(B, -1, *([1] * self.ndims))
            t_input = [t]
        else:
            t_input = t

        pre_dvf = self.network(x=noisy_img * mask, y=cond_img, t=t_input, rec_num=2, text=text)

        # The predicted DVF is composed with the ground-truth forward DVF and
        # scored against it (losses take it as ``inv_lab``, so presumably the
        # prediction is meant to invert dvf_gt).
        loss_grad = self._loss_grad(y_pred=pre_dvf, img=img)
        trm_pred = self.stn_full(pre_dvf, dvf_gt)
        loss_dist = self._loss_dist(pred=trm_pred, inv_lab=dvf_gt)
        loss_ang = self._loss_ang(pred=trm_pred, inv_lab=dvf_gt)
        loss_total = 2.0 * loss_ang + 1.0 * loss_dist + 16.0 * loss_grad

        self._optimizer.zero_grad()
        loss_total.backward()
        self._optimizer.step()

        return {
            "loss_total": loss_total.item(),
            "loss_grad": loss_grad.item(),
            "loss_dist": loss_dist.item(),
            "loss_ang": loss_ang.item(),
        }

    def finetune_save(self, path: str, epoch: int = 0):
        """Save checkpoint in the standard OmniMorph format.

        The dict holds ``model_state_dict``, ``optimizer_state_dict`` (``None``
        without an optimizer) and ``epoch``; parent directories are created.
        """
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        torch.save(
            {
                "model_state_dict": self.network.state_dict(),
                "optimizer_state_dict": self._optimizer.state_dict() if self._optimizer else None,
                "epoch": epoch,
            },
            path,
        )

    def finetune_teardown(self) -> "OMorpher":
        """Switch back to eval mode and drop the optimizer."""
        self.network.eval()
        self.inf_mode = True
        self._optimizer = None
        return self

    # ------------------------------------------------------------------
    # Forward diffusion (random deformation generation)
    # ------------------------------------------------------------------

    def _get_ddf_scale(
        self, t: torch.Tensor, divide_num: int = 1, max_ddf_num: int = 200,
    ) -> Tuple[int, torch.Tensor, torch.Tensor]:
        """Timestep-dependent deformation magnitude. Mirrors DeformDDPM._get_ddf_scale().

        Returns ``(rec_num, mul_num_ddf, mul_num_dvf)`` where
        ``mul_num_ddf = clamp(floor(2 t^1.3 / (3 divide_num)), 1, max_ddf_num)``
        and ``mul_num_dvf = clamp(floor(t^0.6 / divide_num), 0, max_ddf_num)``.
        """
        rec_num = 1
        mul_num_ddf = torch.floor_divide(2 * torch.pow(t.float(), 1.3), 3 * divide_num).int()
        mul_num_dvf = torch.floor_divide(torch.pow(t.float(), 0.6), divide_num).int()
        mul_num_ddf = torch.clamp(mul_num_ddf, min=1, max=max_ddf_num)
        mul_num_dvf = torch.clamp(mul_num_dvf, min=0, max=max_ddf_num)
        return rec_num, mul_num_ddf, mul_num_dvf

    def _sample_random_uniform_multi_order(
        self, high=None, low=0.0, order_num=3,
    ) -> float:
        """Draw ``uniform(prev, high)`` *order_num* times, starting from *low*.

        Repeated draws skew the result towards *high*.
        """
        sample_value = low
        for _ in range(order_num):
            sample_value = np.random.uniform(low=sample_value, high=high)
        return sample_value

    def _multiscale_dvf_generate(
        self, v_scale: float, ctl_szs: list = None, rand_v_scale: bool = True,
        batch_size: Optional[int] = None,
    ) -> torch.Tensor:
        """Multi-scale Gaussian DVF at control-point size ``self.ctl_sz``.

        For each size in *ctl_szs* a Gaussian field is drawn, scaled by
        ``ctl_sz / size`` and up-sampled to the control grid; the fields are
        summed. The result is a CPU tensor
        ``[batch_size, ndims, ctl_sz, ...]``.

        Args:
            v_scale: Noise standard deviation (upper bound when
                *rand_v_scale* is true).
            ctl_szs: Grid sizes to sample at; defaults to ``[4, 8, 16, 32, 64]``.
            rand_v_scale: Draw a random std in ``(1e-8, v_scale)`` per scale.
            batch_size: Number of fields; defaults to ``self.batch_size``.
        """
        if ctl_szs is None:
            ctl_szs = [4, 8, 16, 32, 64]
        if batch_size is None:
            batch_size = self.batch_size
        dvf = 0
        for ctl_sz in ctl_szs:
            _v = (
                self._sample_random_uniform_multi_order(high=v_scale, low=1e-8, order_num=2)
                if rand_v_scale
                else v_scale
            )
            if ctl_sz <= 2:
                _v = _v / 2
            dvf_comp = torch.randn(
                [batch_size, self.ndims] + [ctl_sz] * self.ndims
            ) * _v
            dvf_comp = F.interpolate(
                dvf_comp * self.ctl_sz / ctl_sz,
                [self.ctl_sz] * self.ndims,
                align_corners=False,
                mode="bilinear" if self.ndims == 2 else "trilinear",
            )
            dvf = dvf + dvf_comp
        return dvf

    def _random_ddf_generate(
        self,
        rec_num: int = 3,
        mul_num: list = None,
        noise_ratio: float = 0.08,
        select_num: int = 4,
        batch_size: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compose DVFs to build a DDF. Mirrors DeformDDPM._random_ddf_generate().

        Fields are composed on the control grid, up-sampled to twice the model
        size and centre-cropped to ``img_size``.

        Args:
            rec_num: Number of independent DVF draws.
            mul_num: ``[n_ddf, n_dvf]`` per-sample composition counts; the
                noisy DVF is composed ``n_ddf`` times into the first output
                and the clean DVF ``n_dvf`` times into the second.
            noise_ratio: Relative scale of the extra noise DVF added to the
                first output's DVF (``0`` disables it).
            select_num: Max number of control-grid scales sampled per draw.
            batch_size: Number of fields; defaults to ``self.batch_size``.
                Callers warping an image should pass its batch size.

        Returns:
            ``(ddf, dddf)`` at model resolution.
        """
        if mul_num is None:
            mul_num = [torch.tensor([5]), torch.tensor([5])]
        if batch_size is None:
            batch_size = self.batch_size

        crop_rate = 2

        # Add trailing singleton dims so per-sample counts broadcast over
        # [B, ndims, *spatial].
        for _ in range(self.ndims + 1):
            mul_num = [torch.unsqueeze(n, -1) for n in mul_num]

        ctl_ddf_sz = [batch_size, self.ndims] + [self.ctl_sz] * self.ndims
        ddf = torch.zeros(ctl_ddf_sz, device=self.device)
        dddf = torch.zeros(ctl_ddf_sz, device=self.device)
        scale_num = min(8, int(math.log2(self.ctl_sz)))
        ctl_szs_all = [self.ctl_sz // (2 ** i) for i in range(scale_num)]

        for _i in range(rec_num):
            if len(ctl_szs_all) > select_num:
                ctl_szs = random.sample(ctl_szs_all, select_num)
            else:
                ctl_szs = ctl_szs_all
            dvf = self._multiscale_dvf_generate(
                self.v_scale, ctl_szs=ctl_szs, batch_size=batch_size,
            ).to(self.device)
            if noise_ratio == 0:
                dvf0 = dvf
            else:
                dvf0 = dvf + self.stn_ctl(
                    self._multiscale_dvf_generate(
                        self.v_scale * noise_ratio, ctl_szs=ctl_szs, rand_v_scale=False,
                        batch_size=batch_size,
                    ).to(self.device),
                    dvf,
                )
            # Samples stop composing once their count is reached (flag = 0).
            for j in range(torch.max(mul_num[0]).item()):
                flag = [(n > j).int().to(self.device) for n in mul_num]
                ddf = dvf0 * flag[0] + self.stn_ctl(ddf, dvf0 * flag[0])
                dddf = dvf * flag[1] + self.stn_ctl(dddf, dvf * flag[1])

        interp_mode = "bilinear" if self.ndims == 2 else "trilinear"
        ddf = F.interpolate(
            ddf * self.img_size / self.ctl_sz,
            self.img_size * crop_rate,
            mode=interp_mode,
        )
        dddf = F.interpolate(
            dddf * self.img_size / self.ctl_sz,
            self.img_size * crop_rate,
            mode=interp_mode,
        )
        half = self.img_size // 2
        three_half = self.img_size * 3 // 2
        if self.ndims == 2:
            ddf = ddf[..., half:three_half, half:three_half]
            dddf = dddf[..., half:three_half, half:three_half]
        else:
            ddf = ddf[..., half:three_half, half:three_half, half:three_half]
            dddf = dddf[..., half:three_half, half:three_half, half:three_half]
        return ddf, dddf

    def _get_random_ddf(
        self, img: torch.Tensor, t: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward-diffuse: generate random DDF and warp image.

        The random fields use ``img``'s batch size so they always match it.

        Returns:
            ``(warped_img, dvf_forward, ddf_forward)``.
        """
        rec_num, mul_num_ddf, mul_num_dvf = self._get_ddf_scale(t=t)
        ddf_forward, dvf_forward = self._random_ddf_generate(
            rec_num=rec_num, mul_num=[mul_num_ddf, mul_num_dvf],
            batch_size=img.shape[0],
        )
        warped_img = self.img_stn(img, ddf_forward)
        return warped_img, dvf_forward, ddf_forward

    # ------------------------------------------------------------------
    # Conditioning
    # ------------------------------------------------------------------

    def _proc_cond_img(
        self,
        img: torch.Tensor,
        proc_type: Optional[str] = None,
        noise_scale: float = 0.1,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Conditioning strategies. Mirrors DeformDDPM.proc_cond_img().

        Args:
            img: Conditioning image tensor.
            proc_type: One of ``"adding"``, ``"independ"``, ``"downsample"``,
                ``"slice"``, ``"project"``, ``"uncon"`` or a none-like value
                (``"none"``, ``""``, ``"None"``). ``None`` picks a random
                strategy (``"uncon"`` weighted 3x).
            noise_scale: Scale of the noise maps.

        Returns:
            ``(processed_img, mask, cond_ratio)``; *mask* is ``1`` unless the
            strategy masks voxels, and *cond_ratio* is a tensor describing how
            much of the condition survives.
        """
        proc_img = img.clone().detach()
        if proc_type is None:
            proc_type = random.choices(
                ["adding", "independ", "downsample", "slice", "none", "uncon"],
                weights=[1, 1, 1, 1, 1, 3],
                k=1,
            )[0]

        mask = torch.tensor(1, device=img.device)
        cond_ratio = torch.tensor(1.0, device=img.device)

        if proc_type in ["none", None, "", "None"]:
            return proc_img, mask, cond_ratio

        noise_type = random.choice(["gaussian", "uniform", "none"])

        if proc_type == "uncon":
            # Unconditional: the condition is replaced entirely by noise.
            noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)
            return noise_map, torch.tensor(0, device=img.device), torch.tensor(0, device=img.device)

        noise_map = None
        if proc_type in ["adding", "independ", "slice"]:
            noise_map = self._create_noise_map(img, noise_type=noise_type, noise_scale=noise_scale)

        if proc_type == "adding":
            noise_ratio = np.random.uniform(0.0, 1.0)
            proc_img = proc_img * (1 - noise_ratio) + noise_map * noise_ratio
            cond_ratio = torch.tensor(1 - noise_ratio, device=img.device)
        elif proc_type == "independ":
            mask = self._create_noise_map(img, noise_type="binary")
            proc_img = img * mask
            cond_ratio = mask.float().mean()
        elif proc_type == "downsample":
            down_ratio = list(np.random.uniform(1.0 / 64, 1, [self.ndims]))
            down_img = F.interpolate(
                proc_img, scale_factor=down_ratio,
                mode="bilinear" if self.ndims == 2 else "trilinear",
            )
            proc_img = F.interpolate(
                down_img, size=[self.img_size] * self.ndims,
                mode="bilinear" if self.ndims == 2 else "trilinear",
                align_corners=False,
            )
            cond_ratio = torch.tensor(np.sqrt(np.prod(down_ratio)), device=img.device)
        elif proc_type == "slice":
            slice_num_max = random.randint(1, 64)
            slice_num_max = random.randint(1, slice_num_max)
            mask, sample_ratio = self._get_slice_mask(img, slice_num_range=[0, slice_num_max])
            proc_img = img * mask
            cond_ratio = torch.tensor(sample_ratio, device=img.device)
        elif proc_type == "project":
            proj_img = torch.zeros_like(img)
            rand_bourn = np.random.randint(0, 2, size=[self.ndims])
            proj_dim_num = np.sum(rand_bourn)
            for i, pflag in zip(range(2, 2 + self.ndims), rand_bourn):
                if pflag:
                    proj_img += torch.mean(img, dim=i, keepdim=True)
            proc_img = proj_img / (proj_dim_num + EPS)
            cond_ratio = torch.tensor(proj_dim_num / (128 * self.ndims), device=img.device)

        return proc_img, mask, cond_ratio

    def _create_noise_map(
        self,
        img: torch.Tensor,
        noise_type: str = "gaussian",
        noise_scale: float = 0.1,
    ) -> torch.Tensor:
        """Noise shaped like *img*: gaussian, uniform in ±scale, binary, else zeros."""
        if noise_type == "gaussian":
            return (torch.randn_like(img) * noise_scale).to(img.device)
        elif noise_type == "uniform":
            return (torch.rand_like(img) * noise_scale * 2 - noise_scale).to(img.device)
        elif noise_type == "binary":
            return torch.bernoulli(torch.rand_like(img)).to(img.device)
        return torch.zeros_like(img).to(img.device)

    def _get_slice_mask(
        self,
        img: torch.Tensor,
        slice_num_range: list = None,
    ) -> Tuple[torch.Tensor, float]:
        """Binary mask selecting axis-aligned slices along every spatial axis.

        In ``inf_mode`` only the central slice per axis is kept; otherwise a
        random number of random slices in *slice_num_range* (upper bound capped
        at ``img_size``). The caller's list is not modified.

        Returns:
            ``(mask, sample_ratio)`` with
            ``sample_ratio = mean_over_axes(sqrt(slice_num / img_size))``.
        """
        slice_num_range = [0, 32] if slice_num_range is None else list(slice_num_range)
        slice_num_range[1] = min(slice_num_range[1], self.img_size)
        mask = torch.zeros_like(img)
        sample_ratio = 0.0
        for i in range(self.ndims):
            if self.inf_mode:
                slice_num = 1
                slice_idx = [self.img_size // 2]
            else:
                slice_num = random.randint(slice_num_range[0], slice_num_range[1])
                slice_idx = random.sample(range(self.img_size), slice_num)
            # Rotate the spatial axes so each iteration marks a different one;
            # after ndims rotations the original layout is restored.
            transpose_list = [0, 1, 1 + self.ndims] + list(range(2, 1 + self.ndims))
            for idx in slice_idx:
                mask[..., idx] = 1
            mask = mask.permute(*transpose_list)
            sample_ratio += np.sqrt(slice_num / self.img_size) / self.ndims
        return mask, sample_ratio

    # ------------------------------------------------------------------
    # Input standardisation
    # ------------------------------------------------------------------

    def _standardize_img(
        self,
        img,
        modality: Optional[str] = None,
        keep_raw: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple]]:
        """Deterministic inference variant of the dataloader pipeline.

        Returns ``(model_tensor, fullres_tensor_or_None, orig_shape_or_None)``.

        * *model_tensor*: ``[B, C, S, S, S]`` at model resolution.
        * *fullres_tensor*: ``[B, C, D, H, W]`` at original padded resolution
          (only when *keep_raw=True*).
        * *orig_shape*: spatial dims of padded volume before resize.

        Accepts numpy arrays, torch tensors (any dimensionality, any device),
        or a file path (loaded via SimpleITK). Torch tensors with >= 4 dims
        are treated as already-batched and are passed through with
        appropriate device/dtype conversion.
        """
        fullres_tensor = None
        orig_shape = None

        if isinstance(img, str):
            sitk_img = sitk.ReadImage(img)
            vol = sitk.GetArrayFromImage(sitk_img)
            vol = self._reverse_axis_order(vol)
        elif isinstance(img, np.ndarray):
            vol = img.copy()
        elif isinstance(img, torch.Tensor):
            if img.ndim >= 4:
                t = img.float().to(self.device)
                if keep_raw:
                    fullres_tensor = t.clone()
                return t, fullres_tensor, None
            # detach()/cpu() so GPU tensors and tensors requiring grad work.
            vol = img.detach().cpu().numpy()
        else:
            raise TypeError(f"Unsupported image type: {type(img)}")

        # Keep only the first channel of 4-D volumes.
        if vol.ndim == 4:
            vol = vol[:, :, :, 0]

        # CT intensities are clipped to the configured range before scaling.
        if modality is not None and modality.upper() == "CT" and self.clamp_range is not None:
            vol = np.clip(vol, self.clamp_range[0], self.clamp_range[1])

        # Min-max normalise to [0, 1].
        vol = vol.astype(np.float64)
        vol = (vol - np.min(vol)) / (np.ptp(vol) + 1e-7)

        # Pad to a cube so the resize to a cube keeps the aspect ratio.
        vol = self._center_pad_to_cube(vol)
        orig_shape = vol.shape[:3]

        if keep_raw:
            fullres_tensor = torch.tensor(
                vol[None, None, ...], dtype=torch.float32, device=self.device,
            )

        target_sz = [self.img_size] * self.ndims
        vol_resized = sk_resize(
            vol, target_sz, anti_aliasing=True, preserve_range=True,
        )

        model_tensor = torch.tensor(
            vol_resized[None, None, ...], dtype=torch.float32, device=self.device,
        )
        return model_tensor, fullres_tensor, orig_shape

    def _standardize_label(
        self,
        label,
        fill_value: float = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Standardize a label volume for inference.

        Returns ``(model_tensor, fullres_tensor)``.

        * *model_tensor*: ``[1, C, S, S, S]`` at model resolution
          (nearest-neighbor resize, no anti-aliasing).
        * *fullres_tensor*: ``[1, C, D, H, W]`` at original padded resolution.

        If *label* is ``None``, returns *fill_value*-filled placeholders
        shaped to match the current init image (model-res and full-res).

        Accepts numpy arrays or torch tensors (any device). Does NOT apply
        normalization or clamping (labels are discrete indices).
        """
        if label is None:
            model_sz = [self.img_size] * self.ndims
            model_t = torch.full(
                [1, 1] + model_sz, fill_value,
                dtype=torch.float32, device=self.device,
            )
            if self._init_img_raw is not None:
                fullres_sz = list(self._init_img_raw.shape[2:])
            else:
                fullres_sz = model_sz
            fullres_t = torch.full(
                [1, 1] + fullres_sz, fill_value,
                dtype=torch.float32, device=self.device,
            )
            return model_t, fullres_t

        if isinstance(label, torch.Tensor):
            if label.ndim >= 4:
                # Already batched: only resize to model resolution.
                fullres_t = label.float().to(self.device)
                target_sz = [self.img_size] * self.ndims
                model_t = F.interpolate(
                    fullres_t, size=target_sz, mode="nearest",
                )
                return model_t, fullres_t
            lab = label.detach().cpu().numpy()
        elif isinstance(label, np.ndarray):
            lab = label.copy()
        else:
            raise TypeError(f"Unsupported label type: {type(label)}")

        lab = self._center_pad_to_cube(lab)

        # Move a trailing channel axis to the front: [C, D, H, W].
        if lab.ndim == 3:
            lab = lab[None, :, :, :]
        elif lab.ndim > 3:
            lab = np.transpose(lab, (3, 0, 1, 2))

        fullres_t = torch.tensor(
            lab[None, ...], dtype=torch.float32, device=self.device,
        )

        # Nearest-neighbour (order=0) resize per channel keeps label values.
        target_sz = [self.img_size] * self.ndims
        channels = [
            sk_resize(
                lab[c], target_sz,
                anti_aliasing=False, preserve_range=True, order=0,
            )
            for c in range(lab.shape[0])
        ]
        lab_model = np.stack(channels, axis=0)
        model_t = torch.tensor(
            lab_model[None, ...], dtype=torch.float32, device=self.device,
        )

        return model_t, fullres_t

    def _standardize_txt(self, txt) -> Optional[torch.Tensor]:
        """Convert text input to [B, 1024] tensor.

        Strings are embedded with BERT (mean-reduced); 1-D arrays/tensors get
        a leading batch dim; ``None`` is passed through.

        Raises:
            TypeError: for any other input type.
        """
        if txt is None:
            return None
        if isinstance(txt, str):
            self._ensure_bert()
            from Dataloader.bert_helper import str2emb
            emb = str2emb(
                txt, max_words_num=100,
                embeder=self._bert_model, tokenizer=self._bert_tokenizer,
                reduce_method="mean",
            )
            return emb.to(self.device)
        if isinstance(txt, np.ndarray):
            t = torch.tensor(txt, dtype=torch.float32, device=self.device)
            if t.ndim == 1:
                t = t.unsqueeze(0)
            return t
        if isinstance(txt, torch.Tensor):
            t = txt.float().to(self.device)
            if t.ndim == 1:
                t = t.unsqueeze(0)
            return t
        raise TypeError(f"Unsupported text type: {type(txt)}")

    def _ensure_bert(self):
        """Load the frozen BERT embedder and tokenizer on first use."""
        if self._bert_model is None:
            from Dataloader.bert_helper import get_frozen_embeder
            self._bert_model, self._bert_tokenizer = get_frozen_embeder(self.bert_model_path)

    # ------------------------------------------------------------------
    # Low-level array / tensor helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _reverse_axis_order(arr: np.ndarray) -> np.ndarray:
        """SimpleITK → NumPy axis order (reverse all axes, contiguous copy)."""
        return np.ascontiguousarray(arr.transpose(tuple(range(arr.ndim)[::-1])))

    @staticmethod
    def _center_pad_to_cube(volume: np.ndarray) -> np.ndarray:
        """Pad volume to a cube using the max dimension, with symmetric padding.

        Only the first (up to) three axes are padded, with zeros; any extra
        odd voxel goes after. Trailing axes are left untouched.
        """
        max_dim = max(volume.shape[:3])
        pad_width = []
        for s in volume.shape[:3]:
            total_pad = max_dim - s
            pad_before = total_pad // 2
            pad_width.append((pad_before, total_pad - pad_before))
        pad_width.extend([(0, 0)] * (volume.ndim - 3))
        return np.pad(volume, pad_width, mode="constant", constant_values=0)

    def _apply_ddf(
        self,
        volume_tensor: torch.Tensor,
        ddf: torch.Tensor,
        padding_mode: str = "border",
        resample_mode: str = "bilinear",
    ) -> torch.Tensor:
        """Apply DDF to volume tensor at any resolution via grid_sample.

        *ddf* values are multiplied by the per-axis image size, added to the
        voxel-index grid, mapped to ``[-1, 1]`` (``align_corners=True``) and
        reordered last-axis-first as ``grid_sample`` expects.
        """
        device = ddf.device
        ndims = self.ndims
        img_sz = list(volume_tensor.shape[2:])
        max_sz = torch.reshape(
            torch.tensor(img_sz, dtype=torch.float32, device=device),
            [1, ndims] + [1] * ndims,
        )
        ref_grid = torch.reshape(
            torch.stack(
                torch.meshgrid(
                    [torch.arange(s, device=device, dtype=torch.float32) for s in img_sz],
                    indexing="ij",
                ),
                0,
            ),
            [1, ndims] + img_sz,
        )
        img_shape = torch.reshape(
            torch.tensor(
                [(s - 1) / 2.0 for s in img_sz], dtype=torch.float32, device=device,
            ),
            [1] + [1] * ndims + [ndims],
        )
        grid = torch.flip(
            (ddf * max_sz + ref_grid).permute(
                [0] + list(range(2, 2 + ndims)) + [1]
            )
            / img_shape
            - 1,
            dims=[-1],
        )
        return F.grid_sample(
            volume_tensor.to(device),
            grid.float(),
            mode=resample_mode,
            padding_mode=padding_mode,
            align_corners=True,
        )

    def _ensure_tensor(self, img) -> torch.Tensor:
        """Convert numpy/torch input to a [B, C, ...] float tensor on device."""
        if isinstance(img, np.ndarray):
            t = torch.tensor(img, dtype=torch.float32, device=self.device)
        elif isinstance(img, torch.Tensor):
            t = img.float().to(self.device)
        else:
            raise TypeError(f"Unsupported image type: {type(img)}")
        if t.ndim == self.ndims:
            t = t[None, None, ...]
        elif t.ndim == self.ndims + 1:
            t = t[None, ...]
        return t

    def _to_ddf_tensor(self, ddf) -> torch.Tensor:
        """Convert ddf input to a ``[B, ndims, S, ...]`` tensor on device.

        A missing batch dim is added, and fields of another spatial size are
        linearly resized to model resolution (values unchanged).
        """
        if isinstance(ddf, np.ndarray):
            ddf = torch.tensor(ddf, dtype=torch.float32)
        ddf = ddf.float().to(self.device)
        if ddf.ndim == self.ndims + 1:
            ddf = ddf.unsqueeze(0)

        model_sz = [self.img_size] * self.ndims
        if list(ddf.shape[2:]) != model_sz:
            ddf = F.interpolate(
                ddf, size=model_sz,
                mode="bilinear" if self.ndims == 2 else "trilinear",
                align_corners=False,
            )
        return ddf

    # ------------------------------------------------------------------
    # Misc
    # ------------------------------------------------------------------

    def __repr__(self) -> str:
        status_parts = []
        if self._init_img is not None:
            status_parts.append(f"init_img={list(self._init_img.shape)}")
        if self._cond_img is not None:
            status_parts.append(f"cond_img={list(self._cond_img.shape)}")
        if self._predicted_ddf is not None:
            status_parts.append(f"predicted_ddf={list(self._predicted_ddf.shape)}")
        status = ", ".join(status_parts) if status_parts else "empty"
        return (
            f"OMorpher(net={self.net_name}, ndims={self.ndims}, "
            f"img_size={self.img_size}, device={self.device}, {status})"
        )
|
|
|