mkalia committed on
Commit
e015760
·
verified ·
1 Parent(s): d4950f4

Upload layers.py

Browse files
Files changed (1) hide show
  1. layers.py +962 -0
layers.py ADDED
@@ -0,0 +1,962 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright Niantic 2019. Patent Pending. All rights reserved.
2
+ #
3
+ # This software is licensed under the terms of the Monodepth2 licence
4
+ # which allows for non-commercial use only, the full terms of which are made
5
+ # available in the LICENSE file.
6
+
7
+ from __future__ import absolute_import, division, print_function
8
+
9
+ import numpy as np
10
+ from scipy.spatial.transform import Rotation as R
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ # from torchmetrics.image.fid import FrechetInceptionDistance
16
+
17
+ # def silog(real1, fake1):
18
+ # # filter out invalid pixels
19
+ # real = real1.clone()
20
+ # fake = fake1.clone()
21
+ # N = (real>0).float().sum()
22
+ # mask1 = (real<=0)
23
+ # mask2 = (fake<=0)
24
+ # mask3 = mask1+mask2
25
+ # # mask = 1.0 - (mask3>0).float()
26
+ # mask = (mask3>0)
27
+ # fake[mask] = 1.
28
+ # real[mask] = 1.
29
+
30
+ # loss_ = torch.log(real)-torch.log(fake)
31
+ # loss = torch.sqrt((torch.sum( loss_ ** 2) / N ) - ((torch.sum(loss_)/N)**2))
32
+ # return loss
33
+
34
+
35
+
36
class SpatialTransformer(nn.Module):
    """Resample a source tensor at locations displaced by a dense flow field."""

    def __init__(self, size, mode='bilinear'):
        """
        Instantiate the block.

        :param size: spatial size of the input to the spatial transformer block
        :param mode: method of interpolation for grid_sampler
        """
        super(SpatialTransformer, self).__init__()

        # Identity sampling grid: one channel of integer indices per spatial
        # axis, in the axis order of `size`. Registered as a buffer so it
        # moves with the module across devices.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids).unsqueeze(0).float()
        self.register_buffer('grid', grid)
        self.mode = mode

    def forward(self, src, flow):
        """
        Warp `src` by `flow`.

        :param src: the source image, sampled with F.grid_sample
        :param flow: displacement added to the identity grid; its channels
            follow the same axis order as the grid and are in index units
        :return: `src` resampled at `grid + flow`, border-padded
        """
        new_locs = self.grid + flow
        shape = flow.shape[2:]

        # Map index coordinates to [-1, 1]: index 0 -> -1, index (extent-1) -> +1.
        for i, extent in enumerate(shape):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (extent - 1) - 0.5)

        # grid_sample wants channels-last coordinates ordered (x, y[, z]),
        # i.e. the reverse of the grid's axis order.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        # The normalisation above places -1/+1 on the centres of the corner
        # pixels, which is the align_corners=True convention. With the
        # default (False) a zero flow would not reproduce `src`.
        return F.grid_sample(src, new_locs, mode=self.mode,
                             padding_mode="border", align_corners=True)
