| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import scipy.ndimage as ndimage |
| |
|
def local_scan_zero_ones(locality, x, h_scan=False):
    """Gather ``x`` into a "zeros" sequence and a component-ordered "ones" sequence.

    ``locality`` is squeezed, moved to CPU and treated as a 2-D map. Pixels equal
    to 0 are gathered in row-major order into one sequence. Pixels equal to 1 are
    grouped into connected components by ``scipy.ndimage.label`` (default
    structuring element). They are gathered component by component, label 1
    first, and each component is read in row-major order.

    Args:
        locality: tensor whose squeezed form is a 2-D map of 0/1 values.
        x: tensor indexed as ``x[:, mask]`` with a 2-D mask, presumably shaped
            (C, H, W) — TODO confirm against callers. It is not modified.
        h_scan: if True, gather in the transposed (column-major) order.

    Returns:
        Tuple ``(flattened_zeros, flattened_ones, num_zero_pixels, zero_mask,
        indices_ones, num_ones)``:
        - ``flattened_zeros`` / ``flattened_ones`` are ``x[:, mask]`` gathers.
          ``flattened_ones`` has shape (C, 0) when there are no ones.
        - ``zero_mask`` is a boolean map of the zero pixels.
        - ``indices_ones`` is the component-label map (0 = not a one).
        - Both maps are returned transposed when ``h_scan`` is True, which is the
          orientation ``reverse_local_scan_zero_ones`` reads its H/W from.
    """
    local_flat = locality.squeeze().cpu().numpy()

    # Labels: 0 = background, 1..num_ones = connected components of the ones.
    labeled_ones, num_ones = ndimage.label(local_flat == 1)

    locality_map = torch.tensor(local_flat)
    indices_ones = torch.tensor(labeled_ones)

    if h_scan:
        # Out-of-place transposes: an in-place transpose_ here would also
        # transpose the caller's x.
        x = x.transpose(-1, -2)
        locality_map = locality_map.transpose(-1, -2)
        indices_ones = indices_ones.transpose(-1, -2)

    zero_mask = locality_map == 0
    flattened_zeros = x[:, zero_mask]

    components_ones = [x[:, indices_ones == label] for label in range(1, num_ones + 1)]
    if components_ones:
        flattened_ones = torch.cat(components_ones, dim=-1)
    else:
        # torch.cat([]) raises; an all-False mask yields the empty (C, 0) gather.
        flattened_ones = x[:, torch.zeros_like(zero_mask)]

    return flattened_zeros, flattened_ones, flattened_zeros.shape[-1], zero_mask, indices_ones, num_ones
| |
|
def reverse_local_scan_zero_ones(indices_zeros, indices_ones, num_ones, flattened_zeros, flattened_ones, h_scan=False):
    """Scatter the zero/one sequences back into a float32 (C, rows, cols) map.

    Args:
        indices_zeros: 2-D boolean mask of the zero pixels.
        indices_ones: 2-D integer label map; pixels labelled ``i`` (1..num_ones)
            belong to component ``i``, and 0 means "not a one".
        num_ones: number of components; they are restored in label order.
        flattened_zeros: (C, number of zero pixels) values, in mask row-major order.
        flattened_ones: (C, number of one pixels) values, component after
            component, each component in row-major order.
        h_scan: if True the maps are in transposed orientation, and the result
            is transposed back before returning.

    Returns:
        A float32 tensor on ``flattened_zeros.device``. Its last two dims are
        those of ``indices_ones``, swapped when ``h_scan`` is True. Pixels covered
        by neither mask stay 0.
    """
    C = flattened_zeros.shape[0]
    H, W = indices_ones.shape[-2], indices_ones.shape[-1]
    # Allocate directly on the input's device. The previous
    # `.cuda(get_device())` copied from host and broke for CPU tensors.
    local_restored = torch.zeros((C, H, W), dtype=torch.float32, device=flattened_zeros.device)

    # Index-put requires matching dtypes, so cast values to the float32 buffer.
    local_restored[:, indices_zeros] = flattened_zeros.to(local_restored.dtype)

    start_idx = 0
    for label in range(1, num_ones + 1):
        mask = indices_ones == label
        count = int(mask.sum())
        local_restored[:, mask] = flattened_ones[:, start_idx:start_idx + count].to(local_restored.dtype)
        start_idx += count

    if h_scan:
        local_restored = local_restored.transpose(-1, -2)

    return local_restored
