Spaces:
Runtime error
Runtime error
| import os, sys | |
| import numpy as np | |
| from PIL import Image | |
| import itertools | |
| import glob | |
| import random | |
| import torch | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| from torch.nn.functional import relu as RLU | |
# Registration strategy. Values listed by the original author:
# 'Rawblock', 'matching_points', 'Additive_Recurence',
# 'Multiplicative_Recurence' (a 'recurrent_matrix' option was commented out).
registration_method = 'Additive_Recurence'
imposed_point = 0
Arch = 'ResNet'
Fix_Torch_Wrap = False
BW_Position = False

# Image side lengths; crop_ratio is simply their quotient.
dim, dim0 = 128, 224
crop_ratio = dim / dim0
class Identity(nn.Module):
    """No-op module: ``forward`` hands back exactly the tensor it receives.

    Has no parameters. Presumably meant to stand in for a backbone layer
    that should be disabled (e.g. a classifier head) — confirm at call site.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Return the very same object; no copy is made.
        return x
class Build_IRmodel_Resnet(nn.Module):
    """Predict a 2x3 affine matrix from a (source, target) image pair.

    ``source`` and ``target`` are concatenated along dim 2 (the height axis
    for NCHW tensors) and passed through ``resnet_model``; ``fc3`` maps the
    backbone output to the 6 entries of a 2x3 affine matrix.

    If ``registration_method`` contains the substring ``'Recurence'``, the
    current estimate ``M_i`` is embedded (fc1 -> fc2) into a
    ``(B, 3, 1, image_width)`` strip that is appended after the two images,
    and the predicted matrix is treated as a correction added to ``M_i``.

    Args:
        resnet_model: backbone; must map the concatenated input to a
            ``(B, feature_dim)`` tensor.
        registration_method: strategy name; only the presence of
            ``'Recurence'`` is checked by this class.
        BW_Position: stored as ``self.BW_Position``; not used here.
        image_width: width of ``source``/``target`` (the ``M_i`` strip must
            match it to be concatenated). Defaults to the original 128.
        feature_dim: size of the backbone's output features. Defaults to
            the original 512.
    """

    def __init__(self, resnet_model, registration_method='Additive_Recurence',
                 BW_Position=False, image_width=128, feature_dim=512):
        super(Build_IRmodel_Resnet, self).__init__()
        self.resnet_model = resnet_model
        self.BW_Position = BW_Position
        self.N_parameters = 6  # number of entries in a 2x3 affine matrix
        self.registration_method = registration_method
        self.image_width = image_width
        # Layers are created in the original order so seeded init is unchanged.
        self.fc1 = nn.Linear(self.N_parameters, 64)
        self.fc2 = nn.Linear(64, 3 * image_width)
        self.fc3 = nn.Linear(feature_dim, self.N_parameters)

    def forward(self, input_X_batch):
        """Run one registration step.

        Args:
            input_X_batch: dict with ``'source'`` and ``'target'`` tensors
                (assumed ``(B, 3, H, image_width)`` — the ``M_i`` strip has
                3 channels), plus ``'M_i'`` in recurrent mode, holding 6
                values per sample (``(B, 2, 3)`` or ``(B, 6)``).

        Returns:
            ``{'Affine_mtrx': (B, 2, 3)}``; in recurrent mode also
            ``'predicted_part_mtrx'`` (the correction), with
            ``Affine_mtrx = predicted_part_mtrx + M_i``.
        """
        source = input_X_batch['source']
        target = input_X_batch['target']
        recurrent = 'Recurence' in self.registration_method

        if recurrent:
            # Normalise M_i once so the embedding and the residual add below
            # agree on its layout; previously only (B, 2, 3) added correctly.
            M_i = input_X_batch['M_i'].reshape(-1, 2, 3)
            M_rep = F.relu(self.fc1(M_i.reshape(-1, self.N_parameters)))
            M_rep = F.relu(self.fc2(M_rep)).view(-1, 3, 1, self.image_width)
            concatenated_input = torch.cat((source, target, M_rep), dim=2)
        else:
            concatenated_input = torch.cat((source, target), dim=2)

        resnet_output = self.resnet_model(concatenated_input)
        predicted_part_mtrx = self.fc3(resnet_output).view(-1, 2, 3)

        if recurrent:
            return {'predicted_part_mtrx': predicted_part_mtrx,
                    'Affine_mtrx': predicted_part_mtrx + M_i}
        return {'Affine_mtrx': predicted_part_mtrx}
def pil_to_numpy(im):
    """Decode a PIL image into a freshly allocated, writable NumPy array.

    Intended as a faster equivalent of ``np.array(im)`` (same shape/dtype):
    the raw encoder's output is streamed in 64 KiB chunks straight into a
    preallocated array.

    Uses private Pillow helpers (``Image._getencoder``,
    ``Image._conv_type_shape``); if they are missing from the installed
    Pillow, falls back to the public ``np.array(im)``.

    Raises:
        RuntimeError: if the raw encoder reports an error status.
    """
    im.load()
    try:
        get_encoder = Image._getencoder
        conv_type_shape = Image._conv_type_shape
    except AttributeError:
        return np.array(im)

    # Mode "1" raw data packs 8 pixels per byte, but the target bool array has
    # one byte per pixel; unpack to one byte per pixel (rawmode "L"), as
    # Pillow's own __array_interface__ does. Without this the array would be
    # mostly uninitialized memory.
    rawmode = "L" if im.mode == "1" else im.mode
    encoder = get_encoder(im.mode, "raw", rawmode)
    encoder.setimage(im.im)

    # Target buffer, viewed as flat bytes so encoder chunks can be copied in.
    shape, typestr = conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    mem = data.data.cast("B", (data.data.nbytes,))

    # encode() returns (bytes_written, status, chunk); status 0 = more data,
    # > 0 = finished, < 0 = error.
    bufsize, status, offset = 65536, 0, 0
    while not status:
        _, status, chunk = encoder.encode(bufsize)
        mem[offset:offset + len(chunk)] = chunk
        offset += len(chunk)
    if status < 0:
        raise RuntimeError("encoder error %d in tobytes" % status)
    return data
def load_image_pil_accelerated(image_path, dim=128):
    """Load an image file as a float32 CHW tensor resized to ``(dim, dim)``.

    The file is decoded with PIL and converted to RGB, then turned into an
    ``(H, W, 3)`` array by ``pil_to_numpy``. It is moved to channels-first,
    divided by 255, cast to float32 and resized with
    ``torchvision.transforms.Resize((dim, dim))``.
    """
    # Close the file handle deterministically. ``convert`` loads the pixels
    # and returns an independent copy, so the result stays valid after the
    # source image is closed.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")
    array = pil_to_numpy(image)
    tensor = torch.from_numpy(np.rollaxis(array, 2, 0) / 255).to(torch.float32)
    return torchvision.transforms.Resize((dim, dim))(tensor)
def preprocess_image(image_path, dim = 128):
    """Load ``image_path`` via ``load_image_pil_accelerated`` and prepend a batch axis.

    The result has a leading dimension of size 1, i.e. ``(1, C, dim, dim)``.
    """
    return load_image_pil_accelerated(image_path, dim)[None]
# NOTE(review): the block below is a bare triple-quoted string literal, so
# none of this legacy TensorFlow/Keras code is executed. It references names
# this module never imports (tf, concatenate, LeakyReLU, SpatialTransformer,
# load_img, img_to_array). If re-enabled, it would redefine preprocess_image,
# overriding the PyTorch version above.
# Before reviving it, verify `2*volshape[:-1]`. That expression repeats the
# (dim, dim) tuple rather than doubling the channel count (probably meant
# 2*volshape[-1]); the `inshape` built from it is unused.
| ''' | |
| def load_image_from_url(image_path, dim = 128): | |
| img = Image.open(image_path).convert("RGB") | |
| img = img.resize((dim, dim)) | |
| return img | |
| def preprocess_image(image_path, dim = 128): | |
| img = load_img(image_path, target_size=(dim, dim)) | |
| img = img_to_array(img) | |
| img = np.expand_dims(img, axis=0) | |
| return img | |
| def create_model(dim = 128): | |
| # configure unet input shape (concatenation of moving and fixed images) | |
| volshape = (dim,dim,3) | |
| unet_input_features = 2*volshape[:-1] | |
| inshape = (*volshape[:-1],unet_input_features) | |
| nb_conv_per_level=1 | |
| enc_nf = [dim, dim, dim, dim] | |
| dec_nf = [dim, dim, dim, dim, dim, int(dim/2)] | |
| nb_upsample_skips = 0 | |
| nb_dec_convs = len(enc_nf) | |
| final_convs = dec_nf[nb_dec_convs:] | |
| dec_nf = dec_nf[:nb_dec_convs] | |
| nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1 | |
| source = tf.keras.Input(shape=volshape, name='source_input') | |
| target = tf.keras.Input(shape=volshape, name='target_input') | |
| inputs = [source, target] | |
| unet_input = concatenate(inputs, name='input_concat') | |
| #Define lyers | |
| ndims = len(unet_input.get_shape()) - 2 | |
| MaxPooling = getattr(tf.keras.layers, 'MaxPooling%dD' % ndims) | |
| Conv = getattr(tf.keras.layers, 'Conv%dD' % ndims) | |
| UpSampling = getattr(tf.keras.layers, 'UpSampling%dD' % ndims) | |
| # Encoder | |
| enc_layers = [] | |
| lyr = unet_input | |
| for level in range(nb_levels - 1): | |
| for conv in range(nb_conv_per_level): | |
| nfeat = enc_nf[level * nb_conv_per_level + conv] | |
| lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr) | |
| enc_layers.append(lyr) | |
| lyr = MaxPooling(2)(lyr) | |
| # Decoder | |
| for level in range(nb_levels - 1): | |
| real_level = nb_levels - level - 2 | |
| for conv in range(nb_conv_per_level): | |
| nfeat = dec_nf[level * nb_conv_per_level + conv] | |
| lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr) | |
| # upsample | |
| if level < (nb_levels - 1 - nb_upsample_skips): | |
| upsampled = UpSampling(size=(2,) * ndims)(lyr) | |
| lyr = concatenate([upsampled, enc_layers.pop()]) | |
| # Final convolution | |
| for num, nfeat in enumerate(final_convs): | |
| lyr = Conv(nfeat, kernel_size=3, padding='same', strides=1,activation = LeakyReLU(0.2), kernel_initializer = 'he_normal')(lyr) | |
| unet = tf.keras.models.Model(inputs=inputs, outputs=lyr) | |
| # transform the results into a flow field. | |
| disp_tensor = Conv(ndims, kernel_size=3, padding='same', name='disp')(unet.output) | |
| # using keras, we can easily form new models via tensor pointers | |
| def_model = tf.keras.models.Model(inputs, disp_tensor) | |
| # build transformer layer | |
| spatial_transformer = SpatialTransformer() | |
| # warp the moving image with the transformer | |
| moved_image_tensor = spatial_transformer([source, disp_tensor]) | |
| outputs = [moved_image_tensor, disp_tensor] | |
| vxm_model = tf.keras.models.Model(inputs=inputs, outputs=outputs) | |
| return vxm_model | |
| ''' | |