76
+
77
+
78
+
79
class optical_flow(nn.Module):
    """Project points through K @ T and return the displacement from the identity grid."""

    def __init__(self, size, batch_size, height, width, eps=1e-7):
        """
        :param size: spatial size used to build the identity index grid
        :param batch_size: batch dimension used when reshaping the projection
        :param height: output height used when reshaping the projection
        :param width: output width used when reshaping the projection
        :param eps: added to the third projected coordinate before dividing
        """
        super(optical_flow, self).__init__()

        # Identity grid of indices, one channel per axis of `size`.
        axes = [torch.arange(0, extent) for extent in size]
        identity = torch.stack(torch.meshgrid(axes))
        identity = torch.unsqueeze(identity, 0).type(torch.FloatTensor)
        self.register_buffer('grid', identity)

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: points multiplied by the 3x4 projection; reshaped to
            (batch_size, 2, height, width) after projection
        :param K: matrix left-multiplied onto T (presumably camera intrinsics)
        :param T: matrix right-multiplied by K (presumably camera pose)
        :return: projected coordinates with channels swapped to match the
            grid's axis order, minus the identity grid
        """
        projection = torch.matmul(K, T)[:, :3, :]
        projected = torch.matmul(projection, points)

        # Perspective divide by the third coordinate.
        denom = projected[:, 2, :].unsqueeze(1) + self.eps
        pixels = projected[:, :2, :] / denom
        pixels = pixels.view(self.batch_size, 2, self.height, self.width)

        # Swap the two channels so they line up with the grid's channel order.
        return pixels[:, [1, 0], ...] - self.grid
106
+
107
+
108
def get_corresponding_map(data):
    """
    Accumulate bilinear weights of target coordinates onto an HxW grid.

    Each coordinate contributes to its four surrounding integer cells with
    weight (1 - |dx|) * (1 - |dy|); cells outside the image receive nothing.

    :param data: unnormalized coordinates Bx2xHxW (channel 0 = x, channel 1 = y)
    :return: Bx1xHxW accumulated weights
    """
    batch, _, height, width = data.size()

    xs = data[:, 0, :, :].view(batch, -1)  # BxN (N=H*W)
    ys = data[:, 1, :, :].view(batch, -1)

    # Unclamped neighbours.
    x_lo_raw = torch.floor(xs)
    y_lo_raw = torch.floor(ys)
    x_hi_raw = x_lo_raw + 1
    y_hi_raw = y_lo_raw + 1

    # Neighbours clamped into the image.
    x_lo = x_lo_raw.clamp(0, width - 1)
    y_lo = y_lo_raw.clamp(0, height - 1)
    x_hi = x_hi_raw.clamp(0, width - 1)
    y_hi = y_hi_raw.clamp(0, height - 1)

    # A neighbour is out of bounds exactly when clamping moved it.
    x_lo_oob = x_lo_raw != x_lo
    y_lo_oob = y_lo_raw != y_lo
    x_hi_oob = x_hi_raw != x_hi
    y_hi_oob = y_hi_raw != y_hi

    # (x neighbour, y neighbour, out-of-bounds flag) for the four corners.
    corners = [
        (x_hi, y_hi, x_hi_oob | y_hi_oob),
        (x_hi, y_lo, x_hi_oob | y_lo_oob),
        (x_lo, y_hi, x_lo_oob | y_hi_oob),
        (x_lo, y_lo, x_lo_oob | y_lo_oob),
    ]

    # Flatten (x, y) into one index, since scatter only indexes along one axis.
    indices = torch.cat([cx + cy * width for cx, cy, _ in corners], 1).long()
    values = torch.cat(
        [(1 - torch.abs(xs - cx)) * (1 - torch.abs(ys - cy)) for cx, cy, _ in corners],
        1)
    invalid = torch.cat([oob for _, _, oob in corners], dim=1)

    values[invalid] = 0

    accumulated = torch.zeros(batch, height * width).type_as(data)
    accumulated.scatter_add_(1, indices, values)

    # Unflatten back to image layout and add the channel axis.
    return accumulated.view(batch, height, width).unsqueeze(1)
162
+
163
class get_occu_mask_backward(nn.Module):
    """Threshold the splatted-weight map of a flow field into a binary mask."""

    def __init__(self, size):
        """
        :param size: spatial size used to build the identity index grid
        """
        super(get_occu_mask_backward, self).__init__()

        # Identity grid of indices, one channel per axis of `size`.
        axes = [torch.arange(0, extent) for extent in size]
        identity = torch.stack(torch.meshgrid(axes))
        identity = torch.unsqueeze(identity, 0).type(torch.FloatTensor)
        self.register_buffer('grid', identity)

    def forward(self, flow, th=0.95):
        """
        :param flow: displacement added to the identity grid
        :param th: threshold applied to the accumulated weight map
        :return: (mask, weight map); mask is 1.0 where the weight exceeds `th`
        """
        targets = self.grid + flow
        # Swap the two channels before handing over to get_corresponding_map,
        # which reads channel 0 as x and channel 1 as y.
        targets = targets[:, [1, 0], ...]
        occu_map = get_corresponding_map(targets)
        return (occu_map > th).float(), occu_map
185
+
186
+
187
class get_occu_mask_bidirection(nn.Module):
    """Forward-backward flow consistency residual.

    Despite the name, forward() returns the absolute residual
    |flow12 + warp(flow21)| rather than a thresholded mask.
    """

    def __init__(self, size, mode='bilinear'):
        """
        :param size: spatial size used to build the identity index grid
        :param mode: interpolation method forwarded to F.grid_sample
        """
        super(get_occu_mask_bidirection, self).__init__()

        # Identity grid of indices, one channel per axis of `size`.
        vectors = [torch.arange(0, s) for s in size]
        grids = torch.meshgrid(vectors)
        grid = torch.stack(grids).unsqueeze(0).float()
        self.register_buffer('grid', grid)
        self.mode = mode

    def forward(self, flow12, flow21, scale=0.01, bias=0.5):
        """
        :param flow12: flow added to the identity grid to get sampling locations
        :param flow21: reverse flow, sampled at those locations
        :param scale: unused; kept so existing callers keep working
        :param bias: unused; kept so existing callers keep working
        :return: elementwise |flow12 + flow21 sampled at (grid + flow12)|
        """
        new_locs = self.grid + flow12
        shape = flow12.shape[2:]

        # Map index coordinates to [-1, 1]: index 0 -> -1, index (extent-1) -> +1.
        for i, extent in enumerate(shape):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (extent - 1) - 0.5)

        # grid_sample wants channels-last coordinates in reversed axis order.
        if len(shape) == 2:
            new_locs = new_locs.permute(0, 2, 3, 1)
            new_locs = new_locs[..., [1, 0]]
        elif len(shape) == 3:
            new_locs = new_locs.permute(0, 2, 3, 4, 1)
            new_locs = new_locs[..., [2, 1, 0]]

        # The (extent - 1) normalisation above is the align_corners=True
        # convention; with the default (False) a zero flow12 would not sample
        # flow21 at its own pixel centres.
        flow21_warped = F.grid_sample(flow21, new_locs, mode=self.mode,
                                      padding_mode="border", align_corners=True)
        return torch.abs(flow12 + flow21_warped)
225
+
226
+ # functions
227
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Build the elementary rotation matrix about one coordinate axis for every
    angle in the input.

    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).

    Raises:
        ValueError: if `axis` is not one of "X", "Y", "Z".
    """
    c = torch.cos(angle)
    s = torch.sin(angle)
    o = torch.ones_like(angle)
    z = torch.zeros_like(angle)

    # Row-major entries of each elementary rotation.
    layouts = {
        "X": (o, z, z, z, c, -s, z, s, c),
        "Y": (c, z, s, z, o, z, -s, z, c),
        "Z": (c, -s, z, s, c, z, z, z, o),
    }
    if axis not in layouts:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(layouts[axis], -1).reshape(angle.shape + (3, 3))
