import os, sys
import numpy as np
from PIL import Image
import itertools
import glob
import random
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.functional import relu as RLU

# --- Global configuration ---------------------------------------------------
# Options: 'Rawblock', 'matching_points', 'Additive_Recurence',
# 'Multiplicative_Recurence' ('recurrent_matrix' was an earlier variant).
registration_method = 'Additive_Recurence'
imposed_point = 0
Arch = 'ResNet'
Fix_Torch_Wrap = False
BW_Position = False

dim = 128       # working (cropped) image side, in pixels
dim0 = 224      # padded image side used during warping
crop_ratio = dim / dim0


class Identity(nn.Module):
    """Pass-through module, useful for replacing layers (e.g. a ResNet fc head)."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


class Build_IRmodel_Resnet(nn.Module):
    """Image-registration head built on top of a ResNet backbone.

    The backbone (``resnet_model``) is expected to map the concatenated
    source/target input to a 512-dim feature vector, which ``fc3`` projects
    to the 6 parameters of a 2x3 affine matrix.

    For 'Recurence' methods, the current affine estimate ``M_i`` is embedded
    through fc1/fc2 into an extra one-row image stripe that is concatenated
    with the source and target images, and the network predicts an additive
    update to ``M_i`` rather than the full matrix.
    """

    def __init__(self, resnet_model, registration_method='Additive_Recurence',
                 BW_Position=False):
        super(Build_IRmodel_Resnet, self).__init__()
        self.resnet_model = resnet_model
        self.BW_Position = BW_Position
        self.N_parameters = 6  # a 2x3 affine matrix has 6 free parameters
        self.registration_method = registration_method
        # Embedding of the current affine estimate (6 -> 64 -> 3*128).
        self.fc1 = nn.Linear(6, 64)
        self.fc2 = nn.Linear(64, 128 * 3)
        # Projection from backbone features to the affine parameters.
        self.fc3 = nn.Linear(512, self.N_parameters)

    def forward(self, input_X_batch):
        """Predict an affine matrix from a batch dict.

        input_X_batch: dict with 'source' and 'target' image tensors and,
        for recurrent methods, the current 2x3 estimate 'M_i'.
        Returns a dict with 'Affine_mtrx' (plus 'predicted_part_mtrx' for
        recurrent methods).
        """
        source = input_X_batch['source']
        target = input_X_batch['target']
        if 'Recurence' in self.registration_method:
            M_i = input_X_batch['M_i'].view(-1, 6)
            M_rep = F.relu(self.fc1(M_i))
            # Reshape the matrix embedding into a (3, 1, 128) stripe so it can
            # be concatenated with the images along the height axis (dim=2).
            # NOTE(review): assumes source/target are (B, 3, H, 128) — confirm.
            M_rep = F.relu(self.fc2(M_rep)).view(-1, 3, 1, 128)
            concatenated_input = torch.cat((source, target, M_rep), dim=2)
        else:
            concatenated_input = torch.cat((source, target), dim=2)
        resnet_output = self.resnet_model(concatenated_input)
        predicted_line = self.fc3(resnet_output)
        if 'Recurence' in self.registration_method:
            # The network predicts a residual that is added to the current
            # estimate (additive recurrence).
            predicted_part_mtrx = predicted_line.view(-1, 2, 3)
            Prd_Affine_mtrx = predicted_part_mtrx + input_X_batch['M_i']
            predction = {'predicted_part_mtrx': predicted_part_mtrx,
                         'Affine_mtrx': Prd_Affine_mtrx}
        else:
            Prd_Affine_mtrx = predicted_line.view(-1, 2, 3)
            predction = {'Affine_mtrx': Prd_Affine_mtrx}
        return predction


def pil_to_numpy(im):
    """Convert a PIL image to a numpy array via Pillow's raw encoder.

    Encodes directly into a pre-allocated buffer, avoiding the extra bytes
    copy made by the usual ``np.asarray(im)`` path (classic Pillow recipe,
    relies on the private ``Image._getencoder`` / ``Image._conv_type_shape``).
    """
    im.load()
    # Unpack data
    e = Image._getencoder(im.mode, "raw", im.mode)
    e.setimage(im.im)
    # NumPy buffer for the result
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    mem = data.data.cast("B", (data.data.nbytes,))
    bufsize, s, offset = 65536, 0, 0
    while not s:
        l, s, d = e.encode(bufsize)
        mem[offset:offset + len(d)] = d
        offset += len(d)
    if s < 0:
        raise RuntimeError("encoder error %d in tobytes" % s)
    return data


def load_image_pil_accelerated(image_path, dim=128):
    """Load an RGB image as a float32 CHW tensor in [0, 1], resized to (dim, dim)."""
    image = Image.open(image_path).convert("RGB")
    array = pil_to_numpy(image)
    # HWC uint8 -> CHW float in [0, 1].
    tensor = torch.from_numpy(np.rollaxis(array, 2, 0) / 255).to(torch.float32)
    tensor = torchvision.transforms.Resize((dim, dim))(tensor)
    return tensor


def workaround_matrix(Affine_mtrx0, acc=2):
    """Convert between a "correct" affine matrix and its torch-compatible form.

    To find the equivalent torch-compatible matrix from a correct matrix set
    acc=2 (this is needed for transforming an image). To find the correct
    affine matrix from a torch-compatible matrix set acc=0.5.
    """
    Affine_mtrx_adj = inv_AM(Affine_mtrx0)
    Affine_mtrx_adj[:, :, 2] *= acc  # rescale only the translation column
    return Affine_mtrx_adj


def inv_AM(Affine_mtrx):
    """Invert a batch of 2x3 affine matrices: (B, 2, 3) -> (B, 2, 3).

    NOTE(review): the final indexing assumes a batched input; a single (2, 3)
    matrix would need the unbatched path.
    """
    AM3 = mtrx3(Affine_mtrx)
    AM_inv = torch.linalg.inv(AM3)
    return AM_inv[:, 0:2, :]


def mtrx3(Affine_mtrx):
    """Promote 2x3 affine matrices to homogeneous 3x3 form.

    Accepts (B, 2, 3) or (2, 3); appends the [0, 0, 1] row and returns
    (B, 3, 3) or (3, 3) respectively.
    """
    mtrx_shape = Affine_mtrx.shape
    if len(mtrx_shape) == 3:
        N_Mbatches = mtrx_shape[0]
        AM3 = torch.zeros([N_Mbatches, 3, 3])  # .to(device)
        AM3[:, 0:2, :] = Affine_mtrx
        AM3[:, 2, 2] = 1
    elif len(mtrx_shape) == 2:
        N_Mbatches = 1
        AM3 = torch.zeros([3, 3])  # .to(device)
        AM3[0:2, :] = Affine_mtrx
        AM3[2, 2] = 1
    return AM3


def standarize_point(d, dim=128, flip=False):
    """Map a pixel coordinate in [0, dim] to the normalized range [-0.5, 0.5]."""
    if flip:
        d = -d
    return d / dim - 0.5


def destandarize_point(d, dim=128, flip=False):
    """Map a normalized coordinate in [-0.5, 0.5] back to pixels in [0, dim]."""
    if flip:
        d = -d
    return dim * (d + 0.5)


def generate_standard_elips(N_samples=100, a=1, b=1):
    """Sample N_samples points tracing an ellipse in normalized coordinates.

    The ellipse is centered at 0 with base radius 0.25, scaled by ``a`` along
    x and ``b`` along y. Points are ordered along the upper arc left-to-right,
    then the lower arc right-to-left. Returns the (x, y) coordinate tensors.
    """
    radius = 0.25
    center = 0
    N_samples1 = int(N_samples / 2 - 1)
    N_samples2 = N_samples - N_samples1
    # Upper arc: random x, sorted ascending.
    x1 = torch.distributions.uniform.Uniform(center - radius, center + radius).sample([N_samples1])
    x1_ordered = torch.sort(x1).values
    y1 = center + b * torch.sqrt(radius**2 - ((x1_ordered - center) / a)**2)
    # Lower arc: random x, sorted descending so the contour stays connected.
    x2 = torch.distributions.uniform.Uniform(center - radius, center + radius).sample([N_samples2])
    x2_ordered = torch.sort(x2, descending=True).values
    y2 = center - b * torch.sqrt(radius**2 - ((x2_ordered - center) / a)**2)
    x = torch.cat([x1_ordered, x2_ordered])
    y = torch.cat([y1, y2])
    return x, y


def transform_standard_points(Affine_mat, x, y):
    """Apply an affine matrix to points given as coordinate vectors.

    Builds homogeneous coordinates (3, N) and multiplies by ``Affine_mat``
    (moved to CPU and detached). Returns the transformed (x, y) vectors.
    """
    XY = torch.ones([3, x.shape[0]])
    XY[0, :] = x
    XY[1, :] = y
    XYt = torch.matmul(Affine_mat.to('cpu').detach(), XY)
    xt0 = XYt[0]
    yt0 = XYt[1]
    return xt0, yt0


def wrap_points(img, x_source, y_source, l=1, DIM=dim):
    """Draw the given points onto ``img`` as small black squares of half-size ``l``.

    Points falling outside the open interval (0, DIM) are skipped. Mutates and
    returns ``img`` (expected NCHW).
    """
    for i in range(len(y_source)):
        x0 = x_source[i].int()
        y0 = y_source[i].int()
        # FIX: the original condition was corrupted to "(x00) and (y00)" —
        # the comparison operators were lost (consistent with markup
        # stripping of "x0<DIM) and (x0>0"). Restored as an in-bounds check.
        if (x0 > 0) and (x0 < DIM) and (y0 > 0) and (y0 < DIM):
            img[:, :, y0 - l:y0 + l, x0 - l:x0 + l] = 0
    return img


def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
    """Warp ``source_img`` by ``Affine_mtrx``, padding to dim1 then cropping to dim2.

    Zero-pads the image to (dim1, dim1), samples it through the affine grid,
    and center-crops the result back to (dim2, dim2).
    """
    source_img224 = torch.nn.ZeroPad2d(int((dim1 - dim2) / 2))(source_img)
    grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,
                                          align_corners=False)
    wrapped_img = torch.nn.functional.grid_sample(source_img224, grid=grd,
                                                  mode='bilinear',
                                                  padding_mode='zeros',
                                                  align_corners=False)
    wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
    return wrapped_img


def preprocess_image(image_path, dim=128):
    """Load one image file as a (1, 3, dim, dim) float tensor batch."""
    img = torch.zeros([1, 3, dim, dim])
    img[0] = load_image_pil_accelerated(image_path, dim)
    return img