Amould commited on
Commit
69d1aad
·
verified ·
1 Parent(s): a1b04b8

Update codes.py

Browse files
Files changed (1) hide show
  1. codes.py +171 -130
codes.py CHANGED
@@ -1,66 +1,173 @@
1
  import os, sys
2
  import numpy as np
3
- from PIL import Image
4
- import itertools
5
- import glob
6
  import random
 
 
 
 
7
  import torch
8
- import torchvision
9
- import torchvision.transforms as transforms
10
- import torch.nn as nn
11
- import torch.optim as optim
12
- import torch.nn.functional as F
13
- from torch.nn.functional import relu as RLU
14
 
15
- registration_method = 'Additive_Recurence' #{'Rawblock', 'matching_points', 'Additive_Recurence', 'Multiplicative_Recurence'} #'recurrent_matrix',
16
- imposed_point = 0
17
- Arch = 'ResNet'
18
- Fix_Torch_Wrap = False
19
- BW_Position = False
20
- dim = 128
21
- dim0 =224
22
- crop_ratio = dim/dim0
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
 
 
26
 
27
- class Identity(nn.Module):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def __init__(self):
29
- super(Identity, self).__init__()
30
- def forward(self, x):
31
- return x
32
-
33
- class Build_IRmodel_Resnet(nn.Module):
34
- def __init__(self, resnet_model, registration_method = 'Additive_Recurence', BW_Position=False):
35
- super(Build_IRmodel_Resnet, self).__init__()
36
- self.resnet_model = resnet_model
37
- self.BW_Position = BW_Position
38
- self.N_parameters = 6
39
- self.registration_method = registration_method
40
- self.fc1 =nn.Linear(6, 64)
41
- self.fc2 =nn.Linear(64, 128*3)
42
- self.fc3 =nn.Linear(512, self.N_parameters)
43
  def forward(self, input_X_batch):
44
- source = input_X_batch['source']
45
- target = input_X_batch['target']
46
- if 'Recurence' in self.registration_method:
47
- M_i = input_X_batch['M_i'].view(-1, 6)
48
- M_rep = F.relu(self.fc1(M_i))
49
- M_rep = F.relu(self.fc2(M_rep)).view(-1,3,1,128)
50
- concatenated_input = torch.cat((source,target,M_rep), dim=2)
51
- else:
52
- concatenated_input = torch.cat((source,target), dim=2)
53
- resnet_output = self.resnet_model(concatenated_input)
54
- predicted_line = self.fc3(resnet_output)
55
- if 'Recurence' in self.registration_method:
56
- predicted_part_mtrx = predicted_line.view(-1, 2, 3)
57
- Prd_Affine_mtrx = predicted_part_mtrx + input_X_batch['M_i']
58
- predction = {'predicted_part_mtrx':predicted_part_mtrx,
59
- 'Affine_mtrx': Prd_Affine_mtrx}
60
- else:
61
- Prd_Affine_mtrx = predicted_line.view(-1, 2, 3)
62
- predction = {'Affine_mtrx': Prd_Affine_mtrx}
63
- return predction
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def pil_to_numpy(im):
66
  im.load()
@@ -80,93 +187,27 @@ def pil_to_numpy(im):
80
  raise RuntimeError("encoder error %d in tobytes" % s)
81
  return data
82
 
83
- def load_image_pil_accelerated(image_path, dim=128):
84
- image = Image.open(image_path).convert("RGB")
85
- array = pil_to_numpy(image)
86
- tensor = torch.from_numpy(np.rollaxis(array,2,0)/255).to(torch.float32)
87
- tensor = torchvision.transforms.Resize((dim,dim))(tensor)
88
- return tensor
89
 
90
 
91
- def workaround_matrix(Affine_mtrx0, acc = 2):
92
- # To find the equivalent torch-compatible matrix from a correct matrix set acc=2 #This will be needed for transforming an image
93
- # To find the correct Affine matrix from Torch compatible matrix set acc=0.5
94
- Affine_mtrx_adj = inv_AM(Affine_mtrx0)
95
- Affine_mtrx_adj[:,:,2]*=acc
96
- return Affine_mtrx_adj
97
-
98
- def inv_AM(Affine_mtrx):
99
- AM3 = mtrx3(Affine_mtrx)
100
- AM_inv = torch.linalg.inv(AM3)
101
- return AM_inv[:,0:2,:]
102
-
103
- def mtrx3(Affine_mtrx):
104
- mtrx_shape = Affine_mtrx.shape
105
- if len(mtrx_shape)==3:
106
- N_Mbatches = mtrx_shape[0]
107
- AM3 = torch.zeros( [N_Mbatches,3,3])#.to(device)
108
- AM3[:,0:2,:] = Affine_mtrx
109
- AM3[:,2,2] = 1
110
- elif len(mtrx_shape)==2:
111
- N_Mbatches = 1
112
- AM3 = torch.zeros([3,3])#.to(device)
113
- AM3[0:2,:] = Affine_mtrx
114
- AM3[2,2] = 1
115
- return AM3
116
-
117
- def standarize_point(d, dim=128, flip = False):
118
- if flip:
119
- d = -d
120
- return d/dim - 0.5
121
-
122
- def destandarize_point(d, dim=128, flip = False):
123
- if flip:
124
- d = -d
125
- return dim*(d + 0.5)
126
 