255
+
256
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert Euler angles in radians to homogeneous 4x4 rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}.

    Returns:
        Tensor of shape (N, 4, 4), N being the size of the leading dimension
        of the composed rotations: rotation in the top-left 3x3 block,
        1 at [3, 3], zeros elsewhere.

    Raises:
        ValueError: on a malformed angle tensor or convention string.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    # One elementary rotation per (axis letter, angle column) pair.
    angle_columns = torch.unbind(euler_angles, -1)
    first, second, third = (
        _axis_angle_rotation(axis, column)
        for axis, column in zip(convention, angle_columns)
    )
    composed = torch.matmul(torch.matmul(first, second), third)

    # Embed into a 4x4 homogeneous transform. squeeze() drops singleton
    # dimensions such as the middle one of a (B, 1, 3) input.
    rot = torch.zeros((composed.shape[0], 4, 4)).to(device=composed.device)
    rot[:, :3, :3] = composed.squeeze()
    rot[:, 3, 3] = 1
    return rot
292
+
293
def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Recover the first or third Euler angle from the two matrix entries that
    are a positive constant times its sine and cosine.

    Args:
        axis: Axis label "X", "Y" or "Z" for the angle we are finding.
        other_axis: Axis label "X", "Y" or "Z" for the middle axis in the
            convention.
        data: Row or column of the rotation matrices, last dimension 3.
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler angles in radians for each entry in data as a tensor
        of shape (...).
    """
    first, second = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        first, second = second, first

    # Whether (axis, other_axis) is a cyclic pair decides the sign pattern.
    is_even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    a = data[..., first]
    b = data[..., second]

    if horizontal == is_even:
        return torch.atan2(a, b)
    if tait_bryan:
        return torch.atan2(-b, a)
    return torch.atan2(b, -a)
324
+
325
def matrix_2_euler_vector(matrix, convention = 'ZYX', roll = True):
    """
    Extract Euler angles and translation from batched transform matrices.

    :param matrix: batched matrices; the top-left 3x3 block is read as the
        rotation and column 3 of the first three rows as the translation
    :param convention: Euler convention passed to matrix_to_euler_angles
    :param roll: when True, the first entry along dim 0 of the angles is zeroed
    :return: angles and translation concatenated along dim 0
    """
    rotation_block = matrix[:, :3, :3]
    euler = matrix_to_euler_angles(rotation_block, convention)

    if roll:
        # NOTE(review): this zeroes every angle of the first batch element,
        # not one angle per sample; `euler[:, k] = 0` may have been intended.
        # Confirm with callers before changing.
        euler[0] = 0.0

    translation = matrix[:, :3, 3]

    # NOTE(review): dim=0 stacks angles and translation along the batch axis
    # (result is (2B, 3)); a (B, 6) vector would need dim=1. Confirm.
    return torch.cat([euler, translation], dim = 0)
336
+
337
def _index_from_letter(letter: str) -> int:
    """Return 0, 1 or 2 for the axis letter "X", "Y" or "Z".

    Raises:
        ValueError: for any other value.
    """
    for position, name in enumerate(("X", "Y", "Z")):
        if letter == name:
            return position
    raise ValueError("letter must be either X, Y or Z.")
345
+
346
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotation matrices to Euler angles in radians.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3).

    Raises:
        ValueError: on a malformed convention string or matrix shape.
    """
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    first = _index_from_letter(convention[0])
    last = _index_from_letter(convention[2])
    tait_bryan = first != last

    # The middle angle comes from a single matrix entry: asin (with a sign
    # depending on the axis pair) when the outer axes differ, acos otherwise.
    if tait_bryan:
        sign = -1.0 if first - last in [-1, 2] else 1.0
        middle = torch.asin(matrix[..., first, last] * sign)
    else:
        middle = torch.acos(matrix[..., first, first])

    leading = _angle_from_tan(
        convention[0], convention[1], matrix[..., last], False, tait_bryan
    )
    trailing = _angle_from_tan(
        convention[2], convention[1], matrix[..., first, :], True, tait_bryan
    )
    return torch.stack((leading, middle, trailing), -1)
