| """
|
| networks_opt.py — Optimized network components.
|
|
|
| Subclasses RecMulModMutAttnNet and STN to eliminate per-call overhead:
|
| 1. OptSTN: register_buffer for ref_grid/max_sz — no .to(device) per call
|
| 2. OptRecMulModMutAttnNet: cached max_sz/img_sz tensors, ref_grid device —
|
| eliminates ~80 NumPy→GPU transfers and ~32 tensor recreations per registration step
|
|
|
| All optimizations are mathematically equivalent to the originals.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| import numpy as np
|
|
|
| from Diffusion.networks import RecMulModMutAttnNet, STN
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptSTN(STN):
    """Spatial transformer whose constant tensors are stored as module buffers.

    ``max_sz``, ``ref_grid`` and the half-extent tensor used to normalise
    sampling coordinates are registered with ``register_buffer``. They follow
    the module on ``module.to(device)``, so ``resample()`` and ``forward()``
    do no per-call device transfer.

    NOTE(review): buffers registered without ``persistent=False`` appear in
    ``state_dict()``. Confirm this is compatible with checkpoints saved from
    modules that used the plain ``STN``.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        """Build the constant buffers for a cubic/square image of edge ``img_sz``.

        Args:
            ndims: Number of spatial dimensions.
            img_sz: Edge length; the same value is used for every spatial axis.
            max_sz: Accepted for signature compatibility; not read by this
                class (the maximum size is derived from ``img_sz``).
            device: Stored on ``self.device``; not otherwise used here.
            padding_mode: ``padding_mode`` that ``forward()`` hands to
                ``resample()``.
            resample_mode: Interpolation mode for ``grid_sample``; ``None``
                selects ``'bilinear'``.

        Raises:
            TypeError: If ``img_sz`` is ``None``. Without this check the
                failure surfaced later as an opaque tensor-conversion error.
        """
        if img_sz is None:
            raise TypeError(
                "OptSTN requires an integer img_sz to build its buffers; got None"
            )

        # nn.Module is initialised directly instead of via super().__init__().
        # Presumably STN.__init__ assigns max_sz/ref_grid as plain attributes,
        # which would make register_buffer() below raise -- TODO confirm.
        nn.Module.__init__(self)
        self.ndims = ndims
        self.img_sz = [img_sz] * ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode

        # Per-axis size in channel-first layout: [1, ndims, 1, ..., 1].
        max_sz_tensor = torch.Tensor(
            np.reshape(np.array(self.img_sz), [1, ndims] + [1] * ndims)
        )
        self.register_buffer('max_sz', max_sz_tensor)

        # Index grid of shape [1, ndims, *img_sz]; channel k holds the index
        # along spatial axis k.
        ref_grid = torch.reshape(
            torch.stack(
                torch.meshgrid([torch.arange(end=s) for s in self.img_sz]), 0
            ),
            [1, ndims] + self.img_sz,
        )
        self.register_buffer('ref_grid', ref_grid)

        # (size - 1) / 2 per axis in channel-last layout [1, 1, ..., 1, ndims];
        # dividing by it and subtracting 1 maps indices onto [-1, 1].
        img_sz_for_resample = torch.reshape(
            torch.tensor([(s - 1) / 2. for s in self.img_sz]),
            [1] + [1] * ndims + [ndims],
        )
        self.register_buffer('_img_sz_for_resample', img_sz_for_resample)

        # Channel-first -> channel-last permutation expected by grid_sample.
        self._perm = [0] + list(range(2, 2 + ndims)) + [1]

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf``.

        Args:
            vol: Tensor passed to ``F.grid_sample`` as the input.
            ddf: Displacement field, channel-first; it is multiplied by
                ``max_sz`` before being added to the reference grid.
            ref: Reference grid; defaults to the registered ``ref_grid``.
            img_sz: Only tested against ``None``. Any non-``None`` value
                selects the cached half-extent tensor built from
                ``self.img_sz``; the value itself is not read.
            padding_mode: Forwarded to ``F.grid_sample``.

        Returns:
            The output of ``F.grid_sample`` with ``align_corners=True``.
        """
        base_grid = self.ref_grid if ref is None else ref

        if img_sz is None:
            # NOTE(review): max_sz is channel-first while the grid below is
            # channel-last, so this branch only broadcasts for special shapes.
            # forward() never takes it -- confirm whether it is intended.
            scale = self.max_sz
        else:
            scale = self._img_sz_for_resample

        mode = 'bilinear' if self.resample_mode is None else self.resample_mode

        # Absolute positions -> channel-last -> normalised -> reversed channel
        # order, because grid_sample expects the last axis first.
        positions = (ddf * self.max_sz + base_grid).permute(self._perm)
        grid = torch.flip(positions / scale - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode=mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Warp ``x`` by ``ddf`` using the module's configured padding mode."""
        return self.resample(x, ddf=ddf, img_sz=self.img_sz,
                             padding_mode=self.padding_mode)
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptRecMulModMutAttnNet(RecMulModMutAttnNet):
    """RecMulModMutAttnNet that caches the tensors used by resample/forward.

    The size tensors and the reference grid are built once per
    ``(input size, device)`` pair by ``_ensure_cache`` and reused:

    - ``resample()`` reads the cached max-size tensor instead of rebuilding
      it from NumPy on every call.
    - ``forward()`` reads the cached half-extent tensor and a reference grid
      that already lives on the input's device.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base network, then set up empty cache slots."""
        super().__init__(*args, **kwargs)

        # The cache is filled on the first forward() call.
        self._cached_input_key = None
        self._cached_max_sz_tensor = None
        self._cached_img_sz_tensor = None

        # Channel-first -> channel-last permutation expected by grid_sample.
        self._perm = [0] + list(range(2, 2 + self.dimension)) + [1]

    def _ensure_cache(self, img_sz, device):
        """Populate the cached tensors if the input size or device changed.

        Args:
            img_sz: Spatial size of the current input (sequence of ints).
            device: Device the cached tensors must live on.

        Side effects:
            Sets ``self.max_sz`` to ``[img_sz[0]] * self.dimension`` and may
            replace or move ``self.ref_grid``.
        """
        key = (tuple(img_sz), device)
        if key == self._cached_input_key:
            return

        # Only the first spatial extent is used as the maximum size.
        max_sz_list = [img_sz[0]] * self.dimension
        self.max_sz = max_sz_list

        # Channel-first layout: [1, dimension, 1, ..., 1].
        self._cached_max_sz_tensor = torch.Tensor(
            np.reshape(np.array(max_sz_list),
                       [1, self.dimension] + [1] * self.dimension)
        ).to(device)

        # (size - 1) / 2 per axis, channel-last: [1, 1, ..., 1, dimension].
        self._cached_img_sz_tensor = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=device),
            [1] * (self.dimension + 1) + [self.dimension]
        )

        # Rebuild the grid when the size differs from the configured
        # resolution, or when an earlier call with another size left a grid
        # whose spatial shape no longer matches. Comparing against img_res
        # alone kept that stale grid once the input returned to img_res.
        if (list(img_sz) != self.img_res
                or tuple(self.ref_grid.shape[2:]) != tuple(img_sz)):
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    [torch.arange(end=imsz) for imsz in img_sz]
                ), 0),
                [1, self.dimension] + list(img_sz)
            ).to(device)
        elif self.ref_grid.device != torch.device(device):
            self.ref_grid = self.ref_grid.to(device)

        # Record the key last so a failure above cannot leave a
        # half-populated cache marked as valid.
        self._cached_input_key = key

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with the displacement field ``ddf`` (bilinear).

        Requires the cache to be populated, i.e. ``forward()`` or
        ``_ensure_cache()`` must have run first.

        Args:
            vol: Tensor passed to ``F.grid_sample`` as the input.
            ddf: Displacement field, channel-first; it is multiplied by the
                cached max-size tensor before being added to the grid.
            ref: Reference grid; defaults to ``self.ref_grid``.
            img_sz: Only tested against ``None``. Non-``None`` selects the
                cached half-extent tensor, ``None`` selects the cached
                max-size tensor; the value itself is not read.
            padding_mode: Forwarded to ``F.grid_sample``.

        Returns:
            The output of ``F.grid_sample`` with ``align_corners=True``.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self._cached_img_sz_tensor if img_sz is not None else self._cached_max_sz_tensor

        # Absolute positions -> channel-last -> normalised -> reversed channel
        # order, because grid_sample expects the last axis first.
        grid = torch.flip(
            (ddf * self._cached_max_sz_tensor + ref).permute(self._perm) / img_sz - 1,
            dims=[-1]
        )
        return F.grid_sample(vol, grid, mode='bilinear',
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field over ``rec_num`` recursive passes.

        Args:
            x: Input tensor; its device and spatial size key the cache.
            y: Conditioning input, read only when ``self.conditional_input``.
            t: Input to ``self.time_embed``.
            text: Text tensor; ``None`` falls back to ``self.text``.
            rec_num: Number of recursive passes. With ``rec_num=0`` no field
                is produced and the final ``return`` fails.
            ndims: Not read by this method.

        Returns:
            The displacement field accumulated over all passes.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        ts_emb_shape = [n, -1] + [1] * self.dimension

        # Build or reuse the cached tensors for this size and device.
        self._ensure_cache(img_sz, self.device)
        self.img_sz = self._cached_img_sz_tensor

        img = x
        t = self.time_embed(t)
        if text is None:
            text = self.text
            text = text.to(self.device)
            txt_shape = [1, -1] + [1] * self.dimension
        else:
            txt_shape = [n, -1] + [1] * self.dimension

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            # Down path: add the time embedding, optionally fuse the
            # conditioning branch, keep the feature for the up path, downsample.
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Cross-attention in both directions on flattened features
                # laid out as (positions, batch, channels).
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = out + out_attn
                tgt = tgt + tgt_attn
                out = self.fuse(torch.cat([out, tgt], dim=1))

            if self.conditional_input:
                # Text/image interaction; img_embd is stored on the module
                # as a side effect.
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            # Up path: upsample and concatenate the matching stored feature.
            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            # Fixed 1/128 scaling; its rationale is not visible in this file.
            out = self.conv_out(out) / 128

            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # The warped input feeds the next pass.
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_net_opt(name):
    """Look up a network class by name, preferring the optimized variant.

    ``"recmulmodmutattnnet"`` maps to ``OptRecMulModMutAttnNet``. Every other
    name is delegated to ``Diffusion.networks.get_net``, which is imported
    only when that fallback is taken.
    """
    if name == "recmulmodmutattnnet":
        net_cls = OptRecMulModMutAttnNet
    else:
        from Diffusion.networks import get_net
        net_cls = get_net(name)
    return net_cls
|
|
|