# Page header from the source viewer: revision 2af0e94, file size 10,192 bytes.
"""
networks_opt.py — Optimized network components.

Subclasses RecMulModMutAttnNet and STN to eliminate per-call overhead:
1. OptSTN: ref_grid/max_sz are registered as buffers, so they follow
   module.to(device) instead of being moved with .to(device) on every call.
2. OptRecMulModMutAttnNet: caches the max_sz/img_sz tensors and the ref_grid
   device placement, eliminating ~80 NumPy→GPU transfers and ~32 tensor
   recreations per registration step.
All optimizations are intended to be mathematically equivalent to the originals.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Diffusion.networks import RecMulModMutAttnNet, STN
# ======================================================================
# Optimized STN
# ======================================================================
class OptSTN(STN):
    """Spatial transformer whose constant grids live in registered buffers.

    ``max_sz``, ``ref_grid`` and the normaliser used by ``forward()`` are built
    once in ``__init__``. They are registered as *non-persistent* buffers, so:

    - they follow ``module.to(device)`` automatically, and neither
      ``resample()`` nor ``forward()`` has to move or rebuild them per call;
    - they are excluded from ``state_dict()``, which keeps checkpoints
      interchangeable with the plain ``STN``. The plain ``STN`` holds these
      tensors as ordinary tensor attributes rather than buffers.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        """Build the constant grids for a volume of ``img_sz`` along every axis.

        Args:
            ndims: number of spatial axes.
            img_sz: edge length of the input along every spatial axis.
            max_sz: accepted for signature compatibility with ``STN``. It is
                ignored; the displacement scale is always taken from ``img_sz``.
            device: stored as ``self.device``. Buffer placement follows
                ``module.to()`` instead.
            padding_mode: ``grid_sample`` padding mode used by ``forward()``.
            resample_mode: ``grid_sample`` interpolation mode. ``None`` means
                ``'bilinear'``.

        Raises:
            TypeError: if ``img_sz`` is None, because the grids cannot be
                built without it.
        """
        # STN.__init__ is skipped on purpose. It would create the same tensors
        # as plain attributes, which then need a .to(device) on every call.
        nn.Module.__init__(self)
        if img_sz is None:
            raise TypeError("OptSTN requires an integer img_sz, got None")
        self.ndims = ndims
        self.img_sz = [img_sz] * ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode

        # Displacement scale, shape [1, ndims, 1, ..., 1], every entry img_sz.
        max_sz_tensor = torch.Tensor(
            np.reshape(np.array(self.img_sz), [1, ndims] + [1] * ndims)
        )
        self.register_buffer('max_sz', max_sz_tensor, persistent=False)

        # Identity grid of integer voxel coordinates, shape [1, ndims, *img_sz].
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                [torch.arange(end=s) for s in self.img_sz]
            ), 0),
            [1, ndims] + self.img_sz,
        )
        self.register_buffer('ref_grid', ref_grid, persistent=False)

        # Channel-last (s - 1) / 2 normaliser for resampling at self.img_sz,
        # which is what forward() does.
        self.register_buffer('_img_sz_for_resample',
                             self._half_extent(self.img_sz), persistent=False)

        # Moves the channel axis last: [N, C, *spatial] -> [N, *spatial, C].
        self._perm = [0] + list(range(2, 2 + ndims)) + [1]

    def _half_extent(self, img_sz, device=None, dtype=None):
        """Return ``(s - 1) / 2`` per axis, reshaped to ``[1, 1, ..., 1, ndims]``."""
        return torch.reshape(
            torch.tensor([(s - 1) / 2. for s in img_sz], device=device, dtype=dtype),
            [1] + [1] * self.ndims + [self.ndims],
        )

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by the displacement field ``ddf``.

        Args:
            vol: tensor passed to ``F.grid_sample`` as its input.
            ddf: displacement field, channel-first. It is scaled by ``max_sz``
                and added to ``ref``.
            ref: identity grid. Defaults to the ``ref_grid`` buffer.
            img_sz: if None, coordinates are divided by ``max_sz``. If it is a
                list/tuple that differs from ``self.img_sz``, a matching
                ``(s - 1) / 2`` normaliser is built for it. Otherwise the
                precomputed normaliser for ``self.img_sz`` is used.
            padding_mode: ``grid_sample`` padding mode.

        Returns:
            The resampled tensor from ``F.grid_sample`` (``align_corners=True``).
        """
        # Buffers already sit on the module's device; no per-call .to().
        ref = self.ref_grid if ref is None else ref
        if img_sz is None:
            norm = self.max_sz
        elif isinstance(img_sz, (list, tuple)) and list(img_sz) != self.img_sz:
            # A different grid size was requested: the cached normaliser is wrong here.
            norm = self._half_extent(img_sz, device=vol.device,
                                     dtype=self._img_sz_for_resample.dtype)
        else:
            norm = self._img_sz_for_resample
        mode = 'bilinear' if self.resample_mode is None else self.resample_mode
        # grid_sample wants (x, y[, z]) order in the last dim, so flip the axes.
        grid = torch.flip(
            (ddf * self.max_sz + ref).permute(self._perm) / norm - 1,
            dims=[-1],
        )
        return F.grid_sample(vol, grid, mode=mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Warp ``x`` by ``ddf`` at ``self.img_sz`` using ``self.padding_mode``."""
        # No device check or ref_grid regeneration: the buffers handle that.
        return self.resample(x, ddf=ddf, img_sz=self.img_sz,
                             padding_mode=self.padding_mode)
