File size: 7,135 Bytes
226675b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185

import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.ndimage as ndimage

def local_scan_zero_ones(locality, x, h_scan=False):
    """Gather the entries of ``x`` grouped by the regions of a 0/1 locality map.

    ``locality`` is squeezed and moved to NumPy; its connected regions of ones
    are labelled with ``scipy.ndimage.label``.  Entries of ``x`` are then
    picked with ``x[:, mask]``: first everything where the map equals 0, then
    one labelled region of ones after another, so the entries of a region end
    up next to each other.

    Args:
        locality: Tensor of zeros and ones.  After ``squeeze()`` it must be
            usable as a mask over the dimensions of ``x`` that follow the
            first one (assumed to be an ``(H, W)`` map for a ``(C, H, W)``
            ``x`` -- TODO confirm against the caller).
        x: Tensor indexed as ``x[:, mask]``.
        h_scan: If true, ``x`` and both index maps are transposed over their
            last two dimensions before gathering.  NOTE: ``x`` is transposed
            *in place*, so the tensor object passed by the caller changes
            shape as well.

    Returns:
        A tuple ``(flattened_zeros, flattened_ones, num_zero_entries,
        zero_mask, indices_ones, num_ones)`` where ``zero_mask`` is the boolean
        mask of zero positions, ``indices_ones`` the label map (0 outside the
        regions, ``1..num_ones`` inside) and both maps are already transposed
        when ``h_scan`` is set.  When the map contains no ones,
        ``flattened_ones`` is an empty selection instead of an error.
    """
    # Step 1: bring the locality map to NumPy and label the regions of ones.
    locality_map = locality.squeeze().cpu().numpy()
    labeled_ones, num_ones = ndimage.label(locality_map == 1)

    # Step 2: index maps as tensors (fresh copies, safe to transpose below).
    indices_zeros = torch.tensor(locality_map)
    indices_ones = torch.tensor(labeled_ones)

    if h_scan:
        # In-place on purpose: kept for backward compatibility with callers.
        x.transpose_(-1, -2)
        indices_zeros.transpose_(-1, -2)
        indices_ones.transpose_(-1, -2)

    # Step 3: all zero positions form a single group.
    zero_mask = indices_zeros == 0
    flattened_zeros = x[:, zero_mask]

    # Step 4: one group per labelled region of ones, concatenated in label order.
    components_ones = [x[:, indices_ones == label] for label in range(1, num_ones + 1)]
    if components_ones:
        flattened_ones = torch.cat(components_ones, dim=-1)
    else:
        # torch.cat rejects an empty list; an all-False mask yields the
        # correctly shaped empty selection instead.
        flattened_ones = x[:, torch.zeros_like(zero_mask)]

    return flattened_zeros, flattened_ones, flattened_zeros.shape[-1], zero_mask, indices_ones, num_ones

def reverse_local_scan_zero_ones(indices_zeros, indices_ones, num_ones, flattened_zeros, flattened_ones, h_scan=False):
    """Scatter region-ordered values back onto their original grid positions.

    Inverse of the gathering done by ``local_scan_zero_ones``: the zero
    entries are written back through the boolean mask ``indices_zeros`` and
    the entries of every labelled region of ones are written back, region by
    region, in label order.

    Args:
        indices_zeros: Boolean mask of the zero positions.
        indices_ones: Label map; region ``i`` occupies the positions equal to
            ``i`` for ``i`` in ``1..num_ones``.  Its last two dimensions give
            the grid size of the result.
        num_ones: Number of labelled regions.
        flattened_zeros: Values for the zero positions; its first dimension
            gives the channel count of the result.
        flattened_ones: Values for the regions of ones, concatenated in label
            order along the last dimension.
        h_scan: If true, the restored grid is transposed over its last two
            dimensions before it is returned (undoing the transpose applied
            when scanning).

    Returns:
        A float32 tensor of shape ``(C, H, W)`` (last two dimensions swapped
        when ``h_scan`` is set), allocated on the device of
        ``flattened_zeros``.
    """
    C, H, W = flattened_zeros.shape[0], indices_ones.shape[-2], indices_ones.shape[-1]
    # Allocate on the input's own device so CPU tensors work too (the former
    # ``.cuda(get_device())`` call failed for them).
    local_restored = torch.zeros((C, H, W), dtype=torch.float32, device=flattened_zeros.device)

    # Fill the zero positions.
    local_restored[:, indices_zeros] = flattened_zeros

    # Fill the regions of ones, consuming ``flattened_ones`` left to right.
    start_idx = 0
    for label in range(1, num_ones + 1):
        mask = indices_ones == label
        count = int(mask.sum())
        local_restored[:, mask] = flattened_ones[:, start_idx:start_idx + count]
        start_idx += count

    if h_scan:
        local_restored.transpose_(-1, -2)

    return local_restored


