| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
# Neighbour table for the six cube faces.  Row ``i`` reads
#     [i, right, left, top, bottom]
# and valid_pad_conv_fn unpacks entries 1..4 to decide which face supplies
# each of face i's four padding strips.  Per the table, faces 0-3 each have
# face 4 above and face 5 below, while faces 4 and 5 border all of 0-3.
orderings = [
    [0, 1, 3, 4, 5],  # face 0
    [1, 2, 0, 4, 5],  # face 1
    [2, 3, 1, 4, 5],  # face 2
    [3, 0, 2, 4, 5],  # face 3
    [4, 1, 3, 2, 0],  # face 4
    [5, 1, 3, 0, 2],  # face 5
]

# Orientation codes, laid out exactly like ``orderings``.  Column 0 is never
# read (valid_pad_conv_fn slices [1:5]); columns 1..4 are handed as ``rot`` to
# _take_right / _take_left / _take_top / _take_bottom respectively, which
# understand the codes -1, 0, 1 and 2.
rotations = [
    [0,  0,  0,  0,  0],  # face 0
    [0,  0,  0, -1,  1],  # face 1
    [0,  0,  0,  2,  2],  # face 2
    [0,  0,  0,  1, -1],  # face 3
    [0,  1, -1,  2,  0],  # face 4
    [0, -1,  1,  0,  2],  # face 5
]
| |
|
| | def _take_right(face, rot): |
| | if rot == 0: |
| | return face[:, :, 0] |
| | elif rot == 1: |
| | return face[:, 0, :].flip(1) |
| | elif rot == 2: |
| | return face[:, :, -1].flip(1) |
| | elif rot == -1: |
| | return face[:, -1, :] |
| |
|
| | def _take_left(face, rot): |
| | if rot == 0: |
| | return face[:, :, -1] |
| | elif rot == 1: |
| | return face[:, -1, :].flip(1) |
| | elif rot == 2: |
| | return face[:, :, 0].flip(1) |
| | elif rot == -1: |
| | return face[:, 0, :] |
| |
|
| | def _take_top(face, rot): |
| | if rot == 0: |
| | return face[:, -1, :] |
| | elif rot == 1: |
| | return face[:, :, 0] |
| | elif rot == 2: |
| | return face[:, 0, :].flip(1) |
| | elif rot == -1: |
| | return face[:, :, -1].flip(1) |
| |
|
| | def _take_bottom(face, rot): |
| | if rot == 0: |
| | return face[:, 0, :] |
| | elif rot == 1: |
| | return face[:, :, -1] |
| | elif rot == 2: |
| | return face[:, -1, :].flip(1) |
| | elif rot == -1: |
| | return face[:, :, 0].flip(1) |
| |
|
def valid_pad_conv_fn(x, one_side_pad=False):
    """Pad every face of a six-face stack with edges borrowed from its neighbours.

    For face ``i``, the neighbour indices come from ``orderings[i][1:5]`` and
    the orientation codes from ``rotations[i][1:5]`` (right, left, top,
    bottom).  Each neighbour edge is extracted with the matching ``_take_*``
    helper and written into a one-pixel border around the face.  The four
    border corners are set to the mean of the two border pixels next to them.

    Args:
        x: Tensor of shape ``(6, C, H, W)``, one slice per face.  Faces must be
            square and non-empty: rotated neighbour edges are ``W`` long but
            fill ``H``-long border slots.
        one_side_pad: If True, first drop the last row and column of every
            face, pad as above, then drop the new first row and column.  The
            result keeps the input's ``(6, C, H, W)`` shape, with its last row
            and column filled from neighbouring faces.
            NOTE(review): presumably the caller has already appended a zero
            row/column (e.g. ``F.pad(x, (0, 1, 0, 1))`` before a stride-2
            conv) which this crop discards -- confirm against call sites.

    Returns:
        A ``(6, C, H + 2, W + 2)`` tensor, or ``(6, C, H, W)`` when
        ``one_side_pad`` is True.

    Raises:
        ValueError: if ``x`` is not a 4-D stack of 6 faces, or the faces
            (after the optional crop) are not square and non-empty.
    """
    # Validate before cropping so a malformed input fails with a clear message
    # rather than an indexing error; ``raise`` survives ``python -O``.
    if x.ndim != 4 or x.shape[0] != 6:
        raise ValueError(
            f"expected a (6, C, H, W) tensor of faces, got shape {tuple(x.shape)}"
        )
    if one_side_pad:
        x = x[:, :, :-1, :-1]
    _, C, H, W = x.shape
    if H != W or H == 0:
        # Non-square faces cannot receive rotated edges, and empty faces would
        # make the corner averaging below read uninitialized memory.
        raise ValueError(f"faces must be square and non-empty, got H={H}, W={W}")

    y = x.new_empty(6, C, H + 2, W + 2)
    y[..., 1:-1, 1:-1] = x

    for i in range(6):
        r_idx, l_idx, t_idx, b_idx = orderings[i][1:5]
        r_rot, l_rot, t_rot, b_rot = rotations[i][1:5]

        y[i, :, 1:-1, 0] = _take_left(x[l_idx], l_rot)
        y[i, :, 1:-1, -1] = _take_right(x[r_idx], r_rot)
        y[i, :, 0, 1:-1] = _take_top(x[t_idx], t_rot)
        y[i, :, -1, 1:-1] = _take_bottom(x[b_idx], b_rot)

    # Corners: average the two adjacent border pixels.  Done for all faces at
    # once -- face i's corners read only face i's border strips, which are
    # final once the loop above has finished.
    y[:, :, 0, 0] = 0.5 * (y[:, :, 0, 1] + y[:, :, 1, 0])
    y[:, :, 0, -1] = 0.5 * (y[:, :, 0, -2] + y[:, :, 1, -1])
    y[:, :, -1, 0] = 0.5 * (y[:, :, -2, 0] + y[:, :, -1, 1])
    y[:, :, -1, -1] = 0.5 * (y[:, :, -2, -1] + y[:, :, -1, -2])

    if one_side_pad:
        return y[:, :, 1:, 1:]
    return y
