import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.ndimage as ndimage

def local_scan_zero_ones(locality, x, h_scan=False):
    # Step 1: flatten `locality` so its connected regions can be identified.
    local_flat = locality.squeeze().cpu().numpy()
    # labeled_zeros, num_zeros = ndimage.label(local_flat == 0)  # label connected regions of 0s
    labeled_ones, num_ones = ndimage.label(local_flat == 1)  # label connected regions of 1s
    # Step 2: build index maps for the connected regions.
    indices_zeros = torch.tensor(local_flat)
    indices_ones = torch.tensor(labeled_ones)
    # Step 3: build a mask for each connected region.
    components_zeros = []
    components_ones = []
    if h_scan:
        # Transpose in place so the scan below runs column-wise instead of row-wise.
        x.transpose_(-1, -2)
        indices_zeros.transpose_(-1, -2)
        indices_ones.transpose_(-1, -2)
    # for i in range(1, num_zeros + 1):
    #     mask = (indices_zeros == i)
    #     components_zeros.append(x[:, mask])  # extract the masked values from x
    # All 0 positions are treated as a single region.
    mask = (indices_zeros == 0)
    components_zeros.append(x[:, mask])
    for i in range(1, num_ones + 1):
        mask = (indices_ones == i)
        components_ones.append(x[:, mask])  # extract the masked values from x
    # Step 4: flatten the regions, i.e. perform the requested scan.
    flattened_zeros = torch.cat(components_zeros, dim=-1)  # merge all 0 regions
    flattened_ones = torch.cat(components_ones, dim=-1)  # merge all 1 regions
    return flattened_zeros, flattened_ones, flattened_zeros.shape[-1], indices_zeros == 0, indices_ones, num_ones

def reverse_local_scan_zero_ones(indices_zeros, indices_ones, num_ones, flattened_zeros, flattened_ones, h_scan=False):
    C, H, W = flattened_zeros.shape[0], indices_ones.shape[-2], indices_ones.shape[-1]
    # Create a zero matrix with the same shape as the original.
    local_restored = torch.zeros((C, H, W)).float().cuda(flattened_zeros.get_device())
    # Fill the 0 region (a single region, matching the forward scan).
    # start_idx = 0
    # for i in range(1, num_zeros + 1):
    #     mask = (indices_zeros == i)
    #     local_restored[:, mask] = flattened_zeros[:, start_idx:start_idx + mask.sum()]
    #     start_idx += mask.sum()
    mask = indices_zeros
    local_restored[:, mask] = flattened_zeros
    # Fill the 1 regions.
    start_idx = 0
    for i in range(1, num_ones + 1):
        mask = (indices_ones == i)
        local_restored[:, mask] = flattened_ones[:, start_idx:start_idx + mask.sum()]
        start_idx += mask.sum()
    if h_scan:
        local_restored.transpose_(-1, -2)
    return local_restored

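# Minimal round-trip sketch (illustrative, not part of the original API): it
# assumes `locality` is an (H, W) binary mask and `x` a (C, H, W) CUDA tensor,
# and it requires a GPU because the reverse path allocates on one.
def _demo_local_scan_round_trip():
    locality = torch.zeros(8, 8)
    locality[2:4, 2:4] = 1  # a single connected region of 1s
    x = torch.randn(3, 8, 8).cuda()
    zeros_seq, ones_seq, n_zeros, zero_mask, ones_idx, num_ones = local_scan_zero_ones(locality, x)
    restored = reverse_local_scan_zero_ones(zero_mask, ones_idx, num_ones, zeros_seq, ones_seq)
    assert torch.equal(restored, x)  # scanning then un-scanning is lossless
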
def merge_lists(list1, list2):
    # Interleave two tensors along a new trailing axis: (..., L) x 2 -> (..., L, 2).
    list1, list2 = list1.unsqueeze(-1), list2.unsqueeze(-1)
    merged_list = torch.concat([list1, list2], -1)
    return merged_list

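# Sanity sketch for merge_lists: after the callers' reshape to (..., 2 * L),
# elements of the two inputs alternate, e.g.
# merge_lists(torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])).reshape(-1)
# -> tensor([1, 4, 2, 5, 3, 6])
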
class Scan_FB_S(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, L = x.shape
        ctx.shape = (B, C // 2, L)
        x1, x2 = torch.split(x, C // 2, 1)
        # For each channel half, keep a forward copy and a flipped (backward) copy.
        xs1, xs2 = x1.new_empty((B, 2, C // 2, L)), x2.new_empty((B, 2, C // 2, L))
        xs1[:, 0] = x1
        xs1[:, 1] = x1.flip(-1)
        xs2[:, 0] = x2
        xs2[:, 1] = x2.flip(-1)
        # Interleave the two halves along the sequence axis: (B, 2, C // 2, 2 * L).
        xs = merge_lists(xs1, xs2).reshape(B, 2, C // 2, L * 2)
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        B, C, L = ctx.shape
        # De-interleave, then sum the forward path with the re-flipped backward path.
        ys = ys.view(B, 2, C, L, 2)
        ys1, ys2 = ys[..., 0], ys[..., 1]
        y1 = ys1[:, 0, :, :] + ys1[:, 1, :, :].flip(-1)
        y2 = ys2[:, 0, :, :] + ys2[:, 1, :, :].flip(-1)
        y = torch.concat([y1, y2], 1)
        return y.view(B, C * 2, L).contiguous()

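# Shape sketch for Scan_FB_S (shapes inferred from the code above, so treat
# them as assumptions): forward/backward copies per half, halves interleaved.
def _demo_scan_fb_s():
    x = torch.randn(2, 8, 16, requires_grad=True)  # (B, C, L)
    xs = Scan_FB_S.apply(x)
    assert xs.shape == (2, 2, 4, 32)  # (B, 2 directions, C // 2, 2 * L)
    xs.sum().backward()  # gradients flow through the custom backward
    assert x.grad.shape == x.shape
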
class Merge_FB_S(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, C, L = ys.shape
        ctx.shape = (B, K, C, L)
        # De-interleave the two halves, then fold the flipped backward path
        # onto the forward path.
        ys = ys.view(B, K, C, -1, 2)
        ys1, ys2 = ys[..., 0], ys[..., 1]
        y1 = ys1[:, 0, :, :] + ys1[:, 1, :, :].flip(-1)
        y2 = ys2[:, 0, :, :] + ys2[:, 1, :, :].flip(-1)
        y = torch.concat([y1, y2], 1)
        return y.contiguous()

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        B, K, C, L = ctx.shape
        x1, x2 = torch.split(x, C, 1)
        xs1, xs2 = x1.new_empty((B, K, C, L // 2)), x2.new_empty((B, K, C, L // 2))
        xs1[:, 0] = x1
        xs1[:, 1] = x1.flip(-1)
        xs2[:, 0] = x2
        xs2[:, 1] = x2.flip(-1)
        xs = merge_lists(xs1, xs2).reshape(B, K, C, L)
        return xs

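# Round-trip sketch (assumed shapes): Merge_FB_S undoes Scan_FB_S's layout,
# and each element gathers its forward plus re-flipped backward copy.
def _demo_merge_fb_s():
    x = torch.randn(2, 8, 16)  # (B, C, L)
    xs = Scan_FB_S.apply(x)    # (B, 2, C // 2, 2 * L)
    y = Merge_FB_S.apply(xs)   # (B, C, L)
    assert torch.allclose(y, 2 * x)  # two copies of every element are summed
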
class CrossScanS(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        B, C, H, W = x.shape
        ctx.shape = (B, C // 2, H, W)
        x1, x2 = torch.split(x, x.shape[1] // 2, 1)
        # Four scan directions per half: row-major, column-major, and their reverses.
        xs1, xs2 = x1.new_empty((B, 4, C // 2, H * W)), x2.new_empty((B, 4, C // 2, H * W))
        xs1[:, 0] = x1.flatten(2, 3)
        xs1[:, 1] = x1.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs1[:, 2:4] = torch.flip(xs1[:, 0:2], dims=[-1])
        xs2[:, 0] = x2.flatten(2, 3)
        xs2[:, 1] = x2.transpose(dim0=2, dim1=3).flatten(2, 3)
        xs2[:, 2:4] = torch.flip(xs2[:, 0:2], dims=[-1])
        xs = merge_lists(xs1, xs2).reshape(B, 4, C // 2, H * W * 2)
        return xs

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        # out: (b, k, d, l)
        B, C, H, W = ctx.shape
        L = H * W
        ys = ys.view(B, 4, C, L, 2)
        ys1, ys2 = ys[..., 0], ys[..., 1]
        # Fold the two reversed scans back onto the two forward scans.
        ys1 = ys1[:, 0:2] + ys1[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        ys2 = ys2[:, 0:2] + ys2[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L)
        # Fold the column-major scan back onto the row-major scan.
        y1 = ys1[:, 0] + ys1[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y2 = ys2[:, 0] + ys2[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L)
        y = torch.concat([y1, y2], 1)
        return y.view(B, -1, H, W)

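# Shape sketch for CrossScanS (shapes are assumptions read off the code):
# four scan directions per channel half, halves interleaved along the last axis.
def _demo_cross_scan_s():
    x = torch.randn(2, 8, 4, 4, requires_grad=True)  # (B, C, H, W)
    xs = CrossScanS.apply(x)
    assert xs.shape == (2, 4, 4, 32)  # (B, 4, C // 2, 2 * H * W)
    xs.sum().backward()
    assert x.grad.shape == x.shape
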
class CrossMergeS(torch.autograd.Function):
    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        B, K, D, H, W = ys.shape
        W = W // 2  # the last axis interleaves the two channel halves
        ctx.shape = (H, W)
        ys = ys.view(B, K, D, -1, 2)
        ys1, ys2 = ys[..., 0], ys[..., 1]
        # Fold the two reversed scans back onto the two forward scans.
        ys1 = ys1[:, 0:2] + ys1[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        ys2 = ys2[:, 0:2] + ys2[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # Fold the column-major scan back onto the row-major scan.
        y1 = ys1[:, 0] + ys1[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y2 = ys2[:, 0] + ys2[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y = torch.concat([y1, y2], 1)
        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        # out: (b, k, d, l)
        H, W = ctx.shape
        B, C, L = x.shape
        C = C // 2
        x1, x2 = torch.split(x, x.shape[1] // 2, 1)
        xs1, xs2 = x1.new_empty((B, 4, C, L)), x2.new_empty((B, 4, C, L))
        xs1[:, 0] = x1
        xs1[:, 1] = x1.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs1[:, 2:4] = torch.flip(xs1[:, 0:2], dims=[-1])
        xs2[:, 0] = x2
        xs2[:, 1] = x2.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs2[:, 2:4] = torch.flip(xs2[:, 0:2], dims=[-1])
        xs = merge_lists(xs1, xs2).reshape(B, 4, C, H, W * 2)
        # forward() takes a single tensor, so return exactly one gradient.
        return xs
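
# Pipeline sketch (assumed layout): CrossMergeS consumes CrossScanS's output
# reshaped to five dimensions; each pixel is visited by all four scans.
def _demo_cross_merge_s():
    B, C, H, W = 2, 8, 4, 4
    x = torch.randn(B, C, H, W)
    xs = CrossScanS.apply(x)              # (B, 4, C // 2, 2 * H * W)
    ys = xs.view(B, 4, C // 2, H, 2 * W)  # 5-D layout expected by CrossMergeS
    y = CrossMergeS.apply(ys)             # (B, C, H * W)
    assert torch.allclose(y, 4 * x.flatten(2))  # four scans each contribute once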