386
+
387
def computeFID(real_images, fake_images, fid_criterion):
    """
    Feed both image sets to a metric object and return its computed value.

    :param real_images: passed to `fid_criterion.update(..., real=True)`
    :param fake_images: passed to `fid_criterion.update(..., real=False)`
    :param fid_criterion: object exposing `update(images, real=...)` and
        `compute()`; its state is updated by this call
    :return: whatever `fid_criterion.compute()` returns
    """
    for images, is_real in ((real_images, True), (fake_images, False)):
        fid_criterion.update(images, real=is_real)
    return fid_criterion.compute()
392
+
393
+
394
class SLlog(nn.Module):
    """Scale-invariant log loss: standard deviation of log(real) - log(fake)."""

    def __init__(self):
        super(SLlog, self).__init__()

    def forward(self, fake1, real1):
        """
        :param fake1: prediction; bilinearly resized to `real1`'s spatial size
            when the shapes differ
        :param real1: reference; only entries > 0 count as valid
        :return: scalar tensor sqrt(E[d^2] - E[d]^2) with d = log(real) - log(fake)
        """
        if fake1.shape != real1.shape:
            _, _, H, W = real1.shape
            # Resize the prediction itself. The original referenced an
            # unassigned local here and then discarded the result.
            fake1 = F.interpolate(fake1, size=(H, W), mode='bilinear',
                                  align_corners=False)

        # Work on copies so the caller's tensors are not modified.
        real = real1.clone()
        fake = fake1.clone()

        # Normaliser: number of valid reference pixels.
        N = (real > 0).float().sum()

        # Pixels where either tensor is non-positive are set to 1 in both, so
        # their log-difference is exactly 0 and they drop out of the sums.
        mask = (real <= 0) | (fake <= 0)
        fake[mask] = 1.
        real[mask] = 1.

        loss_ = torch.log(real) - torch.log(fake)
        loss = torch.sqrt((torch.sum(loss_ ** 2) / N) - ((torch.sum(loss_) / N) ** 2))
        return loss
419
+
420
class RMSE_log(nn.Module):
    """Root-mean-square error between log(real) and log(fake) over masked pixels."""

    def __init__(self, use_cuda):
        """
        :param use_cuda: stored on the module; not read by forward()
        """
        super(RMSE_log, self).__init__()
        self.eps = 1e-8
        self.use_cuda = use_cuda

    def forward(self, fake, real):
        """
        :param fake: prediction, bilinearly resized to `real`'s spatial size
        :param real: reference; only entries < 1 enter the loss
        :return: scalar tensor sqrt(mean((log real - log fake)^2)) over the mask
        """
        # NOTE(review): the mask keeps entries below 1, which includes
        # non-positive values whose log is -inf/nan. Confirm the range of `real`.
        mask = real < 1.
        n, _, h, w = real.size()

        # F.upsample is deprecated; F.interpolate with align_corners=False is
        # its documented equivalent for bilinear mode.
        fake = F.interpolate(fake, size=(h, w), mode='bilinear', align_corners=False)
        fake += self.eps  # keep log() finite for zero predictions

        # NOTE(review): N is 0 when no entry passes the mask, giving nan.
        N = len(real[mask])
        loss = torch.sqrt(torch.sum(torch.abs(torch.log(real[mask]) - torch.log(fake[mask])) ** 2) / N)
        return loss
435
+
436
def depth_to_disp(depth, min_disp=0.00001, max_disp = 1.000001):
    """Rescale a normalised depth into [1/max_disp, 1/min_disp] and invert it.

    :param depth: value mapped linearly onto the depth range (0 -> nearest,
        1 -> farthest)
    :param min_disp: smallest disparity, defining the farthest depth
    :param max_disp: largest disparity, defining the nearest depth
    :return: (scaled_depth, disp) with disp = 1 / scaled_depth
    """
    nearest = 1 / max_disp
    farthest = 1 / min_disp
    scaled_depth = nearest + (farthest - nearest) * depth
    return scaled_depth, 1 / scaled_depth
446
+
447
def disp_to_depth(disp, min_depth, max_depth):
    """Convert network's sigmoid output into depth prediction.

    The disparity is mapped linearly onto [1/max_depth, 1/min_depth] and
    then inverted.

    :param disp: value mapped linearly onto the disparity range
    :param min_depth: nearest depth, defining the largest disparity
    :param max_depth: farthest depth, defining the smallest disparity
    :return: (scaled_disp, depth) with depth = 1 / scaled_disp
    """
    lowest = 1 / max_depth
    highest = 1 / min_depth
    scaled_disp = lowest + (highest - lowest) * disp
    return scaled_disp, 1 / scaled_disp
