# Source: Omini3D / Diffusion / losses_opt.py
# Repository page header (maxmo2009): "Sync from local: code + epoch-110 checkpoint, clean README"
# Commit 2af0e94 (verified)
"""
losses_opt.py — Optimized loss functions.
Re-exports the loss classes from Diffusion.losses and redefines LNCC and
MSLNCC so that their convolution kernels are stored with register_buffer
(the kernels move with the module on .to(device), so no per-call
.to(device) is needed).
All other loss classes (LMSE, NCC, MRSE, RMSE, Grad) are re-exported
unchanged.
"""
import numpy as np
import torch
import torch.nn.functional as F
# Re-export unchanged classes
from Diffusion.losses import (
LMSE,
NCC,
MRSE,
RMSE,
Grad,
avg_std_skew_kurt,
grad_std,
avg_std,
EPS,
eps_scale,
)
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC) for 3-D volumes.

    Optimized: kernels are stored as registered buffers so they follow the
    module on ``.to(device)`` / ``.cuda()`` instead of being moved per call.

    Both filters have shape ``(1, 1, D, H, W)``, so the inputs passed to
    :meth:`lncc` / :meth:`forward` must be single-channel 5-D tensors.

    Args:
        win: Window size of the local statistics. Either ``None`` (defaults
            to ``[11, 11, 11]``), a single int (used for all three axes), or
            a sequence of three ints. Odd sizes keep the spatial shape,
            because the padding is ``(w - 1) // 2`` per axis.
        num_ch: Accepted for interface compatibility; not used by this
            implementation.
        eps: Constant added to both local variances in the denominator and
            to the label sum when a label mask is given.
        central: If True, subtract the local means (cross-covariance and
            variances); if False, use raw second moments.
        smooth: If True, both inputs are first filtered with a small
            Gaussian kernel (std 0.5) before the statistics are computed.

    Raises:
        ValueError: If ``win`` does not have exactly three entries.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
        super().__init__()
        self.scale = 2e0
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth

        if win is None:
            win = [11] * self.ndims
        elif isinstance(win, (int, np.integer)):
            # Generalization: a scalar window is expanded to all three axes.
            win = [int(win)] * self.ndims
        if len(win) != self.ndims:
            # Fail here rather than with an opaque conv3d error at call time.
            raise ValueError(
                "win must have %d entries, got %d" % (self.ndims, len(win))
            )
        self.win = win
        self.padding = [(w - 1) // 2 for w in self.win]

        # Always defined; _build_kernel overwrites it for the Gaussian kernel.
        self.tail = None
        if smooth:
            # OPT: auto device transfer
            self.register_buffer('kernels', self._build_kernel(std=0.5))
        # OPT: auto device transfer
        self.register_buffer('sum_filt', self._build_kernel(std=0.0))

    def _build_kernel(self, std=0.0):
        """
        Build a ``(1, 1, D, H, W)`` convolution kernel.

        With ``std == 0.0`` the result is a box filter over ``self.win``
        whose entries sum to one (a local mean). Otherwise the result is a
        separable Gaussian with ``2 * tail + 1`` taps per axis, where
        ``tail = 2 * ceil(std)``; as a side effect ``self.tail`` is set to
        that value so it can be used as the convolution padding.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win]) / float(np.prod(self.win))

        self.tail = int(np.ceil(std)) * 2
        taps = torch.arange(-self.tail, self.tail + 1, dtype=torch.float32)
        gauss = torch.exp(-0.5 * (taps ** 2) / std ** 2)
        gauss = gauss / torch.sum(gauss)
        # Outer product of the 1-D kernel along each of the three axes.
        kernel = gauss.view(-1, 1, 1) * gauss.view(1, -1, 1) * gauss.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def lncc(self, I, J, label=None):
        """
        Return the mean squared local normalized cross-correlation of I and J.

        Args:
            I, J: Single-channel 5-D tensors of identical shape.
            label: Optional mask/weight tensor. When given, the per-voxel
                score is averaged over the spatial axes (2, 3, 4) weighted by
                ``label`` before the final mean is taken.

        Returns:
            A scalar tensor; larger values mean higher local correlation.
        """
        # OPT: no .to(I.device) needed — buffers auto-transfer with module.to()
        if self.smooth:
            I = F.conv3d(I, self.kernels, stride=1, padding=self.tail)
            J = F.conv3d(J, self.kernels, stride=1, padding=self.tail)

        def local_mean(x):
            # Box-filter average over the window (zero padded at the border).
            return F.conv3d(x, self.sum_filt, stride=1, padding=self.padding)

        cross = local_mean(I * J)
        I_var = local_mean(I * I)
        J_var = local_mean(J * J)
        if self.central:
            I_mean = local_mean(I)
            J_mean = local_mean(J)
            cross = cross - (I_mean * J_mean)
            I_var = I_var - (I_mean * I_mean)
            J_var = J_var - (J_mean * J_mean)

        cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)

        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (
                torch.sum(label, dim=(2, 3, 4)) + self.eps
            )
        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return the negated LNCC of the inputs, each multiplied by ``self.scale``."""
        return -self.lncc(I * self.scale, J * self.scale, label=label)
class MSLNCC(LNCC):
    """
    Multi-Scale Local Normalized Cross-Correlation (MSLNCC).

    Optimized: inherits buffer-based kernels from LNCC.

    The LNCC score is evaluated on the inputs at several resolutions and the
    weighted average of the scores is negated to form the loss.

    Args:
        win, num_ch, eps, central, smooth: Forwarded unchanged to ``LNCC``.
        scale_ratios: Resolution ratios; a ratio ``>= 1`` uses the inputs as
            they are, a ratio in ``(0, 1)`` average-pools them by
            ``int(1 / ratio)``.
        scale_weights: Weight of each scale in the average. Ratios and
            weights are paired with ``zip``, so surplus entries in the longer
            sequence are ignored.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
                 scale_ratios=(1, 0.5, 0.25), scale_weights=(0.75, 0.5, 0.25)):
        super().__init__(win=win, num_ch=num_ch, eps=eps,
                         central=central, smooth=smooth)
        # NOTE(review): an earlier revision assigned a local ``win = [9] * 3``
        # here when ``win`` was None. That assignment ran after the base-class
        # constructor and was never read, so the effective default window has
        # always been the one chosen by LNCC. The dead assignment is dropped
        # and the effective behaviour is kept; confirm whether a 9-voxel
        # default was actually intended.

        # Copy into per-instance lists: the defaults are immutable tuples and
        # a caller-supplied sequence is not aliased.
        self.scale_ratios = list(scale_ratios)
        self.scale_weights = list(scale_weights)

    def _downsample(self, I, J, label, ratio):
        """
        Return ``(I, J, label)`` at resolution ``ratio``.

        Images are average-pooled; the label (if any) is cast to float and
        max-pooled with the same kernel size and stride.

        Raises:
            ValueError: If ``ratio`` is not positive.
        """
        if ratio >= 1.0:
            return I, J, label
        if ratio <= 0:
            # Would otherwise be a ZeroDivisionError or a negative pool size.
            raise ValueError("scale ratio must be positive, got %r" % (ratio,))

        factor = int(1.0 / ratio)
        I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
        J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
        label_down = None
        if label is not None:
            label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
        return I_down, J_down, label_down

    def forward(self, I, J, label=None):
        """
        Return the negated, weight-normalized sum of per-scale LNCC scores.

        Raises:
            ValueError: If there is no (ratio, weight) pair to evaluate.
        """
        scales = list(zip(self.scale_ratios, self.scale_weights))
        if not scales:
            raise ValueError("MSLNCC needs at least one (scale_ratio, scale_weight) pair")

        total_loss = 0.0
        total_weight = 0.0
        for ratio, weight in scales:
            I_s, J_s, label_s = self._downsample(I, J, label, ratio)
            total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
            total_weight += weight
        return -total_loss / total_weight