| import torch |
| import warnings |
|
|
# Feature flag: prefer the triton kernels when triton imports cleanly,
# otherwise fall back to the pure-pytorch implementations below.
WITH_TRITON = True
try:
    import triton
    import triton.language as tl
except Exception:
    # Best-effort import: triton may be missing or fail to initialize
    # (e.g. no CUDA toolchain). A bare ``except:`` would also swallow
    # SystemExit/KeyboardInterrupt, so catch Exception instead.
    WITH_TRITON = False
    warnings.warn("Triton not installed, fall back to pytorch implements.")


if WITH_TRITON:
    try:
        # cached_property exists since py38; on py37 it must be backported.
        from functools import cached_property
    except ImportError:
        warnings.warn("if you are using py37, add this line to functools.py: "
                      "cached_property = lambda func: property(lru_cache()(func))")
|
|
| |
def cross_scan_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Unfold a 2D feature map into 4 one-dimensional scan sequences.

    Args:
        x: input of shape (B, C, H, W) if ``in_channel_first`` else (B, H, W, C).
        in_channel_first: layout of ``x`` (channel-first vs channel-last).
        out_channel_first: layout of the result; (B, 4, C, H*W) when True,
            (B, H*W, 4, C) otherwise.
        scans: scan pattern --
            0: cross-scan (row-major, column-major, and both reversed);
            1: unidirectional (row-major repeated 4 times);
            2: bidirectional (row-major twice, then both reversed).

    Returns:
        The stacked scan sequences.

    Raises:
        ValueError: if ``scans`` is not one of 0, 1, 2.
    """
    if scans not in (0, 1, 2):
        # Fail early with a clear message instead of the confusing
        # UnboundLocalError the fall-through used to produce.
        raise ValueError(f"unsupported scans mode: {scans}")

    if in_channel_first:
        B, C, H, W = x.shape
        if scans == 0:
            y = x.new_empty((B, 4, C, H * W))
            y[:, 0, :, :] = x.flatten(2, 3)                            # row-major
            y[:, 1, :, :] = x.transpose(dim0=2, dim1=3).flatten(2, 3)  # column-major
            y[:, 2:4, :, :] = torch.flip(y[:, 0:2, :, :], dims=[-1])   # both reversed
        elif scans == 1:
            y = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
        elif scans == 2:
            y = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
            y = torch.cat([y, y.flip(dims=[-1])], dim=1)
    else:
        B, H, W, C = x.shape
        if scans == 0:
            y = x.new_empty((B, H * W, 4, C))
            y[:, :, 0, :] = x.flatten(1, 2)
            y[:, :, 1, :] = x.transpose(dim0=1, dim1=2).flatten(1, 2)
            y[:, :, 2:4, :] = torch.flip(y[:, :, 0:2, :], dims=[1])
        elif scans == 1:
            y = x.view(B, H * W, 1, C).repeat(1, 1, 4, 1)
        elif scans == 2:
            y = x.view(B, H * W, 1, C).repeat(1, 1, 2, 1)
            y = torch.cat([y, y.flip(dims=[1])], dim=2)

    # Convert between layouts when the input and requested output disagree.
    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y
|
|
|
|
def cross_merge_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Sum the 4 scan sequences back into one feature sequence.

    Inverse-and-reduce of ``cross_scan_fwd``: each direction is realigned to
    row-major order and the 4 directions are summed.

    Args:
        y: (B, 4, D, H, W) if ``out_channel_first`` else (B, H, W, 4, D).
        in_channel_first: layout of the returned tensor; (B, D, H*W) when
            True, (B, H*W, D) otherwise.
        out_channel_first: layout of the input ``y``.
        scans: 0 cross-scan, 1 unidirectional, 2 bidirectional (must match
            the value used in ``cross_scan_fwd``).
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            # Undo the reversal of directions 2/3, then undo the column-major
            # ordering of direction 1 (view as (W, H), transpose back).
            y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = y[:, 0] + y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1)
        elif scans == 1:
            y = y.sum(1)
        elif scans == 2:
            y = y[:, 0:2] + y[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            y = y.sum(1)
    else:
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            # Same as above but the sequence axis is dim 1 and K is dim 2.
            y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = y[:, :, 0] + y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).contiguous().view(B, -1, D)
        elif scans == 1:
            y = y.sum(2)
        elif scans == 2:
            y = y[:, :, 0:2] + y[:, :, 2:4].flip(dims=[1]).view(B, -1, 2, D)
            y = y.sum(2)

    # The merged tensor is 3-D in both branches, so the same (0, 2, 1)
    # permutation converts between (B, D, L) and (B, L, D) either way.
    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 2, 1).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 1).contiguous()

    return y
|
|
|
|
def cross_scan1b1_fwd(x: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Like ``cross_scan_fwd`` but each of the 4 directions consumes its own
    input slice ("one-by-one": the input carries an explicit K=4 axis).

    Args:
        x: (B, 4, C, H, W) if ``in_channel_first`` else (B, H, W, 4, C).
        in_channel_first: layout of ``x``.
        out_channel_first: layout of the result; (B, 4, C, H*W) when True,
            (B, H*W, 4, C) otherwise.
        scans: 0 cross-scan, 1 unidirectional, 2 bidirectional (see
            ``cross_scan_fwd``).
    """
    if in_channel_first:
        B, _, C, H, W = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            y = x.flatten(2, 3)
        elif scans == 2:
            y = torch.stack([
                x[:, 0].flatten(2, 3),
                x[:, 1].flatten(2, 3),
                torch.flip(x[:, 2].flatten(2, 3), dims=[-1]),
                torch.flip(x[:, 3].flatten(2, 3), dims=[-1]),
            ], dim=1)
    else:
        B, H, W, _, C = x.shape
        if scans == 0:
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = x.flatten(1, 2)
        elif scans == 2:
            # BUGFIX: in channel-last layout the K axis is dim 3, not dim 1
            # (the old code sliced the H axis via ``x[:, k]``), and the
            # sequence axis to flip is dim 1 -- mirroring scans == 0 above.
            y = torch.stack([
                x[:, :, :, 0].flatten(1, 2),
                x[:, :, :, 1].flatten(1, 2),
                torch.flip(x[:, :, :, 2].flatten(1, 2), dims=[1]),
                torch.flip(x[:, :, :, 3].flatten(1, 2), dims=[1]),
            ], dim=2)

    # Convert between layouts when the input and requested output disagree.
    if in_channel_first and (not out_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not in_channel_first) and out_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y