| |
|
| |
|
def merge_lists(list1, list2):
    """Pair two equally shaped tensors along a new trailing axis.

    The result has shape ``(*list1.shape, 2)``: ``[..., 0]`` is ``list1`` and
    ``[..., 1]`` is ``list2``. Flattening the last two axes therefore alternates
    the elements of the two inputs.
    """
    return torch.stack((list1, list2), dim=-1)
| |
|
class Scan_FB_S(torch.autograd.Function):
    """Forward + reversed 1-D scan of two channel halves, interleaved element-wise.

    forward: ``x`` of shape (B, C, L) -> (B, 2, C // 2, 2 * L).
    The channels are split into halves ``a`` and ``b``. On dim 1, index 0 holds
    the original order and index 1 the reversed order. Along the last axis the
    halves alternate: ``a[0], b[0], a[1], b[1], ...``.
    backward folds both directions back into a (B, C, L) gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        batch, channels, length = x.shape
        half = channels // 2
        ctx.shape = (batch, half, length)
        part_a, part_b = torch.split(x, half, 1)
        # Dim 1: [original order, reversed order].
        seq_a = torch.stack((part_a, part_a.flip(-1)), dim=1)
        seq_b = torch.stack((part_b, part_b.flip(-1)), dim=1)
        return merge_lists(seq_a, seq_b).reshape(batch, 2, half, length * 2)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        batch, half, length = ctx.shape
        # Split the interleaved last axis: slot 0 -> half a, slot 1 -> half b.
        grads = ys.view(batch, 2, half, length, 2)
        per_half = []
        for slot in range(2):
            g = grads[..., slot]
            # The reversed copy's gradient is flipped back before summing.
            per_half.append(g[:, 0, :, :] + g[:, 1, :, :].flip(-1))
        return torch.concat(per_half, 1).view(batch, half * 2, length).contiguous()
| |
|
| |
|
class Merge_FB_S(torch.autograd.Function):
    """Merge forward/reversed 1-D scans of two interleaved channel halves.

    forward: ``ys`` of shape (B, K, C, L), with K >= 2 and L even, -> (B, 2C, L // 2).
    The last axis is read as interleaved pairs (even positions = half 1, odd =
    half 2). For each half, direction 0 is added to the reversed direction 1.
    Directions 2..K-1, if any, are not read.
    backward rebuilds the interleaved (B, K, C, L) gradient. Unread directions
    get a zero gradient.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, C, L = ys.shape
        ctx.shape = (B, K, C, L)
        ys = ys.view(B, K, C, -1, 2)
        ys1, ys2 = ys[..., 0], ys[..., 1]
        # Reversed-direction outputs are flipped back before summing.
        y1 = ys1[:, 0, :, :] + ys1[:, 1, :, :].flip(-1)
        y2 = ys2[:, 0, :, :] + ys2[:, 1, :, :].flip(-1)
        y = torch.concat([y1, y2], 1)
        return y.contiguous()

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        B, K, C, L = ctx.shape
        x1, x2 = torch.split(x, C, 1)
        # Zero-initialise: forward reads only directions 0 and 1, so any extra
        # directions (K > 2) must get a zero gradient, not uninitialised memory.
        xs1 = x1.new_zeros((B, K, C, L // 2))
        xs2 = x2.new_zeros((B, K, C, L // 2))
        xs1[:, 0] = x1
        xs1[:, 1] = x1.flip(-1)
        xs2[:, 0] = x2
        xs2[:, 1] = x2.flip(-1)
        # Re-interleave the halves along the last axis (same layout as merge_lists).
        return torch.stack((xs1, xs2), dim=-1).reshape(B, K, C, L)
| |
|
class CrossScanS(torch.autograd.Function):
    """Four-direction 2-D scan of two channel halves, interleaved element-wise.

    forward: ``x`` of shape (B, C, H, W) -> (B, 4, C // 2, 2 * H * W).
    For each channel half, the directions on dim 1 are:
    0 = row-major flatten, 1 = flatten of the H/W-transposed map (column-major),
    2 and 3 = directions 0 and 1 reversed.
    The two halves alternate along the last axis: half 1 at even positions,
    half 2 at odd positions.
    backward folds the four directional gradients back into (B, C, H, W).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        half = C // 2
        ctx.shape = (B, half, H, W)

        def directions(part):
            # (B, 2, half, H*W): row-major and column-major orderings.
            scans = torch.stack(
                (part.flatten(2, 3), part.transpose(dim0=2, dim1=3).flatten(2, 3)),
                dim=1,
            )
            # Append the reversed copies -> (B, 4, half, H*W).
            return torch.cat((scans, scans.flip(-1)), dim=1)

        part_a, part_b = torch.split(x, half, 1)
        merged = merge_lists(directions(part_a), directions(part_b))
        return merged.reshape(B, 4, half, H * W * 2)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        B, C, H, W = ctx.shape
        L = H * W
        grads = ys.view(B, 4, C, L, 2)

        def fold(g):
            # Add the reversed directions back onto their forward counterparts.
            g = g[:, 0:2] + g[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
            # The column-major gradient is reordered to row-major before summing.
            col_as_row = g[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
            return g[:, 0] + col_as_row

        y = torch.concat([fold(grads[..., 0]), fold(grads[..., 1])], 1)
        return y.view(B, -1, H, W)
| |
|
| |
|
class CrossMergeS(torch.autograd.Function):
    """Merge four directional 2-D scans of two interleaved channel halves.

    forward: ``ys`` of shape (B, K, D, H, 2W) -> (B, 2D, H * W).
    The last axis holds the two channel halves interleaved (even = half 1, odd =
    half 2). For each half:
    - directions 2 and 3 are flipped and added to directions 0 and 1;
    - the column-major direction (1) is reordered to row-major;
    - the result is summed with direction 0.
    Only directions 0..3 are read.
    backward returns a (B, 4, D, H, 2W) gradient rebuilt in the same layout.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        W = W // 2  # the last axis carries the two halves interleaved
        ctx.shape = (H, W)
        # reshape, not view: this merges the H and 2W axes, which .view cannot
        # do for non-contiguous inputs.
        ys = ys.reshape(B, K, D, -1, 2)
        ys1, ys2 = ys[..., 0], ys[..., 1]
        ys1 = ys1[:, 0:2] + ys1[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        ys2 = ys2[:, 0:2] + ys2[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # Direction 1 is stored column-major (index w*H + h); bring it to row-major.
        y1 = ys1[:, 0] + ys1[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y2 = ys2[:, 0] + ys2[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        return torch.concat([y1, y2], 1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        H, W = ctx.shape
        B, C, L = x.shape
        C = C // 2
        x1, x2 = torch.split(x, C, 1)
        xs1, xs2 = x1.new_empty((B, 4, C, L)), x2.new_empty((B, 4, C, L))
        xs1[:, 0] = x1
        xs1[:, 1] = x1.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs1[:, 2:4] = torch.flip(xs1[:, 0:2], dims=[-1])
        xs2[:, 0] = x2
        xs2[:, 1] = x2.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs2[:, 2:4] = torch.flip(xs2[:, 0:2], dims=[-1])
        # Re-interleave the halves along the last axis (same layout as merge_lists).
        # forward has a single tensor input, so exactly one gradient is returned.
        return torch.stack((xs1, xs2), dim=-1).reshape(B, 4, C, H, W * 2)