"""networks_opt.py — Optimized network components.

Subclasses RecMulModMutAttnNet and STN to eliminate per-call overhead:

1. OptSTN: register_buffer for ref_grid/max_sz — no .to(device) per call
2. OptRecMulModMutAttnNet: cached max_sz/img_sz tensors, ref_grid device —
   eliminates ~80 NumPy→GPU transfers and ~32 tensor recreations per
   registration step

All optimizations are mathematically equivalent to the originals.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from Diffusion.networks import RecMulModMutAttnNet, STN


# ======================================================================
# Optimized STN
# ======================================================================

class OptSTN(STN):
    """STN with register_buffer for automatic device transfer.

    Eliminates per-call .to(device) overhead in resample() and forward().
    Buffers auto-transfer when module.to(device) is called.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        """Build the constant buffers used by resample().

        Args:
            ndims: number of spatial dimensions.
            img_sz: a single side length; it is repeated ``ndims`` times to
                form the spatial size. Required (see Raises).
            max_sz: accepted for signature compatibility but not used; the
                ``max_sz`` buffer is always derived from ``img_sz``.
            device: stored on ``self.device``. When not None the buffers are
                also moved there so the module is usable without a separate
                ``.to(device)`` call.
            padding_mode: padding mode forward() passes to resample().
            resample_mode: interpolation mode for ``F.grid_sample``;
                ``None`` selects ``'bilinear'``.

        Raises:
            TypeError: if ``img_sz`` is None. Every buffer is derived from
                it, and the tensor construction below cannot handle None.
        """
        # Skip parent __init__ to avoid creating plain tensor attributes
        nn.Module.__init__(self)

        if img_sz is None:
            # Previously this surfaced as an opaque TypeError raised from
            # torch.Tensor(); fail early with an actionable message instead.
            raise TypeError(
                "OptSTN requires img_sz (a single side length), got None"
            )

        self.ndims = ndims
        self.img_sz = [img_sz] * ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode

        # OPT: register_buffer — auto device transfer, no per-call .to()
        max_sz_val = [img_sz] * ndims
        max_sz_tensor = torch.Tensor(
            np.reshape(np.array(max_sz_val),
                       [1, self.ndims] + [1] * self.ndims)
        )
        self.register_buffer('max_sz', max_sz_tensor)

        # Coordinate grid reshaped to [1, ndims, *img_sz].
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                [torch.arange(end=s) for s in self.img_sz]
            ), 0),
            [1, self.ndims] + self.img_sz
        )
        self.register_buffer('ref_grid', ref_grid)

        # OPT: pre-compute the img_sz tensor used when forward() calls
        # resample(); shape is [1, 1, ..., 1, ndims] (channel-last).
        img_sz_for_resample = torch.reshape(
            torch.tensor([(s - 1) / 2. for s in self.img_sz]),
            [1] + [1] * self.ndims + [self.ndims]
        )
        self.register_buffer('_img_sz_for_resample', img_sz_for_resample)

        # OPT: pre-compute constant permutation order
        # (moves dim 1 to the last position).
        self._perm = [0] + list(range(2, 2 + self.ndims)) + [1]

        # Honour the `device` argument: without this the buffers stay on the
        # CPU until the caller remembers to call .to(device) explicitly.
        if device is not None:
            self.to(device)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Sample ``vol`` at the locations ``ddf * max_sz + ref``.

        Args:
            vol: input passed to ``F.grid_sample``.
            ddf: displacement field; it is scaled by the ``max_sz`` buffer
                and added to ``ref``.
            ref: reference grid; defaults to the ``ref_grid`` buffer.
            img_sz: sequence of spatial sizes used to normalise the grid
                (each entry ``s`` contributes ``(s - 1) / 2``). When it
                equals ``self.img_sz`` the pre-computed buffer is reused.
            padding_mode: forwarded to ``F.grid_sample``.

        Returns:
            The output of ``F.grid_sample`` with ``align_corners=True``.
        """
        # OPT: no .to(device) — buffers auto-transfer with module.to()
        ref = self.ref_grid if ref is None else ref

        if img_sz is None:
            # NOTE(review): max_sz is laid out as [1, ndims, 1, ...] while
            # the permuted grid below is channel-last — confirm this path is
            # intended/reachable. Behaviour is left unchanged here.
            img_sz_t = self.max_sz
        elif list(img_sz) == self.img_sz:
            # Use pre-computed tensor for the common case (called from forward)
            img_sz_t = self._img_sz_for_resample
        else:
            # A different size was requested: honour it rather than silently
            # substituting the tensor cached for self.img_sz.
            img_sz_t = torch.reshape(
                torch.tensor([(s - 1) / 2. for s in img_sz],
                             device=ddf.device),
                [1] + [1] * self.ndims + [self.ndims]
            )

        resample_mode = ('bilinear' if self.resample_mode is None
                         else self.resample_mode)
        # Flip the last axis to reverse the coordinate order for grid_sample.
        grid = torch.flip(
            (ddf * self.max_sz + ref).permute(self._perm) / img_sz_t - 1,
            dims=[-1]
        )
        return F.grid_sample(vol, grid, mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Resample ``x`` with ``ddf`` using this module's size and padding."""
        # OPT: no device check or ref_grid regeneration — buffers handle it
        return self.resample(x, ddf=ddf, img_sz=self.img_sz,
                             padding_mode=self.padding_mode)