127
- def generate_standard_elips(N_samples = 100, a= 1,b = 1):
128
- radius = 0.25
129
- center = 0
130
- N_samples1 = int(N_samples/2 - 1)
131
- N_samples2 = N_samples - N_samples1
132
- x1 = torch.distributions.uniform.Uniform(center-radius,center + radius).sample([N_samples1])
133
- x1_ordered = torch.sort(x1).values
134
- y1 = center + b*torch.sqrt(radius**2 - ((x1_ordered-center)/a)**2)
135
- x2 = torch.distributions.uniform.Uniform(center-radius,center + radius).sample([N_samples2])
136
- x2_ordered = torch.sort(x2, descending=True).values
137
- y2 = center - b*torch.sqrt(radius**2 - ((x2_ordered-center)/a)**2)
138
- x = torch.cat([x1_ordered, x2_ordered])
139
- y = torch.cat([y1, y2])
140
- return x, y
141
 
142
- def transform_standard_points(Affine_mat, x,y):
143
- XY = torch.ones([3,x.shape[0]])
144
- XY[0,:]= x
145
- XY[1,:]= y
146
- XYt = torch.matmul(Affine_mat.to('cpu').detach(),XY)
147
- xt0 = XYt[0]
148
- yt0 = XYt[1]
149
- return xt0, yt0
150
-
151
- def wrap_points(img, x_source, y_source, l=1, DIM =dim):
152
- for i in range(len(y_source)):
153
- x0 = x_source[i].int()
154
- y0 = y_source[i].int()
155
- if (x0<DIM) and (x0>0) and (y0<DIM) and (y0>0):
156
- img[:,:,y0-l:y0+l,x0-l:x0+l] = 0
157
- return img
158
 
159
 
160
- def wrap_imge_cropped(Affine_mtrx, source_img, dim1=224, dim2=128):
161
- source_img224 = torch.nn.ZeroPad2d(int((dim1-dim2)/2))(source_img)
162
- grd = torch.nn.functional.affine_grid(Affine_mtrx, size=source_img224.shape,align_corners=False)
163
- wrapped_img = torch.nn.functional.grid_sample(source_img224, grid=grd,
164
- mode='bilinear', padding_mode='zeros', align_corners=False)
165
- wrapped_img = torchvision.transforms.CenterCrop((dim2, dim2))(wrapped_img)
166
- return wrapped_img
167
 
168
 
169
 
 
 
 
 
 
 
 
 
 
 
 
 
170
  def preprocess_image(image_path, dim = 128):
171
  img = torch.zeros([1,3,dim,dim])
172
  img[0] = load_image_pil_accelerated(image_path, dim)
 
1
  import os, sys
2
  import numpy as np
 
 
 
3
  import random
4
+ import time
5
+ import json
6
+ import matplotlib.pyplot as plt
7
+ import tqdm
8
  import torch
 
 
 
 
 
 
9
 
 
 
 
 
 
 
 
 
10
 
11
 
12
+ def generate_standard_elips(N_samples = 10, a= 1,b = 1):
13
+ radius = 0.5
14
+ center = 0
15
+ N_samples1 = int(N_samples/2 - 1)
16
+ N_samples2 = N_samples - N_samples1
17
+ x1 = np.random.uniform((center-radius)*a,(center+radius)*a, size = N_samples1)
18
+ x1_ordered = np.sort(x1)
19
+ y1 = center + b*np.sqrt(radius**2 - ((x1_ordered-center)/a)**2)
20
+ x2 = np.random.uniform((center-radius)*a,(center+radius)*a, size = N_samples - N_samples1)
21
+ x2_ordered = -np.sort(-x2) #the minus sign to sort descindingly
22
+ y2 = center - b*np.sqrt(radius**2 - ((x2_ordered-center)/a)**2)
23
+ x = np.concatenate([x1_ordered, x2_ordered], axis=0)
24
+ y = np.concatenate([y1, y2], axis=0)
25
+ return x, y
26
 