|
|
|
|
def cross_merge1b1_fwd(y: torch.Tensor, in_channel_first=True, out_channel_first=True, scans=0):
    """Inverse of ``cross_scan1b1_fwd``: realign each of the 4 per-direction
    sequences to row-major order, keeping them separate (no summation).

    Args:
        y: (B, 4, D, H, W) if ``out_channel_first`` else (B, H, W, 4, D).
        in_channel_first: layout of the returned tensor; (B, 4, D, H*W) when
            True, (B, H*W, 4, D) otherwise.
        out_channel_first: layout of the input ``y``.
        scans: 0 cross-scan, 1 unidirectional, 2 bidirectional (must match
            the value used when scanning).
    """
    if out_channel_first:
        B, K, D, H, W = y.shape
        y = y.view(B, K, D, -1)
        if scans == 0:
            # Direction 1/3 were column-major: view as (W, H) and transpose
            # back; directions 2/3 were reversed along the sequence axis.
            y = torch.stack([
                y[:, 0],
                y[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3),
                torch.flip(y[:, 2], dims=[-1]),
                torch.flip(y[:, 3].view(B, -1, W, H).transpose(dim0=2, dim1=3).flatten(2, 3), dims=[-1]),
            ], dim=1)
        elif scans == 1:
            y = y
        elif scans == 2:
            y = torch.stack([
                y[:, 0],
                y[:, 1],
                torch.flip(y[:, 2], dims=[-1]),
                torch.flip(y[:, 3], dims=[-1]),
            ], dim=1)
    else:
        # BUGFIX: K must be bound here (it was unpacked as ``_`` before,
        # making the view below raise NameError for channel-last inputs).
        B, H, W, K, D = y.shape
        y = y.view(B, -1, K, D)
        if scans == 0:
            y = torch.stack([
                y[:, :, 0],
                y[:, :, 1].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2),
                torch.flip(y[:, :, 2], dims=[1]),
                torch.flip(y[:, :, 3].view(B, W, H, -1).transpose(dim0=1, dim1=2).flatten(1, 2), dims=[1]),
            ], dim=2)
        elif scans == 1:
            y = y
        elif scans == 2:
            y = torch.stack([
                y[:, :, 0],
                y[:, :, 1],
                torch.flip(y[:, :, 2], dims=[1]),
                torch.flip(y[:, :, 3], dims=[1]),
            ], dim=2)

    # Convert between layouts when the input and requested output disagree.
    if out_channel_first and (not in_channel_first):
        y = y.permute(0, 3, 1, 2).contiguous()
    elif (not out_channel_first) and in_channel_first:
        y = y.permute(0, 2, 3, 1).contiguous()

    return y
