Omini3D / Diffusion /networks.py
maxmo2009's picture
Sync from local: code + epoch-110 checkpoint, clean README
2af0e94 verified
from torch import nn
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint as grad_checkpoint
import numpy as np
import math
from Diffusion.safe_conv_transpose import SafeConvTranspose3d
class UpsampleConv(nn.Module):
    """Interpolate-then-convolve stand-in for ``ConvTranspose2d``/``ConvTranspose3d``.

    Per the original author's note, the backward pass of ``ConvTranspose3d``
    leaks about 0.33 GiB per step on Intel XPU (oneDNN bug).  This module
    upsamples with ``F.interpolate`` and then applies a regular convolution
    instead, which also avoids the checkerboard artifacts commonly seen with
    transposed convolutions.

    Args:
        in_channels: number of input channels of the convolution.
        out_channels: number of output channels of the convolution.
        kernel_size: accepted only to mirror the transposed-conv signature;
            the convolution always uses kernel size 3.
        stride: used as the interpolation ``scale_factor``.
        padding: accepted only to mirror the transposed-conv signature; the
            convolution always uses padding 1.
        ndims: 3 selects trilinear interpolation; any other value selects
            bilinear.  The convolution class is ``nn.Conv{ndims}d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, ndims=3):
        super().__init__()
        self.scale_factor = stride
        if ndims == 3:
            self.mode = 'trilinear'
        else:
            self.mode = 'bilinear'
        conv_cls = getattr(nn, 'Conv{}d'.format(ndims))
        # Fixed 3/1/1 kernel/stride/padding keeps the upsampled size unchanged.
        self.conv = conv_cls(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor`` and convolve the result."""
        upsampled = F.interpolate(
            x,
            scale_factor=self.scale_factor,
            mode=self.mode,
            align_corners=False,
        )
        return self.conv(upsampled)
def get_net(name="recresnet"):
    """Return the network class registered under ``name`` (case-insensitive).

    Args:
        name: network identifier; it is lower-cased before the lookup.

    Returns:
        The matching network class (not an instance), or ``None`` when the
        name is not recognised.
    """
    # NOTE(review): the default "recresnet" matches no entry below, so calling
    # get_net() with no argument returns None -- confirm this is intended.
    # Each entry is a thunk so a class is only looked up when it is requested.
    registry = {
        "recresacnet": lambda: RecResACNet,
        "recmutattnnet": lambda: RecMutAttnNet,
        "recmutattnnet0": lambda: RecMutAttnNet0,
        "recmutattnnet1": lambda: RecMutAttnNet1,
        "defrecmutattnnet": lambda: DefRec_MutAttnNet,
        "recmulmodmutattnnet": lambda: RecMulModMutAttnNet,
        "om_net": lambda: OM_net,
    }
    resolve = registry.get(name.lower())
    if resolve is None:
        return None
    return resolve()
