import os
import sys
import itertools
import glob
import random

import numpy as np
from PIL import Image
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.functional import relu as RLU

# ---------------------------------------------------------------------------
# Module-level configuration.
# ---------------------------------------------------------------------------
# One of {'Rawblock', 'matching_points', 'Additive_Recurence',
# 'Multiplicative_Recurence'} ('recurrent_matrix' also appears in older notes).
# Build_IRmodel_Resnet only checks whether the string contains 'Recurence'.
registration_method = 'Additive_Recurence'
imposed_point = 0          # NOTE(review): not used in this file; meaning unclear
Arch = 'ResNet'
Fix_Torch_Wrap = False     # NOTE(review): not used in this file
BW_Position = False
dim = 128                  # side length images are resized to
dim0 = 224                 # presumably the backbone's native input size — confirm
crop_ratio = dim / dim0


class Identity(nn.Module):
    """Pass-through module: ``forward(x)`` returns ``x`` unchanged.

    Equivalent to ``torch.nn.Identity``. Presumably used to replace a
    backbone's final classification layer so that it emits raw features.
    This is an assumption; verify against the callers.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Build_IRmodel_Resnet(nn.Module):
    """Affine image-registration head on top of a feature backbone.

    ``forward`` stacks ``source`` and ``target`` along dimension 2 and feeds
    the result to ``resnet_model``. Its output (last dimension must equal
    ``n_features``) is mapped by ``fc3`` to the 6 entries of a 2x3 affine
    matrix.

    When ``registration_method`` contains ``'Recurence'``, the current matrix
    estimate ``M_i`` is handled in two ways:

    * It is embedded by ``fc1``/``fc2`` into a ``(B, 3, 1, dim)`` strip that
      is appended to the stacked images. The images therefore need 3
      channels and width ``dim``.
    * The network's prediction is treated as an additive update to ``M_i``.

    Args:
        resnet_model: backbone module producing ``n_features`` features.
        registration_method: selects the recurrent mode (see above).
        BW_Position: stored on the instance; not used by this class.
        dim: image width; size of the ``M_i`` embedding strip (default 128).
        n_features: backbone output width fed to ``fc3`` (default 512).
    """

    def __init__(self, resnet_model, registration_method='Additive_Recurence',
                 BW_Position=False, dim=128, n_features=512):
        super(Build_IRmodel_Resnet, self).__init__()
        self.resnet_model = resnet_model
        self.BW_Position = BW_Position
        self.N_parameters = 6  # entries of a 2x3 affine matrix
        self.registration_method = registration_method
        self.fc1 = nn.Linear(6, 64)
        self.fc2 = nn.Linear(64, dim * 3)
        self.fc3 = nn.Linear(n_features, self.N_parameters)

    def forward(self, input_X_batch):
        """Predict the affine matrix for a batch.

        Args:
            input_X_batch: dict with keys ``'source'`` and ``'target'``, plus
                ``'M_i'`` (6 values per sample, e.g. ``(B, 2, 3)`` or
                ``(B, 6)``) in recurrent mode.

        Returns:
            ``{'Affine_mtrx': (B, 2, 3)}``. In recurrent mode the dict also
            contains ``'predicted_part_mtrx'``, the raw update that was added
            to ``M_i``.
        """
        source = input_X_batch['source']
        target = input_X_batch['target']
        recurrent = 'Recurence' in self.registration_method

        if recurrent:
            M_i = input_X_batch['M_i']
            # Derive the strip width from fc2 rather than storing a new
            # attribute, so that models pickled by older code still run.
            strip_width = self.fc2.out_features // 3
            # reshape (not view): tolerates non-contiguous M_i.
            M_rep = F.relu(self.fc1(M_i.reshape(-1, self.N_parameters)))
            M_rep = F.relu(self.fc2(M_rep)).view(-1, 3, 1, strip_width)
            concatenated_input = torch.cat((source, target, M_rep), dim=2)
        else:
            concatenated_input = torch.cat((source, target), dim=2)

        resnet_output = self.resnet_model(concatenated_input)
        predicted_mtrx = self.fc3(resnet_output).view(-1, 2, 3)

        if recurrent:
            # Reshape M_i to (B, 2, 3) explicitly. Adding it raw would crash
            # for (B, 6) and silently broadcast to (B, B, 2, 3) for (B, 1, 2, 3).
            prd_affine_mtrx = predicted_mtrx + M_i.reshape(-1, 2, 3)
            return {'predicted_part_mtrx': predicted_mtrx,
                    'Affine_mtrx': prd_affine_mtrx}
        return {'Affine_mtrx': predicted_mtrx}


def pil_to_numpy(im):
    """Copy a PIL image's pixels into a new NumPy array.

    The raw encoder streams the pixel data straight into a preallocated
    buffer, avoiding the intermediate ``bytes`` object that ``tobytes`` would
    create. This relies on private Pillow APIs (``Image._getencoder``,
    ``Image._conv_type_shape``, ``im.im``), which may change between Pillow
    releases.

    Raises:
        RuntimeError: if the encoder reports an error status.
    """
    im.load()
    # Raw encoder that emits the image's pixel data unchanged.
    encoder = Image._getencoder(im.mode, "raw", im.mode)
    encoder.setimage(im.im)

    # Destination array and a flat byte view of its memory.
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    mem = data.data.cast("B", (data.data.nbytes,))

    bufsize, status, offset = 65536, 0, 0
    while not status:  # 0 = more data pending; nonzero = finished or error
        _, status, chunk = encoder.encode(bufsize)
        mem[offset:offset + len(chunk)] = chunk
        offset += len(chunk)
    if status < 0:
        raise RuntimeError("encoder error %d in tobytes" % status)
    return data


def load_image_pil_accelerated(image_path, dim=128):
    """Load an image as an RGB ``float32`` tensor of shape ``(3, dim, dim)``.

    Pixel values are scaled to ``[0, 1]`` (divided by 255) and the image is
    resized with ``torchvision.transforms.Resize``.
    """
    # Context manager guarantees the file handle is released. convert()
    # returns a fully loaded, independent image, so it is safe to use after
    # the file is closed.
    with Image.open(image_path) as image:
        rgb = image.convert("RGB")
    array = pil_to_numpy(rgb)  # (H, W, 3) uint8
    # HWC -> CHW. Scaling in float32 gives the same values as the former
    # float64 route without allocating a float64 intermediate.
    tensor = torch.from_numpy(array).permute(2, 0, 1).to(torch.float32) / 255
    tensor = torchvision.transforms.Resize((dim, dim))(tensor)
    return tensor


def preprocess_image(image_path, dim=128):
    """Load an image as a batched tensor of shape ``(1, 3, dim, dim)``."""
    return load_image_pil_accelerated(image_path, dim).unsqueeze(0)


# Legacy TensorFlow/Keras implementation (VoxelMorph-style U-Net), kept for
# reference only. It is an unused string literal and is never executed.
'''
def load_image_from_url(image_path, dim = 128):
    img = Image.open(image_path).convert("RGB")
    img = img.resize((dim, dim))
    return img

def preprocess_image(image_path, dim = 128):
    img = load_img(image_path, target_size=(dim, dim))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    return img

def create_model(dim = 128):
    # configure unet input shape (concatenation of moving and fixed images)
    volshape = (dim,dim,3)
    unet_input_features = 2*volshape[:-1]
    inshape = (*volshape[:-1],unet_input_features)
    nb_conv_per_level=1
    enc_nf = [dim, dim, dim, dim]
    dec_nf = [dim, dim, dim, dim, dim, int(dim/2)]
    nb_upsample_skips = 0
    nb_dec_convs = len(enc_nf)
    final_convs = dec_nf[nb_dec_convs:]
    dec_nf = dec_nf[:nb_dec_convs]
    nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1
    source = tf.keras.Input(shape=volshape, name='source_input')
    target = tf.keras.Input(shape=volshape, name='target_input')
    inputs = [source, target]
    unet_input = concatenate(inputs, name='input_concat')
    # Define layers
    ndims = len(unet_input.get_shape()) - 2
    MaxPooling = getattr(tf.keras.layers, 'MaxPooling%dD' % ndims)
    Conv = getattr(tf.keras.layers, 'Conv%dD' % ndims)
    UpSampling = getattr(tf.keras.layers, 'UpSampling%dD' % ndims)
    # Encoder
    enc_layers = []
    lyr = unet_input
    for level in range(nb_levels - 1):
        for conv in range(nb_conv_per_level):
            nfeat = enc_nf[level * nb_conv_per_level + conv]
            lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
        enc_layers.append(lyr)
        lyr = MaxPooling(2)(lyr)
    # Decoder
    for level in range(nb_levels - 1):
        real_level = nb_levels - level - 2
        for conv in range(nb_conv_per_level):
            nfeat = dec_nf[level * nb_conv_per_level + conv]
            lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
        # upsample
        if level < (nb_levels - 1 - nb_upsample_skips):
            upsampled = UpSampling(size=(2,) * ndims)(lyr)
            lyr = concatenate([upsampled, enc_layers.pop()])
    # Final convolution
    for num, nfeat in enumerate(final_convs):
        lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr)
    unet = tf.keras.models.Model(inputs=inputs, outputs=lyr)
    # transform the results into a flow field.
    disp_tensor = Conv(ndims, kernel_size=3, padding='same', name='disp')(unet.output)
    # using keras, we can easily form new models via tensor pointers
    def_model = tf.keras.models.Model(inputs, disp_tensor)
    # build transformer layer
    spatial_transformer = SpatialTransformer()
    # warp the moving image with the transformer
    moved_image_tensor = spatial_transformer([source, disp_tensor])
    outputs = [moved_image_tensor, disp_tensor]
    vxm_model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
    return vxm_model
'''