""" losses_opt.py — Optimized loss functions. Inherits from Diffusion.losses and overrides LNCC and MSLNCC to use register_buffer for convolution kernels (auto device transfer, no per-call .to(device) overhead). All other loss classes (LMSE, NCC, MRSE, RMSE, Grad) are re-exported unchanged. """ import numpy as np import torch import torch.nn.functional as F # Re-export unchanged classes from Diffusion.losses import ( LMSE, NCC, MRSE, RMSE, Grad, avg_std_skew_kurt, grad_std, avg_std, EPS, eps_scale, ) class LNCC(torch.nn.Module): """ Local (over window) normalized cross-correlation (LNCC). Optimized: kernels stored as registered buffers for automatic device transfer. """ def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True): super(LNCC, self).__init__() self.scale = 2e0 self.win = win self.eps = eps self.central = central self.ndims = 3 self.strides = [1] * (self.ndims + 2) self.smooth = smooth if self.win is None: self.win = [11] * self.ndims self.padding = [(w - 1) // 2 for w in self.win] if smooth: self.tail = None # will be set in _build_kernel kernels = self._build_kernel(std=0.5) self.register_buffer('kernels', kernels) # OPT: auto device transfer self.register_buffer('sum_filt', self._build_kernel(std=0.0)) # OPT: auto device transfer def _build_kernel(self, std=0.0): if std == 0.0: return torch.ones([1, 1, *self.win]) / np.prod(self.win) else: self.tail = int(np.ceil(std)) * 2 k = torch.exp(-0.5 * (torch.arange(-self.tail, self.tail + 1, dtype=torch.float32) ** 2) / std ** 2) kernel = k / torch.sum(k) kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1) return kernel.unsqueeze(0).unsqueeze(0) def lncc(self, I, J, label=None): # OPT: no .to(I.device) needed — buffers auto-transfer with module.to() if self.smooth: I = torch.nn.functional.conv3d(I, self.kernels, stride=1, padding=self.tail) J = torch.nn.functional.conv3d(J, self.kernels, stride=1, padding=self.tail) I2 = I * I J2 = J * J IJ = I * J if self.central: I_sum = torch.nn.functional.conv3d(I, self.sum_filt, stride=1, padding=self.padding) J_sum = torch.nn.functional.conv3d(J, self.sum_filt, stride=1, padding=self.padding) I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=self.padding) J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=self.padding) IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=self.padding) cross = IJ_sum - (I_sum * J_sum) I_var = I2_sum - (I_sum * I_sum) J_var = J2_sum - (J_sum * J_sum) else: I2_sum = torch.nn.functional.conv3d(I2, self.sum_filt, stride=1, padding=self.padding) J2_sum = torch.nn.functional.conv3d(J2, self.sum_filt, stride=1, padding=self.padding) IJ_sum = torch.nn.functional.conv3d(IJ, self.sum_filt, stride=1, padding=self.padding) cross = IJ_sum I_var = I2_sum J_var = J2_sum cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps) if label is not None: label = label.float() cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps) return torch.mean(cc) def forward(self, I, J, label=None): return -self.lncc(I * self.scale, J * self.scale, label=label) class MSLNCC(LNCC): """ Multi-Scale Local Normalized Cross-Correlation (MSLNCC). Optimized: inherits buffer-based kernels from LNCC. """ def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False, scale_ratios=[1, 0.5, 0.25], scale_weights=[0.75, 0.5, 0.25]): super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps, central=central, smooth=smooth) if win is None: win = [9] * self.ndims self.scale_ratios = scale_ratios self.scale_weights = scale_weights def _downsample(self, I, J, label, ratio): if ratio >= 1.0: return I, J, label factor = int(1.0 / ratio) I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor) J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor) label_down = None if label is not None: label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor) return I_down, J_down, label_down def forward(self, I, J, label=None): total_loss = 0.0 total_weight = 0.0 for ratio, weight in zip(self.scale_ratios, self.scale_weights): I_s, J_s, label_s = self._downsample(I, J, label, ratio) total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s) total_weight += weight return -total_loss / total_weight