File size: 7,885 Bytes
fffea84
 
 
69d1aad
 
 
 
fffea84
 
 
 
69d1aad
 
 
 
 
 
 
 
 
 
 
 
 
 
fffea84
69d1aad
 
fffea84
69d1aad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fffea84
69d1aad
 
 
 
 
 
fffea84
69d1aad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fffea84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01b1eda
 
 
 
 
 
 
69d1aad
 
 
 
 
 
 
 
 
 
 
 
fffea84
2a6b806
 
fffea84
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import os, sys
import numpy as np
import random
import time
import json
import matplotlib.pyplot as plt
import tqdm
import torch



def generate_standard_elips(N_samples=10, a=1, b=1):
    """Sample points on an axis-aligned ellipse centred at the origin.

    The ellipse has semi-axes ``0.5*a`` (x) and ``0.5*b`` (y). The x
    coordinates are drawn uniformly at random; the first
    ``int(N_samples/2 - 1)`` points lie on the upper arc with ascending x,
    the remaining points lie on the lower arc with descending x, so the
    concatenated sequence walks once around the outline.

    Args:
        N_samples: total number of points to return.
        a: scale applied along x.
        b: scale applied along y.

    Returns:
        Tuple ``(x, y)`` of 1-D NumPy arrays, each of length ``N_samples``.
    """
    center, radius = 0, 0.5
    x_low, x_high = (center - radius)*a, (center + radius)*a

    def arc_height(xs):
        # Positive half-height of the ellipse above/below the centre line.
        return b*np.sqrt(radius**2 - ((xs - center)/a)**2)

    n_upper = int(N_samples/2 - 1)
    n_lower = N_samples - n_upper

    # Upper arc first: the order of the two RNG draws is part of the
    # behaviour when the global NumPy seed is fixed.
    upper_x = np.sort(np.random.uniform(x_low, x_high, size=n_upper))
    upper_y = center + arc_height(upper_x)

    # Negate-sort-negate yields a descending sort.
    lower_x = -np.sort(-np.random.uniform(x_low, x_high, size=n_lower))
    lower_y = center - arc_height(lower_x)

    x = np.concatenate([upper_x, lower_x], axis=0)
    y = np.concatenate([upper_y, lower_y], axis=0)
    return x, y

def destandarize_point(d, dim=128):
    """Map a standardized coordinate to the ``[0, dim]`` scale.

    Computes ``dim * (d + 0.5)``, so ``-0.5`` maps to ``0`` and ``0.5`` maps
    to ``dim``. Works element-wise on arrays as well as on scalars.
    """
    shifted = d + 0.5
    return dim * shifted

def To_pointcloud(x,y,z=0):
    """Stack coordinate vectors into an ``(N, 3)`` point-cloud array.

    Args:
        x: 1-D array of x coordinates; its length ``N`` sets the row count.
        y: x-aligned values written to column 1.
        z: scalar or x-aligned array written to column 2. Defaults to 0,
            which produces a planar cloud.

    Returns:
        A new ``(N, 3)`` float NumPy array with columns ``x, y, z``.
    """
    N_points = x.shape[0]
    point_cloud = np.zeros([N_points, 3])
    point_cloud[:, 0] = x
    point_cloud[:, 1] = y
    # Assign unconditionally: writing the default 0 is a no-op on the zero
    # buffer, and this avoids evaluating the truth value of ``z == 0``,
    # which raises ValueError when ``z`` is an array with several elements.
    point_cloud[:, 2] = z
    return point_cloud

def To_xyz(point_cloud):
    """Split a point cloud into its first three coordinate columns.

    Returns:
        Tuple ``(x, y, z)`` holding ``point_cloud[:, 0]``,
        ``point_cloud[:, 1]`` and ``point_cloud[:, 2]``.
    """
    return tuple(point_cloud[:, axis] for axis in range(3))


