# NOTE: the original capture carried status-banner artifacts here
# ("Spaces:" / "Runtime error" x2) that were not part of the program.
| import os, sys | |
| import numpy as np | |
| from PIL import Image | |
| import itertools | |
| import glob | |
| import random | |
| import torch | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| from torch.nn.functional import relu as RLU | |
# --- Experiment configuration -------------------------------------------------
# Registration strategy; known options:
# {'Rawblock', 'matching_points', 'Additive_Recurence', 'Multiplicative_Recurence'}
# ('recurrent_matrix' appeared in an earlier revision).
registration_method = 'Additive_Recurence'
imposed_point = 0        # NOTE(review): presumably count/index of imposed matching points -- confirm
Arch = 'ResNet'          # backbone architecture selector
Fix_Torch_Wrap = False   # NOTE(review): toggle for a torch warp workaround -- verify against trainer
BW_Position = False      # NOTE(review): black/white positional-encoding flag -- verify usage
dim = 128                # working image size (pixels, square)
dim0 = 224               # original image size before cropping
crop_ratio = dim / dim0  # fraction of the original side length kept after cropping
class Identity(nn.Module):
    """No-op module: ``forward`` returns its input unchanged.

    Handy for replacing a layer (e.g. a classifier head) with a pass-through
    while keeping the surrounding ``nn.Module`` plumbing intact.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Pure pass-through; no parameters, no side effects.
        return x
class Build_IRmodel_Resnet(nn.Module):
    """Affine-registration head wrapped around a feature backbone.

    The backbone (``resnet_model``) is expected to map the concatenated
    source/target input to a 512-dim feature vector (see ``fc3``).  The head
    predicts the 6 entries of a 2x3 affine matrix; when ``registration_method``
    contains ``'Recurence'`` the prediction is a residual that is added to the
    incoming matrix ``M_i``.
    """

    def __init__(self, resnet_model, registration_method='Additive_Recurence', BW_Position=False):
        super(Build_IRmodel_Resnet, self).__init__()
        self.resnet_model = resnet_model
        self.BW_Position = BW_Position
        self.N_parameters = 6
        self.registration_method = registration_method
        # fc1/fc2 embed the flattened 2x3 matrix into a (3, 1, 128) strip that
        # is stacked under the image pair along the height axis.
        self.fc1 = nn.Linear(6, 64)
        self.fc2 = nn.Linear(64, 128 * 3)
        # NOTE(review): assumes the backbone emits 512 features -- confirm.
        self.fc3 = nn.Linear(512, self.N_parameters)

    def forward(self, input_X_batch):
        """Run registration on a batch dict with keys 'source', 'target' (+ 'M_i')."""
        src = input_X_batch['source']
        tgt = input_X_batch['target']
        recurrent = 'Recurence' in self.registration_method
        if recurrent:
            # Embed the current matrix and append it as an extra image row.
            flat_mtrx = input_X_batch['M_i'].view(-1, 6)
            strip = F.relu(self.fc2(F.relu(self.fc1(flat_mtrx)))).view(-1, 3, 1, 128)
            net_input = torch.cat((src, tgt, strip), dim=2)
        else:
            net_input = torch.cat((src, tgt), dim=2)
        line = self.fc3(self.resnet_model(net_input))
        if recurrent:
            part = line.view(-1, 2, 3)
            # Residual update: predicted increment on top of the incoming matrix.
            return {'predicted_part_mtrx': part,
                    'Affine_mtrx': part + input_X_batch['M_i']}
        return {'Affine_mtrx': line.view(-1, 2, 3)}
def pil_to_numpy(im):
    """Convert a PIL image to a NumPy array by encoding straight into the buffer.

    Avoids the intermediate ``bytes`` copy that ``np.asarray(im)`` / ``im.tobytes()``
    would make, by letting the raw encoder write directly into the array's memory.

    NOTE(review): relies on PIL *private* API (``Image._getencoder``,
    ``Image._conv_type_shape``, ``im.im``) -- may break across Pillow versions.
    """
    im.load()
    # Unpack data
    e = Image._getencoder(im.mode, "raw", im.mode)
    e.setimage(im.im)
    # NumPy buffer for the result
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    # Flat unsigned-byte view over the array so encoded chunks can be written in place.
    mem = data.data.cast("B", (data.data.nbytes,))
    bufsize, s, offset = 65536, 0, 0
    # Encoder status ``s`` stays 0 while more data remains; 1 = done, <0 = error.
    while not s:
        l, s, d = e.encode(bufsize)
        mem[offset:offset + len(d)] = d
        offset += len(d)
    if s < 0:
        raise RuntimeError("encoder error %d in tobytes" % s)
    return data
def load_image_pil_accelerated(image_path, dim=128):
    """Load an RGB image as a float32 CHW tensor in [0, 1], resized to (dim, dim)."""
    rgb = Image.open(image_path).convert("RGB")
    # HWC uint8 -> CHW float in [0, 1] (rollaxis moves the channel axis first).
    chw = np.rollaxis(pil_to_numpy(rgb), 2, 0) / 255
    resized = torchvision.transforms.Resize((dim, dim))(torch.from_numpy(chw).to(torch.float32))
    return resized
def preprocess_image(image_path, dim=128):
    """Load a single image and return it as a (1, 3, dim, dim) float32 batch."""
    # unsqueeze(0) adds the batch dimension -- same result as filling a
    # pre-allocated (1, 3, dim, dim) zeros tensor at index 0.
    return load_image_pil_accelerated(image_path, dim).unsqueeze(0)