def sinusoidal_embedding(n, d):
    """Build an ``(n, d)`` table of sinusoidal position embeddings.

    Even columns hold ``sin(t * w)`` and odd columns hold ``cos(t * w)`` for
    positions ``t = 0 .. n-1``, where the frequency of column pair ``p`` is
    ``w = 1 / 10000 ** (2 * (2 * p) / d)``.

    Args:
        n: number of positions (rows).
        d: embedding width (columns).  Odd widths are supported: the last
            sine column then has no cosine partner.

    Returns:
        A float tensor of shape ``(n, d)``.
    """
    embedding = torch.zeros(n, d)
    # Only the even-indexed frequencies are ever used, so build just those.
    freqs = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(0, d, 2)])
    freqs = freqs.reshape((1, -1))
    positions = torch.arange(n).reshape((n, 1))
    angles = positions * freqs
    embedding[:, ::2] = torch.sin(angles)
    # For odd d there is one fewer odd column than even columns; trimming the
    # angles avoids the shape mismatch the untrimmed assignment would raise.
    embedding[:, 1::2] = torch.cos(angles[:, :d // 2])
    return embedding
class AtrousBlock(nn.Module):
    """Residual block made of a chain of dilated (atrous) convolutions.

    The input is optionally projected to ``out_c`` channels, optionally
    instance-normalised, passed through ``activation -> conv`` once per entry
    of ``atrous_rates``, and finally added back to the (projected, normalised)
    input before a last activation.

    Args:
        shape: not used by the block; kept so existing call sites still work.
        in_c: number of input channels.
        out_c: number of output channels.
        kernel_size: kernel size of the projection and dilated convolutions.
        stride: not used by the block; all convolutions use stride 1.
        atrous_rates: one convolution is created per entry.  A rate ``> 0``
            gives a dilated ``kernel_size`` convolution padded to preserve
            the spatial size for odd kernels; a rate ``<= 0`` gives a 1x1
            convolution.
        ndims: selects ``nn.Conv{ndims}d`` / ``nn.InstanceNorm{ndims}d``.
        activation: activation module; defaults to ``nn.LeakyReLU(1e-6)``.
        normalize: when true, apply an affine instance norm after the
            optional channel projection.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, atrous_rates=(1, 3), ndims=2, activation=None, normalize=True):
        super().__init__()
        if normalize:
            # InstanceNorm replaced the earlier LayerNorm(shape) (jzheng 15/03/2024).
            norm = getattr(nn, 'InstanceNorm%dd' % ndims)
            self.ln = norm(out_c, affine=True)
        else:
            self.ln = nn.Identity()
        Conv = getattr(nn, 'Conv%dd' % ndims)
        if in_c != out_c:
            # Channel projection so the residual sum has matching shapes.
            self.conv0 = Conv(in_c, out_c, kernel_size, 1, (kernel_size - 1) // 2 * 1)
        else:
            self.conv0 = None
        self.convs = nn.ModuleList([
            Conv(out_c, out_c, kernel_size, 1, (kernel_size - 1) // 2 * ar, dilation=ar)
            if ar > 0 else Conv(out_c, out_c, 1, 1, 0)
            for ar in atrous_rates
        ])
        self.activation = nn.LeakyReLU(1e-6) if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor with ``out_c`` channels."""
        if self.conv0 is not None:
            x = self.conv0(x)
        if self.normalize:
            x = self.ln(x)
        out = x
        for conv in self.convs:
            out = self.activation(out)
            out = conv(out)
        # Residual connection around the whole dilated-conv chain.
        return self.activation(out + x)
# ==============================================
# Unconditional Network
# ==============================================
class RecResACNet(nn.Module):
    """Unconditional, recursively applied encoder/decoder built from AtrousBlocks.

    Three strided-convolution downsampling stages, a bottleneck, and three
    transposed-convolution upsampling stages with skip concatenations.  The
    final convolution emits ``ndims`` channels, which ``forward`` treats as a
    displacement field (``ddf``) and composes over ``rec_num`` passes.

    Args:
        n_steps: number of rows in the time-embedding table.
        time_emb_dim: width of the time embedding.
        ndims: number of spatial dimensions (selects ``Conv{ndims}d`` etc.).
        num_input_chn: number of channels of the input image.
        res: nominal input resolution; only forwarded to ``AtrousBlock`` as
            part of its ``shape`` argument.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0):
        super().__init__()
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal time-embedding table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        def stack(level_res, widths, **block_kw):
            # Chain of AtrousBlocks mapping widths[0] -> widths[1] -> ...
            tail = [level_res] * ndims
            return nn.Sequential(*[
                AtrousBlock([c_in] + tail, c_in, c_out, ndims=ndims, **block_kw)
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ])

        # Encoder
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = stack(res, (num_input_chn, 10, 10, 10))
        self.down1 = self.Conv(10, 10, 4, 2, 1)

        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = stack(res // 2, (10, 20, 20, 20))
        self.down2 = self.Conv(20, 20, 4, 2, 1)

        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = stack(res // 4, (20, 40, 40, 40))
        self.down3 = self.Conv(40, 40, 4, 2, 1)

        # Bottleneck
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = stack(res // 8, (40, 20, 20, 40))

        # Decoder (no normalisation in these blocks)
        self.up1 = self.ConvT(40, 40, 4, 2, 1)
        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = stack(res // 4, (80, 40, 20, 20), normalize=False)

        self.up2 = self.ConvT(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = stack(res // 2, (40, 20, 10, 10), normalize=False)

        self.up3 = self.ConvT(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = stack(res // 1, (20, 10, 10, 10), normalize=False)

        self.conv_out = self.Conv(10, ndims, 3, 1, 1)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of ``sample_coords0`` against its entry in ``max_sz``.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]`` and scaled
        back.  Channels beyond ``len(max_sz)`` are dropped by the ``zip``.
        """
        limited = []
        for component, sz in zip(torch.split(sample_coords0, split_size_or_sections=1, dim=1), max_sz):
            lower = minus - 1 * sz + plus
            upper = 1 * sz - minus + plus
            limited.append(torch.clamp(component * sz, min=lower, max=upper) / sz)
        return torch.cat(limited, 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by the displacement field ``ddf`` using ``F.grid_sample``.

        ``ddf`` is scaled by ``self.max_sz``, added to the reference grid,
        moved to channels-last, normalised as ``coords / img_sz - 1`` and
        flipped along the last axis before sampling.

        Args:
            vol: tensor to sample from.
            ddf: displacement field, channels on dim 1.
            ref: reference grid; defaults to ``self.ref_grid``.
            img_sz: normalisation divisor; defaults to ``self.max_sz``.
                NOTE(review): that fallback is a plain list; every call in
                this class passes the tensor ``self.img_sz`` instead.
            padding_mode: forwarded to ``F.grid_sample``.
        """
        if ref is None:
            ref = self.ref_grid
        if img_sz is None:
            img_sz = self.max_sz
        scale_shape = [1, self.dimension] + [1] * self.dimension
        scale = torch.Tensor(np.reshape(np.array(self.max_sz), scale_shape)).to(self.device)
        coords = ddf * scale + ref
        channels_last = [0] + list(range(2, 2 + self.dimension)) + [1]
        grid = torch.flip(coords.permute(channels_last) / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, t=None, y=None, rec_num=2, ndims=2):
        """Predict a displacement field for ``x`` at time step(s) ``t``.

        Args:
            x: input image batch; dims 2+ are spatial.
            t: time-step indices fed to the embedding table.
            y: unused.
            rec_num: number of recursive passes; each pass after the first
                runs on ``x`` warped by the field accumulated so far.
            ndims: unused; the network uses ``self.dimension``.

        Returns:
            The accumulated displacement field.
        """
        self.device = x.device
        spatial = x.size()[2:]
        batch = x.size()[0]
        # The first spatial extent is used as the scale for every axis.
        self.max_sz = [spatial[0]] * self.dimension
        emb_shape = [batch, -1] + [1] * self.dimension
        half_extent = torch.tensor([(s - 1) / 2 for s in spatial], device=self.device)
        self.img_sz = torch.reshape(half_extent, [1] * (self.dimension + 1) + [self.dimension])
        axes = [torch.arange(end=s) for s in spatial]
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(axes), 0),
            [1, self.dimension] + list(spatial),
        ).to(self.device)

        img = x
        t = self.time_embed(t)

        def te(layer):
            # Project the time embedding and make it broadcastable over space.
            return layer(t).reshape(emb_shape)

        for rec_id in range(rec_num):
            enc1 = self.b1(img + te(self.te1))
            enc2 = self.b2(self.down1(enc1) + te(self.te2))
            enc3 = self.b3(self.down2(enc2) + te(self.te3))
            # NOTE(review): the bottleneck multiplies by the time embedding
            # while every other stage adds it -- confirm this is intended.
            mid = self.b_mid(self.down3(enc3) * te(self.te_mid))

            dec = torch.cat((enc3, self.up1(mid)), dim=1)
            dec = self.b4(dec + te(self.te4))
            dec = torch.cat((enc2, self.up2(dec)), dim=1)
            dec = self.b5(dec + te(self.te5))
            dec = torch.cat((enc1, self.up3(dec)), dim=1)
            dec = self.b_out(dec + te(self.te_out))
            dec = self.conv_out(dec)

            ddf_one = self.boundary_limit(dec, max_sz=list(self.max_sz))
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the running field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Return the Linear-ReLU-Linear projection for a time embedding."""
        layers = [
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
        ]
        return nn.Sequential(*layers)
# ==============================================
# Conditional Network
# ==============================================
class cross_attn(nn.Module):
    """Cross-attention that queries ``x`` against keys and values from ``y``.

    Args:
        q: callable/module producing queries from ``x``.
        k: callable/module producing keys from ``y``.
        v: callable/module producing values from ``y``.
        ndims: selects the ``Conv{ndims}d`` / ``ConvTranspose{ndims}d``
            classes stored on the instance (not used by ``forward``).
    """

    def __init__(self, q, k, v, ndims=2):
        # nn.Module must be initialised before sub-modules or parameters can
        # be assigned; without this call the assignments below raise
        # AttributeError.
        super().__init__()
        self.q = q
        self.k = k
        self.v = v
        self.ndims = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.ndims)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.ndims)
        self.softmax = nn.Softmax(dim=-1)
        # NOTE(review): gamma is registered but not applied in forward.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x, y):
        """Return ``softmax(q(x) @ k(y)^T) @ v(y)``.

        The softmax is taken over the last axis and no scaling is applied to
        the attention logits.
        """
        q = self.q(x)
        k = self.k(y)
        v = self.v(y)
        attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)))
        return torch.matmul(attn, v)