# ======================================================================
# Optimized RecMulModMutAttnNet
# ======================================================================
class OptRecMulModMutAttnNet(RecMulModMutAttnNet):
    """RecMulModMutAttnNet with per-input-size cached warp tensors.

    The parent rebuilds several small constant tensors on every call. This
    subclass builds them once per (input spatial size, device) pair:

    - ``resample()`` reuses a cached ``max_sz`` tensor instead of going
      NumPy -> Torch -> device on every call.
    - ``forward()`` reuses a cached ``img_sz`` normaliser. It only rebuilds or
      moves ``ref_grid`` when the input size or device actually changes.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Cache slots, populated on first forward() (or first resample()).
        self._cached_input_key = None
        self._cached_max_sz_tensor = None
        self._cached_img_sz_tensor = None
        # Moves the channel axis last: [N, C, *spatial] -> [N, *spatial, C].
        self._perm = [0] + list(range(2, 2 + self.dimension)) + [1]

    def _ensure_cache(self, img_sz, device):
        """(Re)build the cached tensors if the input spatial size or device changed.

        Args:
            img_sz: spatial shape of the input (``x.size()[2:]`` in ``forward``).
            device: device the cached tensors should live on.

        Sets ``self.max_sz`` (a Python list), both cached tensors,
        ``self._cached_input_key`` and, when needed, ``self.ref_grid``.
        """
        key = (tuple(img_sz), device)
        if key == self._cached_input_key:
            return
        self._cached_input_key = key
        dim = self.dimension
        # Every axis uses img_sz[0], i.e. a square/cubic input is assumed.
        max_sz_list = [img_sz[0]] * dim
        self.max_sz = max_sz_list
        # Created once per size/device and reused by every resample() call.
        self._cached_max_sz_tensor = torch.Tensor(
            np.reshape(np.array(max_sz_list), [1, dim] + [1] * dim)
        ).to(device)
        # Channel-last (s - 1) / 2 normaliser.
        self._cached_img_sz_tensor = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=device),
            [1] * (dim + 1) + [dim],
        )
        # Compare against the grid actually held, not self.img_res. After one
        # differently-sized input, the stored grid no longer matches img_res.
        ref_grid = getattr(self, 'ref_grid', None)
        if ref_grid is None or tuple(ref_grid.shape[2:]) != tuple(img_sz):
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    [torch.arange(end=imsz) for imsz in img_sz]
                ), 0),
                [1, dim] + list(img_sz),
            ).to(device)
        elif ref_grid.device != torch.device(device):
            self.ref_grid = ref_grid.to(device)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` (bilinear, ``align_corners=True``).

        Args:
            vol: tensor passed to ``F.grid_sample`` as its input.
            ddf: displacement field, channel-first. It is scaled by the cached
                ``max_sz`` tensor and added to ``ref``.
            ref: identity grid. Defaults to ``self.ref_grid``.
            img_sz: if None, coordinates are divided by the cached ``max_sz``
                tensor. Any other value selects the cached ``(s - 1) / 2``
                normaliser; the value itself is not inspected.
            padding_mode: ``grid_sample`` padding mode.

        Returns:
            The resampled tensor from ``F.grid_sample``.
        """
        if self._cached_input_key is None:
            # Called before any forward(): build the cache from ddf's geometry.
            self._ensure_cache(ddf.shape[2:], ddf.device)
        ref = self.ref_grid if ref is None else ref
        norm = self._cached_max_sz_tensor if img_sz is None else self._cached_img_sz_tensor
        # grid_sample wants (x, y[, z]) order in the last dim, so flip the axes.
        grid = torch.flip(
            (ddf * self._cached_max_sz_tensor + ref).permute(self._perm) / norm - 1,
            dims=[-1],
        )
        return F.grid_sample(vol, grid, mode='bilinear',
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field through ``rec_num`` recursive passes.

        Args:
            x: moving image. Assumed ``[N, C, *spatial]``: ``n`` is read from
                dim 0 and the spatial size from ``x.size()[2:]``.
            y: conditioning image. Used only when ``self.conditional_input``.
            t: timestep input, passed through ``self.time_embed``.
            text: text features. When None, ``self.text`` is moved to
                ``x.device`` and broadcast with batch dim 1.
            rec_num: number of recursive refinement passes. Must be >= 1,
                otherwise ``ddf`` is never assigned.
            ndims: unused here; kept for signature compatibility.

        Returns:
            The accumulated displacement field ``ddf``.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        ts_emb_shape = [n, -1] + [1] * self.dimension
        # Cached tensors are rebuilt only if the input size/device changed.
        self._ensure_cache(img_sz, self.device)
        self.img_sz = self._cached_img_sz_tensor
        img = x
        t = self.time_embed(t)
        if text is None:
            text = self.text
            text = text.to(self.device)
            txt_shape = [1, -1] + [1] * self.dimension
        else:
            txt_shape = [n, -1] + [1] * self.dimension
        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            # Encoder: time-conditioned down blocks, optionally fused with the target branch.
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)
            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Mutual cross-attention between image and target bottleneck features.
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = out + out_attn
                tgt = tgt + tgt_attn
                out = self.fuse(torch.cat([out, tgt], dim=1))
            if self.conditional_input:
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt
            # Decoder with skip connections from the encoder.
            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose the new increment with the field accumulated so far.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # The next pass refines starting from x warped by the current field.
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf
# ======================================================================
# Factory function
# ======================================================================
def get_net_opt(name):
    """Resolve ``name`` to a network class, preferring an optimized subclass.

    ``"recmulmodmutattnnet"`` maps to ``OptRecMulModMutAttnNet``. Any other
    name is forwarded unchanged to ``Diffusion.networks.get_net``.
    """
    if name != "recmulmodmutattnnet":
        # Local import: only needed for names without an optimized variant.
        from Diffusion.networks import get_net as original_get_net
        return original_get_net(name)
    return OptRecMulModMutAttnNet