File size: 5,404 Bytes
2af0e94 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | """
losses_opt.py — Optimized loss functions.
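Optimized variants of Diffusion.losses. LNCC and MSLNCC are re-implemented
here (LNCC subclasses torch.nn.Module directly) so that their convolution
kernels are registered buffers: they follow module.to()/.cuda()
automatically, with no per-call .to(device) overhead.
All other loss classes (LMSE, NCC, MRSE, RMSE, Grad) and helpers
(avg_std_skew_kurt, grad_std, avg_std, EPS, eps_scale) are re-exported
unchanged from Diffusion.losses.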
"""
import numpy as np
import torch
import torch.nn.functional as F
# Re-export unchanged classes
from Diffusion.losses import (
LMSE,
NCC,
MRSE,
RMSE,
Grad,
avg_std_skew_kurt,
grad_std,
avg_std,
EPS,
eps_scale,
)
class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC).

    Optimized: kernels are stored as registered buffers, so they move with
    ``module.to(...)`` / ``.cuda()`` and need no per-call ``.to(device)``.
    The buffers use register_buffer's default ``persistent=True`` and so
    appear in ``state_dict()`` as ``kernels`` (smooth=True only) and
    ``sum_filt``.

    Inputs are (batch, channel, D, H, W) tensors (the label path reduces
    over dims 2-4). Every channel is filtered independently (depthwise
    convolution), so multi-channel inputs are supported; the final value is
    averaged over batch and channels.

    Args:
        win: Local window size per spatial dim. Defaults to ``[11, 11, 11]``.
            Odd sizes keep the spatial size unchanged; an even size shrinks
            that dim by one.
        num_ch: Accepted for interface compatibility; the channel count is
            read from the input at call time.
        eps: Added to each local variance before dividing.
        central: If True, subtract local means (local covariance/variance);
            otherwise use raw local second moments.
        smooth: If True, pre-smooth both inputs with a 5x5x5 Gaussian
            (std 0.5) before computing local statistics.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
        super(LNCC, self).__init__()
        self.scale = 2e0  # forward() multiplies both inputs by this factor
        self.win = win
        self.eps = eps
        self.central = central
        self.ndims = 3
        # Kept as an attribute; the convolutions below always use stride=1.
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth
        if self.win is None:
            self.win = [11] * self.ndims
        # "Same" padding for odd window sizes.
        self.padding = [(w - 1) // 2 for w in self.win]
        # Gaussian half-width, filled in by _build_kernel(std > 0). Defined
        # unconditionally so the attribute exists even when smooth=False.
        self.tail = None
        if smooth:
            self.register_buffer('kernels', self._build_kernel(std=0.5))  # OPT: auto device transfer
        self.register_buffer('sum_filt', self._build_kernel(std=0.0))  # OPT: auto device transfer

    def _build_kernel(self, std=0.0):
        """
        Build a normalised (sums to 1) 3-D kernel of shape (1, 1, kD, kH, kW).

        std == 0.0: a box (mean) filter of size ``self.win``, used for the
        local window statistics.
        std > 0.0: a separable Gaussian of half-width ``2 * ceil(std)`` per
        dim; as a side effect the half-width is stored in ``self.tail`` so
        lncc() can pad by it.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win]) / np.prod(self.win)
        self.tail = int(np.ceil(std)) * 2
        k = torch.exp(-0.5 * (torch.arange(-self.tail, self.tail + 1, dtype=torch.float32) ** 2) / std ** 2)
        kernel = k / torch.sum(k)
        # Outer product of the 1-D Gaussian along the three spatial axes.
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def _filter(x, kernel, padding):
        """Convolve every channel of ``x`` with the single-channel 3-D ``kernel``."""
        # Channel dim of batched (N, C, D, H, W) or unbatched (C, D, H, W) input.
        channels = x.shape[-4]
        if channels != 1:
            # Depthwise: one copy of the (1, 1, kD, kH, kW) kernel per channel.
            kernel = kernel.expand(channels, -1, -1, -1, -1).contiguous()
        return F.conv3d(x, kernel, stride=1, padding=padding, groups=channels)

    def lncc(self, I, J, label=None):
        """
        Return the mean local squared correlation of ``I`` and ``J``.

        Per voxel the value is ``cross**2 / (var_I + eps) / (var_J + eps)``,
        i.e. the *squared* correlation: it lies in [0, 1] up to floating-point
        error and does not distinguish positive from negative correlation.

        Args:
            I, J: Tensors of identical shape (batch, channel, D, H, W).
            label: Optional mask broadcastable to the filtered output. When
                given, cc is label-weighted-averaged over dims 2-4 (with eps
                added to the weight sum) before the final mean.
        """
        # OPT: no .to(I.device) needed — buffers auto-transfer with module.to()
        if self.smooth:
            I = self._filter(I, self.kernels, self.tail)
            J = self._filter(J, self.kernels, self.tail)
        box, pad = self.sum_filt, self.padding
        # sum_filt is normalised, so these are local means of the products.
        I2_sum = self._filter(I * I, box, pad)
        J2_sum = self._filter(J * J, box, pad)
        IJ_sum = self._filter(I * J, box, pad)
        if self.central:
            I_sum = self._filter(I, box, pad)
            J_sum = self._filter(J, box, pad)
            # Local covariance / variances: E[xy] - E[x]E[y].
            cross = IJ_sum - (I_sum * J_sum)
            I_var = I2_sum - (I_sum * I_sum)
            J_var = J2_sum - (J_sum * J_sum)
        else:
            cross, I_var, J_var = IJ_sum, I2_sum, J2_sum
        cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)
        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return ``-lncc(I * self.scale, J * self.scale, label)``."""
        return -self.lncc(I * self.scale, J * self.scale, label=label)
class MSLNCC(LNCC):
    """
    Multi-Scale Local Normalized Cross-Correlation (MSLNCC).

    Evaluates LNCC at several resolutions (obtained by pooling) and returns
    the negated weighted mean of the per-scale scores.
    Optimized: inherits buffer-based kernels from LNCC.

    Args:
        win: Local window size, forwarded to LNCC. When None, LNCC's default
            window ``[11, 11, 11]`` is used.
        num_ch, eps, central, smooth: Forwarded to LNCC.
        scale_ratios: Resolution of each scale relative to the input.
            Defaults to ``[1, 0.5, 0.25]``.
        scale_weights: Weight of each scale. Defaults to ``[0.75, 0.5, 0.25]``.

    Raises:
        ValueError: If the two sequences differ in length, a ratio is not
            positive, or the weights sum to zero.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
                 scale_ratios=None, scale_weights=None):
        # NOTE(review): win=None falls back to LNCC's default window
        # ([11, 11, 11]). Pass win=[9, 9, 9] explicitly if a 9-voxel window
        # was intended for the multi-scale loss.
        super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps,
                                     central=central, smooth=smooth)
        # Copy into fresh lists: no mutable default shared between instances,
        # and one-shot iterables can be reused on every forward().
        self.scale_ratios = list([1, 0.5, 0.25] if scale_ratios is None else scale_ratios)
        self.scale_weights = list([0.75, 0.5, 0.25] if scale_weights is None else scale_weights)
        if len(self.scale_ratios) != len(self.scale_weights):
            raise ValueError(
                "scale_ratios and scale_weights must have the same length, got "
                f"{len(self.scale_ratios)} and {len(self.scale_weights)}")
        if any(r <= 0 for r in self.scale_ratios):
            raise ValueError(f"scale_ratios must be positive, got {self.scale_ratios}")
        if sum(self.scale_weights) == 0:
            # forward() divides by this sum.
            raise ValueError("scale_weights must not sum to zero")

    def _downsample(self, I, J, label, ratio):
        """
        Return ``(I, J, label)`` reduced to roughly ``ratio`` of the input resolution.

        ratio >= 1 returns the inputs unchanged. Otherwise the images are
        average-pooled and the label max-pooled (a coarse voxel counts as
        labelled if any covered voxel is) with factor ``int(1 / ratio)``, so
        ratios that are not reciprocals of integers are truncated
        (e.g. 0.3 -> factor 3).
        """
        if ratio >= 1.0:
            return I, J, label
        factor = int(1.0 / ratio)
        I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
        J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
        label_down = None
        if label is not None:
            label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
        return I_down, J_down, label_down

    def forward(self, I, J, label=None):
        """Return the negated, weight-normalised sum of LNCC over all scales."""
        total_loss = 0.0
        total_weight = 0.0
        for ratio, weight in zip(self.scale_ratios, self.scale_weights):
            I_s, J_s, label_s = self._downsample(I, J, label, ratio)
            total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
            total_weight += weight
        return -total_loss / total_weight
|