File size: 3,488 Bytes
05d33a4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 | import torch.nn as nn
import torch.nn.functional as F
# Face adjacency table; row ``i`` describes face ``i``.
#   column 0     -> the face's own index
#   columns 1..4 -> indices of the faces whose edges are used as its
#                   right, left, top and bottom padding, in that order
#                   (this is how ``valid_pad_conv_fn`` unpacks ``row[1:5]``).
orderings = [
    # self, right, left, top, bottom
    [0, 1, 3, 4, 5],
    [1, 2, 0, 4, 5],
    [2, 3, 1, 4, 5],
    [3, 0, 2, 4, 5],
    [4, 1, 3, 2, 0],
    [5, 1, 3, 0, 2],
]
# Rotation codes matching ``orderings`` entry for entry; row ``i`` is face ``i``.
#   columns 1..4 -> code passed to ``_take_right`` / ``_take_left`` /
#                   ``_take_top`` / ``_take_bottom`` for the corresponding
#                   neighbour. The codes used are 0, 1, 2 and -1.
#   column 0     -> always 0; ``valid_pad_conv_fn`` only reads ``row[1:5]``.
rotations = [
    # self, right, left, top, bottom
    [0, 0, 0, 0, 0],
    [0, 0, 0, -1, 1],
    [0, 0, 0, 2, 2],
    [0, 0, 0, 1, -1],
    [0, 1, -1, 2, 0],
    [0, -1, 1, 0, 2],
]
def _take_right(face, rot):
    """Return the border strip of ``face`` used as right-hand padding.

    ``face`` is sliced with three indices (``face[:, a, b]``), so the result
    is a 2-D strip taken along one border of the last two axes. ``flip(1)``
    reverses that strip along its remaining border axis.

    Args:
        face: Tensor-like object with three axes supporting ``.flip``.
        rot: Rotation code of the neighbouring face; one of 0, 1, 2, -1.

    Returns:
        rot 0  -> first column (``face[:, :, 0]``);
        rot 1  -> first row, reversed;
        rot 2  -> last column, reversed;
        rot -1 -> last row.

    Raises:
        ValueError: If ``rot`` is not a supported code. Without this check
            the function would fall through and silently return ``None``.
    """
    if rot == 0:
        return face[:, :, 0]
    if rot == 1:
        return face[:, 0, :].flip(1)
    if rot == 2:
        return face[:, :, -1].flip(1)
    if rot == -1:
        return face[:, -1, :]
    raise ValueError(f"unsupported rotation code: {rot!r} (expected 0, 1, 2 or -1)")
def _take_left(face, rot):
    """Return the border strip of ``face`` used as left-hand padding.

    ``face`` is sliced with three indices (``face[:, a, b]``), so the result
    is a 2-D strip taken along one border of the last two axes. ``flip(1)``
    reverses that strip along its remaining border axis.

    Args:
        face: Tensor-like object with three axes supporting ``.flip``.
        rot: Rotation code of the neighbouring face; one of 0, 1, 2, -1.

    Returns:
        rot 0  -> last column (``face[:, :, -1]``);
        rot 1  -> last row, reversed;
        rot 2  -> first column, reversed;
        rot -1 -> first row.

    Raises:
        ValueError: If ``rot`` is not a supported code. Without this check
            the function would fall through and silently return ``None``.
    """
    if rot == 0:
        return face[:, :, -1]
    if rot == 1:
        return face[:, -1, :].flip(1)
    if rot == 2:
        return face[:, :, 0].flip(1)
    if rot == -1:
        return face[:, 0, :]
    raise ValueError(f"unsupported rotation code: {rot!r} (expected 0, 1, 2 or -1)")