def random_rigid_transformation(dim=2):
    """Draw a random in-plane transform: rotation about z, axis flips, shift.

    The rotation angle about z is uniform in ``[0, 2*pi]``; the rotations
    about x and y are fixed at zero. The x and y axes are each flipped with
    a sign drawn from ``{-1, 1}``, so the returned matrix may have
    determinant -1. The x/y translation components are uniform in
    ``[-dim, dim]``; the z component is 0.

    Args:
        dim: scale factor of the translation range.

    Returns:
        Tuple ``(RefRotation, Translation)``: the 3x3 matrix
        ``Reflection @ Rotation`` and a length-3 translation vector.
    """
    # The order of the draws below fixes the result under a seeded `random`.
    angle_x, angle_y = 0, 0
    angle_z = random.uniform(0, 2)*np.pi
    shift = [random.uniform(-1, 1)*dim, random.uniform(-1, 1)*dim, 0]
    flips = [random.sample([-1,1],1)[0], random.sample([-1,1],1)[0], 1]

    cos_x, sin_x = np.cos(angle_x), np.sin(angle_x)
    cos_y, sin_y = np.cos(angle_y), np.sin(angle_y)
    cos_z, sin_z = np.cos(angle_z), np.sin(angle_z)

    about_x = np.array([[1, 0, 0],
                        [0, cos_x, -sin_x],
                        [0, sin_x, cos_x]])
    about_y = np.array([[cos_y, 0, sin_y],
                        [0, 1, 0],
                        [-sin_y, 0, cos_y]])
    about_z = np.array([[cos_z, -sin_z, 0],
                        [sin_z, cos_z, 0],
                        [0, 0, 1]])

    # Compose as Rz * Ry * Rx, then apply the axis flips on the left.
    rotation = about_z @ (about_y @ about_x)
    flip_matrix = np.diag(flips)
    return flip_matrix @ rotation, np.array(shift)



def rigid_2Dtransformation(prdction):
    """Build batched 3x3 reflect-then-rotate matrices from predictions.

    Args:
        prdction: mapping with tensor entries
            ``'rotation'``    -- indexed as ``[:, 2]``; column 2 is the angle
                                 about z in radians (columns 0 and 1 are not
                                 read),
            ``'reflection'``  -- indexed as ``[:, 0]`` and ``[:, 1]``; the
                                 diagonal factors for the x and y axes,
            ``'translation'`` -- returned unchanged.

    Returns:
        Tuple ``(RefRotation, Translation)`` where ``RefRotation`` has shape
        ``(N, 3, 3)`` and equals ``Reflection @ Rotation`` per example, with
        ``N = prdction['rotation'].shape[0]``.
    """
    rotation = prdction['rotation']
    reflection = prdction['reflection']
    Translation = prdction['translation']
    N_examples = rotation.shape[0]
    # Allocate the work buffers next to the inputs. With the default device
    # the result would sit on the CPU even for GPU predictions and then
    # clash with GPU point clouds in the subsequent matmul.
    target_device = rotation.device

    Reflection = torch.zeros([N_examples, 3, 3], device=target_device)
    Reflection[:, 0, 0] = reflection[:, 0]
    Reflection[:, 1, 1] = reflection[:, 1]
    Reflection[:, 2, 2] = 1.

    rotation_z = rotation[:, 2]
    cos_z, sin_z = torch.cos(rotation_z), torch.sin(rotation_z)
    Rotation = torch.zeros([N_examples, 3, 3], device=target_device)
    Rotation[:, 0, 0] = cos_z
    Rotation[:, 0, 1] = -sin_z
    Rotation[:, 1, 0] = sin_z
    Rotation[:, 1, 1] = cos_z
    Rotation[:, 2, 2] = 1.

    RefRotation = torch.matmul(Reflection, Rotation)
    return RefRotation, Translation

def batch_hausdorff_prcnt_distance(batch_point_cloud1, point_cloud2, percentile = 0.95):
    """Percentile Hausdorff-style distance between a batch of clouds and one cloud.

    For every cloud in the batch, nearest-neighbour distances are computed
    in both directions (batch cloud -> reference and reference -> batch
    cloud), pooled together, and the requested quantile of the pooled
    distances is returned.

    Args:
        batch_point_cloud1: tensor indexed as ``(B, 3, N1)``; axis 1 must
            have size 3.
        point_cloud2: tensor indexed as ``(3, N2)``; axis 0 must have size 3.
        percentile: quantile in ``[0, 1]`` passed to ``torch.quantile``.

    Returns:
        Tensor of shape ``(B,)`` with one distance per batch element.

    Raises:
        AssertionError: if a coordinate axis does not have size 3.
    """
    assert point_cloud2.shape[0]==3
    assert batch_point_cloud1.shape[1]==3

    # Broadcast (B, 3, 1, N1) against (1, 3, N2, 1) -> (B, 3, N2, N1), then
    # collapse the coordinate axis into Euclidean distances (B, N2, N1).
    batch_points = batch_point_cloud1[:, :, None, :]
    reference_points = point_cloud2[None, :, :, None]
    pairwise = torch.norm(batch_points - reference_points, dim=1)

    nearest_from_batch = pairwise.min(dim=1).values      # (B, N1)
    nearest_from_reference = pairwise.min(dim=2).values  # (B, N2)

    pooled = torch.cat((nearest_from_batch, nearest_from_reference), dim=1)
    return torch.quantile(pooled, percentile, dim=1, interpolation='linear')