class DefRec_MutAttnNet(nn.Module):
    """Recursive encoder/decoder with optional image and text conditioning.

    The encoder/decoder hierarchy is driven by ``feat_channels``.  When
    ``conditional_input`` is true, a second encoder processes the conditional
    image ``y`` once per forward call; its per-level features are fused into
    the main encoder with 1x1 convolutions, its bottleneck features serve as
    keys/values of a multi-head attention layer, and a text feature is mixed
    in at the bottleneck.  The output has ``ndims`` channels and is treated
    as a displacement field (``ddf``) that is composed across passes.

    Args:
        n_steps: number of rows in the time-embedding table.
        time_emb_dim: width of the time embedding.
        ndims: number of spatial dimensions.
        num_input_chn: number of channels of the input image.
        res: nominal input resolution per spatial axis; used for the initial
            reference grid and forwarded to ``AtrousBlock`` as part of its
            ``shape`` argument.
        conditional_input: enable the conditional-image and text branches.
        text_feat_chn: channel width of the text feature.
        num_heads: number of attention heads.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True, text_feat_chn=1024, num_heads=4):
        super().__init__()
        self.feat_channels = [num_input_chn, 16, 32, 128, 256, 512]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
        self.copy = nn.Identity()

        # Frozen sinusoidal time-embedding table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Default (all-zero) text feature used when forward gets text=None.
            self.text = torch.zeros(1, self.text_feat_chn, *([1] * self.dimension))
        self.img_res = [res] * self.dimension
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res], indexing='ij'), 0),
            [1, self.dimension] + list(self.img_res))

        def block_stack(level_res, widths, **block_kw):
            # Chain of AtrousBlocks mapping widths[0] -> widths[1] -> ...
            tail = [level_res] * ndims
            return nn.Sequential(*[
                AtrousBlock([c_in] + tail, c_in, c_out, ndims=ndims, **block_kw)
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ])

        feat = self.feat_channels
        for i in range(1, self.hier_num + 1):
            # i walks the encoder top-down; j = -i walks the decoder bottom-up.
            j = -i
            c_prev, c_cur, c_up = feat[i - 1], feat[i], feat[j]
            self.down_layers.append(self.Conv(c_cur, c_cur, 4, 2, 1))
            self.up_layers.append(self.ConvT(c_up, c_up, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, c_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * c_up))
            down_res = res // (2 ** (i - 1))
            self.block_down.append(block_stack(down_res, (c_prev, c_cur, c_cur, c_cur)))
            if self.conditional_input:
                self.block_down_cond.append(block_stack(down_res, (c_prev, c_cur, c_cur, c_cur)))
                self.fuse_conv0.append(self.Conv(2 * c_cur, c_cur, 1, 1, 0))
            # Each decoder stage hands the next (shallower) level its channel
            # count; the last stage keeps its own.
            k = j if i == self.hier_num else j - 1
            up_res = res // (2 ** (self.hier_num - i - 1))
            self.block_up.append(block_stack(up_res, (2 * c_up, c_up, c_up, feat[k]), normalize=False))

        # Bottleneck
        c_mid = feat[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, feat[-1])
        self.b_mid = block_stack(res // (2 ** self.hier_num), (c_mid, c_mid, c_mid, c_mid))
        self.conv_out = self.Conv(feat[1], ndims, 3, 1, 1)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of ``sample_coords0`` against its entry in ``max_sz``.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]`` and scaled
        back.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
                          for x, sz in zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by the displacement field ``ddf`` using ``F.grid_sample``.

        ``ddf`` is scaled by ``self.max_sz``, added to the reference grid,
        moved to channels-last, normalised as ``coords / img_sz - 1`` and
        flipped along the last axis before sampling.

        Args:
            vol: tensor to sample from.
            ddf: displacement field, channels on dim 1.
            ref: reference grid; defaults to ``self.ref_grid``.
            img_sz: normalisation divisor; defaults to ``self.max_sz``.
                NOTE(review): that fallback is a plain list; every call in
                this class passes the tensor ``self.img_sz`` instead.
            padding_mode: forwarded to ``F.grid_sample``.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        scale = torch.Tensor(
            np.reshape(np.array(self.max_sz), [1, self.dimension] + [1] * self.dimension)).to(self.device)
        channels_last = [0] + list(range(2, 2 + self.dimension)) + [1]
        grid = torch.flip((ddf * scale + ref).permute(channels_last) / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def _sync_ref_grid(self, img_sz):
        """Make ``self.ref_grid`` match the spatial size ``img_sz`` and device.

        The comparison is against the grid currently cached, not against the
        constructor resolution, so a grid rebuilt for one input size is never
        reused for a different size.
        """
        img_sz = list(img_sz)
        if list(self.ref_grid.shape[2:]) != img_sz:
            axes = [torch.arange(end=imsz) for imsz in img_sz]
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(axes, indexing='ij'), 0),
                [1, self.dimension] + img_sz)
        self.ref_grid = self.ref_grid.to(self.device)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field for ``x``.

        Args:
            x: input image batch; dims 2+ are spatial.
            y: conditional image; read only when ``conditional_input`` is true.
            t: iterable of time-step index tensors.  The whole sequence is
                repeated ``rec_num`` times and one pass is run per entry.
            text: optional text feature added to the bottleneck projection;
                defaults to the all-zero ``self.text``.
            rec_num: number of repetitions of the ``t`` sequence.
            ndims: unused; the network uses ``self.dimension``.

        Returns:
            The displacement field accumulated over all passes.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # The first spatial extent is used as the scale for every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension
        self.img_sz = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
        self._sync_ref_grid(img_sz)

        img = x
        if self.conditional_input:
            # Encode the conditional input once; it does not depend on the
            # time step or on the recursion.
            tgt = y
            tgt_down_list = []
            for i in range(self.hier_num):
                tgt = self.block_down_cond[i](tgt)
                tgt_down_list.append(self.copy(tgt))
                tgt = self.down_layers[i](tgt)
            tgt_mid = self.copy(tgt)
            tgt_shape = tgt_mid.shape
            # (N, C, *spatial) -> (prod(spatial), N, C) for nn.MultiheadAttention
            tgt_mid = tgt_mid.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
            if text is None:
                text = self.text
            text = text.to(self.device)

        t = [t0.to(self.device) for t0 in t]
        t = [t0 for _ in range(rec_num) for t0 in t]
        for rec_id, step in enumerate(t):
            t_emb = self.time_embed(step)
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape))
                if self.conditional_input:
                    out = self.fuse_conv0[i](torch.cat([out, tgt_down_list[i]], dim=1))
                enc_list.append(out)
                out = self.down_layers[i](out)

            out = self.b_mid(out + self.tmid(t_emb).reshape(ts_emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                query = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer(query, tgt_mid, tgt_mid)
                # (prod(spatial), N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

                out_txt = self.img2txt(out) + text
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t_emb).reshape(ts_emb_shape))

            # The raw output is scaled down by a fixed factor of 128.
            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the running field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Return the Linear-ReLU-Linear projection for a time embedding."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
class RecMutAttnNet1(nn.Module):
    """Recursive encoder/decoder with mutual image conditioning and attention.

    The hierarchy is driven by ``feat_channels``.  When ``conditional_input``
    is true, a second encoder processes the conditional image ``y`` in step
    with the main encoder; the two streams exchange features through 1x1
    fusion convolutions at every level and through a multi-head attention
    layer at the bottleneck.  The output has ``ndims`` channels and is
    treated as a displacement field (``ddf``) composed across passes.

    Args:
        n_steps: number of rows in the time-embedding table.
        time_emb_dim: width of the time embedding.
        ndims: number of spatial dimensions.
        num_input_chn: number of channels of the input image.
        res: nominal input resolution; only forwarded to ``AtrousBlock`` as
            part of its ``shape`` argument.
        conditional_input: enable the conditional-image branch.
        text_feat_chn: stored on the instance; not used by this class.
        num_heads: number of attention heads.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True, text_feat_chn=1024, num_heads=4):
        super().__init__()
        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal time-embedding table.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
        self.block_up = nn.ModuleList()

        def block_stack(level_res, widths, **block_kw):
            # Chain of AtrousBlocks mapping widths[0] -> widths[1] -> ...
            tail = [level_res] * ndims
            return nn.Sequential(*[
                AtrousBlock([c_in] + tail, c_in, c_out, ndims=ndims, **block_kw)
                for c_in, c_out in zip(widths[:-1], widths[1:])
            ])

        feat = self.feat_channels
        for i in range(1, self.hier_num + 1):
            # i walks the encoder top-down; j = -i walks the decoder bottom-up.
            j = -i
            c_prev, c_cur, c_up = feat[i - 1], feat[i], feat[j]
            self.down_layers.append(self.Conv(c_cur, c_cur, 4, 2, 1))
            self.up_layers.append(self.ConvT(c_up, c_up, 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, c_prev))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * c_up))
            down_res = res // (2 ** (i - 1))
            self.block_down.append(block_stack(down_res, (c_prev, c_cur, c_cur, c_cur)))
            if self.conditional_input:
                self.block_down_cond.append(block_stack(down_res, (c_prev, c_cur, c_cur, c_cur)))
                self.fuse_conv0.append(self.Conv(2 * c_cur, c_cur, 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * c_cur, c_cur, 1, 1, 0))
            # Each decoder stage hands the next (shallower) level its channel
            # count; the last stage keeps its own.
            k = j if i == self.hier_num else j - 1
            up_res = res // (2 ** (self.hier_num - i - 1))
            self.block_up.append(block_stack(up_res, (2 * c_up, c_up, c_up, feat[k]), normalize=False))

        # Bottleneck
        c_mid = feat[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, feat[-1])
        self.b_mid = block_stack(res // (2 ** self.hier_num), (c_mid, c_mid, c_mid, c_mid))
        self.conv_out = self.Conv(feat[1], ndims, 3, 1, 1)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of ``sample_coords0`` against its entry in ``max_sz``.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - max_sz[c] + plus, max_sz[c] - minus + plus]`` and scaled
        back.
        """
        limited = []
        for component, sz in zip(torch.split(sample_coords0, split_size_or_sections=1, dim=1), max_sz):
            lower = minus - 1 * sz + plus
            upper = 1 * sz - minus + plus
            limited.append(torch.clamp(component * sz, min=lower, max=upper) / sz)
        return torch.cat(limited, 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by the displacement field ``ddf`` using ``F.grid_sample``.

        ``ddf`` is scaled by ``self.max_sz``, added to the reference grid,
        moved to channels-last, normalised as ``coords / img_sz - 1`` and
        flipped along the last axis before sampling.

        Args:
            vol: tensor to sample from.
            ddf: displacement field, channels on dim 1.
            ref: reference grid; defaults to ``self.ref_grid``.
            img_sz: normalisation divisor; defaults to ``self.max_sz``.
                NOTE(review): that fallback is a plain list; every call in
                this class passes the tensor ``self.img_sz`` instead.
            padding_mode: forwarded to ``F.grid_sample``.
        """
        if ref is None:
            ref = self.ref_grid
        if img_sz is None:
            img_sz = self.max_sz
        scale_shape = [1, self.dimension] + [1] * self.dimension
        scale = torch.Tensor(np.reshape(np.array(self.max_sz), scale_shape)).to(self.device)
        coords = ddf * scale + ref
        channels_last = [0] + list(range(2, 2 + self.dimension)) + [1]
        grid = torch.flip(coords.permute(channels_last) / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
        """Predict a displacement field for ``x`` at time step(s) ``t``.

        Args:
            x: input image batch; dims 2+ are spatial.
            y: conditional image; read only when ``conditional_input`` is true.
            t: time-step indices fed to the embedding table.
            rec_num: number of recursive passes; each pass after the first
                runs on ``x`` warped by the field accumulated so far.
            ndims: unused; the network uses ``self.dimension``.

        Returns:
            The accumulated displacement field.
        """
        self.device = x.device
        spatial = x.size()[2:]
        batch = x.size()[0]
        # The first spatial extent is used as the scale for every axis.
        self.max_sz = [spatial[0]] * self.dimension
        emb_shape = [batch, -1] + [1] * self.dimension
        half_extent = torch.tensor([(s - 1) / 2 for s in spatial], device=self.device)
        self.img_sz = torch.reshape(half_extent, [1] * (self.dimension + 1) + [self.dimension])
        axes = [torch.arange(end=s) for s in spatial]
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(axes), 0),
            [1, self.dimension] + list(spatial),
        ).to(self.device)

        img = x
        t = self.time_embed(t)
        for rec_id in range(rec_num):
            # The conditional stream restarts from y on every pass.
            if self.conditional_input:
                tgt = y
            skips = []
            out = img
            for level, (block, te, down) in enumerate(zip(self.block_down, self.ted_layers, self.down_layers)):
                out = block(out + te(t).reshape(emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[level](tgt)
                    # Mutual fusion: out sees tgt, then tgt sees the fused out.
                    out = self.fuse_conv0[level](torch.cat([out, tgt], dim=1))
                    tgt = self.fuse_conv1[level](torch.cat([tgt, out], dim=1))
                skips.append(out)
                out = down(out)
                if self.conditional_input:
                    tgt = down(tgt)

            out = self.b_mid(out + self.tmid(t).reshape(emb_shape))
            if self.conditional_input:
                out_shape = out.shape
                tgt_shape = tgt.shape
                # (N, C, *spatial) -> (prod(spatial), N, C) for nn.MultiheadAttention
                tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                query = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer(query, tgt, tgt)
                # (prod(spatial), N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                out = out + out_attn

            # Decoder consumes the skip connections deepest-first.
            for up, te, block, skip in zip(self.up_layers, self.teu_layers, self.block_up, reversed(skips)):
                out = torch.cat((up(out), skip), dim=1)
                out = block(out + te(t).reshape(emb_shape))

            # The raw output is scaled down by a fixed factor of 128.
            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=list(self.max_sz))
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the running field by the new one, then add.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Return the Linear-ReLU-Linear projection for a time embedding."""
        layers = [
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
        ]
        return nn.Sequential(*layers)
class RecMutAttnNet(nn.Module):
def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
super(RecMutAttnNet, self).__init__()
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
self.conditional_input = conditional_input
self.num_heads = num_heads
self.text_feat_chn = text_feat_chn
self.dimension = ndims
self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
# Sinusoidal embedding
self.time_embed = nn.Embedding(n_steps, time_emb_dim)
self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
self.time_embed.requires_grad_(False)
self.hier_num = len(self.feat_channels) - 1
self.down_layers = nn.ModuleList()
self.up_layers = nn.ModuleList()
self.ted_layers = nn.ModuleList()
self.teu_layers = nn.ModuleList()
self.block_down = nn.ModuleList()
self.block_up = nn.ModuleList()
if self.conditional_input:
self.block_down_cond = nn.ModuleList()
self.fuse_conv0 = nn.ModuleList()
self.fuse_conv1 = nn.ModuleList()
self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
self.global_maxpool = Global_Maxpool(1)
self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
self.img_res = [res]*self.dimension
self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
[1, self.dimension]+list(self.img_res))
for i in range(1, self.hier_num + 1):
j=-i
self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
self.block_down.append(nn.Sequential(
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
))
if self.conditional_input:
self.block_down_cond.append(nn.Sequential(
AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
))
self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
if i==self.hier_num:
k=j
else:
k=j-1
self.block_up.append(nn.Sequential(
AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
))
# Bottleneck
self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
self.b_mid = nn.Sequential(
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
)
self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
zip(sample_coords, max_sz)], 1)
def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
ref = self.ref_grid if ref is None else ref
img_sz = self.max_sz if img_sz is None else img_sz
resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
[0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
align_corners=True)
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
    """Predict a displacement field for ``x`` over ``rec_num`` recursive passes.

    Each pass runs the encoder, bottleneck and decoder on the currently
    warped image, clamps the predicted field with ``boundary_limit``,
    composes it with the field accumulated so far and re-warps ``x``.

    Args:
        x: input tensor ``(N, C, *spatial)``; its device becomes ``self.device``.
        y: condition tensor, used only when ``self.conditional_input`` is set.
        t: indices fed to ``self.time_embed``.
        text: optional text embedding; ``self.text`` is used when omitted.
        rec_num: number of recursive passes.
        ndims: unused; kept for interface compatibility.

    Returns:
        The accumulated displacement field after the last pass.
    """
    self.device = x.device
    img_sz = x.size()[2:]
    n = x.size()[0]
    # NOTE: the first spatial size is used as the scale for every axis.
    self.max_sz = [img_sz[0]] * self.dimension
    ts_emb_shape = [n, -1] + [1] * self.dimension
    self.img_sz = torch.reshape(
        torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
        [1] * (self.dimension + 1) + [self.dimension])
    # Compare against the grid actually cached rather than ``self.img_res``:
    # otherwise a call at the configured resolution that follows a call at
    # another resolution would silently reuse the stale grid.
    cached_grid = getattr(self, "ref_grid", None)
    if cached_grid is None or list(cached_grid.shape[2:]) != list(img_sz):
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                *[torch.arange(end=imsz) for imsz in img_sz], indexing="ij"), 0),
            [1, self.dimension] + list(img_sz))
    self.ref_grid = self.ref_grid.to(self.device)
    img = x
    t = self.time_embed(t)
    for rec_id in range(rec_num):
        if self.conditional_input:
            tgt = y
        enc_list = []
        out = img
        # Encoder: time-conditioned blocks, optional fusion with the condition path.
        for i in range(self.hier_num):
            out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
            if self.conditional_input:
                tgt = self.block_down_cond[i](tgt)
                out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
            enc_list.append(out)
            out = self.down_layers[i](out)
            if self.conditional_input:
                tgt = self.down_layers[i](tgt)
        out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
        if self.conditional_input:
            # Cross-attention: image features query the condition features.
            out_shape = out.shape
            tgt_shape = tgt.shape
            # (N, C, *spatial) -> (prod(spatial), N, C)
            tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
            query = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
            out_attn, _ = self.attn_layer(query, tgt, tgt)
            # (prod(spatial), N, C) -> (N, C, *spatial)
            out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
            out = out + out_attn
        if self.conditional_input:
            if text is None:
                text = self.text
            text = text.to(self.device)
            text = text.view(-1, self.text_feat_chn, *([1] * self.dimension))
            out_txt = self.img2txt(out) + text
            out_txt = self.txt_proc(out_txt)
            out_txt = self.txt2img(out_txt)
            out = out + out_txt
        # Decoder with skip connections taken in reverse encoder order.
        for i in range(self.hier_num):
            out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
            out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
        out = self.conv_out(out) / 128
        ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
        if rec_id == 0:
            ddf = ddf_one
        else:
            # Compose the new field with the accumulated one.
            ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
        img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
    return ddf
def _make_te(self, dim_in, dim_out):
return nn.Sequential(
nn.Linear(dim_in, dim_out),
nn.ReLU(),
nn.Linear(dim_out, dim_out)
)
class RecMulModMutAttnNet(nn.Module):
    """Recursive U-Net style network that predicts a displacement field.

    The encoder/decoder are built from ``AtrousBlock`` stacks and are
    conditioned on a frozen sinusoidal time embedding. When
    ``conditional_input`` is true, a second encoder path processes the
    condition image ``y``, a text embedding is injected at every encoder level
    and at the bottleneck, and the two paths exchange information through two
    multi-head attention layers at the bottleneck.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0,
                 conditional_input=True, text_feat_chn=1024, num_heads=4):
        """Create all layers.

        Args:
            n_steps: number of rows in the time-embedding table.
            time_emb_dim: width of the time embedding.
            ndims: number of spatial dimensions (selects ``ConvNd`` classes).
            num_input_chn: channel count of the input image.
            res: configured spatial resolution (same for every axis).
            conditional_input: build the condition/text branches.
            text_feat_chn: width of the text embedding.
            num_heads: heads of the bottleneck attention layers.
        """
        super(RecMulModMutAttnNet, self).__init__()
        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
        # Sinusoidal embedding, kept frozen.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        if self.conditional_input:
            self.txt_layers = nn.ModuleList()
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn,
                                        self.text_feat_chn, ndims=ndims, normalize=False,
                                        atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
        # Default (all-zero) text embedding used when forward() gets text=None.
        self.text = torch.zeros(1, self.text_feat_chn)
        self.img_res = [res] * self.dimension
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                *[torch.arange(end=imsz) for imsz in self.img_res], indexing="ij"), 0),
            [1, self.dimension] + list(self.img_res))
        chn = self.feat_channels
        for i in range(1, self.hier_num + 1):
            j = -i
            size_down = res // (2 ** (i - 1))
            size_up = res // (2 ** (self.hier_num - i - 1))
            self.down_layers.append(self.Conv(chn[i], chn[i], 4, 2, 1))
            self.up_layers.append(self.ConvT(chn[j], chn[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, chn[i - 1]))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * chn[j]))
            self.block_down.append(
                self._atrous_stack([chn[i - 1], chn[i], chn[i], chn[i]], size_down, ndims))
            if self.conditional_input:
                self.txt_layers.append(self._make_te(self.text_feat_chn, chn[i]))
                self.block_down_cond.append(
                    self._atrous_stack([chn[i - 1], chn[i], chn[i], chn[i]], size_down, ndims))
                self.fuse_conv0.append(self.Conv(2 * chn[i], chn[i], 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * chn[i], chn[i], 1, 1, 0))
            # The last decoder stage keeps its width; earlier ones narrow to the next level.
            k = j if i == self.hier_num else j - 1
            self.block_up.append(
                self._atrous_stack([2 * chn[j], chn[j], chn[j], chn[k]], size_up, ndims,
                                   normalize=False))
        # Bottleneck
        if self.conditional_input:
            # Guarded: txt_layers only exists on the conditional path; the
            # unguarded append raised AttributeError for conditional_input=False.
            self.txt_layers.append(self._make_te(self.text_feat_chn, self.text_feat_chn))
        self.tmid = self._make_te(time_emb_dim, chn[-1])
        bott = chn[self.hier_num]
        self.b_mid = self._atrous_stack([bott] * 4, res // (2 ** self.hier_num), ndims)
        self.fuse = self.Conv(2 * bott, bott, 1, 1, 0)
        self.conv_out = self.Conv(chn[1], ndims, 3, 1, 1)

    @staticmethod
    def _atrous_stack(channels, size, ndims, normalize=True):
        """Chain AtrousBlocks mapping channels[k] -> channels[k + 1] at one resolution."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [size] * ndims, c_in, c_out, ndims=ndims, normalize=normalize)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def _sync_ref_grid(self, img_sz):
        """Make ``self.ref_grid`` match ``img_sz`` and live on ``self.device``."""
        # Checked against the cached grid itself (not img_res) so that
        # alternating input sizes can never leave a stale grid behind.
        if list(self.ref_grid.shape[2:]) != list(img_sz):
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    *[torch.arange(end=imsz) for imsz in img_sz], indexing="ij"), 0),
                [1, self.dimension] + list(img_sz))
        self.ref_grid = self.ref_grid.to(self.device)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of a displacement field.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and scaled back.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat(
            [torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
             for x, sz in zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with ``ddf`` via ``F.grid_sample``.

        Args:
            vol: tensor to sample.
            ddf: displacement field, multiplied by ``self.max_sz`` before use.
            ref: reference grid; defaults to ``self.ref_grid``.
            img_sz: divisor for the channels-last grid; when omitted it is
                derived as ``(size - 1) / 2`` from ``self.max_sz`` (the old
                fallback divided a tensor by a list and raised ``TypeError``).
            padding_mode: forwarded to ``F.grid_sample``.
        """
        ref = self.ref_grid if ref is None else ref
        if img_sz is None:
            img_sz = torch.tensor(
                [(sz - 1) / 2 for sz in self.max_sz],
                dtype=torch.float32, device=self.device,
            ).reshape([1] * (self.dimension + 1) + [self.dimension])
        scale = torch.tensor(
            self.max_sz, dtype=torch.float32, device=self.device,
        ).reshape([1, self.dimension] + [1] * self.dimension)
        coords = (ddf * scale + ref).permute(
            [0] + list(range(2, 2 + self.dimension)) + [1])
        # grid_sample wants the last spatial axis first, hence the flip.
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode,
                             align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field over ``rec_num`` recursive passes.

        Args:
            x: input tensor ``(N, C, *spatial)``.
            y: condition tensor (used when ``conditional_input`` is set).
            t: indices for the time-embedding table.
            text: optional text embedding with ``text_feat_chn`` features per
                sample; trailing singleton dims are flattened, and a batch of
                one is broadcast over the image batch. Defaults to ``self.text``.
            rec_num: number of recursive passes.
            ndims: unused; kept for interface compatibility.

        Returns:
            The accumulated displacement field. On the conditional path the
            pooled image/text feature of the last pass is stored in
            ``self.img_embd``.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE: the first spatial size is used as the scale for every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension
        self.img_sz = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
        self._sync_ref_grid(img_sz)
        img = x
        t = self.time_embed(t)
        if text is None:
            text = self.text
        text = text.to(self.device)
        text = text.reshape(text.size(0), -1)
        # Batch taken from the text itself so a single embedding broadcasts.
        txt_shape = [text.size(0), -1] + [1] * self.dimension
        for rec_id in range(rec_num):
            if self.conditional_input:
                tgt = y
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)
            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Mutual attention between the image and condition paths.
                out_shape = out.shape
                tgt_shape = tgt.shape
                # (N, C, *spatial) -> (prod(spatial), N, C)
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                # (prod(spatial), N, C) -> (N, C, *spatial)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = out + out_attn
                tgt = tgt + tgt_attn
                out = self.fuse(torch.cat([out, tgt], dim=1))
            if self.conditional_input:
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                out = out + out_txt
            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz,
                                              padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Return a Linear -> ReLU -> Linear projection for an embedding vector."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
