from torch import nn import torch import torch.nn.functional as F from torch.utils.checkpoint import checkpoint as grad_checkpoint import numpy as np import math from Diffusion.safe_conv_transpose import SafeConvTranspose3d class UpsampleConv(nn.Module): """Drop-in replacement for ConvTranspose3d/2d that avoids the XPU memory leak. ConvTranspose3d backward leaks ~0.33 GiB/step on Intel XPU (oneDNN bug). This uses F.interpolate (zero leak) + Conv (negligible leak) instead. Also avoids checkerboard artifacts common with transposed convolutions. """ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, ndims=3): super().__init__() self.scale_factor = stride self.mode = 'trilinear' if ndims == 3 else 'bilinear' Conv = getattr(nn, f'Conv{ndims}d') self.conv = Conv(in_channels, out_channels, 3, 1, 1) def forward(self, x): x = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False) return self.conv(x) def get_net(name="recresnet"): name = name.lower() if name == "recresacnet": net = RecResACNet elif name == "recmutattnnet": net = RecMutAttnNet elif name == "recmutattnnet0": net = RecMutAttnNet0 elif name == "recmutattnnet1": net = RecMutAttnNet1 elif name == "defrecmutattnnet": net = DefRec_MutAttnNet elif name == "recmulmodmutattnnet": net = RecMulModMutAttnNet elif name == "om_net": net = OM_net else: net = None return net def sinusoidal_embedding(n, d): # Returns the standard positional embedding embedding = torch.zeros(n, d) wk = torch.tensor([1 / 10_000 ** (2 * j / d) for j in range(d)]) wk = wk.reshape((1, d)) t = torch.arange(n).reshape((n, 1)) embedding[:,::2] = torch.sin(t * wk[:,::2]) embedding[:,1::2] = torch.cos(t * wk[:,::2]) return embedding class AtrousBlock(nn.Module): def __init__(self, shape, in_c, out_c, kernel_size=3, stride=1, atrous_rates=[1,3], ndims=2, activation=None, normalize=True): super(AtrousBlock, self).__init__() # if 0 not in shape: if normalize: # print(shape) # self.ln = nn.LayerNorm(shape) # jzheng 15/03/2024 norm=getattr(nn, 'InstanceNorm%dd' % ndims) # jzheng 15/03/2024 self.ln = norm(out_c,affine=True) else: self.ln = nn.Identity() Conv=getattr(nn,'Conv%dd' % ndims) if in_c!=out_c: self.conv0 = Conv(in_c, out_c, kernel_size, 1, (kernel_size-1)//2*1) #if in_c!=out_c else None else: self.conv0 = None self.convs = nn.ModuleList([ Conv(out_c, out_c, kernel_size, 1, (kernel_size-1)//2*ar, dilation=ar) if ar>0 else Conv(out_c, out_c, 1, 1, 0) for ar in atrous_rates ]) # self.conv1 = Conv(out_c, out_c, kernel_size, stride, padding) # self.conv2 = Conv(out_c, out_c, kernel_size, stride, padding) self.activation = nn.LeakyReLU(1e-6) if activation is None else activation # self.activation = nn.ReLU() if activation is None else activation # self.activation = nn.ReLU() self.normalize = normalize def forward(self, x): if self.conv0 is not None: x = self.conv0(x) #if self.conv0 is not None else x x = self.ln(x) if self.normalize else x # jzheng 15/03/2024 out=nn.Identity()(x) for conv in self.convs: out = self.activation(out) out = conv(out) return self.activation(out+x) # ============================================== # Unconditional Network # ============================================== class RecResACNet(nn.Module): def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0): super(RecResACNet, self).__init__() self.dimension = ndims self.Conv = getattr(nn, 'Conv%dd' % self.dimension) self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension) # Sinusoidal embedding self.time_embed = nn.Embedding(n_steps, time_emb_dim) self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) self.time_embed.requires_grad_(False) # First half self.te1 = self._make_te(time_emb_dim, 1) self.b1 = nn.Sequential( AtrousBlock([num_input_chn] + [res] * ndims, num_input_chn, 10, ndims=ndims), AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims), AtrousBlock([10] + [res] * ndims, 10, 10, ndims=ndims), ) self.down1 = self.Conv(10, 10, 4, 2, 1) self.te2 = self._make_te(time_emb_dim, 10) self.b2 = nn.Sequential( AtrousBlock([10] + [res // 2] * ndims, 10, 20, ndims=ndims), AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims), AtrousBlock([20] + [res // 2] * ndims, 20, 20, ndims=ndims) ) self.down2 = self.Conv(20, 20, 4, 2, 1) self.te3 = self._make_te(time_emb_dim, 20) self.b3 = nn.Sequential( AtrousBlock([20] + [res // 4] * ndims, 20, 40, ndims=ndims), AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims), AtrousBlock([40] + [res // 4] * ndims, 40, 40, ndims=ndims) ) self.down3 = self.Conv(40, 40, 4, 2, 1) # Bottleneck self.te_mid = self._make_te(time_emb_dim, 40) self.b_mid = nn.Sequential( AtrousBlock([40] + [res // 8] * ndims, 40, 20, ndims=ndims), AtrousBlock([20] + [res // 8] * ndims, 20, 20, ndims=ndims), AtrousBlock([20] + [res // 8] * ndims, 20, 40, ndims=ndims) ) # Second half self.up1 = self.ConvT(40, 40, 4, 2, 1) self.te4 = self._make_te(time_emb_dim, 80) self.b4 = nn.Sequential( AtrousBlock([80] + [res // 4] * ndims, 80, 40, ndims=ndims, normalize=False), AtrousBlock([40] + [res // 4] * ndims, 40, 20, ndims=ndims, normalize=False), AtrousBlock([20] + [res // 4] * ndims, 20, 20, ndims=ndims, normalize=False) ) self.up2 = self.ConvT(20, 20, 4, 2, 1) self.te5 = self._make_te(time_emb_dim, 40) self.b5 = nn.Sequential( AtrousBlock([40] + [res // 2] * ndims, 40, 20, ndims=ndims, normalize=False), AtrousBlock([20] + [res // 2] * ndims, 20, 10, ndims=ndims, normalize=False), AtrousBlock([10] + [res // 2] * ndims, 10, 10, ndims=ndims, normalize=False) ) self.up3 = self.ConvT(10, 10, 4, 2, 1) self.te_out = self._make_te(time_emb_dim, 20) self.b_out = nn.Sequential( AtrousBlock([20] + [res // 1] * ndims, 20, 10, ndims=ndims, normalize=False), AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False), AtrousBlock([10] + [res // 1] * ndims, 10, 10, ndims=ndims, normalize=False) ) self.conv_out = self.Conv(10, ndims, 3, 1, 1) def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in zip(sample_coords, max_sz)], 1) def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"): ref = self.ref_grid if ref is None else ref img_sz = self.max_sz if img_sz is None else img_sz # resample_mode = 'bicubic' resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear' # padding_mode = "border" if True: # return F.grid_sample(vol, torch.flip(torch.transpose(ddf * torch.Tensor(np.reshape(np.array(self.max_sz), [1, 1, 1, self.dimension])).cuda() + ref,[0, 2, 3, 1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode,align_corners=True) return F.grid_sample(vol, torch.flip((ddf * torch.Tensor( np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute( [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, align_corners=True) def forward(self, x=None, t=None, y=None, rec_num=2, ndims=2): # self.device = x.device # [h, w] = x.size()[2:] img_sz = x.size()[2:] n = x.size()[0] self.max_sz = [img_sz[0]] * self.dimension ts_emb_shape=[n,-1]+[1]*self.dimension # [h,w]=img_sz # self.img_sz = torch.reshape(torch.tensor([(h - 1) / 2., (w - 1) / 2.], device=self.device), [1, 1, 1, 2]) self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension]) # self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=h), torch.arange(end=w)]), 0), # [1, 2, h, w]).to(self.device) self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0), [1, self.dimension]+list(img_sz)).to(self.device) img = x # x is (N, 2, 28, 28) (image with positional embedding stacked on channel dimension) t = self.time_embed(t) for rec_id in range(rec_num): out1 = self.b1(img + self.te1(t).reshape(ts_emb_shape)) # (N, 10, 28, 28) out2 = self.b2(self.down1(out1) + self.te2(t).reshape(ts_emb_shape)) # (N, 20, 14, 14) out3 = self.b3(self.down2(out2) + self.te3(t).reshape(ts_emb_shape)) # (N, 40, 7, 7) out_mid = self.b_mid(self.down3(out3) * self.te_mid(t).reshape(ts_emb_shape)) # (N, 40, 3, 3) out4 = torch.cat((out3, self.up1(out_mid)), dim=1) # (N, 80, 7, 7) out4 = self.b4(out4 + self.te4(t).reshape(ts_emb_shape)) # (N, 20, 7, 7) out5 = torch.cat((out2, self.up2(out4)), dim=1) # (N, 40, 14, 14) out5 = self.b5(out5 + self.te5(t).reshape(ts_emb_shape)) # (N, 10, 14, 14) out = torch.cat((out1, self.up3(out5)), dim=1) # (N, 20, 28, 28) out = self.b_out(out + self.te_out(t).reshape(ts_emb_shape)) # (N, 1, 28, 28) out = self.conv_out(out) ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz) if rec_id == 0: ddf = ddf_one else: ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border") img = self.resample(x, ddf=ddf, img_sz=self.img_sz) return ddf def _make_te(self, dim_in, dim_out): # make time embedding return nn.Sequential( nn.Linear(dim_in, dim_out), # nn.SiLU(), nn.ReLU(), nn.Linear(dim_out, dim_out) ) # ============================================== # Conditional Network # ============================================== class cross_attn(nn.Module): def __init__(self, q, k, v, ndims=2): self.q = q self.k = k self.v = v self.ndims = ndims self.Conv = getattr(nn, 'Conv%dd' % self.ndims) self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.ndims) self.softmax = nn.Softmax(dim=-1) self.gamma = nn.Parameter(torch.zeros(1)) def forward(self, x, y): q = self.q(x) k = self.k(y) v = self.v(y) attn = self.softmax(torch.matmul(q, k.transpose(-2, -1))) out = torch.matmul(attn, v) return out class DefRec_MutAttnNet(nn.Module): def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4): super(DefRec_MutAttnNet, self).__init__() # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64] # self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256] self.feat_channels = [num_input_chn, 16, 32, 128, 256, 512] self.conditional_input = conditional_input self.num_heads = num_heads self.text_feat_chn = text_feat_chn self.dimension = ndims self.Conv = getattr(nn, 'Conv%dd' % self.dimension) self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension) self.copy = nn.Identity() # Sinusoidal embedding self.time_embed = nn.Embedding(n_steps, time_emb_dim) self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) self.time_embed.requires_grad_(False) self.hier_num = len(self.feat_channels) - 1 self.down_layers = nn.ModuleList() self.up_layers = nn.ModuleList() self.ted_layers = nn.ModuleList() self.teu_layers = nn.ModuleList() self.block_down = nn.ModuleList() self.block_up = nn.ModuleList() if self.conditional_input: self.block_down_cond = nn.ModuleList() self.fuse_conv0 = nn.ModuleList() # self.fuse_conv1 = nn.ModuleList() self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads) Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension) self.global_maxpool = Global_Maxpool(1) self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0) self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0]) self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0) self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension)) self.img_res = [res]*self.dimension self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0), [1, self.dimension]+list(self.img_res)) for i in range(1, self.hier_num + 1): j=-i self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1)) self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1)) self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1])) self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j])) self.block_down.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) if self.conditional_input: self.block_down_cond.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) # self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) if i==self.hier_num: k=j else: k=j-1 self.block_up.append(nn.Sequential( AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False) )) # Bottleneck self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1]) self.b_mid = nn.Sequential( AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims) ) self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1) def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in zip(sample_coords, max_sz)], 1) def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"): ref = self.ref_grid if ref is None else ref img_sz = self.max_sz if img_sz is None else img_sz resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear' return F.grid_sample(vol, torch.flip((ddf * torch.Tensor( np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute( [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, align_corners=True) def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2): self.device = x.device img_sz = x.size()[2:] n = x.size()[0] self.max_sz = [img_sz[0]] * self.dimension ts_emb_shape=[n,-1]+[1]*self.dimension self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension]) if list(img_sz) != self.img_res: # print ("Reinitialize the ref_grid to match the model's input image size.") # print(img_sz, self.img_res) self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0), [1, self.dimension]+list(img_sz)) self.ref_grid = self.ref_grid.to(self.device) img = x if self.conditional_input: tgt = y # encode the conditional input tgt_down_list = [] for i in range(self.hier_num): # out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape)) if self.conditional_input: tgt = self.block_down_cond[i](tgt) tgt_down_list.append(self.copy(tgt)) tgt = self.down_layers[i](tgt) tgt_mid = self.copy(tgt) tgt_shape = tgt_mid.shape # out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) tgt_mid = tgt_mid.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) t = [t0.to(self.device) for t0 in t] t = [t0 for _ in range(rec_num) for t0 in t] for rec_id,time in enumerate(t): t_emb = self.time_embed(time) # for rec_id in range(rec_num): # if self.conditional_input: # tgt = y enc_list = [] out = img for i in range(self.hier_num): out = self.block_down[i](out + self.ted_layers[i](t_emb).reshape(ts_emb_shape)) if self.conditional_input: # tgt = self.block_down_cond[i](tgt) out = self.fuse_conv0[i](torch.cat([out, tgt_down_list[i]], axis=1)) # tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1)) enc_list.append(out) out = self.down_layers[i](out) # if self.conditional_input: # tgt = self.down_layers[i](tgt) out = self.b_mid(out + self.tmid(t_emb).reshape(ts_emb_shape)) if self.conditional_input: # out += self.attn_layer(out, tgt, tgt)[0] out_shape = out.shape # tgt_shape = tgt.shape # # out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) # tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt_mid, tgt_mid) out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W) out = out + out_attn if self.conditional_input: if text is None: text = self.text text = text.to(self.device) out_txt = self.img2txt(out) + text out_txt = self.txt_proc(out_txt) out_txt = self.txt2img(out_txt) out = out + out_txt for i in range(self.hier_num): out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1) out = self.block_up[i](out + self.teu_layers[i](t_emb).reshape(ts_emb_shape)) out = self.conv_out(out)/128 ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz) if rec_id == 0: ddf = ddf_one else: ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border") img = self.resample(x, ddf=ddf, img_sz=self.img_sz) return ddf def _make_te(self, dim_in, dim_out): return nn.Sequential( nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out) ) class RecMutAttnNet1(nn.Module): def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4): super(RecMutAttnNet1, self).__init__() # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64] self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256] self.conditional_input = conditional_input self.num_heads = num_heads self.text_feat_chn = text_feat_chn self.dimension = ndims self.Conv = getattr(nn, 'Conv%dd' % self.dimension) self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension) # Sinusoidal embedding self.time_embed = nn.Embedding(n_steps, time_emb_dim) self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) self.time_embed.requires_grad_(False) self.hier_num = len(self.feat_channels) - 1 self.down_layers = nn.ModuleList() self.up_layers = nn.ModuleList() self.ted_layers = nn.ModuleList() self.teu_layers = nn.ModuleList() self.block_down = nn.ModuleList() if self.conditional_input: self.block_down_cond = nn.ModuleList() self.fuse_conv0 = nn.ModuleList() self.fuse_conv1 = nn.ModuleList() self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads) self.block_up = nn.ModuleList() for i in range(1, self.hier_num + 1): j=-i self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1)) self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1)) self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1])) self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j])) self.block_down.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) if self.conditional_input: self.block_down_cond.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) if i==self.hier_num: k=j else: k=j-1 self.block_up.append(nn.Sequential( AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False) )) # Bottleneck self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1]) self.b_mid = nn.Sequential( AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims) ) self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1) def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in zip(sample_coords, max_sz)], 1) def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"): ref = self.ref_grid if ref is None else ref img_sz = self.max_sz if img_sz is None else img_sz resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear' return F.grid_sample(vol, torch.flip((ddf * torch.Tensor( np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute( [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, align_corners=True) def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2): self.device = x.device img_sz = x.size()[2:] n = x.size()[0] self.max_sz = [img_sz[0]] * self.dimension ts_emb_shape=[n,-1]+[1]*self.dimension self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension]) self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0), [1, self.dimension]+list(img_sz)).to(self.device) img = x t = self.time_embed(t) for rec_id in range(rec_num): if self.conditional_input: tgt = y enc_list = [] out = img for i in range(self.hier_num): out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape)) if self.conditional_input: tgt = self.block_down_cond[i](tgt) out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1)) tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1)) enc_list.append(out) out = self.down_layers[i](out) if self.conditional_input: tgt = self.down_layers[i](tgt) out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape)) if self.conditional_input: # out += self.attn_layer(out, tgt, tgt)[0] out_shape = out.shape tgt_shape = tgt.shape # out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt) out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W) out = out + out_attn for i in range(self.hier_num): out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1) out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape)) out = self.conv_out(out)/128 ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz) if rec_id == 0: ddf = ddf_one else: ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border") img = self.resample(x, ddf=ddf, img_sz=self.img_sz) return ddf def _make_te(self, dim_in, dim_out): return nn.Sequential( nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out) ) class RecMutAttnNet(nn.Module): def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4): super(RecMutAttnNet, self).__init__() # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64] self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256] self.conditional_input = conditional_input self.num_heads = num_heads self.text_feat_chn = text_feat_chn self.dimension = ndims self.Conv = getattr(nn, 'Conv%dd' % self.dimension) self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension) # Sinusoidal embedding self.time_embed = nn.Embedding(n_steps, time_emb_dim) self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) self.time_embed.requires_grad_(False) self.hier_num = len(self.feat_channels) - 1 self.down_layers = nn.ModuleList() self.up_layers = nn.ModuleList() self.ted_layers = nn.ModuleList() self.teu_layers = nn.ModuleList() self.block_down = nn.ModuleList() self.block_up = nn.ModuleList() if self.conditional_input: self.block_down_cond = nn.ModuleList() self.fuse_conv0 = nn.ModuleList() self.fuse_conv1 = nn.ModuleList() self.attn_layer = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads) Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension) self.global_maxpool = Global_Maxpool(1) self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0) self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0]) self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0) self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension)) self.img_res = [res]*self.dimension self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0), [1, self.dimension]+list(self.img_res)) for i in range(1, self.hier_num + 1): j=-i self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1)) self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1)) self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1])) self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j])) self.block_down.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) if self.conditional_input: self.block_down_cond.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) if i==self.hier_num: k=j else: k=j-1 self.block_up.append(nn.Sequential( AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False) )) # Bottleneck self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1]) self.b_mid = nn.Sequential( AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims) ) self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1) def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in zip(sample_coords, max_sz)], 1) def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"): ref = self.ref_grid if ref is None else ref img_sz = self.max_sz if img_sz is None else img_sz resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear' return F.grid_sample(vol, torch.flip((ddf * torch.Tensor( np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute( [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, align_corners=True) def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2): self.device = x.device img_sz = x.size()[2:] n = x.size()[0] self.max_sz = [img_sz[0]] * self.dimension ts_emb_shape=[n,-1]+[1]*self.dimension self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension]) if list(img_sz) != self.img_res: # print ("Reinitialize the ref_grid to match the model's input image size.") # print(img_sz, self.img_res) self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0), [1, self.dimension]+list(img_sz)) self.ref_grid = self.ref_grid.to(self.device) img = x t = self.time_embed(t) for rec_id in range(rec_num): if self.conditional_input: tgt = y enc_list = [] out = img for i in range(self.hier_num): out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape)) if self.conditional_input: tgt = self.block_down_cond[i](tgt) out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1)) tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1)) enc_list.append(out) out = self.down_layers[i](out) if self.conditional_input: tgt = self.down_layers[i](tgt) out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape)) if self.conditional_input: # out += self.attn_layer(out, tgt, tgt)[0] out_shape = out.shape tgt_shape = tgt.shape # out = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) tgt = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) out_attn, _ = self.attn_layer(out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1), tgt, tgt) out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W) out = out + out_attn if self.conditional_input: if text is None: text = self.text text = text.to(self.device) text = text.view(-1, self.text_feat_chn, *([1]*self.dimension)) out_txt = self.img2txt(out) + text out_txt = self.txt_proc(out_txt) out_txt = self.txt2img(out_txt) out = out + out_txt for i in range(self.hier_num): out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1) out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape)) out = self.conv_out(out)/128 ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz) if rec_id == 0: ddf = ddf_one else: ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border") img = self.resample(x, ddf=ddf, img_sz=self.img_sz) # print(torch.max(torch.abs(ddf))) return ddf def _make_te(self, dim_in, dim_out): return nn.Sequential( nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out) ) class RecMulModMutAttnNet(nn.Module): def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True,text_feat_chn=1024, num_heads=4): super(RecMulModMutAttnNet, self).__init__() # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64] self.feat_channels = [num_input_chn, 16, 32, 64, 128, 256] self.conditional_input = conditional_input self.num_heads = num_heads self.text_feat_chn = text_feat_chn self.dimension = ndims self.Conv = getattr(nn, 'Conv%dd' % self.dimension) self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension) # Sinusoidal embedding self.time_embed = nn.Embedding(n_steps, time_emb_dim) self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) self.time_embed.requires_grad_(False) self.hier_num = len(self.feat_channels) - 1 self.down_layers = nn.ModuleList() self.up_layers = nn.ModuleList() self.ted_layers = nn.ModuleList() self.teu_layers = nn.ModuleList() self.block_down = nn.ModuleList() self.block_up = nn.ModuleList() if self.conditional_input: # self.gate_img = nn.ModuleList() self.txt_layers = nn.ModuleList() self.block_down_cond = nn.ModuleList() self.fuse_conv0 = nn.ModuleList() self.fuse_conv1 = nn.ModuleList() self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads) self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads) Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension) self.global_maxpool = Global_Maxpool(1) self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0) self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0]) self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0) # self.text = torch.zeros(1, self.text_feat_chn, *([1]*self.dimension)) self.text = torch.zeros(1, self.text_feat_chn) self.img_res = [res]*self.dimension self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0), [1, self.dimension]+list(self.img_res)) for i in range(1, self.hier_num + 1): j=-i self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1)) self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1)) self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1])) self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j])) self.block_down.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) if self.conditional_input: # self.gate_img.append(nn.Sequential( # nn.ConvNd(self.dimension, self.feat_channels[i], self.feat_channels[i], kernel_size=1, stride=1, padding=0), # nn.Sigmoid() # )) self.txt_layers.append((self._make_te(self.text_feat_chn, self.feat_channels[i]))) self.block_down_cond.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) if i==self.hier_num: k=j else: k=j-1 self.block_up.append(nn.Sequential( AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False) )) # Bottleneck self.txt_layers.append((self._make_te(self.text_feat_chn, self.text_feat_chn))) self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1]) self.b_mid = nn.Sequential( AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims) ) self.fuse = self.Conv(2*self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0) self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1) def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in zip(sample_coords, max_sz)], 1) def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"): ref = self.ref_grid if ref is None else ref img_sz = self.max_sz if img_sz is None else img_sz resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear' return F.grid_sample(vol, torch.flip((ddf * torch.Tensor( np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute( [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, align_corners=True) def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2): self.device = x.device img_sz = x.size()[2:] n = x.size()[0] self.max_sz = [img_sz[0]] * self.dimension ts_emb_shape=[n,-1]+[1]*self.dimension self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension]) if list(img_sz) != self.img_res: # print ("Reinitialize the ref_grid to match the model's input image size.") # print(img_sz, self.img_res) self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0), [1, self.dimension]+list(img_sz)) self.ref_grid = self.ref_grid.to(self.device) img = x t = self.time_embed(t) if text is None: text = self.text # print(text.shape) text = text.to(self.device) txt_shape = [1,-1]+[1]*self.dimension else: txt_shape = [n,-1]+[1]*self.dimension for rec_id in range(rec_num): if self.conditional_input: tgt = y enc_list = [] out = img for i in range(self.hier_num): out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape)) if self.conditional_input: tgt = self.block_down_cond[i](tgt) + self.txt_layers[i](text).reshape(txt_shape) out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1)) tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1)) enc_list.append(out) out = self.down_layers[i](out) if self.conditional_input: tgt = self.down_layers[i](tgt) out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape)) if self.conditional_input: # out += self.attn_layer(out, tgt, tgt)[0] out_shape = out.shape tgt_shape = tgt.shape out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) # (N, C, H, W) -> (H*W, N, C) out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat) tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat) out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) # (H*W, N, C) -> (N, C, H, W) tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape) # (H*W, N, C) -> (N, C, H, W) out = out + out_attn tgt = tgt + tgt_attn out = self.fuse(torch.cat([out, tgt], dim=1)) if self.conditional_input: # text = text.view(-1, self.text_feat_chn, *([1]*self.dimension)) # out_txt = self.img2txt(out) + text.reshape(txt_shape) img_txt_feat = self.img2txt(out) self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1) # [B, 1024] out_txt = self.txt_layers[-1](text).reshape(txt_shape) + img_txt_feat out_txt = self.txt_proc(out_txt) out_txt = self.txt2img(out_txt) out = out + out_txt for i in range(self.hier_num): out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1) out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape)) out = self.conv_out(out)/128 ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz) if rec_id == 0: ddf = ddf_one else: ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border") img = self.resample(x, ddf=ddf, img_sz=self.img_sz) # print(torch.max(torch.abs(ddf))) return ddf def _make_te(self, dim_in, dim_out): return nn.Sequential( nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out) ) class OM_net(nn.Module): """ Extended RecMulModMutAttnNet with gated attention mechanisms: 1. Text Gate (bottleneck): sigmoid weight w_txt to interpolate between text-enhanced features and raw image features. Learns to suppress text branch when text embedding is zeros (no text provided). 2. Target Gate (each encoder level): per-voxel spatial gate using residual AtrousBlock to identify condition vs. noise voxels in the target/condition image path, weighting the fuse_conv1 output. Supports gradient checkpointing via `use_checkpoint` flag to reduce peak activation memory (trades compute for memory). """ def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True, text_feat_chn=1024, num_heads=4, use_conv_transpose=False): super(OM_net, self).__init__() self.use_checkpoint = False # Set True to enable gradient checkpointing self.use_conv_transpose = use_conv_transpose self.feat_channels = [num_input_chn, 12, 32, 64, 128, 512] self.conditional_input = conditional_input self.num_heads = num_heads self.text_feat_chn = text_feat_chn self.dimension = ndims self.Conv = getattr(nn, 'Conv%dd' % self.dimension) self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension) # Sinusoidal embedding self.time_embed = nn.Embedding(n_steps, time_emb_dim) self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) self.time_embed.requires_grad_(False) self.hier_num = len(self.feat_channels) - 1 self.down_layers = nn.ModuleList() self.up_layers = nn.ModuleList() self.ted_layers = nn.ModuleList() self.teu_layers = nn.ModuleList() self.block_down = nn.ModuleList() self.block_up = nn.ModuleList() if self.conditional_input: self.txt_layers = nn.ModuleList() self.block_down_cond = nn.ModuleList() self.fuse_conv0 = nn.ModuleList() self.fuse_conv1 = nn.ModuleList() self.tgt_gate = nn.ModuleList() # Target gate per encoder level self.attn_layer0 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads) self.attn_layer1 = nn.MultiheadAttention(self.feat_channels[-1], self.num_heads) Global_Maxpool = getattr(nn, 'AdaptiveMaxPool%dd' % self.dimension) self.global_maxpool = Global_Maxpool(1) self.img2txt = self.Conv(self.feat_channels[-1], self.text_feat_chn, 1, 1, 0) self.txt_proc = AtrousBlock([self.text_feat_chn] + [1] * ndims, self.text_feat_chn, self.text_feat_chn, ndims=ndims, normalize=False, atrous_rates=[0, 0]) self.txt2img = self.Conv(self.text_feat_chn, self.feat_channels[-1], 1, 1, 0) self.text = torch.zeros(1, self.text_feat_chn) # Text Gate: text-only MLP → sigmoid weight (computed before rec loop) self.text_gate = nn.Sequential( nn.Linear(self.text_feat_chn, self.text_feat_chn // 4), nn.ReLU(), nn.Linear(self.text_feat_chn // 4, 1), nn.Sigmoid() ) self.img_res = [res]*self.dimension self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in self.img_res]), 0), [1, self.dimension]+list(self.img_res)) for i in range(1, self.hier_num + 1): j=-i self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1)) self.up_layers.append(SafeConvTranspose3d(self.feat_channels[j], self.feat_channels[j], 4, 2, 1)) self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1])) self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j])) self.block_down.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) if self.conditional_input: self.txt_layers.append((self._make_te(self.text_feat_chn, self.feat_channels[i]))) self.block_down_cond.append(nn.Sequential( AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) )) self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) # Target Gate: residual AtrousBlock → 2-channel softmax (condition vs noise) self.tgt_gate.append(nn.Sequential( AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims, atrous_rates=[1, 3]), self.Conv(self.feat_channels[i], 2, 1, 1, 0) )) if i==self.hier_num: k=j else: k=j-1 self.block_up.append(nn.Sequential( AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False) )) # Bottleneck self.txt_layers.append((self._make_te(self.text_feat_chn, self.text_feat_chn))) self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1]) self.b_mid = nn.Sequential( AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims) ) self.fuse = self.Conv(2*self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], 1, 1, 0) self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1) # Initialize target gates toward pass-through (condition confidence high) self._init_tgt_gates() def _init_tgt_gates(self): """Bias target gates so condition channel starts moderately high (~0.73). Milder than [2,-2] to ensure both cond*tgt and (1-cond)*out halves of fuse_conv1 input have enough signal for healthy early gradient flow.""" for gate_seq in self.tgt_gate: final_conv = gate_seq[-1] # the Conv that outputs 2 channels with torch.no_grad(): final_conv.bias.data[0] = 1.0 # condition channel → softmax ~0.73 final_conv.bias.data[1] = -1.0 # noise channel → softmax ~0.27 def _encoder_level(self, i, out, tgt, t, ts_emb_shape, text, txt_shape, w_txt): """Single encoder level — extracted for gradient checkpointing.""" out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape)) if self.conditional_input and tgt is not None: tgt = self.block_down_cond[i](tgt) + w_txt * self.txt_layers[i](text).reshape(txt_shape) gate_logits = self.tgt_gate[i](tgt) cond_confidence = F.softmax(gate_logits, dim=1)[:, 0:1] tgt = self.fuse_conv1[i](torch.cat([cond_confidence*tgt, (1-cond_confidence)*out], axis=1)) out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1)) return out, tgt def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in zip(sample_coords, max_sz)], 1) def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"): ref = self.ref_grid if ref is None else ref img_sz = self.max_sz if img_sz is None else img_sz resample_mode = 'bilinear' return F.grid_sample(vol, torch.flip((ddf * torch.Tensor( np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute( [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, align_corners=True) def forward(self, x=None, y=None, t=None, text=None, rec_num=2, ndims=2): self.device = x.device img_sz = x.size()[2:] n = x.size()[0] self.max_sz = [img_sz[0]] * self.dimension ts_emb_shape=[n,-1]+[1]*self.dimension self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension]) if list(img_sz) != self.img_res: self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0), [1, self.dimension]+list(img_sz)) self.ref_grid = self.ref_grid.to(self.device) img = x t = self.time_embed(t) if text is None: text = self.text text = text.to(self.device) txt_shape = [1,-1]+[1]*self.dimension else: txt_shape = [n,-1]+[1]*self.dimension # Text Gate: compute w_txt from text embedding alone before rec loop txt_vec = text.view(text.size(0), -1) # [1, 1024] or [n, 1024] if txt_vec.size(0) == 1 and n > 1: txt_vec = txt_vec.expand(n, -1) w_txt = self.text_gate(txt_vec) # [B, 1] w_txt = w_txt.view([w_txt.size(0), 1] + [1] * self.dimension) for rec_id in range(rec_num): if self.conditional_input: tgt = y enc_list = [] out = img for i in range(self.hier_num): # Gradient checkpointing on early encoder levels (large feature maps) # to reduce peak activation memory. Levels 0-2 have 128^3, 64^3, 32^3 maps. if self.use_checkpoint and self.training and i < 3: out, tgt = grad_checkpoint( self._encoder_level, i, out, tgt if self.conditional_input else None, t, ts_emb_shape, text, txt_shape, w_txt, use_reentrant=False, ) else: out, tgt = self._encoder_level( i, out, tgt if self.conditional_input else None, t, ts_emb_shape, text, txt_shape, w_txt, ) enc_list.append(out) out = self.down_layers[i](out) if self.conditional_input: tgt = self.down_layers[i](tgt) out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape)) if self.conditional_input: out_shape = out.shape tgt_shape = tgt.shape out_flat = out.view(out_shape[0], out_shape[1], -1).permute(2, 0, 1) tgt_flat = tgt.view(tgt_shape[0], tgt_shape[1], -1).permute(2, 0, 1) out_attn, _ = self.attn_layer0(out_flat, tgt_flat, tgt_flat) tgt_attn, _ = self.attn_layer1(tgt_flat, out_flat, out_flat) out_attn = out_attn.permute(1, 2, 0).contiguous().view(out_shape) tgt_attn = tgt_attn.permute(1, 2, 0).contiguous().view(tgt_shape) out = out + out_attn tgt = tgt + tgt_attn out = self.fuse(torch.cat([out, tgt], dim=1)) if self.conditional_input: img_txt_feat = self.img2txt(out) self.img_embd = self.global_maxpool(img_txt_feat).view(n, -1) # [B, 1024] out_txt = self.txt_layers[-1](text).reshape(txt_shape) - img_txt_feat out_txt = self.txt_proc(out_txt) out_txt = self.txt2img(out_txt) # Text Gate: w_txt precomputed from text embedding alone out = (1 - w_txt) * out + w_txt * out_txt for i in range(self.hier_num): out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1) out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape)) out = self.conv_out(out)/128 ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz) if rec_id == 0: ddf = ddf_one else: ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border") img = self.resample(x, ddf=ddf, img_sz=self.img_sz) return ddf def _make_te(self, dim_in, dim_out): return nn.Sequential( nn.Linear(dim_in, dim_out), nn.ReLU(), nn.Linear(dim_out, dim_out) ) # class RecMutAttnNet(nn.Module): # def __init__(self, n_steps=1000, time_emb_dim=100, ndims=2, num_input_chn=1, res=0, conditional_input=True): # super(RecMutAttnNet, self).__init__() # self.feat_channels = [num_input_chn, 8, 16, 32, 32, 64] # self.conditional_input = conditional_input # self.dimension = ndims # self.Conv = getattr(nn, 'Conv%dd' % self.dimension) # self.ConvT = getattr(nn, 'ConvTranspose%dd' % self.dimension) # # Sinusoidal embedding # self.time_embed = nn.Embedding(n_steps, time_emb_dim) # self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim) # self.time_embed.requires_grad_(False) # self.hier_num = len(self.feat_channels) - 1 # self.down_layers = nn.ModuleList() # self.up_layers = nn.ModuleList() # self.ted_layers = nn.ModuleList() # self.teu_layers = nn.ModuleList() # self.block_down = nn.ModuleList() # if self.conditional_input: # self.block_down_cond = nn.ModuleList() # self.fuse_conv0 = nn.ModuleList() # self.fuse_conv1 = nn.ModuleList() # self.block_up = nn.ModuleList() # for i in range(1, self.hier_num + 1): # j=-i # self.down_layers.append(self.Conv(self.feat_channels[i], self.feat_channels[i], 4, 2, 1)) # self.up_layers.append(self.ConvT(self.feat_channels[j], self.feat_channels[j], 4, 2, 1)) # self.ted_layers.append(self._make_te(time_emb_dim, self.feat_channels[i-1])) # self.teu_layers.append(self._make_te(time_emb_dim, 2*self.feat_channels[j])) # self.block_down.append(nn.Sequential( # AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) # )) # if self.conditional_input: # self.block_down_cond.append(nn.Sequential( # AtrousBlock([self.feat_channels[i-1]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i-1], self.feat_channels[i], ndims=ndims), # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims), # AtrousBlock([self.feat_channels[i]] + [res // (2 ** (i-1))] * ndims, self.feat_channels[i], self.feat_channels[i], ndims=ndims) # )) # self.fuse_conv0.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) # self.fuse_conv1.append(self.Conv(2*self.feat_channels[i], self.feat_channels[i], 1, 1, 0)) # if i==self.hier_num: # k=j # else: # k=j-1 # self.block_up.append(nn.Sequential( # AtrousBlock([2*self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, 2*self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), # AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[j], ndims=ndims, normalize=False), # AtrousBlock([self.feat_channels[j]] + [res // (2 ** (self.hier_num-i-1))] * ndims, self.feat_channels[j], self.feat_channels[k], ndims=ndims, normalize=False) # )) # # Bottleneck # self.tmid = self._make_te(time_emb_dim, self.feat_channels[-1]) # self.b_mid = nn.Sequential( # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims), # AtrousBlock([self.feat_channels[self.hier_num]] + [res // (2**self.hier_num)] * ndims, self.feat_channels[self.hier_num], self.feat_channels[self.hier_num], ndims=ndims) # ) # self.conv_out = self.Conv(self.feat_channels[1], ndims, 3, 1, 1) # def boundary_limit(self, sample_coords0, max_sz, plus=0., minus=1.): # sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) # return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in # zip(sample_coords, max_sz)], 1) # def resample(self, vol, ddf, ref=None, img_sz=None, padding_mode="zeros"): # ref = self.ref_grid if ref is None else ref # img_sz = self.max_sz if img_sz is None else img_sz # resample_mode = 'bilinear' # if self.dimension==2 else 'trilinear' # return F.grid_sample(vol, torch.flip((ddf * torch.Tensor( # np.reshape(np.array(self.max_sz), [1, self.dimension]+[1]*self.dimension)).to(self.device) + ref).permute( # [0]+list(range(2,2+self.dimension))+[1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, # align_corners=True) # def forward(self, x=None, y=None, t=None, rec_num=2, ndims=2): # self.device = x.device # img_sz = x.size()[2:] # n = x.size()[0] # self.max_sz = [img_sz[0]] * self.dimension # ts_emb_shape=[n,-1]+[1]*self.dimension # self.img_sz = torch.reshape(torch.tensor([(imsz - 1) / 2 for imsz in img_sz], device=self.device), [1]*(self.dimension+1)+[self.dimension]) # self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=imsz) for imsz in img_sz]), 0), # [1, self.dimension]+list(img_sz)).to(self.device) # img = x # t = self.time_embed(t) # for rec_id in range(rec_num): # if self.conditional_input: # tgt = y # enc_list = [] # out = img # for i in range(self.hier_num): # out = self.block_down[i](out + self.ted_layers[i](t).reshape(ts_emb_shape)) # if self.conditional_input: # tgt = self.block_down_cond[i](tgt) # out = self.fuse_conv0[i](torch.cat([out, tgt], axis=1)) # tgt = self.fuse_conv1[i](torch.cat([tgt, out], axis=1)) # enc_list.append(out) # out = self.down_layers[i](out) # if self.conditional_input: # tgt = self.down_layers[i](tgt) # out = self.b_mid(out + self.tmid(t).reshape(ts_emb_shape)) # if self.conditional_input: # out = out + tgt # for i in range(self.hier_num): # out = torch.cat((self.up_layers[i](out),enc_list[-i-1]), dim=1) # out = self.block_up[i](out + self.teu_layers[i](t).reshape(ts_emb_shape)) # out = self.conv_out(out)/128 # ddf_one = self.boundary_limit(out, max_sz=1 * self.max_sz) # if rec_id == 0: # ddf = ddf_one # else: # ddf = ddf_one + self.resample(ddf, ddf=ddf_one, img_sz=self.img_sz, padding_mode="border") # img = self.resample(x, ddf=ddf, img_sz=self.img_sz) # return ddf # def _make_te(self, dim_in, dim_out): # return nn.Sequential( # nn.Linear(dim_in, dim_out), # nn.ReLU(), # nn.Linear(dim_out, dim_out) # ) # ============================================== # Layers # ============================================== def ddf_multiplier(dvf,mul_num=10,stn=None): ddf=dvf for i in range(mul_num): ddf = dvf + stn(ddf, dvf) return ddf def composite(ddfs,stn=None): if stn is None: stn = STN(device=ddfs[0].device,padding_mode="border") comp_ddf=ddfs[0] for i in range(1,len(ddfs)): comp_ddf = ddfs[i] + stn(comp_ddf,ddfs[i]) return comp_ddf class STN(nn.Module): def __init__(self,ndims=2,img_sz=None,max_sz=None,device=None,padding_mode="border",resample_mode=None): super(STN, self).__init__() self.ndims=ndims self.img_sz=[img_sz]*ndims # self.img_sz=img_sz self.device = device self.padding_mode = padding_mode # max_sz=[128]*self.ndims max_sz=[img_sz]*self.ndims # max_sz=img_sz # max_sz=img_sz if max_sz is None else ([128,128] if img_sz is None else img_sz) # self.max_sz=torch.Tensor(np.reshape(np.array(max_sz), [1, self.ndims, 1, 1])).to(self.device) self.max_sz=torch.Tensor(np.reshape(np.array(max_sz), [1, self.ndims]+[1]*self.ndims)).to(self.device) self.resample_mode=resample_mode if self.img_sz is not None: self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=s) for s in self.img_sz]), 0), [1, self.ndims] + self.img_sz).to(self.device) return def max_limit(self, sample_coords0, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) # return tf.stack([tf.maximum(tf.minimum(x, sz - minus + plus), 0 + plus) for x, sz in zip(sample_coords, input_size0)],-1) return torch.cat([torch.clamp(x * sz, min=minus - 1 * sz + plus, max=1 * sz - minus + plus) / sz for x, sz in zip(sample_coords, self.max_sz)], 1) def boundary_limit(self, sample_coords0, plus=0., minus=1.): sample_coords = torch.split(sample_coords0, split_size_or_sections=1, dim=1) # return tf.stack([tf.maximum(tf.minimum(x, sz - minus + plus), 0 + plus) for x, sz in zip(sample_coords, input_size0)],-1) return torch.cat([(torch.clamp(x * sz+ref, min=minus - 1 * sz + plus, max=1 * sz - minus + plus)-ref) / sz for x, sz,ref in zip(sample_coords, self.max_sz, self.ref_grid)], 1) def resample(self, vol, ddf, ref=None, img_sz=None,padding_mode = "zeros"): # print(vol.device, ddf.device) # print(self.device) # print('===================') device = ddf.device ref = self.ref_grid if ref is None else ref if img_sz is None: img_sz = self.max_sz else: img_sz = torch.reshape(torch.tensor([(s - 1) / 2. for s in img_sz], device=device), [1]+[1]*self.ndims+[self.ndims]) # resample_mode = 'bicubic' if self.resample_mode is None: resample_mode = 'bilinear' # if self.ndims==2 else 'trilinear' else: resample_mode=self.resample_mode # padding_mode = "border" # print(ddf.shape, ref.shape) return F.grid_sample(vol.to(device), torch.flip((ddf * self.max_sz.to(device) + ref.to(device)).permute( [0] + list(range(2, 2 + self.ndims)) + [1]) / img_sz - 1, dims=[-1]), mode=resample_mode, padding_mode=padding_mode, align_corners=True) def forward(self,x,ddf): self.device = x.device if self.device is None else self.device if self.img_sz is None: self.img_sz = list(x.size()[2:]).to(self.device) self.ref_grid = torch.reshape(torch.stack(torch.meshgrid([torch.arange(end=s) for s in self.img_sz]), 0),[1, self.ndims]+self.img_sz).to(self.device) resampled_x = self.resample(x, ddf=ddf, img_sz=self.img_sz, padding_mode=self.padding_mode) return resampled_x if __name__ == '__main__': ndims = 3 res = 128 x = torch.rand([1, 1] + [res]*ndims) t = torch.randint(0, 1000, (1,)) text = torch.rand([1, 1024] + [1]*ndims) model = RecMutAttnNet(n_steps=1000, time_emb_dim=100, ndims=ndims, num_input_chn=1, res=res, conditional_input=True) y = model(x, x, t, text=text) print("Ouput shape", y.shape) # Total parameters total_params = sum(p.numel() for p in model.parameters()) # Trainable parameters only trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"Total parameters: {total_params}") print(f"Trainable parameters: {trainable_params}")