def HDloss(prd, pointcloud_source_norm_torch,pointcloud_target_norm_torch, percentile = 0.95):
    """Per-example percentile Hausdorff loss of a predicted transform.

    The source cloud is transformed with the matrix and translation that
    ``rigid_2Dtransformation(prd)`` returns, then compared against the
    target cloud with ``batch_hausdorff_prcnt_distance``.

    Args:
        prd: prediction mapping handed to ``rigid_2Dtransformation``.
        pointcloud_source_norm_torch: source cloud; transposed before the
            matrix product, so points are expected along axis 0.
        pointcloud_target_norm_torch: target cloud, transposed the same way.
        percentile: quantile forwarded to the distance function.

    Returns:
        The tensor produced by ``batch_hausdorff_prcnt_distance`` (one value
        per predicted transform).
    """
    linear_part, offset = rigid_2Dtransformation(prd)
    transformed = torch.matmul(linear_part, pointcloud_source_norm_torch.T)
    # Broadcast each example's translation over all of its points.
    transformed = transformed + offset[:, :, None]
    return batch_hausdorff_prcnt_distance(
        transformed, pointcloud_target_norm_torch.T, percentile)

def Mean_HDloss(prd, pointcloud_source_norm_torch, pointcloud_target_norm_torch, percentile = 0.95):
    """Mean of the per-example ``HDloss`` values.

    Args:
        prd: prediction mapping forwarded to ``HDloss``.
        pointcloud_source_norm_torch: source cloud forwarded to ``HDloss``.
        pointcloud_target_norm_torch: target cloud forwarded to ``HDloss``.
        percentile: quantile forwarded to ``HDloss``.

    Returns:
        Scalar tensor: ``torch.mean`` of the ``HDloss`` result.
    """
    # Forward the caller's percentile rather than a fixed 0.95 so that the
    # argument actually takes effect.
    loss = HDloss(prd, pointcloud_source_norm_torch, pointcloud_target_norm_torch,
                  percentile=percentile)
    return torch.mean(loss)


def wrap_pointcloud(record, pointcloud_source):
    """Apply the transform described by ``record`` to a centred source cloud.

    The source cloud is first centred by subtracting its mean over axis 0
    and converted to a float32 tensor; the matrix and translation returned
    by ``rigid_2Dtransformation(record)`` are then applied.

    Args:
        record: mapping handed to ``rigid_2Dtransformation``.
        pointcloud_source: array-like cloud with points along axis 0.

    Returns:
        Tensor ``A @ centred_source.T + b[:, :, None]``.
    """
    # Centre the cloud on its centroid before transforming.
    centroid = np.mean(pointcloud_source, axis=0)
    centred = pointcloud_source - centroid
    centred_torch = torch.tensor(centred, requires_grad=False).to(torch.float32)

    linear_part, offset = rigid_2Dtransformation(record)
    transformed = torch.matmul(linear_part, centred_torch.T)
    return transformed + offset[:, :, None]

# Module-wide default device: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def move_dict2device(dictionary,device):
    """Move every value of ``dictionary`` to ``device`` in place.

    Each value is replaced by ``value.to(device)``; the same dictionary
    object is returned for convenience.
    """
    # Iterate over a snapshot so the dict can be reassigned while looping.
    for key, value in list(dictionary.items()):
        dictionary[key] = value.to(device)
    return dictionary

# Small positive offset added before torch.sign in Optimization_model.forward,
# so that an exactly-zero blend maps to +1 rather than 0.
eps = 0.000001948


class Optimization_model(torch.nn.Module):
    """Blend learnable transform parameters with an externally supplied batch.

    Learnable parameters: a scalar mixing weight ``alpha`` (initialised to
    0.5) and 3-vectors ``rotation``, ``translation`` and ``reflection``.
    ``forward`` returns ``alpha * own + (1 - alpha) * input`` for rotation
    and translation, and the sign of the same blend (plus ``eps``) for the
    reflection.
    """

    def __init__(self):
        super().__init__()
        self.alpha = torch.nn.Parameter(torch.tensor(0.5))
        self.rotation = torch.nn.Parameter(torch.tensor([0.0, 0.0, 0.0]))
        self.translation = torch.nn.Parameter(torch.tensor([0.01, -0.01, 0.0]))
        # Stored as signs, i.e. [1, -1, 1].
        self.reflection = torch.nn.Parameter(
            torch.sign(torch.tensor([0.01, -0.01, 1])))

    def forward(self, input_X_batch):
        """Return the blended ``rotation`` / ``translation`` / ``reflection`` dict."""

        def blend(own, key):
            # Convex combination of the model's value and the batch value.
            return self.alpha*own + (1-self.alpha)*input_X_batch[key]

        return {
            'rotation': blend(self.rotation, 'rotation'),
            'translation': blend(self.translation, 'translation'),
            'reflection': torch.sign(blend(self.reflection, 'reflection') + eps),
        }