27
+ def destandarize_point(d, dim=128):
28
+ return dim*(d + 0.5)
29
 
30
+ def To_pointcloud(x,y,z=0):
31
+ N_points = x.shape[0]
32
+ point_cloud = np.zeros([N_points,3])
33
+ point_cloud[:,0] = x
34
+ point_cloud[:,1] = y
35
+ if not z==0:
36
+ point_cloud[:,2] = z
37
+ return point_cloud
38
+
39
+ def To_xyz(point_cloud):
40
+ x = point_cloud[:,0]
41
+ y = point_cloud[:,1]
42
+ z = point_cloud[:,2]
43
+ return x,y,z
44
+
45
+
46
+ def random_rigid_transformation(dim=2):
47
+ #dim = 4
48
+ rotation_x = 0
49
+ rotation_y = 0
50
+ rotation_z = random.uniform(0, 2)*np.pi
51
+ translation_x = random.uniform(-1, 1)*dim
52
+ translation_y = random.uniform(-1, 1)*dim
53
+ translation_z = 0
54
+ reflection_x = random.sample([-1,1],1)[0]
55
+ reflection_y = random.sample([-1,1],1)[0]
56
+ reflection_z = 1
57
+ Rotx = np.array([[1,0,0],
58
+ [0,np.cos(rotation_x),-np.sin(rotation_x)],
59
+ [0,np.sin(rotation_x),np.cos(rotation_x)]])
60
+ Roty = np.array([[np.cos(rotation_y),0,np.sin(rotation_y)],
61
+ [0,1,0],
62
+ [-np.sin(rotation_y),0,np.cos(rotation_y)]])
63
+ Rotz = np.array([[np.cos(rotation_z),-np.sin(rotation_z),0],
64
+ [np.sin(rotation_z),np.cos(rotation_z),0],
65
+ [0,0,1]])
66
+ Rotation = np.matmul(Rotz, np.matmul(Roty,Rotx))
67
+ Reflection = np.array([[reflection_x,0,0],[0,reflection_y,0],[0,0,reflection_z]])
68
+ Translation = np.array([translation_x,translation_y,translation_z])
69
+ RefRotation = np.matmul(Reflection,Rotation)
70
+ return RefRotation, Translation
71
+
72
+
73
+
74
+ def rigid_2Dtransformation(prdction):
75
+ #prediction = [rotationz, reflectionx, reflectiony, translationx, translationy]
76
+ N_examples = prdction['rotation'].shape[0]
77
+ Translation = prdction['translation']
78
+ Reflection = torch.zeros([N_examples,3,3])#
79
+ Reflection[:,0,0] = prdction['reflection'][:,0]
80
+ Reflection[:,1,1] = prdction['reflection'][:,1]
81
+ Reflection[:,2,2] = 1.
82
+ rotation_z = prdction['rotation'][:,2]
83
+ Rotation = torch.zeros([N_examples,3,3])#np.repeat(np.eye(3)[None,:,:],N_examples, axis=0))
84
+ Rotation[:,0,0] = torch.cos(rotation_z)
85
+ Rotation[:,1,1] = torch.cos(rotation_z)
86
+ Rotation[:,0,1] = -torch.sin(rotation_z)
87
+ Rotation[:,2,2] = 1.
88
+ Rotation[:,1,0] = torch.sin(rotation_z)
89
+ RefRotation = torch.matmul(Reflection,Rotation)
90
+ return RefRotation, Translation
91
+
92
+ def batch_hausdorff_prcnt_distance(batch_point_cloud1, point_cloud2, percentile = 0.95):
93
+ assert point_cloud2.shape[0]==3
94
+ assert batch_point_cloud1.shape[1]==3
95
+ distances = torch.norm(batch_point_cloud1[:, :, None,:] - point_cloud2[None, :, :,None], dim=1)
96
+ dists1 = torch.min(distances, dim=1).values
97
+ dists2 = torch.min(distances, dim=2).values
98
+ # Calculate the 95th percentile distance
99
+ percentile_95 = torch.quantile(torch.cat([dists1, dists2],axis=1), percentile, interpolation='linear', dim=1)
100
+ return percentile_95
101
+
102
+ def HDloss(prd, pointcloud_source_norm_torch,pointcloud_target_norm_torch, percentile = 0.95):
103
+ A, b = rigid_2Dtransformation(prd)
104
+ point_cloud_wrapped = torch.matmul(A, pointcloud_source_norm_torch.T) + b[:,:,None]
105
+ loss = batch_hausdorff_prcnt_distance(point_cloud_wrapped, pointcloud_target_norm_torch.T, percentile)
106
+ return loss
107
+
108
+ def Mean_HDloss(prd, pointcloud_source_norm_torch, pointcloud_target_norm_torch, percentile = 0.95):
109
+ loss = HDloss(prd, pointcloud_source_norm_torch, pointcloud_target_norm_torch, percentile = 0.95)
110
+ return torch.mean(loss)
111
+
112
+
113
+ def wrap_pointcloud(record, pointcloud_source):
114
+ #normalize first
115
+ PC1_mean = np.mean(pointcloud_source, axis=0)
116
+ pointcloud_source_norm = pointcloud_source - PC1_mean
117
+ pointcloud_source_norm_torch = torch.tensor(pointcloud_source_norm, requires_grad=False).to(torch.float32)
118
+ # find Tx
119
+ A, b = rigid_2Dtransformation(record)
120
+ point_cloud_wrapped = torch.matmul(A, pointcloud_source_norm_torch.T) + b[:,:,None]
121
+ return point_cloud_wrapped
122
+
123
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
124
+
125
+ def move_dict2device(dictionary,device):
126
+ for key in list(dictionary.keys()):
127
+ dictionary[key] = dictionary[key].to(device)
128
+ return dictionary
129
+
130
+ eps = 0.000001948
131
+ class Optimization_model(torch.nn.Module):
132
  def __init__(self):