| |
|
| |
|
class PaddedConv2d(nn.Conv2d):
    """``nn.Conv2d`` whose padding is produced by a user-supplied function.

    The convolution itself always runs with ``padding=0``; all padding is
    delegated to ``pad_fn``, which is called as
    ``pad_fn(x, one_side_pad=self.one_side_pad)`` at the start of ``forward``.
    """

    def __init__(self, *args, pad_fn=None, one_side_pad=False, **kwargs):
        """Build the convolution.

        Args:
            *args, **kwargs: Forwarded to ``nn.Conv2d``.  A ``padding`` keyword
                is overridden with 0.
            pad_fn: Callable ``pad_fn(x, one_side_pad=bool) -> Tensor`` used by
                ``forward``.  Must be set (here, or by assigning
                ``self.pad_fn``) before the module is called.
            one_side_pad: Flag passed through to ``pad_fn`` on every call.
        """
        # ``**kwargs`` is always a fresh dict, so overwriting a key is safe.
        kwargs["padding"] = 0
        super().__init__(*args, **kwargs)
        self.pad_fn = pad_fn
        self.one_side_pad = one_side_pad

    def forward(self, x):
        """Pad ``x`` with ``pad_fn``, then convolve without any further padding.

        Raises:
            TypeError: if ``pad_fn`` was never set.
        """
        if self.pad_fn is None:
            # Previously surfaced as "'NoneType' object is not callable".
            raise TypeError(
                "PaddedConv2d.pad_fn is None; pass pad_fn=... to the constructor "
                "or build the module with PaddedConv2d.from_existing"
            )
        x = self.pad_fn(x, one_side_pad=self.one_side_pad)
        return F.conv2d(
            x, self.weight, self.bias,
            stride=self.stride, padding=0,
            dilation=self.dilation, groups=self.groups,
        )

    @classmethod
    def from_existing(cls, conv: nn.Conv2d, pad_fn, one_side_pad=False):
        """Wrap the parameters of an existing ``nn.Conv2d`` in a ``PaddedConv2d``.

        The returned module shares (does not copy) ``conv.weight`` and, if
        present, ``conv.bias``.  ``conv.padding`` and ``conv.padding_mode``
        are not carried over: padding comes solely from ``pad_fn``, so the
        caller must choose a ``pad_fn`` whose border width matches what
        ``conv`` expected.
        """
        new = cls(
            conv.in_channels, conv.out_channels, conv.kernel_size,
            stride=conv.stride, padding=0, dilation=conv.dilation,
            groups=conv.groups, bias=(conv.bias is not None),
            padding_mode="zeros", pad_fn=pad_fn, one_side_pad=one_side_pad,
        )
        # Re-use the original Parameter objects so both modules stay tied.
        new.weight = conv.weight
        if conv.bias is not None:
            new.bias = conv.bias
        return new
| |
|
| |
|
| |
|