class Dataset(torch.utils.data.Dataset):
    """Synthetic dataset of random transform parameters.

    Every item is freshly sampled (the index is ignored) and is a dict with
    3-vectors:
        ``'rotation'``    -- uniform in ``[-pi, pi)``,
        ``'translation'`` -- uniform in ``[-0.1, 0.1)``,
        ``'reflection'``  -- ``sign(U[0, 1) - 0.5)``, i.e. -1 or +1.

    When ``N_dim == 2`` the sample is restricted to the plane: rotations
    about x and y and the z translation are zeroed, and the z reflection is
    set to 1.

    Args:
        dataset_size: value reported by ``len()``.
        N_dim: 2 for planar samples (default); any other value leaves all
            three components random.
    """

    def __init__(self, dataset_size, N_dim = 2):
        self.dataset_size = dataset_size
        # Store the requested dimensionality so that the planar restriction
        # in __getitem__ applies only when the caller asks for N_dim == 2.
        self.N_dim = N_dim

    def __len__(self):
        return int(self.dataset_size)

    def __getitem__(self, index):
        rotation = np.pi*(-1 + 2*torch.rand([3]))
        translation = -0.1 + 0.2*torch.rand([3])
        reflection = torch.sign(torch.rand([3]) - 0.5)
        if self.N_dim == 2:
            rotation[0:2] = 0
            translation[2] = 0
            reflection[2] = 1
        random_solution = {'rotation': rotation,
                           'translation': translation,
                           'reflection': reflection}
        return random_solution






def pil_to_numpy(im):
    """Copy the pixel data of image ``im`` into a new NumPy array.

    The image is run through a "raw" encoder and the encoded bytes are
    written chunk by chunk straight into a preallocated array whose shape
    and dtype come from ``Image._conv_type_shape(im)``.

    NOTE(review): ``Image`` is presumably ``PIL.Image`` -- it is not imported
    in the visible part of this file; confirm the import exists. The
    ``_getencoder`` / ``_conv_type_shape`` helpers are underscore-prefixed
    (private) names, so they may change between library releases.

    Returns:
        The filled NumPy array.

    Raises:
        RuntimeError: if the encoder reports a negative status code.
    """
    im.load()
    # Unpack data
    e = Image._getencoder(im.mode, "raw", im.mode)
    e.setimage(im.im)
    # NumPy buffer for the result
    shape, typestr = Image._conv_type_shape(im)
    data = np.empty(shape, dtype=np.dtype(typestr))
    # Flat unsigned-byte view of the array's memory, used as the write target.
    mem = data.data.cast("B", (data.data.nbytes,))
    # Encode in 64 KiB chunks; the loop ends once the status `s` is non-zero.
    bufsize, s, offset = 65536, 0, 0
    while not s:
        # `l` (first element of the encoder result) is not used here.
        l, s, d = e.encode(bufsize)
        mem[offset:offset + len(d)] = d
        offset += len(d)
    # A negative status is treated as an encoder failure.
    if s < 0:
        raise RuntimeError("encoder error %d in tobytes" % s)
    return data














def load_image_pil_accelerated(image_path, dim=128):
    """Load an image file as a float32 tensor resized to ``(dim, dim)``.

    The file is opened, converted to RGB and turned into a NumPy array with
    ``pil_to_numpy``. Axis 2 is then moved to the front, the values are
    divided by 255, and the result is resized with
    ``torchvision.transforms.Resize``.

    NOTE(review): ``Image`` and ``torchvision`` are not imported in the
    visible part of this file; confirm they are in scope.

    Args:
        image_path: path handed to ``Image.open``.
        dim: target height and width.

    Returns:
        The resized float32 tensor.
    """
    rgb_image = Image.open(image_path).convert("RGB")
    pixels = pil_to_numpy(rgb_image)
    # Move the last axis (axis 2) to the front before building the tensor.
    channels_first = np.rollaxis(pixels, 2, 0)
    scaled = torch.from_numpy(channels_first/255).to(torch.float32)
    resize = torchvision.transforms.Resize((dim,dim))
    return resize(scaled)


def preprocess_image(image_path, dim = 128):
    """Load one image into a batch tensor of shape ``(1, 3, dim, dim)``.

    The image is read with ``load_image_pil_accelerated`` and copied into
    slot 0 of a freshly allocated zero tensor.
    """
    batch = torch.zeros(1, 3, dim, dim)
    batch[0].copy_(load_image_pil_accelerated(image_path, dim))
    return batch