| import torch |
| import torch.nn as nn |
| import numpy as np |
|
|
|
|
class AbstractPermuter(nn.Module):
    """Base class for the permuters in this module.

    Concrete subclasses override ``forward(x, reverse=False)`` to reorder
    ``x`` and, with ``reverse=True``, to undo that reordering.
    """

    def __init__(self, *_unused_args, **_unused_kwargs):
        # Constructor arguments are accepted but ignored here; subclasses
        # consume whatever they need (e.g. grid height and width).
        super(AbstractPermuter, self).__init__()

    def forward(self, x, reverse=False):
        # No default ordering exists; every subclass must provide one.
        raise NotImplementedError()
|
|
|
|
class Identity(AbstractPermuter):
    """Permuter that returns its input untouched in both directions."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x, reverse=False):
        # The identity permutation is its own inverse, so ``reverse`` has no
        # effect on the result.
        result = x
        return result
|
|
|
|
class Subsample(AbstractPermuter):
    """Hierarchical 2x2 sub-sampling order over an ``H x W`` grid.

    Positions ``0 .. H*W-1`` are laid out row-major as a ``(1, H, W)`` array.
    At each level the array is split into its four 2x2 phase sub-grids
    (row/column offsets (0,0), (0,1), (1,0), (1,1)), which are stacked as new
    leading channels while H and W are halved. This repeats until the spatial
    size is 1x1. The resulting channel order is the forward permutation.

    Only equal, power-of-two ``H`` and ``W`` reduce to 1x1. Other sizes fail,
    either in the numpy ``reshape`` or in the final assertion.
    """

    def __init__(self, H, W):
        super().__init__()
        C = 1
        indices = np.arange(H * W).reshape(C, H, W)
        while min(H, W) > 1:
            # Split each spatial axis into (coarse, phase) and move the two
            # phase axes next to the channel axis so they become channels.
            indices = indices.reshape(C, H // 2, 2, W // 2, 2)
            indices = indices.transpose(0, 2, 4, 1, 3)
            indices = indices.reshape(C * 4, H // 2, W // 2)
            H = H // 2
            W = W // 2
            C = C * 4
        assert H == W == 1, "Subsample requires H == W and a power-of-two side"
        idx = torch.tensor(indices.ravel())
        # Register plain index tensors, as the other permuters do. These are
        # non-trainable lookup tables, so nn.Parameter wrapping added nothing.
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        """Reorder dimension 1 of ``x``; ``reverse=True`` applies the inverse.

        Presumably ``x`` is ``(batch, H*W, ...)``; confirm with callers.
        """
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
|
|
|
|
def mortonify(i, j):
    """(i,j) index to linear morton code.

    Interleaves the bits of the two coordinates. Bit ``k`` of ``j`` goes to
    bit ``2k`` of the result, and bit ``k`` of ``i`` goes to bit ``2k + 1``.
    Only the low 32 bits of each coordinate are used, so the code fits in
    64 bits.

    Returns:
        The Morton code as an ``np.uint64``.
    """
    i = np.uint64(i)
    j = np.uint64(j)
    one = np.uint64(1)
    # Explicit 64-bit accumulator. ``np.uint`` has a platform/NumPy-version
    # dependent width (32-bit on Windows before NumPy 2).
    z = np.uint64(0)
    for pos in range(32):
        bit = one << np.uint64(pos)
        z = z | ((j & bit) << np.uint64(pos))
        z = z | ((i & bit) << np.uint64(pos + 1))
    return z
|
|
|
|
class ZCurve(AbstractPermuter):
    """Z-order (Morton) traversal of an ``H x W`` grid.

    Every cell ``(i, j)`` gets its Morton code from ``mortonify``. The forward
    permutation lists the row-major positions sorted by that code. Only the
    relative order of the codes is used, so any ``H`` and ``W`` are supported.
    """

    def __init__(self, H, W):
        super().__init__()
        # Morton code of every cell, in row-major order.
        codes = [np.int64(mortonify(i, j)) for i in range(H) for j in range(W)]
        # Row-major positions sorted by Morton code: the Z-order traversal.
        idx = torch.tensor(np.argsort(codes))
        self.register_buffer('forward_shuffle_idx', idx)
        # Inverse permutation. The raw Morton codes coincide with it only when
        # they form a permutation of range(H*W) (e.g. square power-of-two
        # grids). Otherwise they exceed H*W-1 and cannot be used as indices.
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        """Reorder dimension 1 of ``x``; ``reverse=True`` applies the inverse."""
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
|
|
|
|
class SpiralOut(AbstractPermuter):
    """Spiral ordering that starts next to the grid centre and winds outwards.

    The walk starts at cell ``(size//2, size//2 - 1)``. It then repeats the
    moves up, right, down, left with run lengths 1, 1, 2, 2, 3, 3, ..., and
    the last ring ends with a downward run shortened by one. ``H`` and ``W``
    must be equal. The final length check holds for even ``size`` (and the
    trivial ``size == 1``).
    """

    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        grid = np.arange(size * size).reshape(size, size)
        half = size // 2

        # Describe the spiral as (d_row, d_col, run_length) segments.
        moves = []
        run = 0
        for ring in range(1, half + 1):
            run += 1
            moves.append((-1, 0, run))  # up
            moves.append((0, 1, run))   # right
            run += 1
            if ring < half:
                moves.append((1, 0, run))   # down
                moves.append((0, -1, run))  # left
            else:
                moves.append((1, 0, run - 1))  # final, shortened down run

        row, col = half, half - 1
        order = [grid[row, col]]
        for d_row, d_col, length in moves:
            for _ in range(length):
                row += d_row
                col += d_col
                order.append(grid[row, col])

        assert len(order) == size * size
        perm = torch.tensor(order)
        self.register_buffer('forward_shuffle_idx', perm)
        self.register_buffer('backward_shuffle_idx', torch.argsort(perm))

    def forward(self, x, reverse=False):
        """Reorder dimension 1 of ``x``; ``reverse=True`` undoes the spiral."""
        chosen = self.backward_shuffle_idx if reverse else self.forward_shuffle_idx
        return x[:, chosen]
|
|
|
|
class SpiralIn(AbstractPermuter):
    """Spiral ordering that starts at the outer edge and winds into the centre.

    This builds the centre-outward spiral and then reverses it. That spiral
    starts at ``(size//2, size//2 - 1)`` and moves up, right, down, left with
    run lengths 1, 1, 2, 2, 3, 3, .... ``H`` and ``W`` must be equal. The
    length check holds for even ``size`` (and the trivial ``size == 1``).
    """

    def __init__(self, H, W):
        super().__init__()
        assert H == W
        size = W
        cells = np.arange(size * size).reshape(size, size)
        half = size // 2

        def unit_steps():
            # Yield single-cell (d_row, d_col) moves of the outward spiral.
            span = 0
            for ring in range(1, half + 1):
                span += 1
                yield from [(-1, 0)] * span
                yield from [(0, 1)] * span
                span += 1
                if ring < half:
                    yield from [(1, 0)] * span
                    yield from [(0, -1)] * span
                else:
                    yield from [(1, 0)] * (span - 1)

        r, c = half, half - 1
        outward = [cells[r, c]]
        for dr, dc in unit_steps():
            r, c = r + dr, c + dc
            outward.append(cells[r, c])

        assert len(outward) == size * size
        # Walk the same path backwards: outer edge first, centre last.
        inward = torch.tensor(outward[::-1])
        self.register_buffer('forward_shuffle_idx', inward)
        self.register_buffer('backward_shuffle_idx', torch.argsort(inward))

    def forward(self, x, reverse=False):
        """Reorder dimension 1 of ``x``; ``reverse=True`` undoes the spiral."""
        if reverse:
            return x[:, self.backward_shuffle_idx]
        return x[:, self.forward_shuffle_idx]
|
|
|
|
class Random(AbstractPermuter):
    """Fixed pseudo-random permutation of the ``H*W`` flattened positions.

    Subclasses ``AbstractPermuter`` like every other permuter in this module,
    so ``isinstance`` checks against the base class accept it.

    Args:
        H, W: grid height and width; only the product ``H*W`` is used.
        seed: seed passed to ``np.random.RandomState``. The default of 1
            matches the previously hard-coded value, so default-constructed
            instances produce the same permutation as before.
    """

    def __init__(self, H, W, seed=1):
        super().__init__()
        indices = np.random.RandomState(seed).permutation(H * W)
        idx = torch.tensor(indices.ravel())
        self.register_buffer('forward_shuffle_idx', idx)
        self.register_buffer('backward_shuffle_idx', torch.argsort(idx))

    def forward(self, x, reverse=False):
        """Reorder dimension 1 of ``x``; ``reverse=True`` applies the inverse."""
        if not reverse:
            return x[:, self.forward_shuffle_idx]
        else:
            return x[:, self.backward_shuffle_idx]
|
|
|
|
class AlternateParsing(AbstractPermuter):
    """Boustrophedon ("snake") order over an ``H x W`` grid.

    Even rows are read left to right and odd rows right to left. The forward
    permutation is the row-major flattening of that grid of positions.
    """

    def __init__(self, H, W):
        super().__init__()
        grid = np.arange(W * H).reshape(H, W)
        # Flip every odd row in one vectorised assignment. The explicit copy
        # keeps the source independent of the slice being written.
        grid[1::2, :] = grid[1::2, ::-1].copy()
        order = grid.flatten()
        assert len(order) == H * W
        order = torch.tensor(order)
        self.register_buffer('forward_shuffle_idx', order)
        self.register_buffer('backward_shuffle_idx', torch.argsort(order))

    def forward(self, x, reverse=False):
        """Reorder dimension 1 of ``x``; ``reverse=True`` restores row-major."""
        lookup = self.backward_shuffle_idx if reverse else self.forward_shuffle_idx
        return x[:, lookup]
|
|
|
|
if __name__ == "__main__":
    # Smoke test: print the snake ordering for a 16x16 grid and check that
    # permuting a batch of random token ids, then un-permuting it, restores
    # the batch exactly.
    snake = AlternateParsing(16, 16)
    print(snake.forward_shuffle_idx)
    print(snake.backward_shuffle_idx)

    batch = torch.randint(0, 768, size=(11, 256))
    permuted = snake(batch)
    restored = snake(permuted, reverse=True)
    assert torch.equal(batch, restored)

    # Show the four-cell centre-out spiral and its inverse.
    spiral = SpiralOut(2, 2)
    print(spiral.forward_shuffle_idx)
    print(spiral.backward_shuffle_idx)
|
|