457
+
458
def disp_to_depth_no_scaling(disp):
    """Invert a disparity into depth without any range rescaling.

    A small constant (1e-7) is added to the disparity so a zero input does
    not divide by zero.
    """
    return 1 / (disp + 1e-7)
465
+
466
+
467
def transformation_from_parameters(axisangle, translation, invert=False):
    """Convert the network's (axisangle, translation) output into a 4x4 matrix.

    :param axisangle: axis-angle rotation handed to rot_from_axisangle
    :param translation: translation vector; cloned, so the input is untouched
    :param invert: when True, compose transpose(R) @ T(-t) instead of T(t) @ R
    :return: the composed 4x4 transformation matrices
    """
    rotation = rot_from_axisangle(axisangle)
    shift = translation.clone()

    if not invert:
        return torch.matmul(get_translation_matrix(shift), rotation)

    # Inverse transform: transposed rotation applied after the negated translation.
    rotation = rotation.transpose(1, 2)
    shift *= -1
    return torch.matmul(rotation, get_translation_matrix(shift))
486
+
487
def transformation_from_parameters_euler(euler, translation, invert=False):
    """Convert (euler angles, translation) into a 4x4 matrix.

    Angles are interpreted with the 'ZYX' convention of euler_angles_to_matrix.

    :param euler: Euler angles handed to euler_angles_to_matrix
    :param translation: translation vector; cloned, so the input is untouched
    :param invert: when True, compose transpose(R) @ T(-t) instead of T(t) @ R
    :return: the composed 4x4 transformation matrices
    """
    rotation = euler_angles_to_matrix(euler, 'ZYX')
    shift = translation.clone()

    if not invert:
        return torch.matmul(get_translation_matrix(shift), rotation)

    # Inverse transform: transposed rotation applied after the negated translation.
    rotation = rotation.transpose(1, 2)
    shift *= -1
    return torch.matmul(rotation, get_translation_matrix(shift))
506
+
507
def get_translation_matrix(translation_vector):
    """Convert a translation vector into a 4x4 transformation matrix.

    :param translation_vector: tensor whose leading dimension is the batch
        size and which holds three values per batch element
    :return: (B, 4, 4) float tensor on the input's device: identity with the
        translation written into the last column
    """
    batch = translation_vector.shape[0]
    T = torch.eye(4).repeat(batch, 1, 1).to(device=translation_vector.device)
    T[:, :3, 3] = translation_vector.contiguous().view(-1, 3)
    return T
521
+
522
+
523
+
524
def rot_from_euler(vec):
    """Build a scipy Rotation from 'zyx' Euler angles given in degrees.

    :param vec: angles in degrees, in the order accepted by
        Rotation.from_euler('zyx', ...)
    :return: the scipy Rotation object (the original computed it but
        returned None)
    """
    return R.from_euler('zyx', vec, degrees=True)
528
+
529
def rot_from_axisangle(vec):
    """Convert an axisangle rotation into a 4x4 transformation matrix
    (adapted from https://github.com/Wallacoloo/printipi)
    Input 'vec' has to be Bx1x3

    The rotation angle is the norm of `vec`, the axis its direction.
    """
    angle = torch.norm(vec, 2, 2, True)
    axis = vec / (angle + 1e-7)  # epsilon guards the zero-rotation case

    ca = torch.cos(angle)
    sa = torch.sin(angle)
    C = 1 - ca

    x = axis[..., 0].unsqueeze(1)
    y = axis[..., 1].unsqueeze(1)
    z = axis[..., 2].unsqueeze(1)

    xs, ys, zs = x * sa, y * sa, z * sa
    xC, yC, zC = x * C, y * C, z * C
    xyC, yzC, zxC = x * yC, y * zC, z * xC

    # Entries of the rotation block, keyed by (row, column).
    entries = {
        (0, 0): x * xC + ca,
        (0, 1): xyC - zs,
        (0, 2): zxC + ys,
        (1, 0): xyC + zs,
        (1, 1): y * yC + ca,
        (1, 2): yzC - xs,
        (2, 0): zxC - ys,
        (2, 1): yzC + xs,
        (2, 2): z * zC + ca,
    }

    rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
    for (row, col), value in entries.items():
        rot[:, row, col] = torch.squeeze(value)
    rot[:, 3, 3] = 1

    return rot
