|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import scipy.ndimage as ndimage |
|
|
|
|
|
def local_scan_zero_ones(locality, x, h_scan=False):
    """Split the spatial positions of ``x`` into a zero-region sequence and a
    sequence of connected one-regions, as marked by ``locality``.

    Args:
        locality: tensor whose ``squeeze()`` is a 2-D (H, W) map. Positions
            equal to 0 form the "zeros" group. Positions equal to 1 are grouped
            into connected components by ``scipy.ndimage.label`` (default
            structuring element). Positions with any other value belong to
            neither group (presumably ``locality`` is binary; TODO confirm).
        x: tensor indexed as ``x[:, mask]`` with an (H, W) mask, i.e. it is
            expected to be (C, H, W). It is NOT modified.
        h_scan: if True, work on the transposed grid. Positions are then
            visited column by column of the original layout.

    Returns:
        A 6-tuple:

        * ``flattened_zeros`` (C, N0): ``x`` at the zero positions, in
          row-major order of the (possibly transposed) grid.
        * ``flattened_ones`` (C, N1): ``x`` at the one positions, grouped by
          component label 1..num_ones and row-major within each component.
          This is (C, 0) when there are no components.
        * ``N0``: ``flattened_zeros.shape[-1]``.
        * the boolean zero mask, in the scan frame.
        * the integer component labels (0 = not in any component), in the
          scan frame.
        * ``num_ones``: the number of components.
    """
    # detach() so .numpy() also works when ``locality`` requires grad.
    local_flat = locality.detach().squeeze().cpu().numpy()
    labeled_ones, num_ones = ndimage.label(local_flat == 1)

    indices_zeros = torch.tensor(local_flat)
    indices_ones = torch.tensor(labeled_ones)

    if h_scan:
        # Out-of-place transpose: the caller's ``x`` must not be altered
        # (the previous in-place ``transpose_`` leaked into the caller).
        x = x.transpose(-1, -2)
        indices_zeros = indices_zeros.transpose(-1, -2)
        indices_ones = indices_ones.transpose(-1, -2)

    zeros_mask = indices_zeros == 0
    flattened_zeros = x[:, zeros_mask]

    # Order all labelled positions by component label, keeping row-major
    # order inside each component. A stable sort does this in one pass
    # instead of building one full-grid mask per component. Label 0
    # (background) sorts first and is dropped. An empty result yields a
    # (C, 0) tensor rather than failing like torch.cat([]) did.
    labels = indices_ones.reshape(-1)
    order = torch.sort(labels, stable=True).indices
    order = order[labels[order] > 0]
    flattened_ones = x.flatten(1, 2)[:, order.to(x.device)]

    return flattened_zeros, flattened_ones, flattened_zeros.shape[-1], zeros_mask, indices_ones, num_ones
|
|
|
|
|
def reverse_local_scan_zero_ones(indices_zeros, indices_ones, num_ones, flattened_zeros, flattened_ones, h_scan=False):
    """Scatter zero-region and one-region sequences back onto a 2-D grid.

    Args:
        indices_zeros: boolean (H, W) mask of the zero-region positions, in the
            (possibly transposed) frame used when the sequences were built.
        indices_ones: integer (H, W) labels in that frame. Label ``i`` in
            ``1..num_ones`` marks component ``i``; 0 marks no component.
        num_ones: number of labelled components to restore.
        flattened_zeros: (C, N0) values for the masked zero positions, in
            row-major order of ``indices_zeros``.
        flattened_ones: (C, N1) values for the labelled positions, grouped by
            label 1..num_ones and row-major within each label. Only the
            leading columns that are needed are consumed, and this argument is
            not read when ``num_ones == 0``.
        h_scan: if True, the frame is transposed; the result is transposed
            back before returning.

    Returns:
        A float32 (C, H, W) tensor on ``flattened_zeros``' device. When
        ``h_scan`` is set this is a transposed view. Positions covered by
        neither group stay 0.
    """
    C = flattened_zeros.shape[0]
    H, W = indices_ones.shape[-2], indices_ones.shape[-1]
    # Allocate next to the inputs. The old hard-coded
    # ``.cuda(get_device())`` failed for CPU tensors. float32 is kept for
    # compatibility with existing callers.
    restored = torch.zeros((C, H * W), dtype=torch.float32, device=flattened_zeros.device)

    # A row-major flattened mask selects positions in the same order as
    # indexing the 2-D grid with the 2-D mask.
    restored[:, indices_zeros.reshape(-1).to(restored.device)] = flattened_zeros

    if num_ones > 0:
        # Positions ordered by label (1..num_ones), row-major within a label,
        # via one stable sort instead of one full-grid mask per label.
        labels = indices_ones.reshape(-1)
        order = torch.sort(labels, stable=True).indices
        sorted_labels = labels[order]
        order = order[(sorted_labels > 0) & (sorted_labels <= num_ones)]
        restored[:, order.to(restored.device)] = flattened_ones[:, :order.numel()]

    local_restored = restored.view(C, H, W)
    if h_scan:
        local_restored = local_restored.transpose(-1, -2)
    return local_restored