class OM_net(nn.Module):
    """
    Extended RecMulModMutAttnNet with gated attention mechanisms:
    1. Text Gate (bottleneck): sigmoid weight w_txt to interpolate between
       text-enhanced features and raw image features. It is computed from the
       text embedding alone, before the recursion loop.
    2. Target Gate (each encoder level): per-voxel spatial gate (AtrousBlock
       followed by a 2-channel 1x1 conv and softmax) that weights the
       target/condition features against the image features fed to fuse_conv1.
    Supports gradient checkpointing on the first three encoder levels via the
    `use_checkpoint` attribute (trades compute for activation memory).
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0,
                 conditional_input=True, text_feat_chn=1024, num_heads=4,
                 use_conv_transpose=False):
        """Create all layers.

        Args:
            n_steps: number of rows in the time-embedding table.
            time_emb_dim: width of the time embedding.
            ndims: number of spatial dimensions.
            num_input_chn: channel count of the input image.
            res: configured spatial resolution (same for every axis).
            conditional_input: build the condition/text/gate branches.
            text_feat_chn: width of the text embedding.
            num_heads: heads of the bottleneck attention layers.
            use_conv_transpose: stored on the instance; not consulted when
                building the layers.
        """
        super(OM_net, self).__init__()
        self.use_checkpoint = False  # Set True to enable gradient checkpointing
        self.use_conv_transpose = use_conv_transpose
        self.feat_channels = [num_input_chn, 12, 32, 64, 128, 512]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
        # Sinusoidal embedding, kept frozen.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        if self.conditional_input:
            self.txt_layers = nn.ModuleList()
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.tgt_gate = nn.ModuleList()  # Target gate per encoder level
            self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn,
                                        self.text_feat_chn, ndims=ndims, normalize=False,
                                        atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Text Gate: text-only MLP -> sigmoid weight (computed before rec loop)
            self.text_gate = nn.Sequential(
                nn.Linear(self.text_feat_chn, self.text_feat_chn // 4),
                nn.ReLU(),
                nn.Linear(self.text_feat_chn // 4, 1),
                nn.Sigmoid()
            )
        # Default (all-zero) text embedding used when forward() gets text=None.
        self.text = torch.zeros(1, self.text_feat_chn)
        self.img_res = [res] * self.dimension
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(
                *[torch.arange(end=imsz) for imsz in self.img_res], indexing="ij"), 0),
            [1, self.dimension] + list(self.img_res))
        chn = self.feat_channels
        for i in range(1, self.hier_num + 1):
            j = -i
            size_down = res // (2 ** (i - 1))
            size_up = res // (2 ** (self.hier_num - i - 1))
            self.down_layers.append(self.Conv(chn[i], chn[i], 4, 2, 1))
            # SafeConvTranspose3d is 3-D only by name; other dimensionalities
            # fall back to the matching nn.ConvTransposeNd.
            if self.dimension == 3:
                self.up_layers.append(SafeConvTranspose3d(chn[j], chn[j], 4, 2, 1))
            else:
                self.up_layers.append(self.ConvT(chn[j], chn[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, chn[i - 1]))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * chn[j]))
            self.block_down.append(
                self._atrous_stack([chn[i - 1], chn[i], chn[i], chn[i]], size_down, ndims))
            if self.conditional_input:
                self.txt_layers.append(self._make_te(self.text_feat_chn, chn[i]))
                self.block_down_cond.append(
                    self._atrous_stack([chn[i - 1], chn[i], chn[i], chn[i]], size_down, ndims))
                self.fuse_conv0.append(self.Conv(2 * chn[i], chn[i], 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * chn[i], chn[i], 1, 1, 0))
                # Target Gate: AtrousBlock -> 2-channel logits (condition vs noise)
                self.tgt_gate.append(nn.Sequential(
                    AtrousBlock([chn[i]] + [size_down] * ndims,
                                chn[i], chn[i], ndims=ndims, atrous_rates=[1, 3]),
                    self.Conv(chn[i], 2, 1, 1, 0)
                ))
            # The last decoder stage keeps its width; earlier ones narrow to the next level.
            k = j if i == self.hier_num else j - 1
            self.block_up.append(
                self._atrous_stack([2 * chn[j], chn[j], chn[j], chn[k]], size_up, ndims,
                                   normalize=False))
        # Bottleneck
        if self.conditional_input:
            # Guarded: txt_layers only exists on the conditional path.
            self.txt_layers.append(self._make_te(self.text_feat_chn, self.text_feat_chn))
        self.tmid = self._make_te(time_emb_dim, chn[-1])
        bott = chn[self.hier_num]
        self.b_mid = self._atrous_stack([bott] * 4, res // (2 ** self.hier_num), ndims)
        self.fuse = self.Conv(2 * bott, bott, 1, 1, 0)
        self.conv_out = self.Conv(chn[1], ndims, 3, 1, 1)
        # Initialize target gates toward pass-through (condition confidence high)
        self._init_tgt_gates()

    @staticmethod
    def _atrous_stack(channels, size, ndims, normalize=True):
        """Chain AtrousBlocks mapping channels[k] -> channels[k + 1] at one resolution."""
        return nn.Sequential(*[
            AtrousBlock([c_in] + [size] * ndims, c_in, c_out, ndims=ndims, normalize=normalize)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def _init_tgt_gates(self):
        """Bias target gates so the condition channel starts moderately high.

        The final conv of every gate gets bias ``[1, -1]``. Does nothing when
        the conditional branch (and therefore ``tgt_gate``) was not built.
        """
        if not self.conditional_input:
            return
        for gate_seq in self.tgt_gate:
            final_conv = gate_seq[-1]  # the Conv that outputs 2 channels
            with torch.no_grad():
                final_conv.bias.data[0] = 1.0   # condition channel
                final_conv.bias.data[1] = -1.0  # noise channel

    def _encoder_level(self, i, out, tgt, t, ts_emb_shape, text, txt_shape, w_txt):
        """Single encoder level, extracted so it can be gradient-checkpointed.

        Returns the updated ``(out, tgt)`` pair; ``tgt`` is passed through
        unchanged when it is ``None`` or the net is unconditional.
        """
        out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
        if self.conditional_input and tgt is not None:
            tgt = self.block_down_cond[i](tgt) + w_txt * self.txt_layers[i](text).reshape(txt_shape)
            gate_logits = self.tgt_gate[i](tgt)
            cond_confidence = F.softmax(gate_logits, dim=1)[:, 0:1]
            tgt = self.fuse_conv1[i](
                torch.cat([cond_confidence * tgt, (1 - cond_confidence) * out], axis=1))
            out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
        return out, tgt

    def _sync_ref_grid(self, img_sz):
        """Make ``self.ref_grid`` match ``img_sz`` and live on ``self.device``."""
        # Checked against the cached grid itself (not img_res) so that
        # alternating input sizes can never leave a stale grid behind.
        if list(self.ref_grid.shape[2:]) != list(img_sz):
            self.ref_grid = torch.reshape(
                torch.stack(torch.meshgrid(
                    *[torch.arange(end=imsz) for imsz in img_sz], indexing="ij"), 0),
                [1, self.dimension] + list(img_sz))
        self.ref_grid = self.ref_grid.to(self.device)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of a displacement field.

        Channel ``c`` is scaled by ``max_sz[c]``, clamped to
        ``[minus - sz + plus, sz - minus + plus]`` and scaled back.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat(
            [torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
             for x, sz in zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with ``ddf`` via ``F.grid_sample``.

        ``img_sz`` is the divisor applied to the channels-last grid; when
        omitted it is derived as ``(size - 1) / 2`` from ``self.max_sz`` (the
        old fallback divided a tensor by a list and raised ``TypeError``).
        """
        ref = self.ref_grid if ref is None else ref
        if img_sz is None:
            img_sz = torch.tensor(
                [(sz - 1) / 2 for sz in self.max_sz],
                dtype=torch.float32, device=self.device,
            ).reshape([1] * (self.dimension + 1) + [self.dimension])
        scale = torch.tensor(
            self.max_sz, dtype=torch.float32, device=self.device,
        ).reshape([1, self.dimension] + [1] * self.dimension)
        coords = (ddf * scale + ref).permute(
            [0] + list(range(2, 2 + self.dimension)) + [1])
        # grid_sample wants the last spatial axis first, hence the flip.
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode,
                             align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field over ``rec_num`` recursive passes.

        Args:
            x: input tensor ``(N, C, *spatial)``.
            y: condition tensor (used when ``conditional_input`` is set).
            t: indices for the time-embedding table.
            text: optional text embedding with ``text_feat_chn`` features per
                sample; trailing singleton dims are flattened, and a batch of
                one is broadcast over the image batch. Defaults to ``self.text``.
            rec_num: number of recursive passes.
            ndims: unused; kept for interface compatibility.

        Returns:
            The accumulated displacement field. On the conditional path the
            pooled image/text feature of the last pass is stored in
            ``self.img_embd``.
        """
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # NOTE: the first spatial size is used as the scale for every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension
        self.img_sz = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
        self._sync_ref_grid(img_sz)
        img = x
        t = self.time_embed(t)
        if text is None:
            text = self.text
        text = text.to(self.device)
        # One flat vector per sample, shared by the gate and the text layers.
        text = text.reshape(text.size(0), -1)
        txt_shape = [text.size(0), -1] + [1] * self.dimension
        # Text Gate: compute w_txt from text embedding alone before rec loop
        w_txt = None
        if self.conditional_input:
            txt_vec = text
            if txt_vec.size(0) == 1 and n > 1:
                txt_vec = txt_vec.expand(n, -1)
            w_txt = self.text_gate(txt_vec)
            w_txt = w_txt.view([w_txt.size(0), 1] + [1] * self.dimension)
        for rec_id in range(rec_num):
            tgt = y if self.conditional_input else None
            enc_list = []
            out = img
            for i in range(self.hier_num):
                # Checkpoint only the first three (largest) encoder levels.
                if self.use_checkpoint and self.training and i < 3:
                    out, tgt = grad_checkpoint(
                        self._encoder_level, i, out, tgt,
                        t, ts_emb_shape, text, txt_shape, w_txt,
                        use_reentrant=False,
                    )
                else:
                    out, tgt = self._encoder_level(
                        i, out, tgt, t, ts_emb_shape, text, txt_shape, w_txt,
                    )
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)
            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Mutual attention between the image and condition paths.
                out_shape = out.shape
                tgt_shape = tgt.shape
                # (N, C, *spatial) -> (prod(spatial), N, C)
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = out + out_attn
                tgt = tgt + tgt_attn
                out = self.fuse(torch.cat([out, tgt], dim=1))
            if self.conditional_input:
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) - img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                # Text Gate: w_txt precomputed from text embedding alone
                out = (1 - w_txt) * out + w_txt * out_txt
            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz,
                                              padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Return a Linear -> ReLU -> Linear projection for an embedding vector."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
# class RecMutAttnNet(nn.Module):
# def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True):
# super(RecMutAttnNet, self).__init__()
# self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64]
# self.conditional_input = conditional_input
# self.dimension = ndims
# self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
# self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
# # Sinusoidal embedding
# self.time_embed = nn.Embedding(n_steps, time_emb_dim)
# self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
# self.time_embed.requires_grad_(False)
# self.hier_num = len(self.feat_channels) - 1
# self.down_layers = nn.ModuleList()
# self.up_layers = nn.ModuleList()
# self.ted_layers = nn.ModuleList()
# self.teu_layers = nn.ModuleList()
# self.block_down = nn.ModuleList()
# if self.conditional_input:
# self.block_down_cond = nn.ModuleList()
# self.fuse_conv0 = nn.ModuleList()
# self.fuse_conv1 = nn.ModuleList()
# self.block_up = nn.ModuleList()
# for i in range(1, self.hier_num + 1):
# j=-i
# self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
# self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
# self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
# self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
# self.block_down.append(nn.Sequential(
# AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
# ))
# if self.conditional_input:
# self.block_down_cond.append(nn.Sequential(
# AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
# AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
# ))
# self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
# self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
# if i==self.hier_num:
# k=j
# else:
# k=j-1
# self.block_up.append(nn.Sequential(
# AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
# AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
# AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
# ))
# # Bottleneck
# self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
# self.b_mid = nn.Sequential(
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
# AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
# )
# self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
# def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
# sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
# return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
# zip(sample_coords, max_sz)], 1)
# def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
# ref = self.ref_grid if ref is None else ref
# img_sz = self.max_sz if img_sz is None else img_sz
# resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear'
# return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
# np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
# [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
# align_corners=True)
# def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
# self.device = x.device
# img_sz = x.size()[2:]
# n = x.size()[0]
# self.max_sz = [img_sz[0]] * self.dimension
# ts_emb_shape=[n,-1]+[1]*self.dimension
# self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension])
# self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0),
# [1, self.dimension]+list(img_sz)).to(self.device)
# img = x
# t = self.time_embed(t)
# for rec_id in range(rec_num):
# if self.conditional_input:
# tgt = y
# enc_list = []
# out = img
# for i in range(self.hier_num):
# out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
# if self.conditional_input:
# tgt = self.block_down_cond[i](tgt)
# out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1))
# tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1))
# enc_list.append(out)
# out = self.down_layers[i](out)
# if self.conditional_input:
# tgt = self.down_layers[i](tgt)
# out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
# if self.conditional_input:
# out = out + tgt
# for i in range(self.hier_num):
# out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1)
# out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))
# out = self.conv_out(out)/128
# ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
# if rec_id == 0:
# ddf = ddf_one
# else:
# ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
# img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
# return ddf
# def _make_te(self, dim_in, dim_out):
# return nn.Sequential(
# nn.Linear(dim_in, dim_out),
# nn.ReLU(),
# nn.Linear(dim_out, dim_out)
# )
# ==============================================
# Layers
# ==============================================
def ddf_multiplier(dvf, mul_num=10, stn=None):
    """Repeatedly compose a displacement field with itself.

    Starting from ``ddf = dvf``, applies ``ddf = dvf + stn(ddf, dvf)``
    ``mul_num`` times.

    Args:
        dvf: displacement field ``(N, ndims, *spatial)``.
        mul_num: number of composition steps.
        stn: callable ``stn(field, ddf)`` that resamples ``field`` by ``ddf``.
            When omitted, an ``STN`` with border padding is built from the
            shape of ``dvf`` (previously ``None`` was called directly, which
            raised ``TypeError``). The first spatial size is used for every
            axis, as elsewhere in this file.

    Returns:
        The composed displacement field.
    """
    if stn is None:
        stn = STN(ndims=dvf.dim() - 2, img_sz=dvf.shape[2], device=dvf.device,
                  padding_mode="border")
    ddf = dvf
    for _ in range(mul_num):
        ddf = dvf + stn(ddf, dvf)
    return ddf
def composite(ddfs, stn=None):
    """Compose a sequence of displacement fields into one.

    Starting from ``ddfs[0]``, each following field ``d`` updates the result
    as ``d + stn(result, d)``.

    Args:
        ddfs: non-empty sequence of displacement fields ``(N, ndims, *spatial)``.
        stn: callable ``stn(field, ddf)``. When omitted, an ``STN`` with border
            padding is built from the shape of ``ddfs[0]``; the previous
            default omitted ``img_sz``/``ndims``, which made ``STN`` fail on
            construction. The first spatial size is used for every axis.

    Returns:
        The composed displacement field.
    """
    if stn is None:
        first = ddfs[0]
        stn = STN(ndims=first.dim() - 2, img_sz=first.shape[2], device=first.device,
                  padding_mode="border")
    comp_ddf = ddfs[0]
    for ddf in ddfs[1:]:
        comp_ddf = ddf + stn(comp_ddf, ddf)
    return comp_ddf
class STN(nn.Module):
    """Spatial transformer that resamples a tensor with a displacement field.

    The displacement field is multiplied by ``max_sz`` and added to an integer
    reference grid before being normalised for ``F.grid_sample``.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None,
                 padding_mode="border", resample_mode=None):
        """Store the configuration and, when ``img_sz`` is known, build the grid.

        Args:
            ndims: number of spatial dimensions.
            img_sz: size of every spatial axis (a single int), or a sequence
                with one size per axis. ``None`` defers grid construction to
                the first ``forward`` call (this used to crash here).
            max_sz: ignored; the scale is always derived from ``img_sz``, as
                in the original implementation.
            device: device for the cached tensors; defaults to the device of
                the first input seen by ``forward``.
            padding_mode: padding mode used by ``forward``.
            resample_mode: interpolation mode; ``None`` means ``'bilinear'``.
        """
        super(STN, self).__init__()
        self.ndims = ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        self.img_sz = None
        self.max_sz = None
        self.ref_grid = None
        if img_sz is not None:
            if isinstance(img_sz, (list, tuple)):
                sizes = list(img_sz)
            else:
                sizes = [img_sz] * ndims
            self._build_grid(sizes, sizes)

    def _build_grid(self, sizes, scales):
        """Cache ``img_sz``, the ``max_sz`` scale tensor and the reference grid."""
        self.img_sz = [int(s) for s in sizes]
        self.ndims = len(self.img_sz)
        self.max_sz = torch.tensor(
            [float(s) for s in scales], dtype=torch.float32, device=self.device,
        ).reshape([1, self.ndims] + [1] * self.ndims)
        axes = [torch.arange(end=s, device=self.device) for s in self.img_sz]
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(*axes, indexing="ij"), 0),
            [1, self.ndims] + self.img_sz)

    def max_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp each displacement channel to ``[minus - sz + plus, sz - minus + plus]``.

        Channels are paired with the per-axis entries of ``max_sz``; the old
        code zipped over the batch axis of ``max_sz`` and so only ever
        processed channel 0.
        """
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        scales = self.max_sz.flatten().tolist()
        return torch.cat(
            [torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz
             for x, sz in zip(sample_coords, scales)], 1)

    def boundary_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp the absolute sampling position of each channel.

        For every channel the position ``x * sz + ref`` is clamped and then
        converted back to a displacement. Channels are paired with the
        per-axis entries of ``max_sz`` and ``ref_grid`` (the old code zipped
        over their batch axis).
        """
        # NOTE(review): the lower bound is ``minus - sz + plus`` although the
        # value being clamped is an absolute position; kept as in the original
        # - confirm whether ``0 + plus`` was intended.
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        scales = self.max_sz.flatten().tolist()
        refs = self.ref_grid[0].to(sample_coords0.device)
        return torch.cat(
            [(torch.clamp(x * sz + ref, min=minus - 1 * sz + plus,
                          max=1 * sz - minus + plus) - ref) / sz
             for x, sz, ref in zip(sample_coords, scales, refs)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` with ``ddf`` via ``F.grid_sample``.

        Args:
            vol: tensor to sample (moved to the device of ``ddf``).
            ddf: displacement field, multiplied by ``self.max_sz`` before use.
            ref: reference grid; defaults to ``self.ref_grid``.
            img_sz: per-axis sizes used to normalise the grid as
                ``(size - 1) / 2``; defaults to ``self.img_sz`` (the old
                default divided by ``max_sz``, whose layout does not match
                the channels-last grid).
            padding_mode: forwarded to ``F.grid_sample``.
        """
        device = ddf.device
        ref = self.ref_grid if ref is None else ref
        if img_sz is None:
            img_sz = self.img_sz
        half_extent = torch.reshape(
            torch.tensor([(s - 1) / 2. for s in img_sz], device=device),
            [1] + [1] * self.ndims + [self.ndims])
        resample_mode = 'bilinear' if self.resample_mode is None else self.resample_mode
        coords = (ddf * self.max_sz.to(device) + ref.to(device)).permute(
            [0] + list(range(2, 2 + self.ndims)) + [1])
        # grid_sample wants the last spatial axis first, hence the flip.
        grid = torch.flip(coords / half_extent - 1, dims=[-1])
        return F.grid_sample(vol.to(device), grid, mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Resample ``x`` by ``ddf`` using the configured padding mode.

        When no size was given at construction, the grid is built from the
        spatial shape of ``x`` on first use; the first spatial size is used
        as the displacement scale for every axis, as in the networks of this
        file.
        """
        self.device = x.device if self.device is None else self.device
        if self.img_sz is None:
            sizes = list(x.size()[2:])
            self._build_grid(sizes, [sizes[0]] * len(sizes))
        resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz,
                                    padding_mode=self.padding_mode)
        return resampled_x
if __name__ == '__main__':
    # Smoke test: build a 3-D RecMutAttnNet, run one forward pass on random
    # data (the input doubles as the condition image) and report the
    # parameter counts.
    ndims, res = 3, 128
    volume_shape = [1, 1] + [res] * ndims
    x = torch.rand(volume_shape)
    t = torch.randint(0, 1000, (1,))
    text = torch.rand([1, 1024] + [1] * ndims)
    model = RecMutAttnNet(
        n_steps=1000,
        time_emb_dim=100,
        ndims=ndims,
        num_input_chn=1,
        res=res,
        conditional_input=True,
    )
    y = model(x, x, t, text=text)
    print("Ouput shape", y.shape)
    # (element count, trainable?) for every parameter tensor.
    param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in param_stats)
    trainable_params = sum(count for count, trainable in param_stats if trainable)
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")