def merge_lists(list1, list2):
    """Pair two equally shaped tensors element-wise on a new trailing axis.

    The result has shape ``list1.shape + (2,)``; index 0 of the last
    dimension holds ``list1`` and index 1 holds ``list2``.
    """
    return torch.stack((list1, list2), dim=-1)

class Scan_FB_S(torch.autograd.Function):
    """Forward/backward sequence scan over two interleaved channel halves.

    The channel dimension of the input is split into two halves.  Each half
    is laid out in two directions (as given and reversed along the last
    dimension) and the two halves are interleaved element by element along
    the sequence, doubling its length.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Map ``x`` of shape ``(B, C, L)`` to ``(B, 2, C // 2, 2 * L)``."""
        batch, channels, length = x.shape
        half = channels // 2
        ctx.shape = (batch, half, length)
        first, second = torch.split(x, half, 1)

        # Direction 0 is the sequence as given, direction 1 its reversal.
        dirs_first = torch.stack((first, first.flip(-1)), dim=1)
        dirs_second = torch.stack((second, second.flip(-1)), dim=1)

        # Pair the halves on a trailing axis, then fold that axis into the
        # sequence so entries of the two halves alternate.
        interleaved = merge_lists(dirs_first, dirs_second)
        return interleaved.reshape(batch, 2, half, length * 2)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """Fold a ``(B, 2, C // 2, 2 * L)`` gradient back to ``(B, C, L)``."""
        batch, half, length = ctx.shape
        # Last axis: 0 -> first channel half, 1 -> second channel half.
        pairs = ys.view(batch, 2, half, length, 2)
        grads = [
            pairs[:, 0, :, :, part] + pairs[:, 1, :, :, part].flip(-1)
            for part in range(2)
        ]
        merged = torch.concat(grads, 1)
        return merged.view(batch, half * 2, length).contiguous()