def _take_top(face, rot):
    """Return the border strip of ``face`` used as top padding.

    ``face`` is sliced with three indices (``face[:, a, b]``), so the result
    is a 2-D strip taken along one border of the last two axes. ``flip(1)``
    reverses that strip along its remaining border axis.

    Args:
        face: Tensor-like object with three axes supporting ``.flip``.
        rot: Rotation code of the neighbouring face; one of 0, 1, 2, -1.

    Returns:
        rot 0  -> last row (``face[:, -1, :]``);
        rot 1  -> first column;
        rot 2  -> first row, reversed;
        rot -1 -> last column, reversed.

    Raises:
        ValueError: If ``rot`` is not a supported code. Without this check
            the function would fall through and silently return ``None``.
    """
    if rot == 0:
        return face[:, -1, :]
    if rot == 1:
        return face[:, :, 0]
    if rot == 2:
        return face[:, 0, :].flip(1)
    if rot == -1:
        return face[:, :, -1].flip(1)
    raise ValueError(f"unsupported rotation code: {rot!r} (expected 0, 1, 2 or -1)")
def _take_bottom(face, rot):
    """Return the border strip of ``face`` used as bottom padding.

    ``face`` is sliced with three indices (``face[:, a, b]``), so the result
    is a 2-D strip taken along one border of the last two axes. ``flip(1)``
    reverses that strip along its remaining border axis.

    Args:
        face: Tensor-like object with three axes supporting ``.flip``.
        rot: Rotation code of the neighbouring face; one of 0, 1, 2, -1.

    Returns:
        rot 0  -> first row (``face[:, 0, :]``);
        rot 1  -> last column;
        rot 2  -> last row, reversed;
        rot -1 -> first column, reversed.

    Raises:
        ValueError: If ``rot`` is not a supported code. Without this check
            the function would fall through and silently return ``None``.
    """
    if rot == 0:
        return face[:, 0, :]
    if rot == 1:
        return face[:, :, -1]
    if rot == 2:
        return face[:, -1, :].flip(1)
    if rot == -1:
        return face[:, :, 0].flip(1)
    raise ValueError(f"unsupported rotation code: {rot!r} (expected 0, 1, 2 or -1)")
def valid_pad_conv_fn(x, one_side_pad=False):
    """Pad six faces by one pixel using the borders of their neighbours.

    For each face ``i`` the one-pixel frame is filled with edge strips taken
    from the faces listed in ``orderings[i][1:5]`` (right, left, top, bottom)
    and oriented according to ``rotations[i][1:5]``. Each corner pixel is the
    mean of the two frame pixels adjacent to it.
    NOTE(review): the six faces are presumably those of a cube map; the
    adjacency tables are taken as given.

    Args:
        x: Tensor of shape ``(6, C, H, W)``. Faces must be square, because
            rotated neighbours contribute a row where a column is expected
            (and vice versa).
        one_side_pad: When true, the last row and column of ``x`` are dropped
            before padding and the first row and column of the padded result
            are dropped afterwards, so the output keeps the input's spatial
            size.

    Returns:
        A new tensor of shape ``(6, C, H + 2, W + 2)``, or ``(6, C, H, W)``
        when ``one_side_pad`` is true.

    Raises:
        ValueError: If ``x`` is not 4-D with a leading dimension of 6, or if
            the faces to be padded are empty or not square.
    """
    # Validate with real exceptions (asserts vanish under ``python -O``) and
    # do it before slicing, so a low-rank input gets this message rather than
    # an IndexError from the slice below.
    if x.ndim != 4 or x.shape[0] != 6:
        raise ValueError(
            f"expected input of shape (6, C, H, W), got {tuple(x.shape)}"
        )
    if one_side_pad:
        x = x[:, :, :-1, :-1]
    _, C, H, W = x.shape
    if H != W:
        raise ValueError(f"faces must be square, got H={H}, W={W}")
    if H == 0:
        # With nothing to copy, the corners would be averaged from
        # uninitialised ``new_empty`` memory.
        raise ValueError("faces must be non-empty before padding")

    y = x.new_empty(6, C, H + 2, W + 2)
    y[..., 1:-1, 1:-1] = x
    for i, (neighbours, turns) in enumerate(zip(orderings, rotations)):
        r_idx, l_idx, t_idx, b_idx = neighbours[1:5]
        r_rot, l_rot, t_rot, b_rot = turns[1:5]

        # One-pixel frame: strips borrowed from the four neighbouring faces.
        y[i, :, 1:-1, 0] = _take_left(x[l_idx], l_rot)
        y[i, :, 1:-1, -1] = _take_right(x[r_idx], r_rot)
        y[i, :, 0, 1:-1] = _take_top(x[t_idx], t_rot)
        y[i, :, -1, 1:-1] = _take_bottom(x[b_idx], b_rot)

        # Corners have no source face; average the two adjacent frame pixels.
        y[i, :, 0, 0] = 0.5 * (y[i, :, 0, 1] + y[i, :, 1, 0])
        y[i, :, 0, -1] = 0.5 * (y[i, :, 0, -2] + y[i, :, 1, -1])
        y[i, :, -1, 0] = 0.5 * (y[i, :, -2, 0] + y[i, :, -1, 1])
        y[i, :, -1, -1] = 0.5 * (y[i, :, -2, -1] + y[i, :, -1, -2])

    if one_side_pad:
        return y[:, :, 1:, 1:]
    return y
class PaddedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose padding is delegated to a caller-supplied function.

    The convolution itself always runs with ``padding=0``; ``pad_fn`` is
    applied to the input beforehand and is called as
    ``pad_fn(x, one_side_pad=...)``.
    """

    # Position of ``padding`` among ``nn.Conv2d``'s positional parameters:
    # (in_channels, out_channels, kernel_size, stride, padding, ...).
    _PADDING_ARG_INDEX = 4

    def __init__(self, *args, pad_fn=None, one_side_pad=False, **kwargs):
        """Build the layer, forcing the built-in padding to zero.

        Args:
            *args: Positional arguments forwarded to ``nn.Conv2d``.
            pad_fn: Callable ``pad_fn(x, one_side_pad=bool)`` returning the
                padded input, or ``None`` to apply no padding at all.
            one_side_pad: Flag forwarded to ``pad_fn`` on every call.
            **kwargs: Keyword arguments forwarded to ``nn.Conv2d``; any
                ``padding`` value is replaced by 0.
        """
        args = list(args)
        kwargs = dict(kwargs)
        if len(args) > self._PADDING_ARG_INDEX:
            # ``padding`` was given positionally: override it in place.
            # Adding it to kwargs as well would raise "multiple values for
            # argument 'padding'".
            args[self._PADDING_ARG_INDEX] = 0
        else:
            kwargs["padding"] = 0
        super().__init__(*args, **kwargs)
        self.pad_fn = pad_fn
        self.one_side_pad = one_side_pad

    def forward(self, x):
        """Pad ``x`` with ``pad_fn`` (if any), then run an unpadded conv."""
        if self.pad_fn is not None:
            x = self.pad_fn(x, one_side_pad=self.one_side_pad)
        return F.conv2d(
            x, self.weight, self.bias,
            stride=self.stride, padding=0,
            dilation=self.dilation, groups=self.groups,
        )

    @classmethod
    def from_existing(cls, conv: nn.Conv2d, pad_fn, one_side_pad=False):
        """Create a ``PaddedConv2d`` mirroring ``conv``'s configuration.

        The new layer reuses ``conv``'s ``weight`` (and ``bias``, when
        present) parameter objects rather than copying them, so both modules
        share the same parameters.

        Args:
            conv: Convolution whose hyper-parameters and parameters are reused.
            pad_fn: Padding callable, see ``__init__``.
            one_side_pad: Flag forwarded to ``pad_fn`` on every call.

        Returns:
            The newly constructed layer.
        """
        new = cls(
            conv.in_channels, conv.out_channels, conv.kernel_size,
            stride=conv.stride, padding=0, dilation=conv.dilation,
            groups=conv.groups, bias=(conv.bias is not None),
            padding_mode="zeros", pad_fn=pad_fn, one_side_pad=one_side_pad,
        )
        new.weight = conv.weight
        if conv.bias is not None:
            new.bias = conv.bias
        return new
|