| from torch import nn
|
| import torch
|
| import torch.nn.functional as F
|
| from torch.utils.checkpoint import checkpoint as grad_checkpoint
|
| import numpy as np
|
| import math
|
| from Diffusion.safe_conv_transpose import SafeConvTranspose3d
|
|
|
class UpsampleConv(nn.Module):
    """Drop-in replacement for ConvTranspose3d/2d that avoids the XPU memory leak.

    ConvTranspose3d backward leaks ~0.33 GiB/step on Intel XPU (oneDNN bug).
    This uses F.interpolate (zero leak) + Conv (negligible leak) instead.
    Also avoids checkerboard artifacts common with transposed convolutions.

    The constructor keeps the ``ConvTranspose`` argument list so call sites can
    swap the class in, but only ``stride`` is used (as the interpolation scale
    factor). ``kernel_size`` and ``padding`` are accepted and ignored: the output
    conv is always kernel 3 / stride 1 / padding 1, which keeps the upsampled
    spatial size.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the output conv.
        kernel_size: ignored (signature compatibility only).
        stride: upsampling factor passed to ``F.interpolate``.
        padding: ignored (signature compatibility only).
        ndims: number of spatial dimensions (1, 2 or 3).
    """

    # Interpolation mode per spatial rank; any other rank falls back to
    # 'bilinear' as before (only 1-D was previously broken).
    _INTERP_MODES = {1: 'linear', 3: 'trilinear'}

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, ndims=3):
        super().__init__()
        self.scale_factor = stride
        self.mode = self._INTERP_MODES.get(ndims, 'bilinear')
        Conv = getattr(nn, f'Conv{ndims}d')
        self.conv = Conv(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor``, then apply the size-preserving conv."""
        x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)
        return self.conv(x)
|
|
|
|
|
def get_net(name="recresnet"):
    """Resolve a network class from its case-insensitive registry name.

    Returns the class object itself (not an instance), or ``None`` when the
    name is not registered. Note that the default ``"recresnet"`` is not a
    registered key, so ``get_net()`` returns ``None``.
    """
    # Each entry is a zero-argument thunk so the global class name is only
    # looked up when its key actually matches, just like an if/elif chain.
    registry = {
        "recresacnet": lambda: RecResACNet,
        "recmutattnnet": lambda: RecMutAttnNet,
        "recmutattnnet0": lambda: RecMutAttnNet0,
        "recmutattnnet1": lambda: RecMutAttnNet1,
        "defrecmutattnnet": lambda: DefRec_MutAttnNet,
        "recmulmodmutattnnet": lambda: RecMulModMutAttnNet,
        "om_net": lambda: OM_net,
    }
    thunk = registry.get(name.lower())
    if thunk is None:
        return None
    return thunk()
|
|
|
|
|
|
|
def sinusoidal_embedding(n, d):
    """Build an ``(n, d)`` table of sinusoidal timestep embeddings.

    Row ``t`` holds ``sin(t * w)`` in the even columns and ``cos(t * w)`` in
    the odd columns. The frequencies ``w`` are ``1 / 10000 ** (2 * j / d)``
    taken at even ``j`` (0, 2, 4, ...); sin and cos columns of the same pair
    share a frequency.

    Args:
        n: number of rows (timesteps 0 .. n-1).
        d: embedding width; may be odd (the last column is then a sine).

    Returns:
        A float tensor of shape ``(n, d)`` in the default dtype.
    """
    embedding = torch.zeros(n, d)
    wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)]).reshape((1, d))
    t = torch.arange(n).reshape((n, 1))
    freqs = wk[:, ::2]  # ceil(d / 2) frequencies, one per even column
    embedding[:, ::2] = torch.sin(t * freqs)
    # There are only d // 2 odd columns; trimming makes odd ``d`` work and is a
    # no-op for even ``d``.
    embedding[:, 1::2] = torch.cos(t * freqs[:, : d // 2])
    return embedding
|
|
|
class AtrousBlock(nn.Module):
    """Residual stack of (optionally dilated) convolutions.

    If ``in_c != out_c`` a ``conv0`` first maps the input to ``out_c`` channels.
    The (optionally instance-normalised) result ``x`` is then passed through
    ``activation -> conv`` once per entry of ``atrous_rates``, and the block
    returns ``activation(out + x)``.

    Args:
        shape: unused; kept for call-site compatibility.
        in_c: input channels.
        out_c: output channels.
        kernel_size: kernel size of ``conv0`` and of the dilated convs.
        stride: unused; every conv here uses stride 1.
        atrous_rates: one conv per rate. A rate > 0 is the dilation (padding
            chosen to preserve size for odd kernels); a rate of 0 selects a
            1x1 conv.
        ndims: number of spatial dimensions (selects ConvNd / InstanceNormNd).
        activation: module applied before each conv and after the residual
            sum; defaults to ``LeakyReLU(1e-6)``.
        normalize: apply an affine InstanceNorm after ``conv0``.
    """

    def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, atrous_rates=(1, 3), ndims=2, activation=None, normalize=True):
        # atrous_rates defaults to an immutable tuple: a list default would be
        # shared across every instance.
        super(AtrousBlock, self).__init__()
        if normalize:
            norm = getattr(nn, 'InstanceNorm%dd' % ndims)
            self.ln = norm(out_c, affine=True)
        else:
            self.ln = nn.Identity()
        Conv = getattr(nn, 'Conv%dd' % ndims)
        if in_c != out_c:
            self.conv0 = Conv(in_c, out_c, kernel_size, 1, (kernel_size - 1) // 2)
        else:
            self.conv0 = None
        self.convs = nn.ModuleList([
            Conv(out_c, out_c, kernel_size, 1, (kernel_size - 1) // 2 * ar, dilation=ar)
            if ar > 0 else Conv(out_c, out_c, 1, 1, 0)
            for ar in atrous_rates
        ])
        self.activation = nn.LeakyReLU(1e-6) if activation is None else activation
        self.normalize = normalize

    def forward(self, x):
        if self.conv0 is not None:
            x = self.conv0(x)
        x = self.ln(x) if self.normalize else x
        # Plain alias; previously an nn.Identity module was allocated per call.
        out = x
        for conv in self.convs:
            out = conv(self.activation(out))
        return self.activation(out + x)
|
|
|
|
|
|
|
|
|
|
|
class RecResACNet(nn.Module):
    """Recursive residual-atrous U-Net that predicts a dense displacement field (DDF).

    The U-Net (three stride-2 stages, widths 10/20/40) is applied ``rec_num``
    times. After every pass the original input ``x`` is warped by the
    accumulated DDF and the warped image feeds the next pass.

    Args:
        n_steps: rows of the frozen sinusoidal timestep table.
        time_emb_dim: width of the timestep embedding.
        ndims: number of spatial dimensions; also the number of DDF channels.
        num_input_chn: channels of the input image.
        res: nominal resolution; only forwarded to AtrousBlock's unused
            ``shape`` argument.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0):
        super(RecResACNet, self).__init__()
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal lookup table indexed by timestep.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # Modules are created in the original order so parameter initialisation
        # and state_dict keys are unchanged.
        # Encoder.
        self.te1 = self._make_te(time_emb_dim, 1)
        self.b1 = self._stage([num_input_chn, 10, 10, 10], res)
        self.down1 = self.Conv(10, 10, 4, 2, 1)
        self.te2 = self._make_te(time_emb_dim, 10)
        self.b2 = self._stage([10, 20, 20, 20], res // 2)
        self.down2 = self.Conv(20, 20, 4, 2, 1)
        self.te3 = self._make_te(time_emb_dim, 20)
        self.b3 = self._stage([20, 40, 40, 40], res // 4)
        self.down3 = self.Conv(40, 40, 4, 2, 1)

        # Bottleneck.
        self.te_mid = self._make_te(time_emb_dim, 40)
        self.b_mid = self._stage([40, 20, 20, 40], res // 8)

        # Decoder (no instance norm).
        self.up1 = self.ConvT(40, 40, 4, 2, 1)
        self.te4 = self._make_te(time_emb_dim, 80)
        self.b4 = self._stage([80, 40, 20, 20], res // 4, normalize=False)
        self.up2 = self.ConvT(20, 20, 4, 2, 1)
        self.te5 = self._make_te(time_emb_dim, 40)
        self.b5 = self._stage([40, 20, 10, 10], res // 2, normalize=False)
        self.up3 = self.ConvT(10, 10, 4, 2, 1)
        self.te_out = self._make_te(time_emb_dim, 20)
        self.b_out = self._stage([20, 10, 10, 10], res // 1, normalize=False)

        # One output channel per spatial axis.
        self.conv_out = self.Conv(10, ndims, 3, 1, 1)

    def _stage(self, channels, spatial, normalize=True):
        """Chain three AtrousBlocks: channels[0] -> [1] -> [2] -> [3]."""
        nd = self.dimension
        return nn.Sequential(*[
            AtrousBlock([c_in] + [spatial] * nd, c_in, c_out, ndims=nd, normalize=normalize)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def _make_ref_grid(self, sizes):
        """Identity grid of integer voxel coordinates, shape (1, ndims, *sizes)."""
        axes = [torch.arange(end=s) for s in sizes]
        grid = torch.stack(torch.meshgrid(axes, indexing='ij'), 0)
        return grid.reshape([1, self.dimension] + list(sizes))

    def _prepare_geometry(self, x):
        """Cache device, scales and reference grid derived from ``x``'s spatial size."""
        self.device = x.device
        img_sz = x.size()[2:]
        # Only the first spatial size is used, so all axes share one scale
        # (assumes square/cubic inputs).
        self.max_sz = [img_sz[0]] * self.dimension
        # Per-axis half extent (s - 1) / 2, shaped to broadcast against a
        # channels-last sampling grid.
        self.img_sz = torch.reshape(
            torch.tensor([(s - 1) / 2 for s in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
        self.ref_grid = self._make_ref_grid(img_sz).to(self.device)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of a normalised DDF to [minus - sz + plus, sz - minus + plus] voxels.

        Channel ``i`` is scaled by ``max_sz[i]``, clamped, then divided back.
        """
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([
            torch.clamp(c * sz, min=minus - sz + plus, max=sz - minus + plus) / sz
            for c, sz in zip(channels, max_sz)
        ], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` with bilinear ``grid_sample`` (align_corners=True).

        ``ddf`` is scaled by ``self.max_sz`` into voxel units and added to the
        reference grid. The result is moved channels-last, mapped to [-1, 1] via
        ``coords / img_sz - 1``, and its last axis reversed into grid_sample's
        (x, y[, z]) order.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        nd = self.dimension
        scale = torch.tensor(self.max_sz, dtype=torch.float32, device=self.device).reshape([1, nd] + [1] * nd)
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + nd)) + [1])
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def _compose(self, ddf, ddf_one):
        """Fold a new increment into the running DDF (the first pass is just the increment)."""
        if ddf is None:
            return ddf_one
        # Resample the previous field at positions displaced by the increment.
        return ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")

    def forward(self, x=None, t=None, y=None, rec_num=2, ndims=2):
        """Predict a DDF for ``x`` with ``rec_num`` recursive passes.

        Args:
            x: input image of shape (N, C, *spatial).
            t: timestep indices for ``time_embed``.
            y: unused.
            rec_num: number of recursive passes; must be >= 1.
            ndims: unused (the rank comes from the constructor).

        Returns:
            The accumulated DDF with ``self.dimension`` channels.

        Raises:
            ValueError: if ``rec_num < 1``.
        """
        if rec_num < 1:
            raise ValueError('rec_num must be >= 1, got %r' % (rec_num,))
        self._prepare_geometry(x)
        ts_emb_shape = [x.size()[0], -1] + [1] * self.dimension
        t = self.time_embed(t)

        # The timestep is identical for every pass, so project it only once.
        def emb(layer):
            return layer(t).reshape(ts_emb_shape)

        e1, e2, e3 = emb(self.te1), emb(self.te2), emb(self.te3)
        e_mid = emb(self.te_mid)
        e4, e5, e_out = emb(self.te4), emb(self.te5), emb(self.te_out)

        img = x
        ddf = None
        for _ in range(rec_num):
            out1 = self.b1(img + e1)
            out2 = self.b2(self.down1(out1) + e2)
            out3 = self.b3(self.down2(out2) + e3)
            # NOTE(review): the bottleneck multiplies by its time embedding
            # while every other stage adds it. Kept for checkpoint
            # compatibility; confirm this is intended.
            out_mid = self.b_mid(self.down3(out3) * e_mid)
            out4 = self.b4(torch.cat((out3, self.up1(out_mid)), dim=1) + e4)
            out5 = self.b5(torch.cat((out2, self.up2(out4)), dim=1) + e5)
            out = self.b_out(torch.cat((out1, self.up3(out5)), dim=1) + e_out)
            ddf_one = self.boundary_limit(self.conv_out(out), max_sz=self.max_sz)
            ddf = self._compose(ddf, ddf_one)
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Two-layer MLP projecting the time embedding to ``dim_out`` channels."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
        )
|
|
|
|
|
|
|
|
|
|
|
class cross_attn(nn.Module):
    """Unscaled dot-product cross-attention from ``x`` (queries) to ``y`` (keys/values).

    Args:
        q, k, v: callables (typically modules) mapping inputs to queries, keys
            and values; the last two dims are (positions, features).
        ndims: spatial rank; only used to pick the ``Conv``/``ConvT`` classes
            stored on the instance.
    """

    def __init__(self, q, k, v, ndims=2):
        # Must run before any submodule/Parameter assignment: nn.Module's
        # __setattr__ raises AttributeError otherwise, so the class could not
        # be constructed at all.
        super().__init__()
        self.q = q
        self.k = k
        self.v = v
        self.ndims = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.ndims)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.ndims)
        self.softmax = nn.Softmax(dim=-1)
        # NOTE(review): created but never used by forward().
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x, y):
        q = self.q(x)
        k = self.k(y)
        v = self.v(y)
        attn = self.softmax(torch.matmul(q, k.transpose(-2, -1)))
        return torch.matmul(attn, v)
|
|
|
class DefRec_MutAttnNet(nn.Module):
    """Recursive registration U-Net conditioned on a second image and a text feature.

    When ``conditional_input`` is set, ``y`` is encoded once per forward by a
    parallel encoder (``block_down_cond``, sharing ``down_layers`` with the
    image encoder). Its per-level features are fused into the image encoder by
    1x1 convs, and its bottleneck provides keys/values for a multi-head
    cross-attention. A text feature is projected in and added at the
    bottleneck. The U-Net runs once per entry of ``t`` (the sequence is replayed
    ``rec_num`` times). Each pass warps the input by the accumulated DDF.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True, text_feat_chn=1024, num_heads=4):
        super(DefRec_MutAttnNet, self).__init__()
        self.feat_channels = [num_input_chn, 16, 32, 128, 256, 512]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
        self.copy = nn.Identity()  # kept for attribute compatibility

        # Frozen sinusoidal lookup table indexed by timestep.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1

        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # Default (all-zero) text feature used when forward() gets none.
            self.text = torch.zeros(1, self.text_feat_chn, *([1] * self.dimension))
        self.img_res = [res] * self.dimension
        self.ref_grid = self._make_ref_grid(self.img_res)

        fc = self.feat_channels
        for i in range(1, self.hier_num + 1):
            j = -i
            down_res = res // (2 ** (i - 1))
            up_res = res // (2 ** (self.hier_num - i - 1))
            self.down_layers.append(self.Conv(fc[i], fc[i], 4, 2, 1))
            self.up_layers.append(self.ConvT(fc[j], fc[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, fc[i - 1]))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * fc[j]))
            self.block_down.append(self._stage([fc[i - 1], fc[i], fc[i], fc[i]], down_res))
            if self.conditional_input:
                self.block_down_cond.append(self._stage([fc[i - 1], fc[i], fc[i], fc[i]], down_res))
                self.fuse_conv0.append(self.Conv(2 * fc[i], fc[i], 1, 1, 0))
            # The last decoder stage keeps its width; earlier ones narrow to the
            # width expected by the next up-sampling layer.
            k = j if i == self.hier_num else j - 1
            self.block_up.append(self._stage([2 * fc[j], fc[j], fc[j], fc[k]], up_res, normalize=False))

        top = fc[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, fc[-1])
        self.b_mid = self._stage([top, top, top, top], res // (2 ** self.hier_num))
        self.conv_out = self.Conv(fc[1], ndims, 3, 1, 1)

    def _stage(self, channels, spatial, normalize=True):
        """Chain three AtrousBlocks: channels[0] -> [1] -> [2] -> [3]."""
        nd = self.dimension
        return nn.Sequential(*[
            AtrousBlock([c_in] + [spatial] * nd, c_in, c_out, ndims=nd, normalize=normalize)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def _make_ref_grid(self, sizes):
        """Identity grid of integer voxel coordinates, shape (1, ndims, *sizes)."""
        axes = [torch.arange(end=s) for s in sizes]
        grid = torch.stack(torch.meshgrid(axes, indexing='ij'), 0)
        return grid.reshape([1, self.dimension] + list(sizes))

    def _prepare_geometry(self, x):
        """Cache device, scales and the reference grid for ``x``'s spatial size."""
        self.device = x.device
        img_sz = x.size()[2:]
        # Only the first spatial size is used (assumes square/cubic inputs).
        self.max_sz = [img_sz[0]] * self.dimension
        self.img_sz = torch.reshape(
            torch.tensor([(s - 1) / 2 for s in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
        # Compare against the cached grid itself (not the constructor
        # resolution), so a grid rebuilt for another size is never reused stale.
        if list(self.ref_grid.shape[2:]) != list(img_sz):
            self.ref_grid = self._make_ref_grid(img_sz)
        self.ref_grid = self.ref_grid.to(self.device)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of a normalised DDF to [minus - sz + plus, sz - minus + plus] voxels."""
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([
            torch.clamp(c * sz, min=minus - sz + plus, max=sz - minus + plus) / sz
            for c, sz in zip(channels, max_sz)
        ], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` with bilinear ``grid_sample`` (align_corners=True).

        ``ddf`` is scaled into voxel units by ``self.max_sz`` and added to the
        reference grid, then moved channels-last, mapped to [-1, 1] via
        ``coords / img_sz - 1``, and its last axis reversed for grid_sample.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        nd = self.dimension
        scale = torch.tensor(self.max_sz, dtype=torch.float32, device=self.device).reshape([1, nd] + [1] * nd)
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + nd)) + [1])
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def _compose(self, ddf, ddf_one):
        """Fold a new increment into the running DDF (the first pass is just the increment)."""
        if ddf is None:
            return ddf_one
        return ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")

    def _encode_condition(self, y):
        """Encode ``y`` once.

        Returns per-level features, fused into the image encoder, and the
        bottleneck flattened to MultiheadAttention's (L, N, C) layout.
        """
        feats = []
        tgt = y
        for i in range(self.hier_num):
            tgt = self.block_down_cond[i](tgt)
            feats.append(tgt)
            # Down-sampling convs are shared with the image encoder.
            tgt = self.down_layers[i](tgt)
        n, c = tgt.shape[:2]
        return feats, tgt.view(n, c, -1).permute(2, 0, 1)

    def _cross_attend(self, out, memory):
        """Attend from bottleneck positions of ``out`` (queries) to ``memory`` (L, N, C)."""
        shape = out.shape
        query = out.view(shape[0], shape[1], -1).permute(2, 0, 1)
        attn, _ = self.attn_layer(query, memory, memory)
        return attn.permute(1, 2, 0).contiguous().view(shape)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a DDF for ``x``, one refinement pass per scheduled timestep.

        Args:
            x: moving image of shape (N, C, *spatial).
            y: conditioning input (used only when ``conditional_input``).
            t: sequence of timestep tensors; replayed ``rec_num`` times.
            text: optional text feature, either (N, text_feat_chn) or already
                shaped (N, text_feat_chn, 1, ...); defaults to zeros.
            rec_num: number of times ``t`` is replayed.
            ndims: unused.

        Raises:
            ValueError: if ``t`` is empty or ``rec_num < 1``.
        """
        self._prepare_geometry(x)
        ts_emb_shape = [x.size()[0], -1] + [1] * self.dimension
        times = [t0.to(self.device) for t0 in t] * rec_num
        if not times:
            raise ValueError('need at least one pass: t is empty or rec_num < 1')

        if self.conditional_input:
            tgt_down_list, tgt_mid = self._encode_condition(y)
            text = (self.text if text is None else text).to(self.device)
            if text.dim() <= 2:
                # Accept flat (N, C) / (C,) features, like RecMutAttnNet does.
                text = text.reshape(-1, self.text_feat_chn, *([1] * self.dimension))

        img = x
        ddf = None
        for time in times:
            t_emb = self.time_embed(time)
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape))
                if self.conditional_input:
                    out = self.fuse_conv0[i](torch.cat([out, tgt_down_list[i]], dim=1))
                enc_list.append(out)
                out = self.down_layers[i](out)

            out = self.b_mid(out + self.tmid(t_emb).reshape(ts_emb_shape))
            if self.conditional_input:
                out = out + self._cross_attend(out, tgt_mid)
                out = out + self.txt2img(self.txt_proc(self.img2txt(out) + text))

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t_emb).reshape(ts_emb_shape))

            # Fixed 1/128 output scale (NOTE(review): presumably keeps early
            # displacements small).
            ddf_one = self.boundary_limit(self.conv_out(out) / 128, max_sz=self.max_sz)
            ddf = self._compose(ddf, ddf_one)
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Two-layer MLP projecting the time embedding to ``dim_out`` channels."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
        )
|
|
|
|
|
class RecMutAttnNet1(nn.Module):
    """Recursive registration U-Net with two-way image/condition fusion and cross-attention.

    At every encoder level the image stream and the condition stream (``y``,
    encoded by ``block_down_cond``) are fused in both directions by 1x1 convs.
    At the bottleneck the image features attend to the condition features. The
    U-Net is applied ``rec_num`` times, each pass warping the input by the
    accumulated DDF. ``text_feat_chn`` is stored but not used by this variant.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True, text_feat_chn=1024, num_heads=4):
        super(RecMutAttnNet1, self).__init__()
        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn
        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal lookup table indexed by timestep.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1

        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        if self.conditional_input:
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
        self.block_up = nn.ModuleList()

        fc = self.feat_channels
        for i in range(1, self.hier_num + 1):
            j = -i
            down_res = res // (2 ** (i - 1))
            up_res = res // (2 ** (self.hier_num - i - 1))
            self.down_layers.append(self.Conv(fc[i], fc[i], 4, 2, 1))
            self.up_layers.append(self.ConvT(fc[j], fc[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, fc[i - 1]))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * fc[j]))
            self.block_down.append(self._stage([fc[i - 1], fc[i], fc[i], fc[i]], down_res))
            if self.conditional_input:
                self.block_down_cond.append(self._stage([fc[i - 1], fc[i], fc[i], fc[i]], down_res))
                self.fuse_conv0.append(self.Conv(2 * fc[i], fc[i], 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * fc[i], fc[i], 1, 1, 0))
            # The last decoder stage keeps its width; earlier ones narrow to the
            # width expected by the next up-sampling layer.
            k = j if i == self.hier_num else j - 1
            self.block_up.append(self._stage([2 * fc[j], fc[j], fc[j], fc[k]], up_res, normalize=False))

        top = fc[self.hier_num]
        self.tmid = self._make_te(time_emb_dim, fc[-1])
        self.b_mid = self._stage([top, top, top, top], res // (2 ** self.hier_num))
        self.conv_out = self.Conv(fc[1], ndims, 3, 1, 1)

    def _stage(self, channels, spatial, normalize=True):
        """Chain three AtrousBlocks: channels[0] -> [1] -> [2] -> [3]."""
        nd = self.dimension
        return nn.Sequential(*[
            AtrousBlock([c_in] + [spatial] * nd, c_in, c_out, ndims=nd, normalize=normalize)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ])

    def _make_ref_grid(self, sizes):
        """Identity grid of integer voxel coordinates, shape (1, ndims, *sizes)."""
        axes = [torch.arange(end=s) for s in sizes]
        grid = torch.stack(torch.meshgrid(axes, indexing='ij'), 0)
        return grid.reshape([1, self.dimension] + list(sizes))

    def _prepare_geometry(self, x):
        """Cache device, scales and the reference grid for ``x``'s spatial size."""
        self.device = x.device
        img_sz = x.size()[2:]
        # Only the first spatial size is used (assumes square/cubic inputs).
        self.max_sz = [img_sz[0]] * self.dimension
        self.img_sz = torch.reshape(
            torch.tensor([(s - 1) / 2 for s in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension])
        self.ref_grid = self._make_ref_grid(img_sz).to(self.device)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each channel of a normalised DDF to [minus - sz + plus, sz - minus + plus] voxels."""
        channels = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([
            torch.clamp(c * sz, min=minus - sz + plus, max=sz - minus + plus) / sz
            for c, sz in zip(channels, max_sz)
        ], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` with bilinear ``grid_sample`` (align_corners=True).

        ``ddf`` is scaled into voxel units by ``self.max_sz`` and added to the
        reference grid, then moved channels-last, mapped to [-1, 1] via
        ``coords / img_sz - 1``, and its last axis reversed for grid_sample.
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        nd = self.dimension
        scale = torch.tensor(self.max_sz, dtype=torch.float32, device=self.device).reshape([1, nd] + [1] * nd)
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + nd)) + [1])
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode='bilinear', padding_mode=padding_mode, align_corners=True)

    def _compose(self, ddf, ddf_one):
        """Fold a new increment into the running DDF (the first pass is just the increment)."""
        if ddf is None:
            return ddf_one
        return ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")

    def _encode(self, img, y, ted):
        """Run the image encoder, fusing with the condition stream at every level.

        Returns the pre-bottleneck features, the skip features, and the
        down-sampled condition features (``None`` without conditioning).
        """
        tgt = y if self.conditional_input else None
        enc_list = []
        out = img
        for i in range(self.hier_num):
            out = self.block_down[i](out + ted[i])
            if self.conditional_input:
                tgt = self.block_down_cond[i](tgt)
                out = self.fuse_conv0[i](torch.cat([out, tgt], dim=1))
                # The condition stream is updated with the already-fused image features.
                tgt = self.fuse_conv1[i](torch.cat([tgt, out], dim=1))
            enc_list.append(out)
            out = self.down_layers[i](out)
            if self.conditional_input:
                tgt = self.down_layers[i](tgt)
        return out, enc_list, tgt

    def _cross_attend(self, out, tgt):
        """Attend from ``out`` (queries) to ``tgt`` (keys/values), both (N, C, *spatial).

        Inputs are flattened to MultiheadAttention's default (L, N, C) layout;
        the result is returned in ``out``'s shape.
        """
        shape = out.shape
        query = out.view(shape[0], shape[1], -1).permute(2, 0, 1)
        memory = tgt.view(tgt.shape[0], tgt.shape[1], -1).permute(2, 0, 1)
        attn, _ = self.attn_layer(query, memory, memory)
        return attn.permute(1, 2, 0).contiguous().view(shape)

    def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2):
        """Predict a DDF for ``x`` with ``rec_num`` recursive passes.

        Args:
            x: moving image of shape (N, C, *spatial).
            y: conditioning input (used only when ``conditional_input``).
            t: timestep indices for ``time_embed``.
            rec_num: number of recursive passes; must be >= 1.
            ndims: unused.

        Raises:
            ValueError: if ``rec_num < 1``.
        """
        if rec_num < 1:
            raise ValueError('rec_num must be >= 1, got %r' % (rec_num,))
        self._prepare_geometry(x)
        ts_emb_shape = [x.size()[0], -1] + [1] * self.dimension
        t = self.time_embed(t)
        # The timestep is identical for every pass, so project it only once.
        ted = [layer(t).reshape(ts_emb_shape) for layer in self.ted_layers]
        teu = [layer(t).reshape(ts_emb_shape) for layer in self.teu_layers]
        e_mid = self.tmid(t).reshape(ts_emb_shape)

        img = x
        ddf = None
        for _ in range(rec_num):
            out, enc_list, tgt = self._encode(img, y, ted)
            out = self.b_mid(out + e_mid)
            if self.conditional_input:
                out = out + self._cross_attend(out, tgt)
            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + teu[i])
            # Fixed 1/128 output scale (NOTE(review): presumably keeps early
            # displacements small).
            ddf_one = self.boundary_limit(self.conv_out(out) / 128, max_sz=self.max_sz)
            ddf = self._compose(ddf, ddf_one)
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)
        return ddf

    def _make_te(self, dim_in, dim_out):
        """Two-layer MLP projecting the time embedding to ``dim_out`` channels."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out),
        )
|
|
|
| class RecMutAttnNet(nn.Module):
|
| def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4):
|
| super(RecMutAttnNet, self).__init__()
|
|
|
|
|
| self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
|
| self.conditional_input = conditional_input
|
| self.num_heads = num_heads
|
| self.text_feat_chn = text_feat_chn
|
|
|
| self.dimension = ndims
|
| self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
|
| self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)
|
|
|
|
|
| self.time_embed = nn.Embedding(n_steps, time_emb_dim)
|
| self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
|
| self.time_embed.requires_grad_(False)
|
| self.hier_num = len(self.feat_channels) - 1
|
| self.down_layers = nn.ModuleList()
|
| self.up_layers = nn.ModuleList()
|
| self.ted_layers = nn.ModuleList()
|
| self.teu_layers = nn.ModuleList()
|
| self.block_down = nn.ModuleList()
|
| self.block_up = nn.ModuleList()
|
| if self.conditional_input:
|
| self.block_down_cond = nn.ModuleList()
|
| self.fuse_conv0 = nn.ModuleList()
|
| self.fuse_conv1 = nn.ModuleList()
|
| self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
|
| Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
|
| self.global_maxpool = Global_Maxpool(1)
|
| self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
|
| self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
|
| self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
|
| self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension))
|
| self.img_res = [res]*self.dimension
|
| self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0),
|
| [1, self.dimension]+list(self.img_res))
|
|
|
| for i in range(1, self.hier_num + 1):
|
| j=-i
|
| self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
|
| self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
|
| self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1]))
|
| self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j]))
|
| self.block_down.append(nn.Sequential(
|
| AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| ))
|
| if self.conditional_input:
|
| self.block_down_cond.append(nn.Sequential(
|
| AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims),
|
| AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
|
| AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims)
|
| ))
|
| self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
|
| if i==self.hier_num:
|
| k=j
|
| else:
|
| k=j-1
|
| self.block_up.append(nn.Sequential(
|
| AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
|
| AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False)
|
| ))
|
|
|
|
|
| self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
|
| self.b_mid = nn.Sequential(
|
| AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
|
| AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims)
|
| )
|
|
|
| self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)
|
|
|
| def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
|
| sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
|
| return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
|
| zip(sample_coords, max_sz)], 1)
|
|
|
| def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
|
| ref = self.ref_grid if ref is None else ref
|
| img_sz = self.max_sz if img_sz is None else img_sz
|
| resample_mode = 'bilinear'
|
|
|
| return F.grid_sample(vol, torch.flip((ddf * torch.Tensor(
|
| np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute(
|
| [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,
|
| align_corners=True)
|
|
|
def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
    """Predict a displacement field for ``x``, refined over ``rec_num`` passes.

    Each pass runs the time-conditioned U-Net on the current (warped) image.
    When ``self.conditional_input`` is set, the condition image ``y`` is fused
    into the encoder at every level, the image bottleneck attends to the
    condition bottleneck, and a text embedding is added at the bottleneck.
    The per-pass field is composed with the running field and ``x`` is
    re-warped with the result before the next pass.

    Args:
        x: Moving image, ``(N, C, *spatial)``.
        y: Condition image; only used when ``self.conditional_input``.
        t: Integer time steps looked up in ``self.time_embed``.
        text: Optional text embedding (defaults to ``self.text``); it is
            reshaped to ``(-1, self.text_feat_chn, 1, ...)`` before use.
        rec_num: Number of refinement passes; must be >= 1.
        ndims: Unused; kept for interface compatibility.

    Returns:
        The accumulated displacement field ``ddf`` with one channel per
        spatial axis, normalised by ``self.max_sz``.

    Raises:
        ValueError: if ``rec_num`` < 1 (no field would be produced).

    Side effects:
        Sets ``self.device``, ``self.max_sz`` and ``self.img_sz``, and rebuilds
        ``self.ref_grid`` when the input size differs from ``self.img_res``.
    """
    if rec_num < 1:
        raise ValueError(f"rec_num must be >= 1, got {rec_num}")
    self.device = x.device
    img_sz = x.size()[2:]
    n = x.size()[0]
    # Displacements are normalised by the first spatial size on every axis.
    self.max_sz = [img_sz[0]] * self.dimension
    ts_emb_shape = [n, -1] + [1] * self.dimension

    # Per-axis half extent (size - 1) / 2, laid out for the channels-last sampling grid.
    self.img_sz = torch.reshape(
        torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
        [1] * (self.dimension + 1) + [self.dimension],
    )
    if list(img_sz) != self.img_res:
        # Input size differs from the construction-time resolution: rebuild the voxel-index grid.
        self.ref_grid = torch.reshape(
            torch.stack(torch.meshgrid(*[torch.arange(end=imsz) for imsz in img_sz], indexing="ij"), 0),
            [1, self.dimension] + list(img_sz),
        )
    self.ref_grid = self.ref_grid.to(self.device)

    img = x
    t = self.time_embed(t)
    if self.conditional_input:
        # The text embedding does not change between passes: resolve it once.
        if text is None:
            text = self.text
        text = text.to(self.device).view(-1, self.text_feat_chn, *([1] * self.dimension))

    for rec_id in range(rec_num):
        tgt = y if self.conditional_input else None
        enc_list = []
        out = img

        # Encoder: time-conditioned blocks, optionally fused with the condition stream.
        for i in range(self.hier_num):
            out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
            if self.conditional_input:
                tgt = self.block_down_cond[i](tgt)
                out = self.fuse_conv0[i](torch.cat([out, tgt], dim=1))
                tgt = self.fuse_conv1[i](torch.cat([tgt, out], dim=1))
            enc_list.append(out)
            out = self.down_layers[i](out)
            if self.conditional_input:
                tgt = self.down_layers[i](tgt)

        # Bottleneck.
        out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
        if self.conditional_input:
            out_shape = out.shape
            tgt_shape = tgt.shape
            # MultiheadAttention expects (sequence, batch, channels): flatten the spatial axes.
            tgt_seq = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
            out_seq = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
            out_attn, _ = self.attn_layer(out_seq, tgt_seq, tgt_seq)
            out = out + out_attn.permute(1, 2, 0).contiguous().view(out_shape)

            # Text modulation of the bottleneck features.
            out_txt = self.img2txt(out) + text
            out_txt = self.txt_proc(out_txt)
            out = out + self.txt2img(out_txt)

        # Decoder with skip connections from the matching encoder level.
        for i in range(self.hier_num):
            out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
            out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

        out = self.conv_out(out) / 128
        ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
        if rec_id == 0:
            ddf = ddf_one
        else:
            # Compose: warp the running field by the new one and add the new one.
            ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
        img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

    return ddf
|
|
|
def _make_te(self, dim_in, dim_out):
    """Return a small MLP projecting a ``dim_in`` embedding to ``dim_out`` features.

    Layout: Linear(dim_in -> dim_out), ReLU, Linear(dim_out -> dim_out).
    """
    layers = [
        nn.Linear(dim_in, dim_out),
        nn.ReLU(),
        nn.Linear(dim_out, dim_out),
    ]
    return nn.Sequential(*layers)
|
|
|
class RecMulModMutAttnNet(nn.Module):
    """Recurrent registration U-Net with image/condition mutual attention and text input.

    The network predicts a displacement field for ``x`` (``ndims`` output
    channels from ``conv_out``). When ``conditional_input`` is True:

    - a parallel encoder processes the condition image ``y``;
    - the two streams are fused at every encoder level;
    - they exchange information through two cross-attention layers at the
      bottleneck;
    - a text embedding is injected per level and at the bottleneck.

    ``forward`` refines the field ``rec_num`` times, re-warping ``x`` with the
    accumulated field before each new pass.

    Args:
        n_steps: Number of rows of the frozen sinusoidal time-embedding table.
        time_emb_dim: Width of the time embedding.
        ndims: Spatial dimensionality (selects Conv2d/Conv3d etc.).
        num_input_chn: Channels of ``x`` (and ``y``).
        res: Nominal spatial resolution; sizes the blocks and the cached
            reference grid.
        conditional_input: Enable the condition-image / text branch.
        text_feat_chn: Width of the text embedding.
        num_heads: Heads of the bottleneck cross-attention layers.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True, text_feat_chn=1024, num_heads=4):
        super(RecMulModMutAttnNet, self).__init__()
        # NOTE: module creation order is kept stable on purpose; it fixes the
        # parameter order (optimizer state) and the random initialisation.
        self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal time-step embedding.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        if self.conditional_input:
            self.txt_layers = nn.ModuleList()
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # All-zero text embedding used when forward() receives text=None.
            self.text = torch.zeros(1, self.text_feat_chn)

        self.img_res = [res] * self.dimension
        self.ref_grid = self._make_ref_grid(self.img_res)

        for i in range(1, self.hier_num + 1):
            j = -i
            down_sz = res // (2 ** (i - 1))
            up_sz = res // (2 ** (self.hier_num - i - 1))
            self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
            self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i - 1]))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * self.feat_channels[j]))
            self.block_down.append(nn.Sequential(
                AtrousBlock([self.feat_channels[i - 1]] + [down_sz] * ndims, self.feat_channels[i - 1], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
            ))
            if self.conditional_input:
                self.txt_layers.append(self._make_te(self.text_feat_chn, self.feat_channels[i]))
                self.block_down_cond.append(nn.Sequential(
                    AtrousBlock([self.feat_channels[i - 1]] + [down_sz] * ndims, self.feat_channels[i - 1], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                ))
                self.fuse_conv0.append(self.Conv(2 * self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
            # Each up stage narrows to the next shallower width; the last keeps
            # feat_channels[1], which is what conv_out consumes.
            k = j if i == self.hier_num else j - 1
            self.block_up.append(nn.Sequential(
                AtrousBlock([2 * self.feat_channels[j]] + [up_sz] * ndims, 2 * self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [up_sz] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [up_sz] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False),
            ))

        if self.conditional_input:
            # Bottleneck text projection (txt_layers[-1]); txt_layers only exists in this mode.
            self.txt_layers.append(self._make_te(self.text_feat_chn, self.text_feat_chn))
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        mid_sz = res // (2 ** self.hier_num)
        self.b_mid = nn.Sequential(
            AtrousBlock([self.feat_channels[self.hier_num]] + [mid_sz] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [mid_sz] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [mid_sz] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
        )
        self.fuse = self.Conv(2 * self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0)

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

    def _make_ref_grid(self, sizes):
        """Return the integer voxel-index grid of shape ``(1, ndims, *sizes)`` ('ij' order)."""
        sizes = list(sizes)
        grids = torch.meshgrid(*[torch.arange(end=imsz) for imsz in sizes], indexing="ij")
        return torch.reshape(torch.stack(grids, 0), [1, self.dimension] + sizes)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each normalised coordinate channel so ``|x * sz|`` stays within ``sz - minus``."""
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` (normalised by ``self.max_sz``) with ``grid_sample``.

        ``img_sz`` is the per-axis half extent used to map voxel positions to
        ``[-1, 1]`` (``forward`` passes ``self.img_sz``).
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'
        scale = torch.Tensor(np.reshape(np.array(self.max_sz), [1, self.dimension] + [1] * self.dimension)).to(self.device)
        # Absolute voxel positions, moved to channels-last for grid_sample.
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + self.dimension)) + [1])
        # grid_sample wants (x, y[, z]) i.e. reversed axis order.
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode=resample_mode, padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field for ``x``, refined over ``rec_num`` passes.

        Args:
            x: Moving image, ``(N, C, *spatial)``.
            y: Condition image (used when ``conditional_input``).
            t: Integer time steps looked up in ``self.time_embed``.
            text: Text embedding fed to ``nn.Linear`` layers of width
                ``text_feat_chn``; defaults to ``self.text``. A batch of 1 is
                broadcast over the image batch.
            rec_num: Number of refinement passes; must be >= 1.
            ndims: Unused; kept for interface compatibility.

        Returns:
            The accumulated displacement field (normalised by ``self.max_sz``).

        Raises:
            ValueError: if ``rec_num`` < 1.

        Side effects:
            Sets ``self.device``, ``self.max_sz`` and ``self.img_sz``, may
            rebuild ``self.ref_grid``, and (when conditional) stores
            ``self.img_embd``.
        """
        if rec_num < 1:
            raise ValueError(f"rec_num must be >= 1, got {rec_num}")
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # Displacements are normalised by the first spatial size on every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension

        # Per-axis half extent (size - 1) / 2 for the channels-last sampling grid.
        self.img_sz = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension],
        )
        if list(img_sz) != self.img_res:
            self.ref_grid = self._make_ref_grid(img_sz)
        self.ref_grid = self.ref_grid.to(self.device)

        img = x
        t = self.time_embed(t)
        txt_shape = None
        if self.conditional_input:
            if text is None:
                text = self.text
            # Caller-supplied text may live on another device.
            text = text.to(self.device)
            # Use the text's own batch size so a single embedding broadcasts over the images.
            txt_batch = text.shape[0] if text.dim() > 1 else 1
            txt_shape = [txt_batch, -1] + [1] * self.dimension

        for rec_id in range(rec_num):
            tgt = y if self.conditional_input else None
            enc_list = []
            out = img
            for i in range(self.hier_num):
                out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
                if self.conditional_input:
                    tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape)
                    out = self.fuse_conv0[i](torch.cat([out, tgt], dim=1))
                    tgt = self.fuse_conv1[i](torch.cat([tgt, out], dim=1))
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Bidirectional cross-attention; MultiheadAttention expects (seq, batch, channels).
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out = out + out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt = tgt + tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = self.fuse(torch.cat([out, tgt], dim=1))

                # Text modulation; the pooled image-text feature is exposed as self.img_embd.
                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out = out + self.txt2img(out_txt)

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the running field by the new one and add the new one.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Two-layer MLP (Linear, ReLU, Linear) projecting an embedding to ``dim_out``."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
|
|
|
|
|
class OM_net(nn.Module):
    """
    Extended RecMulModMutAttnNet with gated attention mechanisms:

    1. Text Gate (bottleneck): sigmoid weight ``w_txt`` computed from the
       text embedding, used to interpolate between text-enhanced features
       and raw image features. It is meant to learn to suppress the text
       branch when the embedding is all zeros (the no-text default).
    2. Target Gate (each encoder level): per-voxel spatial gate (an
       AtrousBlock followed by a 1x1 conv to two logits) that weights the
       condition features against the image features at the *input* of
       ``fuse_conv1``.

    Gradient checkpointing: set ``model.use_checkpoint = True`` to recompute
    the first three encoder levels in the backward pass while training
    (trades compute for peak activation memory).

    Args:
        n_steps: Rows of the frozen sinusoidal time-embedding table.
        time_emb_dim: Width of the time embedding.
        ndims: Spatial dimensionality.
        num_input_chn: Channels of ``x`` (and ``y``).
        res: Nominal spatial resolution; sizes the blocks and the cached grid.
        conditional_input: Enable the condition-image / text branch.
        text_feat_chn: Width of the text embedding.
        num_heads: Heads of the bottleneck cross-attention layers.
        use_conv_transpose: Stored as ``self.use_conv_transpose``.
            NOTE(review): not consulted in this class; up-sampling always uses
            ``SafeConvTranspose3d`` (also when ``ndims != 3``); confirm intent.
    """

    def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0,
                 conditional_input=True, text_feat_chn=1024, num_heads=4,
                 use_conv_transpose=False):
        super(OM_net, self).__init__()
        # NOTE: module creation order is kept stable on purpose; it fixes the
        # parameter order (optimizer state) and the random initialisation.
        self.use_checkpoint = False
        self.use_conv_transpose = use_conv_transpose

        self.feat_channels = [num_input_chn, 12, 32, 64, 128, 512]
        self.conditional_input = conditional_input
        self.num_heads = num_heads
        self.text_feat_chn = text_feat_chn

        self.dimension = ndims
        self.Conv = getattr(nn, 'Conv%dd' % self.dimension)
        self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension)

        # Frozen sinusoidal time-step embedding.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)
        self.hier_num = len(self.feat_channels) - 1
        self.down_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()
        self.ted_layers = nn.ModuleList()
        self.teu_layers = nn.ModuleList()
        self.block_down = nn.ModuleList()
        self.block_up = nn.ModuleList()
        if self.conditional_input:
            self.txt_layers = nn.ModuleList()
            self.block_down_cond = nn.ModuleList()
            self.fuse_conv0 = nn.ModuleList()
            self.fuse_conv1 = nn.ModuleList()
            self.tgt_gate = nn.ModuleList()
            self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads)
            Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension)
            self.global_maxpool = Global_Maxpool(1)
            self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0)
            self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0])
            self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0)
            # All-zero text embedding used when forward() receives text=None.
            self.text = torch.zeros(1, self.text_feat_chn)

            # Text gate: embedding -> scalar weight in (0, 1) per sample.
            self.text_gate = nn.Sequential(
                nn.Linear(self.text_feat_chn, self.text_feat_chn // 4),
                nn.ReLU(),
                nn.Linear(self.text_feat_chn // 4, 1),
                nn.Sigmoid()
            )

        self.img_res = [res] * self.dimension
        self.ref_grid = self._make_ref_grid(self.img_res)

        for i in range(1, self.hier_num + 1):
            j = -i
            down_sz = res // (2 ** (i - 1))
            up_sz = res // (2 ** (self.hier_num - i - 1))
            self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1))
            self.up_layers.append(SafeConvTranspose3d(self.feat_channels[j], self.feat_channels[j], 4, 2, 1))
            self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i - 1]))
            self.teu_layers.append(self._make_te(time_emb_dim, 2 * self.feat_channels[j]))
            self.block_down.append(nn.Sequential(
                AtrousBlock([self.feat_channels[i - 1]] + [down_sz] * ndims, self.feat_channels[i - 1], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
            ))
            if self.conditional_input:
                self.txt_layers.append(self._make_te(self.text_feat_chn, self.feat_channels[i]))
                self.block_down_cond.append(nn.Sequential(
                    AtrousBlock([self.feat_channels[i - 1]] + [down_sz] * ndims, self.feat_channels[i - 1], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                    AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims),
                ))
                self.fuse_conv0.append(self.Conv(2 * self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
                self.fuse_conv1.append(self.Conv(2 * self.feat_channels[i], self.feat_channels[i], 1, 1, 0))
                # Target gate: 2 logits per voxel (channel 0 = condition confidence).
                self.tgt_gate.append(nn.Sequential(
                    AtrousBlock([self.feat_channels[i]] + [down_sz] * ndims,
                                self.feat_channels[i], self.feat_channels[i], ndims=ndims, atrous_rates=[1, 3]),
                    self.Conv(self.feat_channels[i], 2, 1, 1, 0)
                ))
            # Each up stage narrows to the next shallower width; the last keeps
            # feat_channels[1], which is what conv_out consumes.
            k = j if i == self.hier_num else j - 1
            self.block_up.append(nn.Sequential(
                AtrousBlock([2 * self.feat_channels[j]] + [up_sz] * ndims, 2 * self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [up_sz] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False),
                AtrousBlock([self.feat_channels[j]] + [up_sz] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False),
            ))

        if self.conditional_input:
            # Bottleneck text projection (txt_layers[-1]); txt_layers only exists in this mode.
            self.txt_layers.append(self._make_te(self.text_feat_chn, self.text_feat_chn))
        self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1])
        mid_sz = res // (2 ** self.hier_num)
        self.b_mid = nn.Sequential(
            AtrousBlock([self.feat_channels[self.hier_num]] + [mid_sz] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [mid_sz] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
            AtrousBlock([self.feat_channels[self.hier_num]] + [mid_sz] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims),
        )
        self.fuse = self.Conv(2 * self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0)

        self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1)

        if self.conditional_input:
            # tgt_gate only exists in conditional mode.
            self._init_tgt_gates()

    def _init_tgt_gates(self):
        """Bias each target gate towards the condition channel at initialisation.

        The gate's final 1x1 conv gets bias ``[1, -1]``. After the 2-way
        channel softmax in ``_encoder_level``, this gives a condition
        confidence of about ``sigmoid(2) ~= 0.88`` (ignoring the weight
        term). That is milder than ``[2, -2]`` (~0.98), so both the
        ``cond*tgt`` and ``(1-cond)*out`` halves of the fuse_conv1 input
        carry signal for early gradient flow.

        NOTE(review): ``sigmoid(1) ~= 0.73`` would need biases ``[0.5, -0.5]``;
        confirm which starting confidence was intended.
        """
        for gate_seq in self.tgt_gate:
            final_conv = gate_seq[-1]
            with torch.no_grad():
                final_conv.bias.data[0] = 1.0
                final_conv.bias.data[1] = -1.0

    def _encoder_level(self, i, out, tgt, t, ts_emb_shape, text, txt_shape, w_txt):
        """Run encoder level ``i`` (a separate method so it can be checkpointed).

        Returns ``(out, tgt)``. ``tgt`` is returned unchanged when it is
        ``None`` or the condition branch is disabled.
        """
        out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape))
        if self.conditional_input and tgt is not None:
            # Condition features plus the text-gated per-level text projection.
            tgt = self.block_down_cond[i](tgt) + w_txt * self.txt_layers[i](text).reshape(txt_shape)
            gate_logits = self.tgt_gate[i](tgt)
            cond_confidence = F.softmax(gate_logits, dim=1)[:, 0:1]
            tgt = self.fuse_conv1[i](torch.cat([cond_confidence * tgt, (1 - cond_confidence) * out], dim=1))
            out = self.fuse_conv0[i](torch.cat([out, tgt], dim=1))
        return out, tgt

    def _make_ref_grid(self, sizes):
        """Return the integer voxel-index grid of shape ``(1, ndims, *sizes)`` ('ij' order)."""
        sizes = list(sizes)
        grids = torch.meshgrid(*[torch.arange(end=imsz) for imsz in sizes], indexing="ij")
        return torch.reshape(torch.stack(grids, 0), [1, self.dimension] + sizes)

    def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.):
        """Clamp each normalised coordinate channel so ``|x * sz|`` stays within ``sz - minus``."""
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in
                          zip(sample_coords, max_sz)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Warp ``vol`` by ``ddf`` (normalised by ``self.max_sz``) with ``grid_sample``.

        ``img_sz`` is the per-axis half extent used to map voxel positions to
        ``[-1, 1]`` (``forward`` passes ``self.img_sz``).
        """
        ref = self.ref_grid if ref is None else ref
        img_sz = self.max_sz if img_sz is None else img_sz
        resample_mode = 'bilinear'
        scale = torch.Tensor(np.reshape(np.array(self.max_sz), [1, self.dimension] + [1] * self.dimension)).to(self.device)
        # Absolute voxel positions, moved to channels-last for grid_sample.
        coords = (ddf * scale + ref).permute([0] + list(range(2, 2 + self.dimension)) + [1])
        # grid_sample wants (x, y[, z]) i.e. reversed axis order.
        grid = torch.flip(coords / img_sz - 1, dims=[-1])
        return F.grid_sample(vol, grid, mode=resample_mode, padding_mode=padding_mode, align_corners=True)

    def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2):
        """Predict a displacement field for ``x``, refined over ``rec_num`` passes.

        Args:
            x: Moving image, ``(N, C, *spatial)``.
            y: Condition image (used when ``conditional_input``).
            t: Integer time steps looked up in ``self.time_embed``.
            text: Text embedding of width ``text_feat_chn``; defaults to
                ``self.text``. A batch of 1 is broadcast over the image batch.
            rec_num: Number of refinement passes; must be >= 1.
            ndims: Unused; kept for interface compatibility.

        Returns:
            The accumulated displacement field (normalised by ``self.max_sz``).

        Raises:
            ValueError: if ``rec_num`` < 1.

        Side effects:
            Sets ``self.device``, ``self.max_sz`` and ``self.img_sz``, may
            rebuild ``self.ref_grid``, and (when conditional) stores
            ``self.img_embd``.
        """
        if rec_num < 1:
            raise ValueError(f"rec_num must be >= 1, got {rec_num}")
        self.device = x.device
        img_sz = x.size()[2:]
        n = x.size()[0]
        # Displacements are normalised by the first spatial size on every axis.
        self.max_sz = [img_sz[0]] * self.dimension
        ts_emb_shape = [n, -1] + [1] * self.dimension

        # Per-axis half extent (size - 1) / 2 for the channels-last sampling grid.
        self.img_sz = torch.reshape(
            torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device),
            [1] * (self.dimension + 1) + [self.dimension],
        )
        if list(img_sz) != self.img_res:
            self.ref_grid = self._make_ref_grid(img_sz)
        self.ref_grid = self.ref_grid.to(self.device)

        img = x
        t = self.time_embed(t)

        txt_shape = None
        w_txt = None
        if self.conditional_input:
            if text is None:
                text = self.text
            # Caller-supplied text may live on another device.
            text = text.to(self.device)
            # Use the text's own batch size so a single embedding broadcasts over the images.
            txt_batch = text.shape[0] if text.dim() > 1 else 1
            txt_shape = [txt_batch, -1] + [1] * self.dimension

            # Text gate: one weight per sample, shaped to broadcast over feature maps.
            txt_vec = text.reshape(txt_batch, -1)
            if txt_vec.size(0) == 1 and n > 1:
                txt_vec = txt_vec.expand(n, -1)
            w_txt = self.text_gate(txt_vec)
            w_txt = w_txt.view([w_txt.size(0), 1] + [1] * self.dimension)

        for rec_id in range(rec_num):
            tgt = y if self.conditional_input else None
            enc_list = []
            out = img
            for i in range(self.hier_num):
                # Only the (largest) first three levels are checkpointed.
                if self.use_checkpoint and self.training and i < 3:
                    out, tgt = grad_checkpoint(
                        self._encoder_level, i, out, tgt,
                        t, ts_emb_shape, text, txt_shape, w_txt,
                        use_reentrant=False,
                    )
                else:
                    out, tgt = self._encoder_level(
                        i, out, tgt,
                        t, ts_emb_shape, text, txt_shape, w_txt,
                    )
                enc_list.append(out)
                out = self.down_layers[i](out)
                if self.conditional_input:
                    tgt = self.down_layers[i](tgt)

            out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape))
            if self.conditional_input:
                # Bidirectional cross-attention; MultiheadAttention expects (seq, batch, channels).
                out_shape = out.shape
                tgt_shape = tgt.shape
                out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1)
                tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1)
                out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat)
                tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat)
                out = out + out_attn.permute(1, 2, 0).contiguous().view(out_shape)
                tgt = tgt + tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape)
                out = self.fuse(torch.cat([out, tgt], dim=1))

                img_txt_feat = self.img2txt(out)
                self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1)
                # NOTE(review): RecMulModMutAttnNet *adds* img_txt_feat here;
                # confirm the subtraction is intentional.
                out_txt = self.txt_layers[-1](text).reshape(txt_shape) - img_txt_feat
                out_txt = self.txt_proc(out_txt)
                out_txt = self.txt2img(out_txt)
                # Text gate: blend raw and text-enhanced bottleneck features.
                out = (1 - w_txt) * out + w_txt * out_txt

            for i in range(self.hier_num):
                out = torch.cat((self.up_layers[i](out), enc_list[-i - 1]), dim=1)
                out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape))

            out = self.conv_out(out) / 128
            ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz)
            if rec_id == 0:
                ddf = ddf_one
            else:
                # Compose: warp the running field by the new one and add the new one.
                ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border")
            img = self.resample(x, ddf=ddf, img_sz=self.img_sz)

        return ddf

    def _make_te(self, dim_in, dim_out):
        """Two-layer MLP (Linear, ReLU, Linear) projecting an embedding to ``dim_out``."""
        return nn.Sequential(
            nn.Linear(dim_in, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, dim_out)
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ddf_multiplier(dvf, mul_num=10, stn=None):
    """Self-compose a displacement field ``mul_num`` times.

    Each step warps the running field by ``dvf`` and adds ``dvf``:
    ``ddf <- dvf + stn(ddf, dvf)``.

    Args:
        dvf: Displacement field, assumed ``(N, ndims, *spatial)`` when ``stn``
            is left at its default.
        mul_num: Number of composition steps; 0 returns ``dvf`` unchanged.
        stn: Callable ``stn(field, ddf)`` that warps ``field`` by ``ddf``.
            Defaults to a border-padded ``STN`` (as in ``composite``), whose
            dimensionality is inferred from ``dvf``. Previously ``None`` crashed.

    Returns:
        The composed field.
    """
    if stn is None and mul_num > 0:
        stn = STN(ndims=dvf.dim() - 2, device=dvf.device, padding_mode="border")
    ddf = dvf
    for _ in range(mul_num):
        ddf = dvf + stn(ddf, dvf)
    return ddf
|
|
|
|
|
def composite(ddfs, stn=None):
    """Compose a sequence of displacement fields into a single field.

    The fields are applied in order: the running composition is warped by
    the next field and that field is added,
    ``comp = ddfs[i] + stn(comp, ddfs[i])``.

    Args:
        ddfs: Non-empty sequence (or stacked tensor) of fields, each assumed
            ``(N, ndims, *spatial)`` when ``stn`` is left at its default.
        stn: Callable ``stn(field, ddf)`` that warps ``field`` by ``ddf``.
            Defaults to a border-padded ``STN`` whose dimensionality is
            inferred from ``ddfs[0]`` (it was previously fixed at 2).

    Returns:
        The composed field (``ddfs[0]`` itself when only one is given).

    Raises:
        ValueError: if ``ddfs`` is empty.
    """
    if len(ddfs) == 0:
        raise ValueError("composite() needs at least one displacement field")
    if stn is None and len(ddfs) > 1:
        stn = STN(ndims=ddfs[0].dim() - 2, device=ddfs[0].device, padding_mode="border")
    comp_ddf = ddfs[0]
    for ddf in ddfs[1:]:
        comp_ddf = ddf + stn(comp_ddf, ddf)
    return comp_ddf
|
|
|
|
|
|
|
class STN(nn.Module):
    """Warp an image with a normalised dense displacement field.

    A field ``ddf`` of shape ``(N, ndims, *spatial)`` is processed in three steps:

    1. It is multiplied by the per-axis scale ``max_sz``.
    2. It is added to an integer voxel-index reference grid ('ij' order).
    3. The result is mapped to the ``[-1, 1]`` coordinates that
       :func:`torch.nn.functional.grid_sample` expects (``align_corners=True``).

    Args:
        ndims: Number of spatial dimensions.
        img_sz: Spatial size, either a scalar (isotropic) or a per-axis
            sequence. If ``None``, it is taken from the first input passed to
            :meth:`forward`. Previously ``None`` crashed.
        max_sz: Scale applied to ``ddf``, scalar or per-axis; defaults to
            ``img_sz``. Previously this argument was ignored.
        device: Device for the cached grid and scale; if ``None``, the first
            input's device is used.
        padding_mode: ``grid_sample`` padding mode used by :meth:`forward`.
        resample_mode: ``grid_sample`` interpolation mode; ``None`` means ``'bilinear'``.
    """

    def __init__(self, ndims=2, img_sz=None, max_sz=None, device=None, padding_mode="border", resample_mode=None):
        super(STN, self).__init__()
        self.ndims = ndims
        self.device = device
        self.padding_mode = padding_mode
        self.resample_mode = resample_mode
        # Kept so that geometry built lazily in forward() still honours it.
        self._max_sz_arg = max_sz
        self.img_sz = None
        self.max_sz = None
        self.ref_grid = None
        if img_sz is not None:
            self._build_geometry(img_sz)

    def _per_axis(self, value):
        """Return ``value`` as a list with one entry per spatial axis (scalars are repeated)."""
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value] * self.ndims

    def _build_geometry(self, img_sz):
        """Cache the per-axis size, the ``(1, ndims, 1, ...)`` scale and the reference grid."""
        self.img_sz = self._per_axis(img_sz)
        scale = self.img_sz if self._max_sz_arg is None else self._per_axis(self._max_sz_arg)
        max_sz = torch.tensor(scale, dtype=torch.float32).reshape([1, self.ndims] + [1] * self.ndims)
        grids = torch.meshgrid(*[torch.arange(end=s) for s in self.img_sz], indexing="ij")
        ref_grid = torch.reshape(torch.stack(grids, 0), [1, self.ndims] + self.img_sz)
        if self.device is not None:
            max_sz = max_sz.to(self.device)
            ref_grid = ref_grid.to(self.device)
        self.max_sz = max_sz
        self.ref_grid = ref_grid

    def max_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp each coordinate channel ``d`` so ``|x * max_sz[d]|`` stays within ``max_sz[d] - minus``.

        Each channel is now paired with its own axis scale. Previously every
        output channel was derived from channel 0.
        """
        sizes = self.max_sz.flatten().tolist()
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        return torch.cat([torch.clamp(x * sz, min=minus - sz + plus, max=sz - minus + plus) / sz
                          for x, sz in zip(sample_coords, sizes)], 1)

    def boundary_limit(self, sample_coords0, plus=0., minus=1.):
        """Clamp ``x * max_sz[d] + ref[d]`` per axis to ``[minus - sz + plus, sz - minus + plus]``.

        The result is returned as a normalised displacement. Each channel is
        now paired with its own scale and reference-grid channel.
        """
        sizes = self.max_sz.flatten().tolist()
        sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1)
        refs = torch.split(self.ref_grid.to(sample_coords0.device), 1, dim=1)
        return torch.cat([(torch.clamp(x * sz + ref, min=minus - sz + plus, max=sz - minus + plus) - ref) / sz
                          for x, sz, ref in zip(sample_coords, sizes, refs)], 1)

    def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"):
        """Sample ``vol`` at ``ref + ddf * max_sz`` with ``grid_sample``.

        ``img_sz`` (scalar or per-axis) sets the normalisation to ``[-1, 1]``.
        It defaults to the cached ``self.img_sz``; previously the default
        divided by a wrongly shaped tensor.
        """
        device = ddf.device
        ref = self.ref_grid if ref is None else ref
        sizes = self.img_sz if img_sz is None else self._per_axis(img_sz)
        half_extent = torch.reshape(torch.tensor([(s - 1) / 2. for s in sizes], device=device),
                                    [1] + [1] * self.ndims + [self.ndims])
        resample_mode = 'bilinear' if self.resample_mode is None else self.resample_mode
        # Absolute voxel positions, moved to channels-last for grid_sample.
        coords = (ddf * self.max_sz.to(device) + ref.to(device)).permute(
            [0] + list(range(2, 2 + self.ndims)) + [1])
        # grid_sample wants (x, y[, z]) i.e. reversed axis order.
        grid = torch.flip(coords / half_extent - 1, dims=[-1])
        return F.grid_sample(vol.to(device), grid, mode=resample_mode,
                             padding_mode=padding_mode, align_corners=True)

    def forward(self, x, ddf):
        """Warp ``x`` by ``ddf``, building the cached geometry from ``x`` on first use."""
        self.device = x.device if self.device is None else self.device
        if self.img_sz is None:
            self._build_geometry(list(x.size()[2:]))
        return self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode)
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test: one forward pass of RecMutAttnNet on a random 3-D volume,
    # followed by a parameter count.
    ndims = 3
    res = 128
    x = torch.rand([1, 1] + [res] * ndims)
    t = torch.randint(0, 1000, (1,))
    text = torch.rand([1, 1024] + [1] * ndims)
    model = RecMutAttnNet(n_steps=1000, time_emb_dim=100, ndims=ndims, num_input_chn=1, res=res, conditional_input=True)
    # No gradients are needed here; skip building the autograd graph for a 128^3 pass.
    with torch.no_grad():
        y = model(x, x, t, text=text)
    print("Output shape", y.shape)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params}")
    print(f"Trainable parameters: {trainable_params}")