|
|
|
|
|
|
|
|
def merge_lists(list1, list2):
    """Pair two same-shaped tensors along a new trailing axis.

    The result has shape ``(*list1.shape, 2)``: ``[..., 0]`` is ``list1`` and
    ``[..., 1]`` is ``list2``. Flattening the last two axes therefore
    alternates the elements of the two inputs.
    """
    return torch.stack((list1, list2), dim=-1)
|
|
|
|
|
class Scan_FB_S(torch.autograd.Function):
    """Forward-and-reversed 1-D scan of two interleaved channel halves.

    ``forward`` maps ``x`` of shape (B, C, L) to (B, 2, C // 2, 2 * L):

    * the channels are split into two halves of ``C // 2``. This assumes ``C``
      is even; otherwise the split does not yield exactly two pieces;
    * direction 0 holds each half as given, and direction 1 holds it reversed
      along ``L``;
    * the halves are interleaved along the last axis, so ``[..., 2 * l]``
      comes from the first half and ``[..., 2 * l + 1]`` from the second.

    ``backward`` is the adjoint. It de-interleaves the incoming gradient, adds
    the re-reversed direction 1 onto direction 0, and joins the halves back
    along the channel axis.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        batch, channels, length = x.shape
        half = channels // 2
        ctx.shape = (batch, half, length)
        first, second = torch.split(x, half, 1)

        def both_directions(part):
            # (B, half, L) -> (B, 2, half, L): the sequence and its reversal.
            out = part.new_empty((batch, 2, half, length))
            out[:, 0] = part
            out[:, 1] = part.flip(-1)
            return out

        # A new trailing axis, flattened away, interleaves the two halves.
        paired = torch.stack([both_directions(first), both_directions(second)], dim=-1)
        return paired.reshape(batch, 2, half, length * 2)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        batch, half, length = ctx.shape
        grads = ys.view(batch, 2, half, length, 2)
        per_half = []
        for slot in range(2):
            g = grads[..., slot]
            # Direction 1 was reversed in forward, so reverse its gradient back.
            per_half.append(g[:, 0, :, :] + g[:, 1, :, :].flip(-1))
        merged = torch.concat(per_half, 1)
        return merged.view(batch, half * 2, length).contiguous()
|
|
|
|
|
|
|
|
class Merge_FB_S(torch.autograd.Function):
    """Merge a forward/reversed pair of interleaved sequences.

    ``forward`` takes ``ys`` of shape (B, K, C, L) with ``L`` even. The last
    axis alternates two channel groups. For each group it adds direction 0 to
    the reversed direction 1, then concatenates the two groups on the channel
    axis, giving (B, 2 * C, L // 2). Only directions 0 and 1 are read, so this
    presumably expects K == 2. ``backward`` only fills slots 0 and 1 of an
    uninitialised (B, K, C, L // 2) buffer, so any further slots would hold
    garbage.

    ``backward`` is the adjoint of ``forward``.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, C, L = ys.shape
        ctx.shape = (B, K, C, L)
        pairs = ys.view(B, K, C, -1, 2)
        merged = [pairs[:, 0, ..., slot] + pairs[:, 1, ..., slot].flip(-1) for slot in (0, 1)]
        return torch.concat(merged, 1).contiguous()

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        B, K, C, L = ctx.shape
        first, second = torch.split(x, C, 1)
        buffers = []
        for part in (first, second):
            buf = part.new_empty((B, K, C, L // 2))
            buf[:, 0] = part
            # Undo the reversal applied to direction 1 in forward.
            buf[:, 1] = part.flip(-1)
            buffers.append(buf)
        # Re-interleave the two groups along the last axis.
        return torch.stack(buffers, dim=-1).reshape(B, K, C, L)
|
|
|
|
|
class CrossScanS(torch.autograd.Function):
    """Four-direction 2-D scan of two interleaved channel halves.

    ``forward`` maps ``x`` of shape (B, C, H, W) to (B, 4, C // 2, 2 * H * W).
    It assumes ``C`` is even so the channel split yields exactly two halves.
    For each half:

    * direction 0 is the row-major flattening of the (H, W) grid;
    * direction 1 is the flattening of its transpose (column by column);
    * directions 2 and 3 are directions 0 and 1 reversed.

    The halves are interleaved along the last axis: ``[..., 2 * l]`` comes
    from the first half and ``[..., 2 * l + 1]`` from the second.

    ``backward`` is the adjoint and returns a (B, C, H, W) gradient.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        half = C // 2
        ctx.shape = (B, half, H, W)
        first, second = torch.split(x, x.shape[1] // 2, 1)
        sequences = []
        for part in (first, second):
            buf = part.new_empty((B, 4, half, H * W))
            buf[:, 0] = part.flatten(2, 3)
            buf[:, 1] = part.transpose(dim0=2, dim1=3).flatten(2, 3)
            buf[:, 2:4] = torch.flip(buf[:, 0:2], dims=[-1])
            sequences.append(buf)
        return torch.stack(sequences, dim=-1).reshape(B, 4, half, H * W * 2)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        B, C, H, W = ctx.shape
        L = H * W
        grads = ys.view(B, 4, C, L, 2)
        per_half = []
        for slot in range(2):
            g = grads[..., slot]
            # Fold the reversed directions (2, 3) onto their sources (0, 1).
            g = g[:, 0:2] + g[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
            # Direction 1 is column-major; transpose it back to row-major.
            col_major = g[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3)
            per_half.append(g[:, 0] + col_major.contiguous().view(B, -1, L))
        return torch.concat(per_half, 1).view(B, -1, H, W)
|
|
|
|
|
|
|
|
class CrossMergeS(torch.autograd.Function):
    """Merge four directional, two-group-interleaved scans back to a sequence.

    ``forward`` takes ``ys`` of shape (B, K, D, H, 2 * W). Its flattened
    spatial axis alternates two channel groups. For each group:

    * directions 2 and 3 are reversed and added to directions 0 and 1;
    * direction 1, stored column by column, is transposed back to row-major
      order and added to direction 0.

    The two groups are then concatenated on the channel axis, giving
    (B, 2 * D, H * W). Only directions 0-3 are read (presumably K == 4).

    ``backward`` is the adjoint and returns a (B, 4, C, H, 2 * W) gradient.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W2 = ys.shape
        W = W2 // 2
        ctx.shape = (H, W)
        pairs = ys.view(B, K, D, -1, 2)
        per_group = []
        for slot in range(2):
            g = pairs[..., slot]
            g = g[:, 0:2] + g[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            row_major = g[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
            per_group.append(g[:, 0] + row_major)
        return torch.concat(per_group, 1)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        H, W = ctx.shape
        B, C, L = x.shape
        C = C // 2
        first, second = torch.split(x, x.shape[1] // 2, 1)
        sequences = []
        for part in (first, second):
            buf = part.new_empty((B, 4, C, L))
            buf[:, 0] = part
            # Re-express the row-major gradient in column-major order.
            buf[:, 1] = part.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
            buf[:, 2:4] = torch.flip(buf[:, 0:2], dims=[-1])
            sequences.append(buf)
        grad = torch.stack(sequences, dim=-1).reshape(B, 4, C, H, W * 2)
        # NOTE(review): forward takes a single tensor, so the trailing Nones
        # look vestigial; kept to preserve the returned value.
        return grad, None, None