569
+
570
+
571
class ConvBlock(nn.Module):
    """Layer to perform a padded 3x3 convolution followed by ELU
    """
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then the ELU non-linearity."""
        return self.nonlin(self.conv(x))
584
+
585
def batchNorm(num_ch_dec):
    """Return a 2D batch-norm layer over `num_ch_dec` channels."""
    layer = nn.BatchNorm2d(num_ch_dec)
    return layer
587
+
588
class Conv3x3(nn.Module):
    """Layer to pad and convolve input

    A one-pixel pad followed by an unpadded 3x3 convolution, so the spatial
    size is preserved.
    """
    def __init__(self, in_channels, out_channels, use_refl=True):
        """
        :param in_channels: input channel count (cast to int)
        :param out_channels: output channel count (cast to int)
        :param use_refl: reflection padding when True, zero padding otherwise
        """
        super(Conv3x3, self).__init__()

        pad_layer = nn.ReflectionPad2d if use_refl else nn.ZeroPad2d
        self.pad = pad_layer(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        """Pad by one pixel on each side, then convolve."""
        return self.conv(self.pad(x))
604
+
605
+
606
class BackprojectDepth(nn.Module):
    """Layer to transform a depth image into a point cloud

    Homogeneous pixel coordinates are precomputed once for a fixed
    (batch_size, height, width) and stored as frozen parameters.
    """
    def __init__(self, batch_size, height, width):
        super(BackprojectDepth, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width

        # (2, H, W) array: channel 0 holds x (column) indices, channel 1 y (row).
        xs, ys = np.meshgrid(range(self.width), range(self.height), indexing='xy')
        coords = np.stack([xs, ys], axis=0).astype(np.float32)
        self.id_coords = nn.Parameter(torch.from_numpy(coords),
                                      requires_grad=False)

        # Row of ones, used for both homogeneous coordinates below.
        self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width),
                                 requires_grad=False)

        # (B, 3, H*W) homogeneous pixel coordinates [x, y, 1].
        flat = torch.stack(
            [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0)
        flat = flat.unsqueeze(0).repeat(batch_size, 1, 1)
        self.pix_coords = nn.Parameter(torch.cat([flat, self.ones], 1),
                                       requires_grad=False)

    def forward(self, depth, inv_K):
        """
        :param depth: depth values, reshaped to (batch_size, 1, H*W)
        :param inv_K: matrices whose top-left 3x3 block is applied to the
            pixel coordinates (presumably inverse intrinsics)
        :return: (B, 4, H*W) homogeneous points: scaled rays plus a row of ones
        """
        rays = torch.matmul(inv_K[:, :3, :3], self.pix_coords)
        scaled = depth.view(self.batch_size, 1, -1) * rays
        return torch.cat([scaled, self.ones], 1)
636
+
637
+
638
class Project3D(nn.Module):
    """Layer which projects 3D points into a camera with intrinsics K and at position T

    The output is a channels-last sampling grid normalised to [-1, 1].
    """
    def __init__(self, batch_size, height, width, eps=1e-7):
        super(Project3D, self).__init__()

        self.batch_size = batch_size
        self.height = height
        self.width = width
        self.eps = eps

    def forward(self, points, K, T):
        """
        :param points: points multiplied by the 3x4 projection; reshaped to
            (batch_size, 2, height, width) after projection
        :param K: matrix left-multiplied onto T
        :param T: matrix right-multiplied by K
        :return: (B, H, W, 2) coordinates where 0 maps to -1 and
            (width-1, height-1) maps to +1
        """
        projection = torch.matmul(K, T)[:, :3, :]
        projected = torch.matmul(projection, points)

        # Perspective divide; eps avoids dividing by an exact zero.
        denom = projected[:, 2, :].unsqueeze(1) + self.eps
        pix_coords = projected[:, :2, :] / denom
        pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
        pix_coords = pix_coords.permute(0, 2, 3, 1)

        # Scale to [0, 1] per axis, then shift and stretch to [-1, 1].
        pix_coords[..., 0] /= self.width - 1
        pix_coords[..., 1] /= self.height - 1
        return (pix_coords - 0.5) * 2
661
+
662
+
663
def upsample(x):
    """Upsample input tensor by a factor of 2 using nearest-neighbour interpolation
    """
    doubled = F.interpolate(x, scale_factor=2, mode="nearest")
    return doubled
667
+
668
class deconv(nn.Module):
    """Layer to perform a 3x3 transposed convolution with stride 2 and padding 1

    No non-linearity is applied.
    """
    def __init__(self, ch_in, ch_out):
        super(deconv, self).__init__()

        self.deconvlayer = nn.ConvTranspose2d(ch_in, ch_out, 3, stride=2, padding=1)

    def forward(self, x):
        """Apply the transposed convolution."""
        return self.deconvlayer(x)
679
+
680
+
681
def get_smooth_loss_gauss_mask(disp, img, gauss_mask):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness, and the result is
    additionally weighted per pixel by `gauss_mask`.

    :param disp: disparity, 4D
    :param img: image, 4D; its gradients are averaged over channels
    :param gauss_mask: 4D weights, cropped by one column/row to match the
        gradient shapes
    :return: scalar tensor, mean weighted x-gradient plus mean weighted y-gradient
    """
    disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    img_dx = torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]).mean(1, keepdim=True)
    img_dy = torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]).mean(1, keepdim=True)

    # Down-weight disparity gradients where the image has an edge.
    disp_dx.mul_(torch.exp(-img_dx))
    disp_dy.mul_(torch.exp(-img_dy))

    # Apply the spatial weights.
    disp_dx.mul_(gauss_mask[:, :, :, :-1])
    disp_dy.mul_(gauss_mask[:, :, :-1, :])

    return disp_dx.mean() + disp_dy.mean()