|
|
|
|
class CrossScanF(torch.autograd.Function):
    """Autograd wrapper around the pure-pytorch cross-scan.

    Forward expands the input into 4 scan sequences; backward realigns the
    incoming gradients and sums them via the matching merge routine.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # Stash the configuration so backward can mirror it.
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        # Recover the logical (B, C, H, W) from whichever layout x uses.
        if one_by_one:
            if in_channel_first:
                B, _, C, H, W = x.shape
            else:
                B, H, W, _, C = x.shape
        else:
            if in_channel_first:
                B, C, H, W = x.shape
            else:
                B, H, W, C = x.shape
        ctx.shape = (B, C, H, W)

        scan = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        return scan(x, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, ys: torch.Tensor):
        in_cf = ctx.in_channel_first
        out_cf = ctx.out_channel_first
        B, C, H, W = ctx.shape

        # Restore the 5-D spatial view expected by the merge routines.
        ys = ys.view(B, -1, C, H, W) if out_cf else ys.view(B, H, W, -1, C)
        merge = cross_merge1b1_fwd if ctx.one_by_one else cross_merge_fwd
        grad = merge(ys, in_cf, out_cf, ctx.scans)

        # Reshape the gradient to match the forward input's layout.
        if ctx.one_by_one:
            grad = grad.view(B, 4, -1, H, W) if in_cf else grad.view(B, H, W, 4, -1)
        else:
            grad = grad.view(B, -1, H, W) if in_cf else grad.view(B, H, W, -1)

        # One gradient per forward input; the four flags get None.
        return grad, None, None, None, None
|
|
|
|
class CrossMergeF(torch.autograd.Function):
    """Autograd wrapper around the pure-pytorch cross-merge.

    Forward merges 4 scan sequences into one; backward re-expands the
    incoming gradient with the matching scan routine.
    """

    @staticmethod
    def forward(ctx, ys: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # Stash the configuration so backward can mirror it.
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans

        # Recover the logical (B, C, H, W) from the input layout.
        if out_channel_first:
            B, K, C, H, W = ys.shape
        else:
            B, H, W, K, C = ys.shape
        ctx.shape = (B, C, H, W)

        merge = cross_merge1b1_fwd if one_by_one else cross_merge_fwd
        return merge(ys, in_channel_first, out_channel_first, scans)

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        in_cf = ctx.in_channel_first
        out_cf = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        B, C, H, W = ctx.shape

        # Restore the spatial view expected by the scan routines.
        if one_by_one:
            shape = (B, 4, C, H, W) if in_cf else (B, H, W, 4, C)
        else:
            shape = (B, C, H, W) if in_cf else (B, H, W, C)
        x = x.view(*shape)

        scan = cross_scan1b1_fwd if one_by_one else cross_scan_fwd
        x = scan(x, in_cf, out_cf, ctx.scans)
        x = x.view(B, 4, C, H, W) if out_cf else x.view(B, H, W, 4, C)

        # One gradient per forward input; the four flags get None.
        return x, None, None, None, None
|
|
|
|
| |
|
|
# NOTE(review): this definition is not guarded by ``WITH_TRITON``; importing
# the module without triton installed raises NameError on ``@triton.jit``.
# Consider wrapping it in ``if WITH_TRITON:`` -- verify against callers.
@triton.jit
def triton_cross_scan_flex(
    x,  # pointer to the merged feature map (or per-direction maps if onebyone)
    y,  # pointer to the 4-direction scan sequences
    x_layout: tl.constexpr,   # 0: channel-first (B, C, H, W); 1: channel-last (B, H, W, C)
    y_layout: tl.constexpr,   # 0: channel-first (B, 4, C, L); 1: channel-last (B, L, 4, C)
    operation: tl.constexpr,  # 0: scan (read x, write y); 1: merge (read y, write x)
    onebyone: tl.constexpr,   # 0: single x map; 1: x carries its own K=4 axis
    scans: tl.constexpr,      # 0: cross; 1: unidirectional; 2: bidirectional
    BC: tl.constexpr,  # channel tile size
    BH: tl.constexpr,  # spatial tile height
    BW: tl.constexpr,  # spatial tile width
    DC: tl.constexpr,  # total channels C
    DH: tl.constexpr,  # feature-map height H
    DW: tl.constexpr,  # feature-map width W
    NH: tl.constexpr,  # number of tiles along H
    NW: tl.constexpr,  # number of tiles along W
):
    # One program instance handles a (BH, BW) spatial tile of up to BC
    # channels for a single batch element.
    i_hw, i_c, i_b = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_h, i_w = (i_hw // NW), (i_hw % NW)
    # Mask off out-of-range rows/cols when H or W is not a multiple of BH/BW.
    _mask_h = (i_h * BH + tl.arange(0, BH)) < DH
    _mask_w = (i_w * BW + tl.arange(0, BW)) < DW
    _mask_hw = _mask_h[:, None] & _mask_w[None, :]
    # Number of channels actually present in this channel tile.
    _for_C = min(DC - i_c * BC, BC)

    # Flattened spatial offsets of this tile for each scan direction:
    # 0 row-major, 1 column-major, 2 reversed row-major, 3 reversed
    # column-major.
    # NOTE(review): the trailing (DH - NH * BH) / (DW - NW * BW) terms in
    # routes 2/3 appear to compensate for partial edge tiles when H/W are
    # not multiples of the tile size -- verify against the mask logic.
    HWRoute0 = i_h * BH * DW + tl.arange(0, BH)[:, None] * DW + i_w * BW + tl.arange(0, BW)[None, :]
    HWRoute1 = i_w * BW * DH + tl.arange(0, BW)[None, :] * DH + i_h * BH + tl.arange(0, BH)[:, None]
    HWRoute2 = (NH - i_h - 1) * BH * DW + (BH - 1 - tl.arange(0, BH)[:, None]) * DW + (NW - i_w - 1) * BW + (BW - 1 - tl.arange(0, BW)[None, :]) + (DH - NH * BH) * DW + (DW - NW * BW)
    HWRoute3 = (NW - i_w - 1) * BW * DH + (BW - 1 - tl.arange(0, BW)[None, :]) * DH + (NH - i_h - 1) * BH + (BH - 1 - tl.arange(0, BH)[:, None]) + (DH - NH * BH) + (DW - NW * BW) * DH

    # Degenerate scan modes reuse route 0 (and route 2 for bidirectional).
    if scans == 1:
        HWRoute1 = HWRoute0
        HWRoute2 = HWRoute0
        HWRoute3 = HWRoute0
    elif scans == 2:
        HWRoute1 = HWRoute0
        HWRoute3 = HWRoute2

    # Size of one (C, H, W) plane; y holds 4 of them per batch element.
    _tmp1 = DC * DH * DW

    y_ptr_base = y + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if y_layout == 0 else i_c * BC)
    if y_layout == 0:
        # Channel-first y: the 4 directions are contiguous planes.
        p_y1 = y_ptr_base + HWRoute0
        p_y2 = y_ptr_base + _tmp1 + HWRoute1
        p_y3 = y_ptr_base + 2 * _tmp1 + HWRoute2
        p_y4 = y_ptr_base + 3 * _tmp1 + HWRoute3
    else:
        # Channel-last y: directions are interleaved with stride DC.
        p_y1 = y_ptr_base + HWRoute0 * 4 * DC
        p_y2 = y_ptr_base + DC + HWRoute1 * 4 * DC
        p_y3 = y_ptr_base + 2 * DC + HWRoute2 * 4 * DC
        p_y4 = y_ptr_base + 3 * DC + HWRoute3 * 4 * DC

    if onebyone == 0:
        # Single input map: route 0 addresses x for all four directions.
        x_ptr_base = x + i_b * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
        if x_layout == 0:
            p_x = x_ptr_base + HWRoute0
        else:
            p_x = x_ptr_base + HWRoute0 * DC

        if operation == 0:
            # Scan: broadcast each x element into all 4 directions of y.
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                _x = tl.load(p_x + _idx_x, mask=_mask_hw)
                tl.store(p_y1 + _idx_y, _x, mask=_mask_hw)
                tl.store(p_y2 + _idx_y, _x, mask=_mask_hw)
                tl.store(p_y3 + _idx_y, _x, mask=_mask_hw)
                tl.store(p_y4 + _idx_y, _x, mask=_mask_hw)
        elif operation == 1:
            # Merge: sum the 4 directions of y into x.
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                _y1 = tl.load(p_y1 + _idx_y, mask=_mask_hw)
                _y2 = tl.load(p_y2 + _idx_y, mask=_mask_hw)
                _y3 = tl.load(p_y3 + _idx_y, mask=_mask_hw)
                _y4 = tl.load(p_y4 + _idx_y, mask=_mask_hw)
                tl.store(p_x + _idx_x, _y1 + _y2 + _y3 + _y4, mask=_mask_hw)

    else:
        # One-by-one: x has its own K=4 axis, so each direction reads/writes
        # its own plane of x (no broadcasting or summation).
        x_ptr_base = x + i_b * 4 * _tmp1 + (i_c * BC * DH * DW if x_layout == 0 else i_c * BC)
        if x_layout == 0:
            p_x1 = x_ptr_base + HWRoute0
            p_x2 = p_x1 + _tmp1
            p_x3 = p_x2 + _tmp1
            p_x4 = p_x3 + _tmp1
        else:
            p_x1 = x_ptr_base + HWRoute0 * 4 * DC
            p_x2 = p_x1 + DC
            p_x3 = p_x2 + DC
            p_x4 = p_x3 + DC

        if operation == 0:
            # Scan: copy each x plane into the matching y direction.
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                tl.store(p_y1 + _idx_y, tl.load(p_x1 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_y2 + _idx_y, tl.load(p_x2 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_y3 + _idx_y, tl.load(p_x3 + _idx_x, mask=_mask_hw), mask=_mask_hw)
                tl.store(p_y4 + _idx_y, tl.load(p_x4 + _idx_x, mask=_mask_hw), mask=_mask_hw)
        else:
            # Merge: copy each y direction back into the matching x plane.
            for idxc in range(_for_C):
                _idx_x = idxc * DH * DW if x_layout == 0 else idxc
                _idx_y = idxc * DH * DW if y_layout == 0 else idxc
                tl.store(p_x1 + _idx_x, tl.load(p_y1 + _idx_y), mask=_mask_hw)
                tl.store(p_x2 + _idx_x, tl.load(p_y2 + _idx_y), mask=_mask_hw)
                tl.store(p_x3 + _idx_x, tl.load(p_y3 + _idx_y), mask=_mask_hw)
                tl.store(p_x4 + _idx_x, tl.load(p_y4 + _idx_y), mask=_mask_hw)
|
|
|
class CrossScanTritonF(torch.autograd.Function):
    """Triton-backed cross-scan: forward launches the kernel in scan mode
    (operation=0); backward launches it in merge mode (operation=1) to
    accumulate the four direction gradients back into the input shape."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # Recover the logical (B, C, H, W) from whichever layout x uses.
        if one_by_one:
            if in_channel_first:
                B, _, C, H, W = x.shape
            else:
                B, H, W, _, C = x.shape
        else:
            if in_channel_first:
                B, C, H, W = x.shape
            else:
                B, H, W, C = x.shape
        B, C, H, W = int(B), int(C), int(H), int(W)
        # Tile sizes: 1 channel x 32x32 spatial per program instance.
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)

        # Stash config and launch geometry for backward.
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)

        y = x.new_empty((B, 4, C, H * W)) if out_channel_first else x.new_empty((B, H * W, 4, C))
        # Grid: one program per (spatial tile, channel tile, batch element).
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return y

    @staticmethod
    def backward(ctx, y: torch.Tensor):
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        # Allocate the gradient with the forward input's layout.
        if one_by_one:
            x = y.new_empty((B, 4, C, H, W)) if in_channel_first else y.new_empty((B, H, W, 4, C))
        else:
            x = y.new_empty((B, C, H, W)) if in_channel_first else y.new_empty((B, H, W, C))

        # operation=1: merge the 4 direction gradients back into x.
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return x, None, None, None, None
|
|
|
|
class CrossMergeTritonF(torch.autograd.Function):
    """Triton-backed cross-merge: forward launches the kernel in merge mode
    (operation=1); backward launches it in scan mode (operation=0) to
    re-expand the incoming gradient into the four directions."""

    @staticmethod
    def forward(ctx, y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0):
        # Recover the logical (B, C, H, W) from the input layout.
        if out_channel_first:
            B, _, C, H, W = y.shape
        else:
            B, H, W, _, C = y.shape
        B, C, H, W = int(B), int(C), int(H), int(W)
        # Tile sizes: 1 channel x 32x32 spatial per program instance.
        BC, BH, BW = 1, 32, 32
        NH, NW, NC = triton.cdiv(H, BH), triton.cdiv(W, BW), triton.cdiv(C, BC)
        # Stash config and launch geometry for backward.
        ctx.in_channel_first = in_channel_first
        ctx.out_channel_first = out_channel_first
        ctx.one_by_one = one_by_one
        ctx.scans = scans
        ctx.shape = (B, C, H, W)
        ctx.triton_shape = (BC, BH, BW, NC, NH, NW)
        # Allocate the merged output in the requested layout.
        if one_by_one:
            x = y.new_empty((B, 4, C, H * W)) if in_channel_first else y.new_empty((B, H * W, 4, C))
        else:
            x = y.new_empty((B, C, H * W)) if in_channel_first else y.new_empty((B, H * W, C))
        # operation=1: merge the 4 directions of y into x.
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x, y.contiguous(),
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 1, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        return x

    @staticmethod
    def backward(ctx, x: torch.Tensor):
        in_channel_first = ctx.in_channel_first
        out_channel_first = ctx.out_channel_first
        one_by_one = ctx.one_by_one
        scans = ctx.scans
        B, C, H, W = ctx.shape
        BC, BH, BW, NC, NH, NW = ctx.triton_shape
        y = x.new_empty((B, 4, C, H, W)) if out_channel_first else x.new_empty((B, H, W, 4, C))
        # operation=0: re-expand the gradient into the 4 directions.
        triton_cross_scan_flex[(NH * NW, NC, B)](
            x.contiguous(), y,
            (0 if in_channel_first else 1), (0 if out_channel_first else 1), 0, (0 if not one_by_one else 1), scans,
            BC, BH, BW, C, H, W, NH, NW
        )
        # BUGFIX: forward takes 5 inputs (y + 4 flags), so backward must
        # return exactly 5 gradients; the previous extra trailing None made
        # autograd raise "returned an incorrect number of gradients".
        return y, None, None, None, None
