File size: 5,404 Bytes
2af0e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""

losses_opt.py — Optimized loss functions.



Inherits from Diffusion.losses and overrides LNCC and MSLNCC to use

register_buffer for convolution kernels (auto device transfer, no

per-call .to(device) overhead).



All other loss classes (LMSE, NCC, MRSE, RMSE, Grad) are re-exported

unchanged.

"""

import numpy as np
import torch
import torch.nn.functional as F

# Re-export unchanged classes
from Diffusion.losses import (
    LMSE,
    NCC,
    MRSE,
    RMSE,
    Grad,
    avg_std_skew_kurt,
    grad_std,
    avg_std,
    EPS,
    eps_scale,
)


class LNCC(torch.nn.Module):
    """
    Local (over window) normalized cross-correlation (LNCC) loss for 3-D volumes.

    Optimized: the smoothing kernel and the box (mean) filter are stored as
    registered buffers, so they move with the module on ``.to()`` / ``.cuda()``
    and no per-call device transfer is needed.

    Args:
        win: Size of the local window. Either a sequence with one entry per
            spatial dimension (three entries) or a single integer that is
            applied to every dimension. Defaults to ``[11, 11, 11]``.
        num_ch: Accepted for interface compatibility; not used by this class
            (both kernels are built with a single input/output channel).
        eps: Constant added to each local variance term and to the label sum
            to keep the divisions well defined.
        central: If True, local means are subtracted (covariance / variances);
            if False, raw local second moments are used instead.
        smooth: If True, both inputs are pre-filtered with a small separable
            Gaussian kernel (std 0.5) before the local statistics are computed.

    Raises:
        ValueError: If ``win`` does not contain exactly one size per spatial
            dimension.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=True):
        super(LNCC, self).__init__()
        self.scale = 2e0
        self.eps = eps
        self.central = central
        self.ndims = 3
        self.strides = [1] * (self.ndims + 2)
        self.smooth = smooth

        if win is None:
            win = [11] * self.ndims
        elif isinstance(win, (int, np.integer)):
            # A single integer means "same window along every spatial axis".
            win = [int(win)] * self.ndims
        if len(win) != self.ndims:
            raise ValueError(
                "win must have %d entries (one per spatial dimension), got %d"
                % (self.ndims, len(win))
            )
        self.win = win
        self.padding = [(w - 1) // 2 for w in self.win]

        # Half-width of the smoothing kernel; stays None when smoothing is off
        # and is filled in by _build_kernel otherwise.
        self.tail = None
        if smooth:
            self.register_buffer('kernels', self._build_kernel(std=0.5))  # OPT: auto device transfer
        self.register_buffer('sum_filt', self._build_kernel(std=0.0))  # OPT: auto device transfer

    def _build_kernel(self, std=0.0):
        """
        Build a convolution kernel of shape ``[1, 1, k, k, k]``.

        Args:
            std: ``0.0`` returns the normalized box filter over ``self.win``
                (every weight equals ``1 / prod(win)``). Any other value
                returns a normalized separable Gaussian with that standard
                deviation and half-width ``2 * ceil(std)``.

        Returns:
            A float tensor with leading singleton out/in channel dimensions.

        Side effects:
            For ``std != 0`` the half-width is stored in ``self.tail`` so that
            ``lncc`` can use it as the padding of the smoothing convolution.
        """
        if std == 0.0:
            return torch.ones([1, 1, *self.win]) / float(np.prod(self.win))

        self.tail = int(np.ceil(std)) * 2
        offsets = torch.arange(-self.tail, self.tail + 1, dtype=torch.float32)
        k = torch.exp(-0.5 * (offsets ** 2) / std ** 2)
        kernel = k / torch.sum(k)
        # Outer product of the 1-D profile along the three spatial axes.
        kernel = kernel.view(-1, 1, 1) * kernel.view(1, -1, 1) * kernel.view(1, 1, -1)
        return kernel.unsqueeze(0).unsqueeze(0)

    def lncc(self, I, J, label=None):
        """
        Return the mean squared local normalized cross-correlation of I and J.

        Args:
            I, J: Tensors accepted by ``conv3d`` with single-channel kernels;
                the label path reduces over dims (2, 3, 4), i.e. it expects
                5-D input (batch, 1, D, H, W).
            label: Optional weight/mask multiplied with the correlation map.
                When given, the map is averaged per sample and channel with
                ``label`` as weights before the final mean.

        Returns:
            A scalar tensor (higher means more similar).
        """
        # OPT: no .to(I.device) needed — buffers auto-transfer with module.to()
        conv3d = torch.nn.functional.conv3d

        if self.smooth:
            I = conv3d(I, self.kernels, stride=1, padding=self.tail)
            J = conv3d(J, self.kernels, stride=1, padding=self.tail)

        def local_mean(x):
            # Box filter: mean over the window (zero padding at the borders).
            return conv3d(x, self.sum_filt, stride=1, padding=self.padding)

        I2_sum = local_mean(I * I)
        J2_sum = local_mean(J * J)
        IJ_sum = local_mean(I * J)

        if self.central:
            I_sum = local_mean(I)
            J_sum = local_mean(J)

            cross = IJ_sum - (I_sum * J_sum)
            I_var = I2_sum - (I_sum * I_sum)
            J_var = J2_sum - (J_sum * J_sum)
        else:
            cross = IJ_sum
            I_var = I2_sum
            J_var = J2_sum

        cc = (cross * cross) / (I_var + self.eps) / (J_var + self.eps)
        if label is not None:
            label = label.float()
            cc = torch.sum(cc * label, dim=(2, 3, 4)) / (torch.sum(label, dim=(2, 3, 4)) + self.eps)

        return torch.mean(cc)

    def forward(self, I, J, label=None):
        """Return the loss: the negated LNCC of the inputs scaled by ``self.scale``."""
        return -self.lncc(I * self.scale, J * self.scale, label=label)


class MSLNCC(LNCC):
    """
    Multi-Scale Local Normalized Cross-Correlation (MSLNCC).

    Optimized: inherits buffer-based kernels from LNCC.

    The LNCC similarity is evaluated on the inputs at several resolutions
    (coarser scales are produced by average pooling) and the per-scale values
    are combined as a weighted average; ``forward`` returns its negation.

    Args:
        win, num_ch, eps, central, smooth: Forwarded unchanged to
            ``LNCC.__init__``. With ``win=None`` the window is therefore the
            LNCC default.
        scale_ratios: Resolution ratio of each scale. Ratios ``>= 1`` use the
            inputs as they are; smaller ratios pool with an integer factor of
            ``int(1 / ratio)``.
        scale_weights: Weight of each scale in the weighted average; must have
            the same length as ``scale_ratios``.

    Raises:
        ValueError: If the two sequences are empty, differ in length, or a
            ratio is not strictly positive.

    NOTE(review): an earlier revision assigned a local ``win = [9] * ndims``
    after the base constructor had already run, so it never took effect and the
    effective default window has always been the LNCC one. Confirm whether a
    window of 9 was the intended default before changing it.
    """

    def __init__(self, win=None, num_ch=1, eps=1e-3, central=True, smooth=False,
                 scale_ratios=(1, 0.5, 0.25), scale_weights=(0.75, 0.5, 0.25)):
        super(MSLNCC, self).__init__(win=win, num_ch=num_ch, eps=eps,
                                     central=central, smooth=smooth)
        # Copy into fresh lists so instances never share (or alias) the
        # caller's / default sequences.
        ratios = list(scale_ratios)
        weights = list(scale_weights)
        if not ratios:
            raise ValueError("scale_ratios must contain at least one scale")
        if len(ratios) != len(weights):
            raise ValueError(
                "scale_ratios and scale_weights must have the same length, got %d and %d"
                % (len(ratios), len(weights))
            )
        if any(r <= 0 for r in ratios):
            raise ValueError("scale_ratios must be strictly positive")
        self.scale_ratios = ratios
        self.scale_weights = weights

    def _downsample(self, I, J, label, ratio):
        """
        Bring I, J and the optional label to the resolution given by ``ratio``.

        Ratios ``>= 1`` return the inputs untouched. Otherwise the images are
        average-pooled and the label is max-pooled (so any labelled voxel in a
        pooling cell keeps the cell labelled), all with kernel size and stride
        ``int(1 / ratio)``.
        """
        if ratio >= 1.0:
            return I, J, label
        factor = int(1.0 / ratio)
        I_down = F.avg_pool3d(I, kernel_size=factor, stride=factor)
        J_down = F.avg_pool3d(J, kernel_size=factor, stride=factor)
        label_down = None
        if label is not None:
            label_down = F.max_pool3d(label.float(), kernel_size=factor, stride=factor)
        return I_down, J_down, label_down

    def forward(self, I, J, label=None):
        """Return the negated weighted average of the per-scale LNCC values."""
        total_loss = 0.0
        total_weight = 0.0
        for ratio, weight in zip(self.scale_ratios, self.scale_weights):
            I_s, J_s, label_s = self._downsample(I, J, label, ratio)
            total_loss += weight * self.lncc(I_s * self.scale, J_s * self.scale, label=label_s)
            total_weight += weight
        return -total_loss / total_weight