133
+ super(Optimization_model, self).__init__()
134
+ self.alpha = torch.nn.Parameter(torch.tensor(0.5, requires_grad=True))
135
+ self.rotation = torch.nn.Parameter(torch.tensor([0.0, 0.0, 0.0], requires_grad=True))
136
+ self.translation = torch.nn.Parameter(torch.tensor([0.01,-0.01, 0.0], requires_grad=True))
137
+ self.reflection = torch.nn.Parameter(torch.sign(torch.tensor([0.01,-0.01, 1], requires_grad=True)))
138
+ #self.rigid = torch.nn.Parameter(torch.tensor([0.0,1.0,1.0,0.1,0.1], requires_grad=True))
 
 
 
 
 
 
 
 
139
  def forward(self, input_X_batch):
140
+ predicted_rotation = self.alpha*self.rotation + (1-self.alpha)*input_X_batch['rotation']
141
+ predicted_translation = self.alpha*self.translation + (1-self.alpha)*input_X_batch['translation']
142
+ predicted_reflection = torch.sign(self.alpha*self.reflection +
143
+ (1-self.alpha)*input_X_batch['reflection']+eps)
144
+ return {'rotation':predicted_rotation,
145
+ 'translation':predicted_translation,
146
+ 'reflection':predicted_reflection}
147
+
148
+ class Dataset(torch.utils.data.Dataset):
149
+ def __init__(self, dataset_size, N_dim = 2):
150
+ self.dataset_size = dataset_size
151
+ self.N_dim = 2
152
+ def __len__(self):
153
+ return int(self.dataset_size)
154
+ def __getitem__(self, index):
155
+ rotation = np.pi*(-1 + 2*torch.rand([3]))
156
+ translation = -0.1 + 0.2*torch.rand([3])
157
+ reflection = torch.sign(torch.rand([3]) - 0.5)
158
+ if self.N_dim == 2:
159
+ rotation[0:2]=0
160
+ translation[2]=0
161
+ reflection[2]=1
162
+ random_solution = {'rotation':rotation,
163
+ 'translation':translation,
164
+ 'reflection':reflection}
165
+ return random_solution
166
+
167
+
168
+
169
+
170
+
171
 
172
  def pil_to_numpy(im):
173
  im.load()
 
187
  raise RuntimeError("encoder error %d in tobytes" % s)
188
  return data
189
 
 
 
 
 
 
 
190
 
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
 
 
 
 
 
 
 
 
196
 
197
 
198
 
199
+
200
+
201
+
202
+
203
+ def load_image_pil_accelerated(image_path, dim=128):
204
+ image = Image.open(image_path).convert("RGB")
205
+ array = pil_to_numpy(image)
206
+ tensor = torch.from_numpy(np.rollaxis(array,2,0)/255).to(torch.float32)
207
+ tensor = torchvision.transforms.Resize((dim,dim))(tensor)
208
+ return tensor
209
+
210
+
211
  def preprocess_image(image_path, dim = 128):
212
  img = torch.zeros([1,3,dim,dim])
213
  img[0] = load_image_pil_accelerated(image_path, dim)