# ======================================================================
# Optimized RecMulModMutAttnNet
# ======================================================================

class OptRecMulModMutAttnNet(RecMulModMutAttnNet):
    """RecMulModMutAttnNet with cached tensors for resample/forward.

    Eliminates per-call overhead:
    - resample(): cached max_sz tensor (was: NumPy→Torch→GPU every call)
    - forward(): cached img_sz tensor and ref_grid device placement
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and set up empty caches."""
        super().__init__(*args, **kwargs)
        # Cache slots — populated on first forward
        self._cached_input_key = None
        self._cached_max_sz_tensor = None
        self._cached_img_sz_tensor = None
        # Spatial size self.ref_grid was last rebuilt for by _ensure_cache;
        # None means it has not been rebuilt yet (self.img_res is used then).
        self._ref_grid_size = None
        # OPT: pre-compute constant permutation order
        self._perm = [0] + list(range(2, 2 + self.dimension)) + [1]

    def _ensure_cache(self, img_sz, device):
        """Populate cached tensors if input size or device changed.

        Side effects: sets ``self.max_sz`` (a plain list), the two cached
        tensors, and rebuilds or moves ``self.ref_grid`` when needed.
        """
        key = (tuple(img_sz), device)
        if key == self._cached_input_key:
            return
        self._cached_input_key = key

        max_sz_list = [img_sz[0]] * self.dimension
        self.max_sz = max_sz_list

        # OPT: create max_sz tensor ONCE, reuse across all resample() calls
        self._cached_max_sz_tensor = torch.Tensor(
            np.reshape(np.array(max_sz_list),
                       [1, self.dimension] + [1] * self.dimension)
        ).to(device)

        # OPT: create img_sz tensor ONCE per size change
        self._cached_img_sz_tensor = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=device),
            [1] * (self.dimension + 1) + [self.dimension]
        )

        # OPT: ref_grid — only regenerate if size changed, only .to() if needed.
        # Compare against the size the grid was last built for rather than
        # always against self.img_res: otherwise, after running a different
        # size and then returning to img_res, the stale grid built for the
        # other size would be reused.
        new_size = list(img_sz)
        built_for = (self.img_res if self._ref_grid_size is None
                     else self._ref_grid_size)
        if new_size != built_for:
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    [torch.arange(end=imsz) for imsz in img_sz]
                ), 0),
                [1, self.dimension] + new_size
            ).to(device)
            self._ref_grid_size = new_size
        elif self.ref_grid.device != torch.device(device):
            self.ref_grid = self.ref_grid.to(device)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Sample ``vol`` at the locations ``ddf * max_sz + ref``.

        Args:
            vol: input passed to ``F.grid_sample``.
            ddf: displacement field; scaled by the cached max_sz tensor and
                added to ``ref``.
            ref: reference grid; defaults to ``self.ref_grid``.
            img_sz: normalisation divisor. A tensor is used as given
                (forward() passes the cached img_sz tensor); any other
                non-None value selects the cached img_sz tensor.
            padding_mode: forwarded to ``F.grid_sample``.

        Returns:
            The output of bilinear ``F.grid_sample`` with
            ``align_corners=True``.
        """
        # The caches are normally filled by forward(); fill them from `vol`
        # when resample() is called first, instead of failing on None.
        if self._cached_max_sz_tensor is None:
            self._ensure_cache(vol.size()[2:], vol.device)

        # OPT: use cached max_sz tensor instead of NumPy→Torch→GPU every call
        ref = self.ref_grid if ref is None else ref

        if img_sz is None:
            # NOTE(review): this tensor is laid out as [1, dim, 1, ...] while
            # the permuted grid is channel-last — confirm this path is
            # intended/reachable. Behaviour is left unchanged here.
            divisor = self._cached_max_sz_tensor
        elif torch.is_tensor(img_sz):
            # Honour an explicitly supplied tensor rather than discarding it.
            divisor = img_sz
        else:
            divisor = self._cached_img_sz_tensor

        grid = torch.flip(
            (ddf * self._cached_max_sz_tensor + ref).permute(self._perm)
            / divisor - 1,
            dims=[-1]
        )
        return F.grid_sample(vol, grid, mode='bilinear',
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field over ``rec_num`` recursive passes.

        Args:
            x: input tensor; its device, batch size and trailing (spatial)
                dimensions drive the cached tensors.
            y: conditioning input, used only when ``self.conditional_input``.
            t: value passed to ``self.time_embed``.
            text: text input for the ``txt_layers``; when None, ``self.text``
                is moved to the input device and used instead.
            rec_num: number of passes. Must be at least 1 — with 0 the
                result is never assigned.
            ndims: not referenced in this method.

        Returns:
            The displacement field accumulated over all passes.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        ts_emb_shape = [n, -1] + [1] * self.dimension

        # OPT: cache tensors — only recreate if input size/device changes
        self._ensure_cache(img_sz, self.device)
        self.img_sz = self._cached_img_sz_tensor

        img = x
        t = self.time_embed(t)
        if text is None:
            text = self.text
            text = text.to(self.device)
            txt_shape = [1, -1] + [1] * self.dimension
        else:
            txt_shape = [n, -1] + [1] * self.dimension

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img

            # Down path: keep each level's output for the up path below.
            for i in range(self.hier_num):
                out = self.block_down[i](
                    out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = (self.block_down_cond[i](tgt)
                           + self.txt_layers[i](text).reshape(txt_shape))
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))

            if self.conditional_input:
                # Attention in both directions between the two branches,
                # applied on flattened trailing dimensions.
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(
                    out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(
                    tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = out + out_attn
                tgt = tgt + tgt_attn
                out = self.fuse(torch.cat([out, tgt], dim=1))

            if self.conditional_input:
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = (self.txt_layers[-1](text).reshape(txt_shape)
                           + img_txt_feat)
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            # Up path: concatenate with the stored outputs in reverse order.
            for i in range(self.hier_num):
                out = torch.cat(
                    (self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](
                    out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)

            if rec_id == 0:
                ddf = ddf_one
            else:
                # Add the previous field, resampled by the new one.
                ddf = ddf_one + self.resample(
                    ddf, ddf=ddf_one, img_sz=self.img_sz,
                    padding_mode="border")
            # Warp the original input with the accumulated field; the result
            # is the input to the next pass.
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf


# ======================================================================
# Factory function
# ======================================================================

def get_net_opt(name):
    """Return optimized network class if available, else fall back to original."""
    if name == "recmulmodmutattnnet":
        return OptRecMulModMutAttnNet
    # Fall back to original for other network types
    from Diffusion.networks import get_net
    return get_net(name)