| """
|
| losses_opt.py — Optimized loss functions.
|
|
|
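Re-implements LNCC and MSLNCC (LNCC subclasses torch.nn.Module directly rather
than inheriting from Diffusion.losses) so that their convolution kernels are
stored with register_buffer: the kernels follow the module across .to(device)
calls, removing the per-call .to(device) overhead.

All other loss classes (LMSE, NCC, MRSE, RMSE, Grad) and helpers
(avg_std_skew_kurt, grad_std, avg_std, EPS, eps_scale) are re-exported
unchanged from Diffusion.losses.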
|
| """
|
|
|
| import numpy as np
|
| import torch
|
| import torch.nn.functional as F
|
|
|
|
|
| from Diffusion.losses import (
|
| LMSE,
|
| NCC,
|
| MRSE,
|
| RMSE,
|
| Grad,
|
| avg_std_skew_kurt,
|
| grad_std,
|
| avg_std,
|
| EPS,
|
| eps_scale,
|
| )
|
|
|
|
|
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC).

    Optimized: kernels stored as registered buffers for automatic device transfer.

    For every voxel the squared local correlation
    ``cross^2 / ((var(I) + eps) * (var(J) + eps))`` is computed over a box
    window and then averaged (optionally weighted by ``label``). ``forward``
    returns the negated average so the loss can be minimized.

    Args:
        win: Box-window size per spatial axis: a list/tuple of 3 ints, or a
            single int used for all three axes. Defaults to ``[11, 11, 11]``.
        num_ch: Kept for interface compatibility. The channel count is read
            from the input at call time; every channel is processed
            independently with the same kernel.
        eps: Stabilizer added to both local variances.
        central: If True, use centered local (co)variances; otherwise use the
            raw local second moments.
        smooth: If True, pre-smooth I and J with a normalized separable
            Gaussian (std 0.5, 5 taps per axis) before the local statistics.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
        super(LNCC, self).__init__()
        self.scale = 2e0  # inputs are multiplied by this in forward()
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth
        # Half-width of the Gaussian smoothing kernel (used as its padding).
        # Set by _build_kernel when smooth=True; None otherwise instead of
        # leaving the attribute undefined.
        self.tail = None

        if win is None:
            win = [11] * self.ndims
        elif isinstance(win, int):
            win = [win] * self.ndims
        self.win = win
        # "Same" padding for odd windows; an even window shrinks each spatial
        # axis of the output by one voxel.
        self.padding = [(w - 1) // 2 for w in self.win]

        # Buffers follow the module on .to()/.cuda(), so no per-call transfer.
        # Registration order (kernels, sum_filt) is kept for state_dict parity.
        if smooth:
            self.register_buffer('kernels', self._build_kernel(std=0.5))
        self.register_buffer('sum_filt', self._build_kernel(std=0.0))

    def _build_kernel(self, std=0.0):
        """Return a single-channel 3-D kernel of shape (1, 1, kD, kH, kW).

        ``std == 0`` yields a normalized box (mean) filter of size ``self.win``.
        Otherwise it yields a normalized separable Gaussian with half-width
        ``2 * ceil(std)``; that half-width is stored in ``self.tail``.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win]) / np.prod(self.win)
        self.tail = int(np.ceil(std)) * 2
        taps = torch.arange(-self.tail, self.tail + 1, dtype=torch.float32)
        k = torch.exp(-0.5 * taps ** 2 / std ** 2)
        k = k / torch.sum(k)
        kernel = k.view(-1, 1, 1) * k.view(1, -1, 1) * k.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _depthwise_conv(x, weight, padding):
        """Convolve every channel of ``x`` independently with ``weight``.

        ``weight`` is a (1, 1, ...) kernel. It is cast to the input's floating
        dtype so float16/float64 inputs work, and replicated per channel for a
        grouped conv. With a single channel this is a plain conv3d. The channel
        axis is taken as ``shape[-4]`` so both (N, C, D, H, W) and unbatched
        (C, D, H, W) inputs are handled.
        """
        channels = x.shape[-4]
        if x.is_floating_point():
            weight = weight.to(dtype=x.dtype)
        if channels != 1:
            weight = weight.repeat(channels, 1, 1, 1, 1)
        return F.conv3d(x, weight, stride=1, padding=padding, groups=channels)

    def lncc(self, I, J, label=None):
        """Return the mean squared local correlation between ``I`` and ``J``.

        Args:
            I, J: Volumes accepted by conv3d, e.g. (N, C, D, H, W).
            label: Optional mask multiplied into the correlation map. When
                given, the map is averaged per leading (N, C) entry over
                dims (2, 3, 4) under the mask before the final mean.

        Returns:
            A scalar tensor.
        """
        if self.smooth:
            I = self._depthwise_conv(I, self.kernels, self.tail)
            J = self._depthwise_conv(J, self.kernels, self.tail)

        def local_mean(x):
            # sum_filt is ones / window volume, so this is a local mean.
            return self._depthwise_conv(x, self.sum_filt, self.padding)

        I2_mean = local_mean(I * I)
        J2_mean = local_mean(J * J)
        IJ_mean = local_mean(I * J)

        if self.central:
            I_mean = local_mean(I)
            J_mean = local_mean(J)
            cross = IJ_mean - (I_mean * J_mean)
            I_var = I2_mean - (I_mean * I_mean)
            J_var = J2_mean - (J_mean * J_mean)
        else:
            cross, I_var, J_var = IJ_mean, I2_mean, J2_mean

        cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)

        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return the negated LNCC of ``I * scale`` and ``J * scale`` (lower is better)."""
        return -self.lncc(I * self.scale, J * self.scale, label=label)
|
|
|
|
|
class MSLNCC(LNCC):
    """
    Multi-Scale Local Normalized Cross-Correlation (MSLNCC).

    Optimized: inherits buffer-based kernels from LNCC.

    Evaluates ``LNCC.lncc`` on pooled copies of the inputs, one per entry of
    ``scale_ratios``, and returns the negated weighted average.

    Args:
        win, num_ch, eps, central, smooth: Forwarded to ``LNCC``. ``win=None``
            resolves to LNCC's default window.
        scale_ratios: Resolution of each scale relative to the input. A ratio
            ``r < 1`` is realized by pooling with integer factor ``int(1 / r)``;
            ``r >= 1`` uses the inputs as-is.
        scale_weights: Weight of each scale. Must have the same length as
            ``scale_ratios`` and must not sum to zero.

    Raises:
        ValueError: If the scale lists are empty or differ in length, a ratio
            is non-positive, or the weights sum to zero.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
                 scale_ratios=(1, 0.5, 0.25), scale_weights=(0.75, 0.5, 0.25)):
        super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps,
                                     central=central, smooth=smooth)
        # NOTE(review): the previous code assigned ``win = [9] * self.ndims``
        # here, after the base class had already resolved win=None into
        # self.win, so it had no effect (the effective default was LNCC's).
        # It was removed to keep numerics unchanged; pass win=[9, 9, 9]
        # explicitly if a 9-voxel window is wanted.

        # Copy so instances never alias each other or a caller-owned list.
        self.scale_ratios = list(scale_ratios)
        self.scale_weights = list(scale_weights)
        if not self.scale_ratios or len(self.scale_ratios) != len(self.scale_weights):
            raise ValueError(
                "scale_ratios and scale_weights must be non-empty and of equal "
                f"length, got {len(self.scale_ratios)} and {len(self.scale_weights)}")
        if any(r <= 0 for r in self.scale_ratios):
            raise ValueError(f"scale_ratios must be positive, got {self.scale_ratios}")
        if sum(self.scale_weights) == 0:
            raise ValueError("scale_weights must not sum to zero")

    def _downsample(self, I, J, label, ratio):
        """Pool I and J (average) and label (max) by factor ``int(1 / ratio)``.

        Returns the inputs unchanged when ``ratio >= 1``.
        """
        if ratio >= 1.0:
            return I, J, label
        factor = int(1.0 / ratio)
        I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
        J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
        label_down = None
        if label is not None:
            # Max-pooling keeps a cell active if any voxel in it is active.
            label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
        return I_down, J_down, label_down

    def forward(self, I, J, label=None):
        """Return the negated weighted mean of per-scale LNCC values."""
        total_loss = 0.0
        total_weight = 0.0
        for ratio, weight in zip(self.scale_ratios, self.scale_weights):
            I_s, J_s, label_s = self._downsample(I, J, label, ratio)
            total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
            total_weight += weight
        return -total_loss / total_weight
|
|
|