704
+
705
def get_smooth_loss(disp, img):
    """Computes the smoothness loss for a disparity image
    The color image is used for edge-aware smoothness

    :param disp: disparity, 4D
    :param img: image, 4D; its gradients are averaged over channels
    :return: scalar tensor, mean weighted x-gradient plus mean weighted y-gradient
    """
    disp_dx = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
    disp_dy = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])

    img_dx = torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]).mean(1, keepdim=True)
    img_dy = torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]).mean(1, keepdim=True)

    # Down-weight disparity gradients where the image has an edge.
    disp_dx.mul_(torch.exp(-img_dx))
    disp_dy.mul_(torch.exp(-img_dy))

    return disp_dx.mean() + disp_dy.mean()
719
+
720
+
721
class SSIM(nn.Module):
    """Layer to compute the SSIM loss between a pair of images

    Returns the per-pixel dissimilarity (1 - SSIM) / 2 clamped to [0, 1],
    using 3x3 local statistics on reflection-padded inputs.
    """
    def __init__(self):
        super(SSIM, self).__init__()
        # Five identical 3x3, stride-1 box filters, one per statistic.
        for name in ("mu_x_pool", "mu_y_pool", "sig_x_pool", "sig_y_pool", "sig_xy_pool"):
            setattr(self, name, nn.AvgPool2d(3, 1))

        self.refl = nn.ReflectionPad2d(1)

        # Stabilising constants of the SSIM formula.
        self.C1 = 0.01 ** 2
        self.C2 = 0.03 ** 2

    def forward(self, x, y):
        """
        :param x: first image batch
        :param y: second image batch, same shape as `x`
        :return: tensor of the input's shape with values in [0, 1]
        """
        x = self.refl(x)
        y = self.refl(y)

        mean_x = self.mu_x_pool(x)
        mean_y = self.mu_y_pool(y)

        # Local variances and covariance via E[ab] - E[a]E[b].
        var_x = self.sig_x_pool(x ** 2) - mean_x ** 2
        var_y = self.sig_y_pool(y ** 2) - mean_y ** 2
        cov_xy = self.sig_xy_pool(x * y) - mean_x * mean_y

        numerator = (2 * mean_x * mean_y + self.C1) * (2 * cov_xy + self.C2)
        denominator = (mean_x ** 2 + mean_y ** 2 + self.C1) * (var_x + var_y + self.C2)

        return torch.clamp((1 - numerator / denominator) / 2, 0, 1)
752
+
753
+
754
def compute_depth_errors(gt, pred):
    """Computation of error metrics between predicted and ground truth depths

    :param gt: ground-truth depths
    :param pred: predicted depths, same shape as `gt`
    :return: (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3), where a1..a3 are
        the fractions of entries whose max(gt/pred, pred/gt) is below
        1.25, 1.25^2 and 1.25^3
    """
    ratio = torch.max((gt / pred), (pred / gt))
    a1, a2, a3 = [(ratio < 1.25 ** power).float().mean() for power in (1, 2, 3)]

    diff = gt - pred
    rmse = torch.sqrt((diff ** 2).mean())

    log_diff = torch.log(gt) - torch.log(pred)
    rmse_log = torch.sqrt((log_diff ** 2).mean())

    abs_rel = torch.mean(torch.abs(diff) / gt)
    sq_rel = torch.mean(diff ** 2 / gt)

    return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
773
+
774
+
775
+ """ Parts of the U-Net model """
776
class InstanceNormDoubleConv(nn.Module):
    """conv => InstanceNorm => ReLU => conv => BatchNorm => ReLU"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        """
        :param in_channels: channels of the input
        :param out_channels: channels of the output
        :param mid_channels: channels between the two convolutions;
            falls back to `out_channels` when falsy
        """
        super().__init__()
        mid_channels = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine = True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both conv/norm/ReLU stages."""
        return self.double_conv(x)
