Omini3D / Diffusion /networks_opt.py
maxmo2009's picture
Sync from local: code + epoch-110 checkpoint, clean README
2af0e94 verified
"""
networks_opt.py — Optimized network components.
Subclasses RecMulModMutAttnNet and STN to eliminate per-call overhead:
1. OptSTN: register_buffer for ref_grid/max_sz — no .to(device) per call
2. OptRecMulModMutAttnNet: cached max_sz/img_sz tensors, ref_grid device —
eliminates ~80 NumPy→GPU transfers and ~32 tensor recreations per registration step
All optimizations are mathematically equivalent to the originals.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from Diffusion.networks import RecMulModMutAttnNet, STN
# ======================================================================
# Optimized STN
# ======================================================================
class OptSTN(STN):
    """Spatial-transformer resampler whose constant tensors live in buffers.

    Drop-in subclass of ``STN``.  The displacement scale (``max_sz``), the
    identity sampling grid (``ref_grid``) and the grid-normalisation divisor
    (``_img_sz_for_resample``) are registered with ``register_buffer`` so that
    ``module.to(device)`` moves them once, instead of ``resample()`` /
    ``forward()`` calling ``.to(device)`` on every invocation.

    NOTE(review): the buffers use the default ``persistent=True`` and so show
    up in ``state_dict()``; confirm that checkpoints written with the plain
    ``STN`` still load under ``strict=True``.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        """Build the resampler and pre-compute its constant tensors.

        Args:
            ndims: Number of spatial dimensions.
            img_sz: Edge length used for every spatial axis (it is replicated
                ``ndims`` times).  Required despite the ``None`` default.
            max_sz: Accepted for signature compatibility only; the
                displacement scale is derived from ``img_sz``.
            device: Stored on ``self.device``; buffer placement follows
                ``module.to(...)`` rather than this value.
            padding_mode: ``grid_sample`` padding mode used by ``forward()``.
            resample_mode: ``grid_sample`` interpolation mode; ``None`` means
                ``'bilinear'``.

        Raises:
            ValueError: If ``img_sz`` is ``None``.
        """
        # Skip STN.__init__ on purpose so it does not create plain (non-buffer)
        # tensor attributes under the names registered below.
        nn.Module.__init__(self)
        if img_sz is None:
            # Previously this fell through to torch.Tensor(np.array([None, ...]))
            # and died with an opaque TypeError.
            raise ValueError("OptSTN requires img_sz (got None)")

        self.ndims = ndims
        self.img_sz = [img_sz] * ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode

        # Displacement scale, channel-first: [1, ndims, 1, ..., 1].
        max_sz_tensor = torch.tensor(
            self.img_sz, dtype=torch.get_default_dtype()
        ).reshape([1, ndims] + [1] * ndims)
        self.register_buffer('max_sz', max_sz_tensor)

        # Identity grid of integer pixel indices: [1, ndims, *img_sz].
        ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                [torch.arange(end=s) for s in self.img_sz]
            ), 0),
            [1, ndims] + self.img_sz
        )
        self.register_buffer('ref_grid', ref_grid)

        # Half extent (size - 1) / 2 per axis, channel-last so that it
        # broadcasts against the permuted grid: [1, 1, ..., 1, ndims].
        self.register_buffer(
            '_img_sz_for_resample', self._build_half_extent(self.img_sz)
        )

        # Channel-first -> channel-last permutation: [0, 2, ..., ndims + 1, 1].
        self._perm = [0] + list(range(2, 2 + ndims)) + [1]

    def _build_half_extent(self, sizes, device=None):
        """Return ``(size - 1) / 2`` per axis, shaped ``[1, ..., 1, ndims]``."""
        return torch.reshape(
            torch.tensor([(s - 1) / 2. for s in sizes], device=device),
            [1] + [1] * self.ndims + [self.ndims]
        )

    def _half_extent(self, img_sz, device):
        """Pick the normalisation divisor for ``resample``.

        The pre-computed buffer is returned whenever ``img_sz`` is ``None`` or
        describes the module's own size; otherwise a tensor is built for the
        requested size on ``device``.
        """
        if img_sz is None or img_sz is self.img_sz:
            return self._img_sz_for_resample
        if isinstance(img_sz, (int, float)):
            # Same scalar convention as the constructor argument.
            sizes = [img_sz] * self.ndims
        else:
            sizes = list(img_sz)
        if sizes == self.img_sz:
            return self._img_sz_for_resample
        return self._build_half_extent(sizes, device=device)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Sample ``vol`` at ``ref + ddf * max_sz`` via ``F.grid_sample``.

        Args:
            vol: Tensor handed to ``grid_sample`` as its input.
            ddf: Channel-first displacement field; it is scaled by the
                ``max_sz`` buffer and added to ``ref``.
            ref: Sampling grid in pixel indices; defaults to ``ref_grid``.
            img_sz: Spatial size used to normalise the grid.  ``None`` (or the
                module's own size) selects the pre-computed divisor; any other
                size gets a divisor built for that call.
            padding_mode: ``grid_sample`` padding mode.

        Returns:
            The output of ``F.grid_sample``.
        """
        # No .to(device) here: buffers follow the module.
        ref = self.ref_grid if ref is None else ref
        half_extent = self._half_extent(img_sz, ddf.device)
        mode = 'bilinear' if self.resample_mode is None else self.resample_mode

        # Absolute pixel coordinates, moved to channel-last layout.
        coords = (ddf * self.max_sz + ref).permute(self._perm)
        # x / ((size - 1) / 2) - 1 maps [0, size - 1] onto [-1, 1], which is
        # what align_corners=True expects.  grid_sample wants the last spatial
        # axis first, hence the flip of the coordinate dimension.
        grid = torch.flip(coords / half_extent - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode=mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Warp ``x`` by ``ddf`` using the module's size and padding mode."""
        # No device check or ref_grid regeneration: the buffers handle it.
        return self.resample(x, ddf=ddf, img_sz=self.img_sz,
                             padding_mode=self.padding_mode)
# ======================================================================
# Optimized RecMulModMutAttnNet
# ======================================================================
class OptRecMulModMutAttnNet(RecMulModMutAttnNet):
    """``RecMulModMutAttnNet`` with size/device-keyed cached helper tensors.

    ``forward()`` and ``resample()`` reuse three tensors that depend only on
    the input's spatial size and device:

    - ``_cached_max_sz_tensor``: displacement scale, ``[1, dim, 1, ..., 1]``
    - ``_cached_img_sz_tensor``: ``(size - 1) / 2`` per axis,
      ``[1, ..., 1, dim]``
    - ``ref_grid``: identity grid of pixel indices, ``[1, dim, *img_sz]``

    They are rebuilt only when the (size, device) pair changes.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent and create empty cache slots."""
        super().__init__(*args, **kwargs)
        # Cache slots: filled by the first forward() / resample() call.
        self._cached_input_key = None
        self._cached_max_sz_tensor = None
        self._cached_img_sz_tensor = None
        # Channel-first -> channel-last permutation: [0, 2, ..., dim + 1, 1].
        self._perm = [0] + list(range(2, 2 + self.dimension)) + [1]

    def _ensure_cache(self, img_sz, device):
        """Rebuild the cached tensors if the input size or device changed.

        Also sets ``self.max_sz`` to ``[img_sz[0]] * dimension`` and keeps
        ``self.ref_grid`` matching ``img_sz`` and ``device``.

        Args:
            img_sz: Spatial size of the current input (any int sequence).
            device: Device the cached tensors must live on.
        """
        img_sz = tuple(img_sz)
        key = (img_sz, device)
        if key == self._cached_input_key:
            return

        dim = self.dimension
        # Only the first spatial extent is used as the displacement scale.
        self.max_sz = [img_sz[0]] * dim
        self._cached_max_sz_tensor = torch.tensor(
            self.max_sz, dtype=torch.get_default_dtype(), device=device
        ).reshape([1, dim] + [1] * dim)

        self._cached_img_sz_tensor = torch.reshape(
            torch.tensor([(s - 1) / 2 for s in img_sz], device=device),
            [1] * (dim + 1) + [dim]
        )

        # Compare against the grid that is actually stored.  Comparing against
        # self.img_res (which is never updated) left a stale, wrong-size grid
        # in place after switching away from and back to the initial size.
        grid = getattr(self, 'ref_grid', None)
        if grid is None or tuple(grid.shape[2:]) != img_sz:
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    [torch.arange(end=s, device=device) for s in img_sz]
                ), 0),
                [1, dim] + list(img_sz)
            )
        elif grid.device != torch.device(device):
            self.ref_grid = grid.to(device)

        # Store the key last so a failure above cannot leave a half-built
        # cache marked as valid.
        self._cached_input_key = key

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Sample ``vol`` at ``ref + ddf * max_sz`` via bilinear ``grid_sample``.

        Args:
            vol: Tensor handed to ``grid_sample`` as its input.
            ddf: Channel-first displacement field; it is scaled by the cached
                ``max_sz`` tensor and added to ``ref``.
            ref: Sampling grid in pixel indices; defaults to ``self.ref_grid``.
            img_sz: Normalisation divisor.  A tensor is used as given (this is
                what ``forward()`` passes); anything else, including ``None``,
                selects the cached ``(size - 1) / 2`` tensor.
            padding_mode: ``grid_sample`` padding mode.

        Returns:
            The output of ``F.grid_sample``.
        """
        if self._cached_max_sz_tensor is None:
            # Called before any forward(): derive the cache from ddf itself.
            self._ensure_cache(ddf.shape[2:], ddf.device)
        ref = self.ref_grid if ref is None else ref
        half_extent = img_sz if torch.is_tensor(img_sz) else self._cached_img_sz_tensor

        # Absolute pixel coordinates, moved to channel-last layout.
        coords = (ddf * self._cached_max_sz_tensor + ref).permute(self._perm)
        # x / ((size - 1) / 2) - 1 maps [0, size - 1] onto [-1, 1] as required
        # by align_corners=True; grid_sample wants the last spatial axis
        # first, hence the flip.
        grid = torch.flip(coords / half_extent - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear',
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field over ``rec_num`` recursive passes.

        Args:
            x: Input tensor; batch size is ``x.size()[0]`` and the spatial
                size is ``x.size()[2:]``.
            y: Second input, consumed only when ``self.conditional_input``.
            t: Passed to ``self.time_embed``.
            text: Text embedding.  When ``None``, ``self.text`` is moved to
                ``x``'s device and broadcast over the batch.
            rec_num: Number of recursive passes.
            ndims: Unused; kept for signature compatibility.

        Returns:
            The composed displacement field after the last pass.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        ts_emb_shape = [n, -1] + [1] * self.dimension

        # Helper tensors are rebuilt only when input size or device changes.
        self._ensure_cache(img_sz, self.device)
        self.img_sz = self._cached_img_sz_tensor

        img = x
        t = self.time_embed(t)
        if text is None:
            text = self.text
            text = text.to(self.device)
            # A single stored embedding is broadcast across the batch.
            txt_shape = [1, -1] + [1] * self.dimension
        else:
            txt_shape = [n, -1] + [1] * self.dimension

        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img

            # Down path: time embedding is added before every block.
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], dim=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], dim=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))

            if self.conditional_input:
                # Mutual attention between the two streams, with the spatial
                # positions flattened into the leading (sequence) dimension.
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = out + out_attn
                tgt = tgt + tgt_attn
                out = self.fuse(torch.cat([out, tgt], dim=1))

            if self.conditional_input:
                img_txt_feat = self.img2txt(out)
                # Side effect: the pooled feature is stored on the module.
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            # Up path with skip connections taken in reverse order.
            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)

            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the accumulated field by the new increment,
                # then add the increment.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            # The next pass operates on the input warped by the current field.
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf
# ======================================================================
# Factory function
# ======================================================================
def get_net_opt(name):
    """Resolve ``name`` to a network class, preferring the optimized variant.

    Only ``"recmulmodmutattnnet"`` has an optimized implementation in this
    module; every other name is delegated to ``Diffusion.networks.get_net``.
    """
    if name != "recmulmodmutattnnet":
        # No optimized counterpart: defer to the stock registry (imported
        # lazily, only on this path).
        from Diffusion.networks import get_net
        return get_net(name)
    return OptRecMulModMutAttnNet