import os, sys
import numpy as np
from PIL import Image
import itertools
import glob
import random
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.functional import relu as RLU

# Registration strategy used by the model; one of:
# {'Rawblock', 'matching_points', 'Additive_Recurence', 'Multiplicative_Recurence'}
# ('recurrent_matrix' was a historical option).
# NOTE: the spelling 'Recurence' is load-bearing — forward() tests
# `'Recurence' in registration_method` — so do not "fix" these strings.
registration_method = 'Additive_Recurence'
imposed_point = 0
Arch = 'ResNet'
Fix_Torch_Wrap = False
BW_Position = False

# Working resolution (dim) vs. the assumed native resolution (dim0).
dim = 128
dim0 = 224
crop_ratio = dim / dim0


class Identity(nn.Module):
    """No-op module: returns its input unchanged.

    Typically used to strip the classification head off a pretrained
    backbone (e.g. replacing a ResNet's final ``fc`` layer).
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Build_IRmodel_Resnet(nn.Module):
    """Image-registration head on top of a ResNet feature extractor.

    Predicts the 6 parameters of a 2x3 affine matrix aligning a source
    image to a target image.  In the recurrent variants the current
    matrix estimate ``M_i`` is embedded via two fully-connected layers
    and concatenated onto the image pair, and the network predicts an
    additive correction to ``M_i``.
    """

    def __init__(self, resnet_model, registration_method='Additive_Recurence', BW_Position=False):
        """
        Args:
            resnet_model: backbone returning a 512-dim feature vector per
                sample (e.g. a ResNet with its fc layer replaced by
                ``Identity``) -- assumed from ``fc3``'s input size 512;
                confirm against the caller.
            registration_method: strategy name; only whether it contains
                the substring 'Recurence' matters inside ``forward``.
            BW_Position: stored but not used in this class's forward pass.
        """
        super(Build_IRmodel_Resnet, self).__init__()
        self.resnet_model = resnet_model
        self.BW_Position = BW_Position
        self.N_parameters = 6  # 2x3 affine matrix
        self.registration_method = registration_method
        # Embed the flattened 2x3 matrix estimate into an image-like strip.
        self.fc1 = nn.Linear(6, 64)
        self.fc2 = nn.Linear(64, 128 * 3)
        # Map backbone features to the 6 affine parameters.
        self.fc3 = nn.Linear(512, self.N_parameters)

    def forward(self, input_X_batch):
        """Predict an affine matrix for each (source, target) pair.

        Args:
            input_X_batch: dict with 'source' and 'target' image tensors
                (presumably (B, 3, H, 128) so they concatenate on dim=2
                with the (B, 3, 1, 128) matrix embedding -- TODO confirm),
                plus 'M_i' (B, 2, 3) for the recurrent variants.

        Returns:
            dict with 'Affine_mtrx' (B, 2, 3); recurrent variants also
            include 'predicted_part_mtrx', the additive correction.
        """
        source = input_X_batch['source']
        target = input_X_batch['target']
        if 'Recurence' in self.registration_method:
            # Embed the current matrix estimate and stack it below the images.
            M_i = input_X_batch['M_i'].view(-1, 6)
            M_rep = F.relu(self.fc1(M_i))
            M_rep = F.relu(self.fc2(M_rep)).view(-1, 3, 1, 128)
            concatenated_input = torch.cat((source, target, M_rep), dim=2)
        else:
            concatenated_input = torch.cat((source, target), dim=2)

        resnet_output = self.resnet_model(concatenated_input)
        predicted_line = self.fc3(resnet_output)

        if 'Recurence' in self.registration_method:
            # Network predicts a correction; final matrix = correction + M_i.
            predicted_part_mtrx = predicted_line.view(-1, 2, 3)
            Prd_Affine_mtrx = predicted_part_mtrx + input_X_batch['M_i']
            prediction = {'predicted_part_mtrx': predicted_part_mtrx,
                          'Affine_mtrx': Prd_Affine_mtrx}
        else:
            Prd_Affine_mtrx = predicted_line.view(-1, 2, 3)
            prediction = {'Affine_mtrx': Prd_Affine_mtrx}
        return prediction


def pil_to_numpy(im):
    """Convert a PIL image to a NumPy array without an intermediate copy.

    Uses PIL's private raw encoder (``Image._getencoder``) to stream the
    pixel data straight into a pre-allocated NumPy buffer, avoiding the
    extra bytes object that ``np.asarray(im)`` / ``im.tobytes()`` create.

    NOTE(review): relies on private PIL internals (``_getencoder``,
    ``_conv_type_shape``); may break across Pillow versions.
    """
    im.load()
    # Unpack data through the raw encoder.
    e = Image._getencoder(im.mode, "raw", im.mode)
    e.setimage(im.im)

    # NumPy buffer for the result, viewed as a flat byte memoryview.
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    mem = data.data.cast("B", (data.data.nbytes,))

    # Stream encoded chunks into the buffer; s becomes non-zero at end (1)
    # or on error (<0).
    bufsize, s, offset = 65536, 0, 0
    while not s:
        l, s, d = e.encode(bufsize)
        mem[offset:offset + len(d)] = d
        offset += len(d)
    if s < 0:
        raise RuntimeError("encoder error %d in tobytes" % s)
    return data


def load_image_pil_accelerated(image_path, dim=128):
    """Load an image file as a float32 CHW tensor in [0, 1], resized to (dim, dim)."""
    image = Image.open(image_path).convert("RGB")
    array = pil_to_numpy(image)
    # HWC uint8 -> CHW float32 in [0, 1].
    tensor = torch.from_numpy(np.rollaxis(array, 2, 0) / 255).to(torch.float32)
    tensor = torchvision.transforms.Resize((dim, dim))(tensor)
    return tensor


def preprocess_image(image_path, dim=128):
    """Load one image as a batched tensor of shape (1, 3, dim, dim)."""
    img = torch.zeros([1, 3, dim, dim])
    img[0] = load_image_pil_accelerated(image_path, dim)
    return img