794
+
795
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        """
        :param in_channels: channels of the input
        :param out_channels: channels of the output
        :param mid_channels: channels between the two convolutions;
            falls back to `out_channels` when falsy
        """
        super().__init__()
        mid_channels = mid_channels if mid_channels else out_channels
        stages = [
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run the input through both conv/BN/ReLU stages."""
        return self.double_conv(x)
813
+
814
class DoubleConvIN(nn.Module):
    """(convolution => InstanceNorm => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        """
        :param in_channels: channels of the input
        :param out_channels: channels of the output
        :param mid_channels: channels between the two convolutions;
            falls back to `out_channels` when falsy
        """
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        # The norm layers are no longer pinned to 'cuda' at construction:
        # that raised on CPU-only hosts and left the convolutions and the
        # norms on different devices. Move the whole module with
        # `.to(device)` / `.cuda()` instead.
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(mid_channels, affine=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True),
            nn.ReLU(inplace=True))

    def forward(self, x):
        """Run the input through both conv/IN/ReLU stages."""
        return self.double_conv(x)
831
+
832
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 2x2 max-pool halves the resolution before the DoubleConv.
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool, then convolve."""
        return self.maxpool_conv(x)
844
+
845
class DownIN(nn.Module):
    """Downscaling with maxpool then instance-norm double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # 2x2 max-pool halves the resolution before the DoubleConvIN.
        pool = nn.MaxPool2d(2)
        convs = DoubleConvIN(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool, then convolve."""
        return self.maxpool_conv(x)
857
+
858
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """
        :param in_channels: channels of the concatenated input to the conv
        :param out_channels: channels of the output
        :param bilinear: bilinear upsampling when True, a learned transposed
            convolution otherwise
        """
        super().__init__()

        if bilinear:
            # Upsampling keeps the channel count, so the DoubleConv reduces
            # channels through a narrower middle layer.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        :param x1: tensor to upsample
        :param x2: tensor whose spatial size `x1` is padded to and with which
            it is concatenated along channels
        :return: DoubleConv applied to cat([x2, padded x1])
        """
        x1 = self.up(x1)

        # Pad the upsampled tensor so its height/width match x2, splitting
        # any odd difference with the extra pixel on the right/bottom.
        gap_h = x2.size()[2] - x1.size()[2]
        gap_w = x2.size()[3] - x1.size()[3]
        left, top = gap_w // 2, gap_h // 2
        x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top])

        return self.conv(torch.cat([x2, x1], dim=1))
885
+
886
+
887
class OutConv(nn.Module):
    """1x1 convolution mapping `in_channels` to `out_channels`."""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Apply the 1x1 convolution."""
        projected = self.conv(x)
        return projected
894
+
895
+
896
class UpIN(nn.Module):
    """Upscaling then instance-norm double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        """
        :param in_channels: channels of the concatenated input to the conv
        :param out_channels: channels of the output
        :param bilinear: bilinear upsampling when True, a learned transposed
            convolution otherwise
        """
        super().__init__()

        if bilinear:
            # Upsampling keeps the channel count, so the DoubleConvIN reduces
            # channels through a narrower middle layer.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConvIN(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConvIN(in_channels, out_channels)

    def forward(self, x1, x2):
        """
        :param x1: tensor to upsample
        :param x2: tensor whose spatial size `x1` is padded to and with which
            it is concatenated along channels
        :return: DoubleConvIN applied to cat([x2, padded x1])
        """
        x1 = self.up(x1)

        # Pad the upsampled tensor so its height/width match x2, splitting
        # any odd difference with the extra pixel on the right/bottom.
        gap_h = x2.size()[2] - x1.size()[2]
        gap_w = x2.size()[3] - x1.size()[3]
        left, top = gap_w // 2, gap_h // 2
        x1 = F.pad(x1, [left, gap_w - left, top, gap_h - top])

        return self.conv(torch.cat([x2, x1], dim=1))
923
+
924
+
925
+
926
+ # def gaussian_fn(M, std):
927
+ # n = torch.arange(0, M) - (M - 1.0) / 2.0
928
+ # sig2 = 2 * std * std
929
+ # w = torch.exp(-n ** 2 / sig2)
930
+ # return w
931
+
932
+ # def gkern(kernlen=256, std=128):
933
+ # """Returns a 2D Gaussian kernel array."""
934
+ # gkern1d = gaussian_fn(kernlen, std=std)
935
+ # gkern2d = torch.outer(gkern1d, gkern1d)
936
+ # return gkern2d
937
+
938
+ # A = np.random.rand(256*256).reshape([256,256])
939
+ # A = torch.from_numpy(A)
940
+ # guassian_filter = gkern(256, std=32)
941
+
942
+
943
+ # class GaussianLayer(nn.Module):
944
+ # def __init__(self):
945
+ # super(GaussianLayer, self).__init__()
946
+ # self.seq = nn.Sequential(
947
+ # nn.ReflectionPad2d(10),
948
+ # nn.Conv2d(3, 3, 21, stride=1, padding=0, bias=None, groups=3)
949
+ # )
950
+
951
+ # self.weights_init()
952
+
953
+ # def forward(self, x):
954
+ # return self.seq(x)
955
+
956
+ # def weights_init(self):
957
+ # n= np.zeros((21,21))
958
+ # n[10,10] = 1
959
+ # k = scipy.ndimage.gaussian_filter(n,sigma=3)
960
+ # for name, f in self.named_parameters():
961
+ # f.data.copy_(torch.from_numpy(k))
962
+