|
|
|
|
| |
def cross_scan_fn(x: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Dispatch cross-scan to the triton implementation when available and
    ``x`` lives on CUDA, otherwise (or when ``force_torch``) use pytorch."""
    use_triton = WITH_TRITON and x.is_cuda and not force_torch
    impl = CrossScanTritonF if use_triton else CrossScanF
    return impl.apply(x, in_channel_first, out_channel_first, one_by_one, scans)
|
|
|
|
| |
def cross_merge_fn(y: torch.Tensor, in_channel_first=True, out_channel_first=True, one_by_one=False, scans=0, force_torch=False):
    """Dispatch cross-merge to the triton implementation when available and
    ``y`` lives on CUDA, otherwise (or when ``force_torch``) use pytorch."""
    use_triton = WITH_TRITON and y.is_cuda and not force_torch
    impl = CrossMergeTritonF if use_triton else CrossMergeF
    return impl.apply(y, in_channel_first, out_channel_first, one_by_one, scans)
|
|
|
|
| |
|
|
class CHECK:
    # Namespace for manual consistency / benchmark checks (requires CUDA).
    # NOTE(review): ``check_csm_triton`` is defined without ``self`` or
    # ``@staticmethod``; it only works when called on the class itself
    # (``CHECK.check_csm_triton()``), not on an instance -- confirm intended.
    def check_csm_triton():
        """Compare cross_scan_fn/cross_merge_fn against reference pytorch
        implementations (values and gradients), and benchmark both."""
        # Deliberately non-square (H != W) to catch transpose mistakes.
        B, C, H, W = 2, 192, 56, 57
        dtype=torch.float16
        dtype=torch.float32  # fp16 line above is intentionally overridden
        x = torch.randn((B, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
        y = torch.randn((B, 4, C, H, W), dtype=dtype, device=torch.device("cuda")).requires_grad_(True)
        # Independent leaves so the two implementations get separate grads.
        x1 = x.clone().detach().requires_grad_(True)
        y1 = y.clone().detach().requires_grad_(True)

        # Reference: plain-pytorch 4-direction scan (scans == 0).
        def cross_scan(x: torch.Tensor):
            B, C, H, W = x.shape
            L = H * W
            xs = torch.stack([
                x.view(B, C, L),
                torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L),
                torch.flip(x.contiguous().view(B, C, L), dims=[-1]),
                torch.flip(torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
            ], dim=1).view(B, 4, C, L)
            return xs

        # Reference: realign the 4 directions and sum (scans == 0).
        def cross_merge(out_y: torch.Tensor):
            B, K, D, H, W = out_y.shape
            L = H * W
            out_y = out_y.view(B, K, D, L)
            inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
            wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
            invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
            y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y
            return y

        # Reference: one-by-one scan (separate input per direction).
        def cross_scan_1b1(x: torch.Tensor):
            B, K, C, H, W = x.shape
            L = H * W
            xs = torch.stack([
                x[:, 0].view(B, C, L),
                torch.transpose(x[:, 1], dim0=2, dim1=3).contiguous().view(B, C, L),
                torch.flip(x[:, 2].contiguous().view(B, C, L), dims=[-1]),
                torch.flip(torch.transpose(x[:, 3], dim0=2, dim1=3).contiguous().view(B, C, L), dims=[-1]),
            ], dim=1).view(B, 4, C, L)
            return xs

        # Reference pair for scans == 1 (unidirectional).
        def unidi_scan(x):
            B, C, H, W = x.shape
            x = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1)
            return x

        def unidi_merge(ys):
            B, K, C, H, W = ys.shape
            return ys.view(B, 4, -1, H * W).sum(1)

        # Reference pair for scans == 2 (bidirectional).
        def bidi_scan(x):
            B, C, H, W = x.shape
            x = x.view(B, 1, C, H * W).repeat(1, 2, 1, 1)
            x = torch.cat([x, x.flip(dims=[-1])], dim=1)
            return x

        def bidi_merge(ys):
            B, K, D, H, W = ys.shape
            ys = ys.view(B, K, D, -1)
            ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1)
            return ys.contiguous().sum(1)

        # Benchmarks: reference vs fused, forward then forward+backward.
        if True:
            res0 = triton.testing.do_bench(lambda :cross_scan(x))
            res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False))
            res3 = triton.testing.do_bench(lambda :cross_merge(y))
            res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False))
            print(res0, res1, res3, res4)
            res0 = triton.testing.do_bench(lambda :cross_scan(x).sum().backward())
            res1 = triton.testing.do_bench(lambda :cross_scan_fn(x, True, True, False).sum().backward())
            res3 = triton.testing.do_bench(lambda :cross_merge(y).sum().backward())
            res4 = triton.testing.do_bench(lambda :cross_merge_fn(y, True, True, False).sum().backward())
            print(res0, res1, res3, res4)

        # Correctness: each tuple pairs a reference (cs0, cm0) with the fused
        # implementation (cs1, cm1); printed max-abs diffs should be ~0.
        print("test cross scan")
        for (cs0, cm0, cs1, cm1) in [
            (cross_scan, cross_merge, cross_scan_fn, cross_merge_fn),
            (unidi_scan, unidi_merge, lambda x: cross_scan_fn(x, scans=1), lambda x: cross_merge_fn(x, scans=1)),
            (bidi_scan, bidi_merge, lambda x: cross_scan_fn(x, scans=2), lambda x: cross_merge_fn(x, scans=2)),
            # Layout permutations: exercise channel-last input/output paths.
            (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False), lambda x: cross_merge_fn(x, in_channel_first=False).permute(0, 2, 1)),
            (cross_scan, cross_merge, lambda x: cross_scan_fn(x, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), out_channel_first=False)),
            (cross_scan, cross_merge, lambda x: cross_scan_fn(x.permute(0, 2, 3, 1), in_channel_first=False, out_channel_first=False).permute(0, 2, 3, 1), lambda x: cross_merge_fn(x.permute(0, 3, 4, 1, 2), in_channel_first=False, out_channel_first=False).permute(0, 2, 1)),
        ]:
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
            o0 = cs0(x)
            o1 = cs1(x1)
            # y's values double as upstream gradients (y1 holds equal values).
            o0.backward(y.view(B, 4, C, H * W))
            o1.backward(y.view(B, 4, C, H * W))
            print((o0 - o1).abs().max())
            print((x.grad - x1.grad).abs().max())
            o0 = cm0(y)
            o1 = cm1(y1)
            o0.backward(x.view(B, C, H * W))
            o1.backward(x.view(B, C, H * W))
            print((o0 - o1).abs().max())
            print((y.grad - y1.grad).abs().max())
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
            print("===============", flush=True)

        # Same check for the one-by-one (K=4 inputs) scan variant.
        print("test cross scan one by one")
        for (cs0, cs1) in [
            (cross_scan_1b1, lambda x: cross_scan_fn(x, one_by_one=True)),
        ]:
            o0 = cs0(y)
            o1 = cs1(y1)
            o0.backward(y.view(B, 4, C, H * W))
            o1.backward(y.view(B, 4, C, H * W))
            print((o0 - o1).abs().max())
            print((y.grad - y1.grad).abs().max())
            x.grad, x1.grad, y.grad, y1.grad = None, None, None, None
            print("===============", flush=True)
|
|
|
|
if __name__ == "__main__":
    # Run the CUDA consistency/benchmark checks when executed as a script.
    CHECK.check_csm_triton()
|
|
|
|
|
|
|
|
|
|