Spaces:
Sleeping
Sleeping
| import os, sys | |
| import numpy as np | |
| from PIL import Image | |
| import itertools | |
| import glob | |
| import random | |
| import torch | |
| import torchvision | |
| import torchvision.transforms as transforms | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| import torch.nn.functional as F | |
| from torch.nn.functional import relu as RLU | |
# Registration strategy. Options listed by the author: 'Rawblock',
# 'matching_points', 'Additive_Recurence', 'Multiplicative_Recurence'
# ('recurrent_matrix' is commented out of that list).
registration_method = 'Additive_Recurence'
imposed_point = 0
Arch = 'ResNet'
Fix_Torch_Wrap = False
BW_Position = False

# Two image side lengths and their quotient (dim / dim0).
dim = 128
dim0 = 224
crop_ratio = dim / dim0
class Identity(nn.Module):
    """Pass-through layer: ``forward`` returns its argument unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No computation; the very same object is handed back.
        return x
class Build_IRmodel_Resnet(nn.Module):
    """Predict a 2x3 affine matrix from a source/target image pair.

    The source and target images are stacked along dim 2 and passed through
    ``resnet_model``; a final linear layer maps the backbone features to the
    six affine parameters.

    When ``registration_method`` contains ``'Recurence'`` the current
    estimate ``M_i`` is additionally encoded by two linear layers into one
    extra row of shape ``(B, 3, 1, image_width)`` that is appended to the
    stacked images, and the network output is added to ``M_i``.

    NOTE(review): ``'Multiplicative_Recurence'`` also contains ``'Recurence'``
    and therefore takes the same additive path here -- confirm intended.

    Args:
        resnet_model: callable backbone mapping the stacked images to a
            ``(B, feature_dim)`` feature tensor.
        registration_method: strategy name; only the presence of the
            substring ``'Recurence'`` is inspected.
        BW_Position: stored on the instance; not used by ``forward``.
        image_width: width of the input images (the encoded ``M_i`` row must
            match it to be concatenated). Defaults to the original 128.
        feature_dim: size of the backbone output. Defaults to the original 512.
    """

    def __init__(self, resnet_model, registration_method='Additive_Recurence',
                 BW_Position=False, image_width=128, feature_dim=512):
        super().__init__()
        self.resnet_model = resnet_model
        self.BW_Position = BW_Position
        self.N_parameters = 6
        self.registration_method = registration_method
        self.image_width = image_width
        # Layer names are kept so existing checkpoints still load.
        self.fc1 = nn.Linear(6, 64)
        self.fc2 = nn.Linear(64, image_width * 3)
        self.fc3 = nn.Linear(feature_dim, self.N_parameters)

    def forward(self, input_X_batch):
        """Run the model on a batch dictionary.

        Args:
            input_X_batch: mapping with keys ``'source'`` and ``'target'``
                (image tensors stacked along dim 2) and, for recurrent
                methods, ``'M_i'`` holding six affine parameters per sample
                (either ``(B, 2, 3)`` or flat ``(B, 6)``).

        Returns:
            dict with ``'Affine_mtrx'`` of shape ``(B, 2, 3)``; recurrent
            methods also include ``'predicted_part_mtrx'``, the raw network
            output before ``M_i`` is added.
        """
        source = input_X_batch['source']
        target = input_X_batch['target']
        is_recurrent = 'Recurence' in self.registration_method

        if is_recurrent:
            M_i = input_X_batch['M_i']
            M_rep = F.relu(self.fc1(M_i.reshape(-1, 6)))
            M_rep = F.relu(self.fc2(M_rep)).view(-1, 3, 1, self.image_width)
            concatenated_input = torch.cat((source, target, M_rep), dim=2)
        else:
            concatenated_input = torch.cat((source, target), dim=2)

        resnet_output = self.resnet_model(concatenated_input)
        predicted_line = self.fc3(resnet_output)

        if not is_recurrent:
            return {'Affine_mtrx': predicted_line.view(-1, 2, 3)}

        predicted_part_mtrx = predicted_line.view(-1, 2, 3)
        # Reshape M_i the same way it was accepted above, so a flat (B, 6)
        # estimate no longer fails to broadcast against the (B, 2, 3) output.
        Prd_Affine_mtrx = predicted_part_mtrx + M_i.reshape(-1, 2, 3)
        return {'predicted_part_mtrx': predicted_part_mtrx,
                'Affine_mtrx': Prd_Affine_mtrx}
def pil_to_numpy(im):
    """Copy the pixels of a PIL image into a freshly allocated NumPy array.

    The image is run through Pillow's "raw" encoder in its own mode and the
    encoded chunks are written straight into the array's memory.

    NOTE(review): relies on the private helpers ``Image._getencoder`` and
    ``Image._conv_type_shape`` -- confirm they exist in the Pillow version
    in use.

    Args:
        im: PIL image.

    Returns:
        numpy.ndarray with the shape and dtype reported by
        ``Image._conv_type_shape(im)``.

    Raises:
        RuntimeError: if the encoder reports a negative status.
    """
    im.load()

    # Raw encoder over the image's internal pixel storage.
    encoder = Image._getencoder(im.mode, "raw", im.mode)
    encoder.setimage(im.im)

    # Destination array plus a flat unsigned-byte view over its buffer.
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    byte_view = data.data.cast("B", (data.data.nbytes,))

    chunk_size = 65536
    status = 0
    write_pos = 0
    # Keep pulling encoded chunks until the encoder reports a non-zero status.
    while not status:
        _, status, chunk = encoder.encode(chunk_size)
        byte_view[write_pos:write_pos + len(chunk)] = chunk
        write_pos += len(chunk)

    if status < 0:
        raise RuntimeError("encoder error %d in tobytes" % status)
    return data
def load_image_pil_accelerated(image_path, dim=128):
    """Load an image file as a float32 RGB tensor resized to ``dim`` x ``dim``.

    Args:
        image_path: path (or file object) accepted by ``PIL.Image.open``.
        dim: side length passed to ``torchvision.transforms.Resize``.

    Returns:
        torch.float32 tensor in channel-first layout with values divided by
        255, resized to ``(dim, dim)``.
    """
    # Image.open is lazy and keeps the file open; convert() forces the pixel
    # data to be read and returns an independent image, so the handle can be
    # released as soon as the conversion is done.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")

    array = pil_to_numpy(image)
    # Move the channel axis to the front and scale to [0, 1].
    tensor = torch.from_numpy(np.rollaxis(array, 2, 0) / 255).to(torch.float32)
    tensor = torchvision.transforms.Resize((dim, dim))(tensor)
    return tensor
def workaround_matrix(Affine_mtrx0, acc=2):
    """Convert between a "correct" affine matrix and its torch-compatible form.

    The matrix is inverted (via ``inv_AM``) and its translation column is
    scaled by ``acc``.

    * ``acc=2``: from a correct matrix to the equivalent torch-compatible
      matrix (needed when transforming an image).
    * ``acc=0.5``: from a torch-compatible matrix back to the correct one.

    Args:
        Affine_mtrx0: affine matrix/matrices whose last two dims are 2x3.
        acc: factor applied to the translation column of the inverse.

    Returns:
        The adjusted 2x3 matrix/matrices; the input is not modified.
    """
    Affine_mtrx_adj = inv_AM(Affine_mtrx0)
    # Ellipsis indexing addresses the last column for batched and unbatched
    # matrices alike (identical to [:, :, 2] in the batched case).
    Affine_mtrx_adj[..., 2] *= acc
    return Affine_mtrx_adj
def inv_AM(Affine_mtrx):
    """Invert 2x3 affine matrices.

    The input is lifted to 3x3 homogeneous form with ``mtrx3``, inverted with
    ``torch.linalg.inv``, and the top two rows of the inverse are returned.

    Args:
        Affine_mtrx: tensor of shape ``(B, 2, 3)`` or ``(2, 3)``.

    Returns:
        Tensor of shape ``(B, 2, 3)`` or ``(2, 3)`` respectively.
    """
    AM3 = mtrx3(Affine_mtrx)
    AM_inv = torch.linalg.inv(AM3)
    # Ellipsis keeps this valid for the unbatched 3x3 case as well; the old
    # [:, 0:2, :] indexing raised IndexError on a 2-D inverse.
    return AM_inv[..., 0:2, :]
def mtrx3(Affine_mtrx):
    """Lift 2x3 affine matrices to 3x3 homogeneous form.

    The input is copied into the top two rows of a zero matrix and the
    bottom-right element is set to 1, giving a last row of ``[0, 0, 1]``.
    The result is allocated with ``torch.zeros`` defaults (it does not follow
    the input's device or dtype).

    Args:
        Affine_mtrx: tensor of shape ``(B, 2, 3)`` or ``(2, 3)``.

    Returns:
        Tensor of shape ``(B, 3, 3)`` or ``(3, 3)`` respectively.

    Raises:
        ValueError: if the input is neither 2-D nor 3-D.
    """
    mtrx_shape = Affine_mtrx.shape
    n_dims = len(mtrx_shape)
    if n_dims == 3:
        AM3 = torch.zeros([mtrx_shape[0], 3, 3])
    elif n_dims == 2:
        AM3 = torch.zeros([3, 3])
    else:
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError(
            "Affine_mtrx must have 2 or 3 dimensions, got %d" % n_dims)
    AM3[..., 0:2, :] = Affine_mtrx
    AM3[..., 2, 2] = 1
    return AM3
def standarize_point(d, dim=128, flip = False):
    """Map a coordinate to ``d / dim - 0.5``, negating ``d`` first if ``flip``.

    NOTE(review): with ``flip=True`` this is not undone by
    ``destandarize_point(..., flip=True)`` (the round trip yields ``d + dim``)
    -- confirm the intended flip semantics.

    Args:
        d: coordinate (number or tensor).
        dim: divisor applied to ``d``.
        flip: negate ``d`` before scaling.
    """
    value = -d if flip else d
    return value / dim - 0.5
def destandarize_point(d, dim=128, flip = False):
    """Map a coordinate to ``dim * (d + 0.5)``, negating ``d`` first if ``flip``.

    NOTE(review): with ``flip=True`` this does not invert
    ``standarize_point(..., flip=True)`` (the round trip yields ``d + dim``)
    -- confirm the intended flip semantics.

    Args:
        d: coordinate (number or tensor).
        dim: multiplier applied after the 0.5 shift.
        flip: negate ``d`` before shifting and scaling.
    """
    value = -d if flip else d
    return dim * (value + 0.5)
def generate_standard_elips(N_samples = 100, a= 1,b = 1):
    """Sample points on an origin-centred ellipse, ordered around the curve.

    The ellipse satisfies ``(x / a)**2 + (y / b)**2 == 0.25**2``. The first
    ``int(N_samples / 2 - 1)`` points lie on the upper half with ascending x;
    the remaining points lie on the lower half with descending x.

    x is drawn uniformly from ``[-|a| * 0.25, |a| * 0.25]``, the ellipse's
    actual horizontal extent. The previous fixed ``[-0.25, 0.25]`` range gave
    NaN for ``|a| < 1`` and covered only part of the curve for ``|a| > 1``;
    results for the default ``a=1`` are unchanged.

    Args:
        N_samples: total number of points (0 yields empty tensors).
        a: horizontal scale of the ellipse.
        b: vertical scale of the ellipse.

    Returns:
        Tuple ``(x, y)`` of 1-D tensors with ``N_samples`` elements each.
    """
    radius = 0.25
    center = 0
    # Clamp so that N_samples == 0 does not request a negative sample count.
    n_upper = max(int(N_samples / 2 - 1), 0)
    n_lower = N_samples - n_upper

    half_width = abs(a) * radius
    sampler = torch.distributions.uniform.Uniform(center - half_width,
                                                  center + half_width)

    def _half_height(xs):
        # Clamp guards against tiny negative values from float rounding at
        # the ends of the x range, which would otherwise become NaN.
        inside = torch.clamp(radius**2 - ((xs - center) / a)**2, min=0)
        return b * torch.sqrt(inside)

    x1_ordered = torch.sort(sampler.sample([n_upper])).values
    y1 = center + _half_height(x1_ordered)
    x2_ordered = torch.sort(sampler.sample([n_lower]), descending=True).values
    y2 = center - _half_height(x2_ordered)

    x = torch.cat([x1_ordered, x2_ordered])
    y = torch.cat([y1, y2])
    return x, y
def transform_standard_points(Affine_mat, x,y):
    """Apply an affine matrix to a set of 2-D points.

    The points are stacked into homogeneous coordinates ``[x; y; 1]`` and
    multiplied by ``Affine_mat`` (moved to CPU and detached first).

    Args:
        Affine_mat: tensor of shape ``(2, 3)`` or batched ``(B, 2, 3)``.
        x: 1-D tensor of x coordinates.
        y: 1-D tensor of y coordinates, same length as ``x``.

    Returns:
        Tuple ``(xt, yt)`` of transformed coordinates, each of shape ``(N,)``
        for an unbatched matrix or ``(B, N)`` for a batched one.
    """
    XY = torch.ones([3, x.shape[0]])
    XY[0, :] = x
    XY[1, :] = y
    XYt = torch.matmul(Affine_mat.to('cpu').detach(), XY)
    # Index the row axis explicitly: plain XYt[0] / XYt[1] would pick batch
    # entries instead of the x / y rows when the matrix is batched.
    xt0 = XYt[..., 0, :]
    yt0 = XYt[..., 1, :]
    return xt0, yt0
def wrap_points(img, x_source, y_source, l=1, DIM =dim):
    """Zero out a small square of ``img`` around each given point, in place.

    Each coordinate is converted to an integer; points with either coordinate
    outside the open interval ``(0, DIM)`` are skipped. For the rest, the
    region ``[y - l, y + l) x [x - l, x + l)`` of the last two dims is set
    to 0.

    Args:
        img: 4-D tensor indexed as ``[:, :, row, col]``; modified in place.
        x_source: sequence of column coordinates (tensor elements).
        y_source: sequence of row coordinates, same length as ``x_source``.
        l: half-size of the square that is cleared.
        DIM: exclusive upper bound for valid coordinates.

    Returns:
        The same ``img`` object that was passed in.
    """
    for i, y_pt in enumerate(y_source):
        x0 = int(x_source[i].int())
        y0 = int(y_pt.int())
        if 0 < x0 < DIM and 0 < y0 < DIM:
            # Clamp the lower bounds at 0: a negative slice start would wrap
            # around to the far edge and (for l > 1 near the border) select
            # an empty region, so the mark was silently dropped.
            row_start = max(y0 - l, 0)
            col_start = max(x0 - l, 0)
            img[:, :, row_start:y0 + l, col_start:x0 + l] = 0
    return img
def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
    """Warp an image with an affine matrix on a zero-padded canvas, then crop.

    The image is zero-padded by ``int((dim1 - dim2) / 2)`` on every side,
    resampled through ``affine_grid`` / ``grid_sample`` (bilinear, zeros
    outside, ``align_corners=False``) and finally centre-cropped to
    ``dim2`` x ``dim2``.

    Args:
        Affine_mtrx: affine matrices passed to ``affine_grid`` as ``theta``.
        source_img: image batch to warp.
        dim1: side length used to derive the padding amount.
        dim2: side length of the returned crop.

    Returns:
        The warped image batch, centre-cropped to ``(dim2, dim2)``.
    """
    border = int((dim1 - dim2) / 2)
    padded = torch.nn.ZeroPad2d(border)(source_img)

    sampling_grid = torch.nn.functional.affine_grid(
        Affine_mtrx, size=padded.shape, align_corners=False)
    warped = torch.nn.functional.grid_sample(
        padded, grid=sampling_grid, mode='bilinear',
        padding_mode='zeros', align_corners=False)

    center_crop = torchvision.transforms.CenterCrop((dim2, dim2))
    return center_crop(warped)
def preprocess_image(image_path, dim = 128):
    """Load an image file into a batch tensor of shape ``(1, 3, dim, dim)``.

    Args:
        image_path: path handed to ``load_image_pil_accelerated``.
        dim: side length of the loaded image.

    Returns:
        Tensor of shape ``(1, 3, dim, dim)`` whose single entry is the loaded
        image.
    """
    loaded = load_image_pil_accelerated(image_path, dim)
    # Copy into a pre-shaped single-item batch buffer.
    batch = torch.zeros([1, 3, dim, dim])
    batch[0] = loaded
    return batch