class Merge_FB_S(torch.autograd.Function):
    """Merge a forward/backward scan of two interleaved channel halves.

    The last dimension of the input alternates between entries of the first
    and the second channel half.  For each half, direction 0 and the reversed
    direction 1 are summed, and the two results are concatenated along the
    channel dimension.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        """Map ``ys`` of shape ``(B, K, C, L)`` to ``(B, 2 * C, L // 2)``."""
        batch, dirs, channels, length = ys.shape
        ctx.shape = (batch, dirs, channels, length)
        # Last axis: 0 -> first channel half, 1 -> second channel half.
        pairs = ys.view(batch, dirs, channels, -1, 2)
        merged = [
            pairs[:, 0, :, :, part] + pairs[:, 1, :, :, part].flip(-1)
            for part in range(2)
        ]
        return torch.concat(merged, 1).contiguous()

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Expand a ``(B, 2 * C, L // 2)`` gradient back to ``(B, K, C, L)``.

        Only direction slots 0 and 1 are written; the buffers come from
        ``new_empty``, so any further slots (``K > 2``) stay uninitialised.
        """
        batch, dirs, channels, length = ctx.shape
        grad_first, grad_second = torch.split(x, channels, 1)

        expanded = []
        for grad_half in (grad_first, grad_second):
            per_dir = grad_half.new_empty((batch, dirs, channels, length // 2))
            per_dir[:, 0] = grad_half
            per_dir[:, 1] = grad_half.flip(-1)
            expanded.append(per_dir)

        # Re-interleave the two halves along the sequence.
        return merge_lists(expanded[0], expanded[1]).reshape(batch, dirs, channels, length)

class CrossScanS(torch.autograd.Function):
    """Four-direction scan of a 2-D map over two interleaved channel halves.

    The channels are split into two halves.  Each half is flattened in four
    ways: row by row, column by column (after swapping the two spatial
    dimensions), and the reversals of those two.  The two halves are then
    interleaved element by element along the flattened sequence.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor):
        """Map ``x`` of shape ``(B, C, H, W)`` to ``(B, 4, C // 2, 2 * H * W)``."""
        batch, channels, height, width = x.shape
        half = channels // 2
        ctx.shape = (batch, half, height, width)
        first, second = torch.split(x, x.shape[1] // 2, 1)

        def four_way(part):
            # Slots 0/1: row-major and column-major flattening,
            # slots 2/3: the same two sequences reversed.
            routes = part.new_empty((batch, 4, half, height * width))
            routes[:, 0] = part.flatten(2, 3)
            routes[:, 1] = part.transpose(dim0=2, dim1=3).flatten(2, 3)
            routes[:, 2:4] = torch.flip(routes[:, 0:2], dims=[-1])
            return routes

        interleaved = merge_lists(four_way(first), four_way(second))
        return interleaved.reshape(batch, 4, half, height * width * 2)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        """Fold a ``(B, 4, C // 2, 2 * H * W)`` gradient back to ``(B, C, H, W)``."""
        batch, half, height, width = ctx.shape
        length = height * width
        # Last axis: 0 -> first channel half, 1 -> second channel half.
        pairs = ys.view(batch, 4, half, length, 2)

        def fold(routes):
            # Undo the reversal of slots 2/3 and add them onto slots 0/1.
            both = routes[:, 0:2] + routes[:, 2:4].flip(dims=[-1]).view(batch, 2, -1, length)
            # Bring the column-major sequence back to row-major order.
            row_major = (
                both[:, 1]
                .view(batch, -1, width, height)
                .transpose(dim0=2, dim1=3)
                .contiguous()
                .view(batch, -1, length)
            )
            return both[:, 0] + row_major

        grad = torch.concat([fold(pairs[..., 0]), fold(pairs[..., 1])], 1)
        return grad.view(batch, -1, height, width)


class CrossMergeS(torch.autograd.Function):
    """Merge a four-direction scan of two interleaved channel halves.

    The flattened spatial sequence of the input alternates between entries of
    the first and the second channel half.  For each half the four direction
    slots (row-major, column-major and their reversals) are summed back in
    row-major order; the two results are concatenated along the channel
    dimension.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor):
        """Map ``ys`` of shape ``(B, K, D, H, 2 * W)`` to ``(B, 2 * D, H * W)``.

        Direction slots 0..3 of ``K`` are read.
        """
        B, K, D, H, W = ys.shape
        W = W // 2  # the last dimension carries both channel halves interleaved
        ctx.shape = (H, W)
        # reshape instead of view: merging H and W needs contiguous memory,
        # and reshape copies when the input is not contiguous instead of raising.
        ys = ys.reshape(B, K, D, -1, 2)
        ys1, ys2 = ys[..., 0], ys[..., 1]
        # Undo the reversal of slots 2/3 and add them onto slots 0/1.
        ys1 = ys1[:, 0:2] + ys1[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        ys2 = ys2[:, 0:2] + ys2[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
        # Bring the column-major sequence (slot 1) back to row-major order.
        y1 = ys1[:, 0] + ys1[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y2 = ys2[:, 0] + ys2[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        y = torch.concat([y1, y2], 1)
        return y

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        """Expand a ``(B, 2 * C, H * W)`` gradient to ``(B, 4, C, H, 2 * W)``."""
        H, W = ctx.shape
        B, C, L = x.shape
        C = C // 2
        x1, x2 = torch.split(x, x.shape[1] // 2, 1)
        xs1, xs2 = x1.new_empty((B, 4, C, L)), x2.new_empty((B, 4, C, L))
        xs1[:, 0] = x1
        xs1[:, 1] = x1.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs1[:, 2:4] = torch.flip(xs1[:, 0:2], dims=[-1])
        xs2[:, 0] = x2
        xs2[:, 1] = x2.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3)
        xs2[:, 2:4] = torch.flip(xs2[:, 0:2], dims=[-1])
        xs = merge_lists(xs1, xs2).reshape(B, 4, C, H, W * 2)
        # forward takes a single tensor, so exactly one gradient is returned.
        return xs