File size: 10,192 Bytes
2af0e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""

networks_opt.py — Optimized network components.



Subclasses RecMulModMutAttnNet and STN to eliminate per-call overhead:

  1. OptSTN: register_buffer for ref_grid/max_sz — no .to(device) per call

  2. OptRecMulModMutAttnNet: cached max_sz/img_sz tensors, ref_grid device —

     eliminates ~80 NumPy→GPU transfers and ~32 tensor recreations per registration step



All optimizations are mathematically equivalent to the originals.

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from Diffusion.networks import RecMulModMutAttnNet, STN


# ======================================================================
# Optimized STN
# ======================================================================

class OptSTN(STN):
    """Spatial transformer whose constant tensors are registered as buffers.

    ``max_sz``, ``ref_grid`` and the half-extent divisor used by ``forward()``
    are built once in ``__init__`` and stored with ``register_buffer``, so they
    follow the module through ``module.to(device)`` and ``resample()`` needs no
    per-call ``.to(device)``.

    Args:
        ndims: Number of spatial dimensions.
        img_sz: Edge length applied to every spatial axis (required).
        max_sz: Accepted for signature compatibility; not read here — the
            ``max_sz`` buffer is filled from ``img_sz``.
        device: Stored on ``self.device``; not otherwise used by this class.
        padding_mode: ``padding_mode`` passed to ``F.grid_sample`` by ``forward()``.
        resample_mode: ``mode`` passed to ``F.grid_sample``; ``None`` selects
            ``'bilinear'``.

    Raises:
        ValueError: If ``img_sz`` is ``None`` (the buffers cannot be built).
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        # Bypass STN.__init__ on purpose so the tensors below are created as
        # buffers instead of plain tensor attributes.
        nn.Module.__init__(self)
        if img_sz is None:
            # Previously this fell through to an opaque tensor-conversion error.
            raise ValueError(
                "OptSTN needs img_sz (edge length per spatial axis) to build its buffers"
            )
        self.ndims = ndims
        self.img_sz = [img_sz] * ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode

        # NOTE(review): register_buffer defaults to persistent=True, so the three
        # buffers below are part of state_dict(). Confirm that checkpoints written
        # with the parent STN still load (or load with strict=False).

        # Scale applied to ddf, laid out channel-first: [1, ndims, 1, ..., 1].
        max_sz_tensor = torch.Tensor(
            np.reshape(np.array(self.img_sz), [1, self.ndims] + [1] * self.ndims)
        )
        self.register_buffer('max_sz', max_sz_tensor)

        # Index coordinates of every sample position: [1, ndims, *img_sz].
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                [torch.arange(end=s) for s in self.img_sz]
            ), 0),
            [1, self.ndims] + self.img_sz
        )
        self.register_buffer('ref_grid', ref_grid)

        # Divisor used when forward() calls resample() with self.img_sz.
        self.register_buffer('_img_sz_for_resample',
                             self._half_extent_tensor(self.img_sz))

        # Constant permutation that moves the channel axis (1) to the end.
        self._perm = [0] + list(range(2, 2 + self.ndims)) + [1]

    def _half_extent_tensor(self, sizes, device=None):
        """Return ``(s - 1) / 2`` per axis, shaped ``[1, 1, ..., ndims]`` for a channel-last grid."""
        return torch.reshape(
            torch.tensor([(s - 1) / 2. for s in sizes], device=device),
            [1] + [1] * self.ndims + [self.ndims]
        )

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by the displacement field ``ddf`` using ``F.grid_sample``.

        Args:
            vol: Tensor to sample from (passed to ``F.grid_sample`` as input).
            ddf: Displacement field; multiplied by the ``max_sz`` buffer and added
                to ``ref`` before normalisation.
            ref: Reference coordinates; defaults to the ``ref_grid`` buffer.
            img_sz: ``None`` divides the grid by the ``max_sz`` buffer. A list or
                tuple of per-axis sizes divides by ``(size - 1) / 2``; the cached
                tensor is reused when the sizes equal ``self.img_sz``. Any other
                non-``None`` value selects the cached tensor.
            padding_mode: ``padding_mode`` for ``F.grid_sample``.

        Returns:
            The result of ``F.grid_sample`` with ``align_corners=True``.
        """
        ref = self.ref_grid if ref is None else ref

        if img_sz is None:
            # NOTE(review): max_sz is channel-first ([1, ndims, 1, ...]) while the
            # permuted grid below is channel-last — confirm this branch matches
            # the parent STN before relying on it.
            divisor = self.max_sz
        elif img_sz is self.img_sz or not isinstance(img_sz, (list, tuple)):
            # Common case (called from forward) or a non-sequence value:
            # reuse the pre-computed buffer.
            divisor = self._img_sz_for_resample
        elif list(img_sz) == self.img_sz:
            divisor = self._img_sz_for_resample
        else:
            # Caller asked for a different size: honour it instead of silently
            # normalising with the tensor cached for self.img_sz.
            divisor = self._half_extent_tensor(img_sz, device=ddf.device)

        resample_mode = 'bilinear' if self.resample_mode is None else self.resample_mode

        # Absolute coordinates -> channel-last -> divide by half extent and shift
        # by -1 (index range [0, size-1] maps to [-1, 1]); the last axis is then
        # reversed before being handed to grid_sample.
        grid = torch.flip(
            (ddf * self.max_sz + ref).permute(self._perm) / divisor - 1,
            dims=[-1]
        )
        return F.grid_sample(vol, grid, mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Resample ``x`` with ``ddf`` at the configured size and padding mode."""
        return self.resample(x, ddf=ddf, img_sz=self.img_sz,
                             padding_mode=self.padding_mode)


# ======================================================================
# Optimized RecMulModMutAttnNet
# ======================================================================

class OptRecMulModMutAttnNet(RecMulModMutAttnNet):
    """RecMulModMutAttnNet that caches the constant tensors used for resampling.

    The ``max_sz`` tensor, the half-extent (``(size - 1) / 2``) tensor and the
    reference grid are rebuilt only when the input's spatial size or device
    changes, instead of on every ``resample()`` / ``forward()`` call.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and create empty cache slots."""
        super().__init__(*args, **kwargs)
        # Cache slots — populated by _ensure_cache().
        self._cached_input_key = None
        self._cached_max_sz_tensor = None
        self._cached_img_sz_tensor = None
        # Constant permutation that moves the channel axis (1) to the end.
        self._perm = [0] + list(range(2, 2 + self.dimension)) + [1]

    def _ensure_cache(self, img_sz, device):
        """Rebuild the cached tensors if ``(img_sz, device)`` differs from the last call.

        Also updates ``self.max_sz`` and, when its spatial shape or device does
        not match, ``self.ref_grid``.

        Args:
            img_sz: Spatial size of the input (e.g. ``x.size()[2:]``).
            device: Device the cached tensors must live on.
        """
        key = (tuple(img_sz), device)
        if key == self._cached_input_key:
            return

        # Every axis uses the first spatial extent as its displacement scale.
        max_sz_list = [img_sz[0]] * self.dimension
        self.max_sz = max_sz_list

        # Channel-first scale tensor: [1, dimension, 1, ..., 1].
        self._cached_max_sz_tensor = torch.Tensor(
            np.reshape(np.array(max_sz_list), [1, self.dimension] + [1] * self.dimension)
        ).to(device)

        # Channel-last half-extent tensor: [1, ..., 1, dimension].
        self._cached_img_sz_tensor = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=device),
            [1] * (self.dimension + 1) + [self.dimension]
        )

        # Compare against the grid actually held rather than the nominal
        # resolution: after a call at another size the stored grid no longer
        # matches the native size either, and must be rebuilt on return.
        current_grid = getattr(self, "ref_grid", None)
        if current_grid is None or tuple(current_grid.shape[2:]) != tuple(img_sz):
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    [torch.arange(end=imsz) for imsz in img_sz]
                ), 0),
                [1, self.dimension] + list(img_sz)
            ).to(device)
        elif current_grid.device != torch.device(device):
            self.ref_grid = current_grid.to(device)

        # Record the key last so a failure above cannot leave a key that
        # claims the cache is valid.
        self._cached_input_key = key

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by the displacement field ``ddf`` using ``F.grid_sample``.

        Args:
            vol: Tensor to sample from.
            ddf: Displacement field; scaled by the cached ``max_sz`` tensor and
                added to ``ref``.
            ref: Reference coordinates; defaults to ``self.ref_grid``.
            img_sz: Only tested against ``None``. Any non-``None`` value selects
                the cached half-extent tensor; ``None`` selects the cached
                ``max_sz`` tensor. The value itself is not read.
            padding_mode: ``padding_mode`` for ``F.grid_sample``.

        Returns:
            The bilinear ``F.grid_sample`` result with ``align_corners=True``.
        """
        if self._cached_max_sz_tensor is None:
            # Called before any forward(): build the cache from ddf instead of
            # failing on a multiplication by None.
            # NOTE(review): assumes ddf has the same spatial size as the volume
            # that forward() would have seen — confirm for external callers.
            self._ensure_cache(ddf.shape[2:], ddf.device)

        ref = self.ref_grid if ref is None else ref
        img_sz = self._cached_img_sz_tensor if img_sz is not None else self._cached_max_sz_tensor

        # Absolute coordinates -> channel-last -> divide and shift by -1, then
        # reverse the coordinate axis before handing the grid to grid_sample.
        grid = torch.flip(
            (ddf * self._cached_max_sz_tensor + ref).permute(self._perm) / img_sz - 1,
            dims=[-1]
        )
        return F.grid_sample(vol, grid, mode='bilinear',
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field over ``rec_num`` recursive passes.

        Args:
            x: Input tensor; its device and spatial size (``x.size()[2:]``) drive
                the tensor cache. Required despite the ``None`` default.
            y: Conditioning tensor, read only when ``self.conditional_input`` is set.
            t: Time-step input passed through ``self.time_embed``.
            text: Text embedding; when ``None``, ``self.text`` is moved to the
                input device and used instead.
            rec_num: Number of passes. Must be at least 1, otherwise ``ddf`` is
                never assigned.
            ndims: Not read by this method.

        Returns:
            The displacement field accumulated over all passes.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        ts_emb_shape = [n, -1] + [1] * self.dimension

        # Rebuild cached tensors only when input size/device changed.
        self._ensure_cache(img_sz, self.device)
        self.img_sz = self._cached_img_sz_tensor

        img = x
        t = self.time_embed(t)
        if text is None:
            # Fall back to the module's own text embedding, broadcast over batch.
            text = self.text
            text = text.to(self.device)
            txt_shape = [1, -1] + [1] * self.dimension
        else:
            txt_shape = [n, -1] + [1] * self.dimension

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            # Down path: each level adds the time embedding, optionally fuses the
            # conditioning branch, stores the skip feature, then downsamples.
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Flatten spatial axes to (positions, batch, channels), run the two
                # attention layers in opposite directions, and add the results back.
                # presumably (query, key, value) argument order — confirm against
                # the attn_layer definitions in the parent class.
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = out + out_attn
                tgt = tgt + tgt_attn
                out = self.fuse(torch.cat([out, tgt], dim=1))

            if self.conditional_input:
                # Mix the text embedding with the bottleneck feature; the pooled
                # feature is kept on self.img_embd as a side effect.
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            # Up path: upsample, concatenate the matching skip feature (deepest
            # first), add the time embedding.
            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out) / 128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new increment, then add it.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # The next pass sees the original input warped by the field so far.
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf


# ======================================================================
# Factory function
# ======================================================================

def get_net_opt(name):
    """Look up a network class by name, preferring the optimized variant.

    Args:
        name: Network identifier string.

    Returns:
        ``OptRecMulModMutAttnNet`` for ``"recmulmodmutattnnet"``; for any other
        name, whatever ``Diffusion.networks.get_net(name)`` returns.
    """
    if name != "recmulmodmutattnnet":
        # No optimized counterpart: defer to the original lookup. The import is
        # kept local to this branch, as in the original.
        from Diffusion.networks import get_net
        return get_net(name)
    return OptRecMulModMutAttnNet