import os import time import math import copy from functools import partial from typing import Optional, Callable, Any from collections import OrderedDict import torch import torch.nn as nn import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from einops import rearrange, repeat from timm.models.layers import DropPath, trunc_normal_ from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})" # import selective scan ============================== try: import selective_scan_cuda_oflex except Exception as e: ... # print(f"WARNING: can not import selective_scan_cuda_oflex.", flush=True) # print(e, flush=True) try: import selective_scan_cuda_core except Exception as e: ... # print(f"WARNING: can not import selective_scan_cuda_core.", flush=True) # print(e, flush=True) try: import selective_scan_cuda except Exception as e: ... # print(f"WARNING: can not import selective_scan_cuda.", flush=True) # print(e, flush=True) # fvcore flops ======================================= def flops_selective_scan_fn(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_complex=False): """ u: r(B D L) delta: r(B D L) A: r(D N) B: r(B N L) C: r(B N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 ignores: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] """ assert not with_complex # https://github.com/state-spaces/mamba/issues/110 flops = 9 * B * L * D * N if with_D: flops += B * D * L if with_Z: flops += B * D * L return flops # this is only for selective_scan_ref... def flops_selective_scan_ref(B=1, L=256, D=768, N=16, with_D=True, with_Z=False, with_Group=True, with_complex=False): """ u: r(B D L) delta: r(B D L) A: r(D N) B: r(B N L) C: r(B N L) D: r(D) z: r(B D L) delta_bias: r(D), fp32 ignores: [.float(), +, .softplus, .shape, new_zeros, repeat, stack, to(dtype), silu] """ import numpy as np # fvcore.nn.jit_handles def get_flops_einsum(input_shapes, equation): np_arrs = [np.zeros(s) for s in input_shapes] optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] for line in optim.split("\n"): if "optimized flop" in line.lower(): # divided by 2 because we count MAC (multiply-add counted as one flop) flop = float(np.floor(float(line.split(":")[-1]) / 2)) return flop assert not with_complex flops = 0 # below code flops = 0 flops += get_flops_einsum([[B, D, L], [D, N]], "bdl,dn->bdln") if with_Group: flops += get_flops_einsum([[B, D, L], [B, N, L], [B, D, L]], "bdl,bnl,bdl->bdln") else: flops += get_flops_einsum([[B, D, L], [B, D, N, L], [B, D, L]], "bdl,bdnl,bdl->bdln") in_for_flops = B * D * N if with_Group: in_for_flops += get_flops_einsum([[B, D, N], [B, D, N]], "bdn,bdn->bd") else: in_for_flops += get_flops_einsum([[B, D, N], [B, N]], "bdn,bn->bd") flops += L * in_for_flops if with_D: flops += B * D * L if with_Z: flops += B * D * L return flops def print_jit_input_names(inputs): print("input params: ", end=" ", flush=True) try: for i in range(10): print(inputs[i].debugName(), end=" ", flush=True) except Exception as e: pass print("", flush=True) # cross selective scan =============================== class SelectiveScanMamba(torch.autograd.Function): # comment all checks if inside cross_selective_scan @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): # assert nrows in [1, 2, 3, 4], f"{nrows}" # 8+ is too slow to compile # assert u.shape[1] % (B.shape[1] * nrows) == 0, f"{nrows}, {u.shape}, {B.shape}" ctx.delta_softplus = delta_softplus # all in float # if u.stride(-1) != 1: # u = u.contiguous() # if delta.stride(-1) != 1: # delta = delta.contiguous() # if D is not None and D.stride(-1) != 1: # D = D.contiguous() # if B.stride(-1) != 1: # B = B.contiguous() # if C.stride(-1) != 1: # C = C.contiguous() # if B.dim() == 3: # B = B.unsqueeze(dim=1) # ctx.squeeze_B = True # if C.dim() == 3: # C = C.unsqueeze(dim=1) # ctx.squeeze_C = True out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, None, delta_bias, delta_softplus) ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd( u, delta, A, B, C, D, None, delta_bias, dout, x, None, None, ctx.delta_softplus, False ) # dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB # dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) class SelectiveScanCore(torch.autograd.Function): # comment all checks if inside cross_selective_scan @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): ctx.delta_softplus = delta_softplus out, x, *rest = selective_scan_cuda_core.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1) ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_core.bwd( u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 ) return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) class SelectiveScanOflex(torch.autograd.Function): # comment all checks if inside cross_selective_scan @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): ctx.delta_softplus = delta_softplus out, x, *rest = selective_scan_cuda_oflex.fwd(u, delta, A, B, C, D, delta_bias, delta_softplus, 1, oflex) ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda_oflex.bwd( u, delta, A, B, C, D, delta_bias, dout, x, ctx.delta_softplus, 1 ) return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) class SelectiveScanFake(torch.autograd.Function): # comment all checks if inside cross_selective_scan @staticmethod @torch.cuda.amp.custom_fwd def forward(ctx, u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=False, nrows=1, backnrows=1, oflex=True): ctx.delta_softplus = delta_softplus ctx.backnrows = backnrows x = delta out = u ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x) return out @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout, *args): u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors if dout.stride(-1) != 1: dout = dout.contiguous() du, ddelta, dA, dB, dC, dD, ddelta_bias = u * 0, delta * 0, A * 0, B * 0, C * 0, C * 0, (D * 0 if D else None), (delta_bias * 0 if delta_bias else None) return (du, ddelta, dA, dB, dC, dD, ddelta_bias, None, None, None, None) # ============= def antidiagonal_gather(tensor): # 取出矩阵所有反斜向的元素并拼接 B, C, H, W = tensor.size() shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1] index = (torch.arange(W, device=tensor.device) - shift) % W # 利用广播创建索引矩阵[H, W] # 扩展索引以适应B和C维度 expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) # 使用gather进行索引选择 return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W) def diagonal_gather(tensor): # 取出矩阵所有反斜向的元素并拼接 B, C, H, W = tensor.size() shift = torch.arange(H, device=tensor.device).unsqueeze(1) # 创建一个列向量[H, 1] index = (shift + torch.arange(W, device=tensor.device)) % W # 利用广播创建索引矩阵[H, W] # 扩展索引以适应B和C维度 expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) # 使用gather进行索引选择 return tensor.gather(3, expanded_index).transpose(-1,-2).reshape(B, C, H*W) def diagonal_scatter(tensor_flat, original_shape): # 把斜向元素拼接起来的一维向量还原为最初的矩阵形式 B, C, H, W = original_shape shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1] index = (shift + torch.arange(W, device=tensor_flat.device)) % W # 利用广播创建索引矩阵[H, W] # 扩展索引以适应B和C维度 expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) # 创建一个空的张量来存储反向散布的结果 result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype) # 将平铺的张量重新变形为[B, C, H, W],考虑到需要使用transpose将H和W调换 tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2) # 使用scatter_根据expanded_index将元素放回原位 result_tensor.scatter_(3, expanded_index, tensor_reshaped) return result_tensor def antidiagonal_scatter(tensor_flat, original_shape): # 把反斜向元素拼接起来的一维向量还原为最初的矩阵形式 B, C, H, W = original_shape shift = torch.arange(H, device=tensor_flat.device).unsqueeze(1) # 创建一个列向量[H, 1] index = (torch.arange(W, device=tensor_flat.device) - shift) % W # 利用广播创建索引矩阵[H, W] expanded_index = index.unsqueeze(0).unsqueeze(0).expand(B, C, -1, -1) # 初始化一个与原始张量形状相同、元素全为0的张量 result_tensor = torch.zeros(B, C, H, W, device=tensor_flat.device, dtype=tensor_flat.dtype) # 将平铺的张量重新变形为[B, C, W, H],因为操作是沿最后一个维度收集的,需要调整形状并交换维度 tensor_reshaped = tensor_flat.reshape(B, C, W, H).transpose(-1, -2) # 使用scatter_将元素根据索引放回原位 result_tensor.scatter_(3, expanded_index, tensor_reshaped) return result_tensor class CrossScan(torch.autograd.Function): # ZSJ 这里是把图像按照特定方向展平的地方,改变扫描方向可以在这里修改 @staticmethod def forward(ctx, x: torch.Tensor): B, C, H, W = x.shape ctx.shape = (B, C, H, W) # xs = x.new_empty((B, 4, C, H * W)) xs = x.new_empty((B, 8, C, H * W)) # 添加横向和竖向的扫描 xs[:, 0] = x.flatten(2, 3) xs[:, 1] = x.transpose(dim0=2, dim1=3).flatten(2, 3) xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # 提供斜向和反斜向的扫描 xs[:, 4] = diagonal_gather(x) xs[:, 5] = antidiagonal_gather(x) xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1]) return xs @staticmethod def backward(ctx, ys: torch.Tensor): # out: (b, k, d, l) B, C, H, W = ctx.shape L = H * W # 把横向和竖向的反向部分再反向回来,并和原来的横向和竖向相加 # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式 # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) y_rb = y_rb.view(B, -1, H, W) # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加 y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, -1, L) # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加 y_da = diagonal_scatter(y_da[:, 0], (B,C,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,C,H,W)) y_res = y_rb + y_da # return y.view(B, -1, H, W) return y_res class CrossMerge(torch.autograd.Function): @staticmethod def forward(ctx, ys: torch.Tensor): B, K, D, H, W = ys.shape ctx.shape = (H, W) ys = ys.view(B, K, D, -1) # ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) # y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) y_rb = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) # 把竖向的部分转成横向,然后再相加,再转回最初是的矩阵形式 y_rb = y_rb[:, 0] + y_rb[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, D, -1) y_rb = y_rb.view(B, -1, H, W) # 把斜向和反斜向的反向部分再反向回来,并和原来的斜向和反斜向相加 y_da = ys[:, 4:6] + ys[:, 6:8].flip(dims=[-1]).view(B, 2, D, -1) # 把斜向和反斜向的部分都转成原来的最初的矩阵形式,再相加 y_da = diagonal_scatter(y_da[:, 0], (B,D,H,W)) + antidiagonal_scatter(y_da[:, 1], (B,D,H,W)) y_res = y_rb + y_da return y_res.view(B, D, -1) # return y @staticmethod def backward(ctx, x: torch.Tensor): # B, D, L = x.shape # out: (b, k, d, l) H, W = ctx.shape B, C, L = x.shape # xs = x.new_empty((B, 4, C, L)) xs = x.new_empty((B, 8, C, L)) # 横向和竖向扫描 xs[:, 0] = x xs[:, 1] = x.view(B, C, H, W).transpose(dim0=2, dim1=3).flatten(2, 3) xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) # xs = xs.view(B, 4, C, H, W) # 提供斜向和反斜向的扫描 xs[:, 4] = diagonal_gather(x.view(B,C,H,W)) xs[:, 5] = antidiagonal_gather(x.view(B,C,H,W)) xs[:, 6:8] = torch.flip(xs[:, 4:6], dims=[-1]) # return xs return xs.view(B, 8, C, H, W) # these are for ablations ============= class CrossScan_Ab_2direction(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor): B, C, H, W = x.shape ctx.shape = (B, C, H, W) xs = x.new_empty((B, 4, C, H * W)) xs[:, 0] = x.flatten(2, 3) xs[:, 1] = x.flatten(2, 3) xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) return xs @staticmethod def backward(ctx, ys: torch.Tensor): # out: (b, k, d, l) B, C, H, W = ctx.shape L = H * W ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, -1, L) y = ys[:, 0] + ys[:, 1].view(B, -1, W, H).transpose(dim0=2, dim1=3).contiguous().view(B, -1, L) return y.view(B, -1, H, W) class CrossMerge_Ab_2direction(torch.autograd.Function): @staticmethod def forward(ctx, ys: torch.Tensor): B, K, D, H, W = ys.shape ctx.shape = (H, W) ys = ys.view(B, K, D, -1) ys = ys[:, 0:2] + ys[:, 2:4].flip(dims=[-1]).view(B, 2, D, -1) y = ys.sum(dim=1) return y @staticmethod def backward(ctx, x: torch.Tensor): # B, D, L = x.shape # out: (b, k, d, l) H, W = ctx.shape B, C, L = x.shape xs = x.new_empty((B, 4, C, L)) xs[:, 0] = x xs[:, 1] = x xs[:, 2:4] = torch.flip(xs[:, 0:2], dims=[-1]) xs = xs.view(B, 4, C, H, W) return xs class CrossScan_Ab_1direction(torch.autograd.Function): @staticmethod def forward(ctx, x: torch.Tensor): B, C, H, W = x.shape ctx.shape = (B, C, H, W) xs = x.view(B, 1, C, H * W).repeat(1, 4, 1, 1).contiguous() return xs @staticmethod def backward(ctx, ys: torch.Tensor): # out: (b, k, d, l) B, C, H, W = ctx.shape y = ys.sum(dim=1).view(B, C, H, W) return y class CrossMerge_Ab_1direction(torch.autograd.Function): @staticmethod def forward(ctx, ys: torch.Tensor): B, K, D, H, W = ys.shape ctx.shape = (H, W) y = ys.sum(dim=1).view(B, D, H * W) return y @staticmethod def backward(ctx, x: torch.Tensor): # B, D, L = x.shape # out: (b, k, d, l) H, W = ctx.shape B, C, L = x.shape xs = x.view(B, 1, C, L).repeat(1, 4, 1, 1).contiguous().view(B, 4, C, H, W) return xs # ============= # ZSJ 这里是mamba的具体内容,要增加扫描方向就在这里改 def cross_selective_scan( x: torch.Tensor=None, x_proj_weight: torch.Tensor=None, x_proj_bias: torch.Tensor=None, dt_projs_weight: torch.Tensor=None, dt_projs_bias: torch.Tensor=None, A_logs: torch.Tensor=None, Ds: torch.Tensor=None, delta_softplus = True, out_norm: torch.nn.Module=None, out_norm_shape="v0", # ============================== to_dtype=True, # True: final out to dtype force_fp32=False, # True: input fp32 # ============================== nrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable; backnrows = -1, # for SelectiveScanNRow; 0: auto; -1: disable; ssoflex=True, # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore # ============================== SelectiveScan=None, CrossScan=CrossScan, CrossMerge=CrossMerge, ): # out_norm: whatever fits (B, L, C); LayerNorm; Sigmoid; Softmax(dim=1);... B, D, H, W = x.shape D, N = A_logs.shape K, D, R = dt_projs_weight.shape L = H * W if nrows == 0: if D % 4 == 0: nrows = 4 elif D % 3 == 0: nrows = 3 elif D % 2 == 0: nrows = 2 else: nrows = 1 if backnrows == 0: if D % 4 == 0: backnrows = 4 elif D % 3 == 0: backnrows = 3 elif D % 2 == 0: backnrows = 2 else: backnrows = 1 def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True): return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, backnrows, ssoflex) xs = CrossScan.apply(x) x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, x_proj_weight) if x_proj_bias is not None: x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1) dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight) xs = xs.view(B, -1, L) dts = dts.contiguous().view(B, -1, L) As = -torch.exp(A_logs.to(torch.float)) # (k * c, d_state) Bs = Bs.contiguous() Cs = Cs.contiguous() Ds = Ds.to(torch.float) # (K * c) delta_bias = dt_projs_bias.view(-1).to(torch.float) if force_fp32: xs = xs.to(torch.float) dts = dts.to(torch.float) Bs = Bs.to(torch.float) Cs = Cs.to(torch.float) # ZSJ 这里把矩阵拆分成不同方向的序列,并进行扫描 ys: torch.Tensor = selective_scan( xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus ).view(B, K, -1, H, W) # ZSJ 这里把处理之后的序列融合起来,并还原回原来的矩阵形式 y: torch.Tensor = CrossMerge.apply(ys) if out_norm_shape in ["v1"]: # (B, C, H, W) y = out_norm(y.view(B, -1, H, W)).permute(0, 2, 3, 1) # (B, H, W, C) else: # (B, L, C) y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) y = out_norm(y).view(B, H, W, -1) return (y.to(x.dtype) if to_dtype else y) def selective_scan_flop_jit(inputs, outputs): print_jit_input_names(inputs) B, D, L = inputs[0].type().sizes() N = inputs[2].type().sizes()[1] flops = flops_selective_scan_fn(B=B, L=L, D=D, N=N, with_D=True, with_Z=False) return flops # ===================================================== class PatchMerging2D(nn.Module): def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm): super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False) self.norm = norm_layer(4 * dim) @staticmethod def _patch_merging_pad(x: torch.Tensor): H, W, _ = x.shape[-3:] if (W % 2 != 0) or (H % 2 != 0): x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C return x def forward(self, x): x = self._patch_merging_pad(x) x = self.norm(x) x = self.reduction(x) return x class OSSM(nn.Module): def __init__( self, # basic dims =========== d_model=96, d_state=16, ssm_ratio=2.0, dt_rank="auto", act_layer=nn.SiLU, # dwconv =============== d_conv=3, # < 2 means no conv conv_bias=True, # ====================== dropout=0.0, bias=False, # dt init ============== dt_min=0.001, dt_max=0.1, dt_init="random", dt_scale=1.0, dt_init_floor=1e-4, initialize="v0", # ====================== forward_type="v2", # ====================== **kwargs, ): factory_kwargs = {"device": None, "dtype": None} super().__init__() d_inner = int(ssm_ratio * d_model) dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank self.d_conv = d_conv # tags for forward_type ============================== def checkpostfix(tag, value): ret = value[-len(tag):] == tag if ret: value = value[:-len(tag)] return ret, value self.disable_force32, forward_type = checkpostfix("no32", forward_type) self.disable_z, forward_type = checkpostfix("noz", forward_type) self.disable_z_act, forward_type = checkpostfix("nozact", forward_type) # softmax | sigmoid | dwconv | norm =========================== if forward_type[-len("none"):] == "none": forward_type = forward_type[:-len("none")] self.out_norm = nn.Identity() elif forward_type[-len("dwconv3"):] == "dwconv3": forward_type = forward_type[:-len("dwconv3")] self.out_norm = nn.Conv2d(d_inner, d_inner, kernel_size=3, padding=1, groups=d_inner, bias=False) self.out_norm_shape = "v1" elif forward_type[-len("softmax"):] == "softmax": forward_type = forward_type[:-len("softmax")] self.out_norm = nn.Softmax(dim=1) elif forward_type[-len("sigmoid"):] == "sigmoid": forward_type = forward_type[:-len("sigmoid")] self.out_norm = nn.Sigmoid() else: self.out_norm = nn.LayerNorm(d_inner) # forward_type debug ======================================= FORWARD_TYPES = dict( v0=self.forward_corev0, # v2=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanCore), v2=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanCore), v3=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex), v31d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial( cross_selective_scan, CrossScan=CrossScan_Ab_1direction, CrossMerge=CrossMerge_Ab_1direction, )), v32d=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=partial( cross_selective_scan, CrossScan=CrossScan_Ab_2direction, CrossMerge=CrossMerge_Ab_2direction, )), # =============================== fake=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanFake), v1=partial(self.forward_corev2, force_fp32=True, SelectiveScan=SelectiveScanOflex), v01=partial(self.forward_corev2, force_fp32=(not self.disable_force32), SelectiveScan=SelectiveScanMamba), ) if forward_type.startswith("debug"): from .ss2d_ablations import SS2D_ForwardCoreSpeedAblations, SS2D_ForwardCoreModeAblations, cross_selective_scanv2 FORWARD_TYPES.update(dict( debugforward_core_mambassm_seq=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_seq, self), debugforward_core_mambassm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm, self), debugforward_core_mambassm_fp16=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fp16, self), debugforward_core_mambassm_fusecs=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecs, self), debugforward_core_mambassm_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_mambassm_fusecscm, self), debugforward_core_sscore_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_sscore_fusecscm, self), debugforward_core_sscore_fusecscm_fwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fwdnrow, self), debugforward_core_sscore_fusecscm_bwdnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_bwdnrow, self), debugforward_core_sscore_fusecscm_fbnrow=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssnrow_fusecscm_fbnrow, self), debugforward_core_ssoflex_fusecscm=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm, self), debugforward_core_ssoflex_fusecscm_i16o32=partial(SS2D_ForwardCoreSpeedAblations.forward_core_ssoflex_fusecscm_i16o32, self), debugscan_sharessm=partial(self.forward_corev2, force_fp32=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scanv2), )) self.forward_core = FORWARD_TYPES.get(forward_type, None) # ZSJ k_group 指的是扫描的方向 # k_group = 4 if forward_type not in ["debugscan_sharessm"] else 1 k_group = 8 if forward_type not in ["debugscan_sharessm"] else 1 # in proj ======================================= d_proj = d_inner if self.disable_z else (d_inner * 2) self.in_proj = nn.Linear(d_model, d_proj, bias=bias, **factory_kwargs) self.act: nn.Module = act_layer() # conv ======================================= if d_conv > 1: self.conv2d = nn.Conv2d( in_channels=d_inner, out_channels=d_inner, groups=d_inner, bias=conv_bias, kernel_size=d_conv, padding=(d_conv - 1) // 2, **factory_kwargs, ) # x proj ============================ self.x_proj = [ nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs) for _ in range(k_group) ] self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0)) # (K, N, inner) del self.x_proj # out proj ======================================= self.out_proj = nn.Linear(d_inner, d_model, bias=bias, **factory_kwargs) self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity() if initialize in ["v0"]: # dt proj ============================ self.dt_projs = [ self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs) for _ in range(k_group) ] self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0)) # (K, inner, rank) self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0)) # (K, inner) del self.dt_projs # A, D ======================================= self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True) # (K * D, N) self.Ds = self.D_init(d_inner, copies=k_group, merge=True) # (K * D) elif initialize in ["v1"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter(torch.randn((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) elif initialize in ["v2"]: # simple init dt_projs, A_logs, Ds self.Ds = nn.Parameter(torch.ones((k_group * d_inner))) self.A_logs = nn.Parameter(torch.zeros((k_group * d_inner, d_state))) # A == -A_logs.exp() < 0; # 0 < exp(A * dt) < 1 self.dt_projs_weight = nn.Parameter(torch.randn((k_group, d_inner, dt_rank))) self.dt_projs_bias = nn.Parameter(torch.randn((k_group, d_inner))) @staticmethod def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4, **factory_kwargs): dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs) # Initialize special dt projection to preserve variance at initialization dt_init_std = dt_rank**-0.5 * dt_scale if dt_init == "constant": nn.init.constant_(dt_proj.weight, dt_init_std) elif dt_init == "random": nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std) else: raise NotImplementedError # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max dt = torch.exp( torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ).clamp(min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) with torch.no_grad(): dt_proj.bias.copy_(inv_dt) # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit # dt_proj.bias._no_reinit = True return dt_proj @staticmethod def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True): # S4D real initialization A = repeat( torch.arange(1, d_state + 1, dtype=torch.float32, device=device), "n -> d n", d=d_inner, ).contiguous() A_log = torch.log(A) # Keep A_log in fp32 if copies > 0: A_log = repeat(A_log, "d n -> r d n", r=copies) if merge: A_log = A_log.flatten(0, 1) A_log = nn.Parameter(A_log) A_log._no_weight_decay = True return A_log @staticmethod def D_init(d_inner, copies=-1, device=None, merge=True): # D "skip" parameter D = torch.ones(d_inner, device=device) if copies > 0: D = repeat(D, "n1 -> r n1", r=copies) if merge: D = D.flatten(0, 1) D = nn.Parameter(D) # Keep in fp32 D._no_weight_decay = True return D # only used to run previous version def forward_corev0(self, x: torch.Tensor, to_dtype=False, channel_first=False): def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True, nrows=1): return SelectiveScanCore.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, nrows, False) if not channel_first: x = x.permute(0, 3, 1, 2).contiguous() B, D, H, W = x.shape D, N = self.A_logs.shape K, D, R = self.dt_projs_weight.shape L = H * W # ZSJ 这里进行data expand操作,也就是把相同的数据在不同方向展开成一维,并拼接起来,但是这个函数只用在旧版本 # 把横向和竖向拼接在K维度 x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L) # torch.flip把横向和竖向两个方向都进行反向操作 xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1) # (b, k, d, l) x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs, self.x_proj_weight) # x_dbl = x_dbl + self.x_proj_bias.view(1, K, -1, 1) dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2) dts = torch.einsum("b k r l, k d r -> b k d l", dts, self.dt_projs_weight) xs = xs.float().view(B, -1, L) # (b, k * d, l) dts = dts.contiguous().float().view(B, -1, L) # (b, k * d, l) Bs = Bs.float() # (b, k, d_state, l) Cs = Cs.float() # (b, k, d_state, l) As = -torch.exp(self.A_logs.float()) # (k * d, d_state) Ds = self.Ds.float() # (k * d) dt_projs_bias = self.dt_projs_bias.float().view(-1) # (k * d) # assert len(xs.shape) == 3 and len(dts.shape) == 3 and len(Bs.shape) == 4 and len(Cs.shape) == 4 # assert len(As.shape) == 2 and len(Ds.shape) == 1 and len(dt_projs_bias.shape) == 1 out_y = selective_scan( xs, dts, As, Bs, Cs, Ds, delta_bias=dt_projs_bias, delta_softplus=True, ).view(B, K, -1, L) # assert out_y.dtype == torch.float inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L) wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L) y = out_y[:, 0] + inv_y[:, 0] + wh_y + invwh_y y = y.transpose(dim0=1, dim1=2).contiguous() # (B, L, C) y = self.out_norm(y).view(B, H, W, -1) return (y.to(x.dtype) if to_dtype else y) def forward_corev2(self, x: torch.Tensor, channel_first=False, SelectiveScan=SelectiveScanOflex, cross_selective_scan=cross_selective_scan, force_fp32=None): if not channel_first: x = x.permute(0, 3, 1, 2).contiguous() # ZSJ V2版本使用的mamba,要改扫描方向在这里改 x = cross_selective_scan( x, self.x_proj_weight, None, self.dt_projs_weight, self.dt_projs_bias, self.A_logs, self.Ds, delta_softplus=True, out_norm=getattr(self, "out_norm", None), out_norm_shape=getattr(self, "out_norm_shape", "v0"), force_fp32=force_fp32, SelectiveScan=SelectiveScan, ) return x def forward(self, x: torch.Tensor, **kwargs): with_dconv = (self.d_conv > 1) x = self.in_proj(x) if not self.disable_z: x, z = x.chunk(2, dim=-1) # (b, h, w, d) if not self.disable_z_act: z = self.act(z) if with_dconv: x = x.permute(0, 3, 1, 2).contiguous() x = self.conv2d(x) # (b, d, h, w) x = self.act(x) y = self.forward_core(x, channel_first=with_dconv) if not self.disable_z: y = y * z out = self.dropout(self.out_proj(y)) return out class Permute(nn.Module): def __init__(self, *args): super().__init__() self.args = args def forward(self, x: torch.Tensor): return x.permute(*self.args) class Mlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,channels_first=False): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features Linear = partial(nn.Conv2d, kernel_size=1, padding=0) if channels_first else nn.Linear self.fc1 = Linear(in_features, hidden_features) self.act = act_layer() self.fc2 = Linear(hidden_features, out_features) self.drop = nn.Dropout(drop) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop(x) x = self.fc2(x) x = self.drop(x) return x class VSSBlock(nn.Module): def __init__( self, hidden_dim: int = 0, drop_path: float = 0, norm_layer: Callable[..., torch.nn.Module] = partial(nn.LayerNorm, eps=1e-6), # ============================= ssm_d_state: int = 16, ssm_ratio=2.0, ssm_dt_rank: Any = "auto", ssm_act_layer=nn.SiLU, ssm_conv: int = 3, ssm_conv_bias=True, ssm_drop_rate: float = 0, ssm_init="v0", forward_type="v2", # ============================= mlp_ratio=4.0, mlp_act_layer=nn.GELU, mlp_drop_rate: float = 0.0, # ============================= use_checkpoint: bool = False, post_norm: bool = False, **kwargs, ): super().__init__() self.ssm_branch = ssm_ratio > 0 self.mlp_branch = mlp_ratio > 0 self.use_checkpoint = use_checkpoint self.post_norm = post_norm try: from ss2d_ablations import SS2DDev _OSSM = SS2DDev if forward_type.startswith("dev") else OSSM except: _OSSM = OSSM if self.ssm_branch: self.norm = norm_layer(hidden_dim) self.op = _OSSM( d_model=hidden_dim, d_state=ssm_d_state, ssm_ratio=ssm_ratio, dt_rank=ssm_dt_rank, act_layer=ssm_act_layer, # ========================== d_conv=ssm_conv, conv_bias=ssm_conv_bias, # ========================== dropout=ssm_drop_rate, # bias=False, # ========================== # dt_min=0.001, # dt_max=0.1, # dt_init="random", # dt_scale="random", # dt_init_floor=1e-4, initialize=ssm_init, # ========================== forward_type=forward_type, ) self.drop_path = DropPath(drop_path) if self.mlp_branch: self.norm2 = norm_layer(hidden_dim) mlp_hidden_dim = int(hidden_dim * mlp_ratio) self.mlp = Mlp(in_features=hidden_dim, hidden_features=mlp_hidden_dim, act_layer=mlp_act_layer, drop=mlp_drop_rate, channels_first=False) def _forward(self, input: torch.Tensor): if self.ssm_branch: if self.post_norm: x = input + self.drop_path(self.norm(self.op(input))) else: x = input + self.drop_path(self.op(self.norm(input))) if self.mlp_branch: if self.post_norm: x = x + self.drop_path(self.norm2(self.mlp(x))) # FFN else: x = x + self.drop_path(self.mlp(self.norm2(x))) # FFN return x def forward(self, input: torch.Tensor): if self.use_checkpoint: return checkpoint.checkpoint(self._forward, input) else: return self._forward(input) class Decoder_Block(nn.Module): """Basic block in decoder.""" def __init__(self, in_channel, out_channel): super().__init__() assert out_channel == in_channel // 2, 'the out_channel is not in_channel//2 in decoder block' self.up = nn.Upsample(scale_factor=2, mode='nearest') self.fuse = nn.Sequential(nn.Conv2d(in_channels=in_channel + out_channel, out_channels=out_channel, kernel_size=1, padding=0, bias=False), nn.BatchNorm2d(out_channel), nn.ReLU(inplace=True), ) def forward(self, de, en): de = self.up(de) output = torch.cat([de, en], dim=1) output = self.fuse(output) return output class Fuse_Block(nn.Module): """Basic block in decoder.""" def __init__(self, in_channel): super().__init__() self.fuse = nn.Sequential(nn.Conv2d(in_channels=in_channel*2, out_channels=in_channel, kernel_size=1, padding=0, bias=False), nn.BatchNorm2d(in_channel), nn.ReLU(inplace=True), ) def forward(self, x1, x2): # shapes of x1 and x2 are b,h,w,c x1 = rearrange(x1, "b h w c -> b c h w").contiguous() x2 = rearrange(x2, "b h w c -> b c h w").contiguous() output = torch.cat([x1, x2], dim=1) output = self.fuse(output) return output class RSM_CD(nn.Module): def __init__( self, patch_size=4, in_chans=3, num_classes=1000, depths=[2, 2, 9, 2], dims=[96, 192, 384, 768], # ========================= ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer="silu", ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, ssm_init="v0", forward_type="v2", # ========================= mlp_ratio=4.0, mlp_act_layer="gelu", mlp_drop_rate=0.0, # ========================= drop_path_rate=0.2, patch_norm=True, norm_layer="LN", use_checkpoint=False, **kwargs, ): super().__init__() self.num_classes = num_classes self.num_layers = len(depths) if isinstance(dims, int): dims = [int(dims * 2 ** i_layer) for i_layer in range(self.num_layers)] self.num_features = dims[-1] self.dims = dims dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule _NORMLAYERS = dict( ln=nn.LayerNorm, bn=nn.BatchNorm2d, ) _ACTLAYERS = dict( silu=nn.SiLU, gelu=nn.GELU, relu=nn.ReLU, sigmoid=nn.Sigmoid, ) if isinstance(norm_layer, str) and norm_layer.lower() in ["ln"]: norm_layer: nn.Module = _NORMLAYERS[norm_layer.lower()] if isinstance(ssm_act_layer, str) and ssm_act_layer.lower() in ["silu", "gelu", "relu"]: ssm_act_layer: nn.Module = _ACTLAYERS[ssm_act_layer.lower()] if isinstance(mlp_act_layer, str) and mlp_act_layer.lower() in ["silu", "gelu", "relu"]: mlp_act_layer: nn.Module = _ACTLAYERS[mlp_act_layer.lower()] _make_patch_embed = self._make_patch_embed_v2 self.patch_embed = _make_patch_embed(in_chans, dims[0], patch_size, patch_norm, norm_layer) _make_downsample = self._make_downsample_v3 # self.encoder_layers = [nn.ModuleList()] * self.num_layers self.encoder_layers = [] self.fuse_layers = [] self.decoder_layers = [] for i_layer in range(self.num_layers): # downsample = _make_downsample( # self.dims[i_layer], # self.dims[i_layer + 1], # norm_layer=norm_layer, # ) if (i_layer < self.num_layers - 1) else nn.Identity() downsample = _make_downsample( self.dims[i_layer - 1], self.dims[i_layer], norm_layer=norm_layer, ) if (i_layer != 0) else nn.Identity() # ZSJ 修改为i_layer != 0,也就是第一层不下采样,和论文的图保持一致,也方便我取出每个尺度处理好的特征 self.encoder_layers.append(self._make_layer( dim = self.dims[i_layer], drop_path = dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], use_checkpoint=use_checkpoint, norm_layer=norm_layer, downsample=downsample, # ================= ssm_d_state=ssm_d_state, ssm_ratio=ssm_ratio, ssm_dt_rank=ssm_dt_rank, ssm_act_layer=ssm_act_layer, ssm_conv=ssm_conv, ssm_conv_bias=ssm_conv_bias, ssm_drop_rate=ssm_drop_rate, ssm_init=ssm_init, forward_type=forward_type, # ================= mlp_ratio=mlp_ratio, mlp_act_layer=mlp_act_layer, mlp_drop_rate=mlp_drop_rate, )) self.fuse_layers.append(Fuse_Block(in_channel=self.dims[i_layer])) if i_layer != 0: self.decoder_layers.append(Decoder_Block(in_channel=self.dims[i_layer], out_channel=self.dims[i_layer-1])) self.encoder_block1, self.encoder_block2, self.encoder_block3, self.encoder_block4 = self.encoder_layers self.fuse_block1, self.fuse_block2, self.fuse_block3, self.fuse_block4 = self.fuse_layers self.deocder_block1, self.deocder_block2, self.deocder_block3 = self.decoder_layers # self.classifier = nn.Sequential(OrderedDict( # norm=norm_layer(self.num_features), # B,H,W,C # permute=Permute(0, 3, 1, 2), # avgpool=nn.AdaptiveAvgPool2d(1), # flatten=nn.Flatten(1), # head=nn.Linear(self.num_features, num_classes), # )) self.upsample_x4 = nn.Sequential( nn.Conv2d(self.dims[0], self.dims[0]//2, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(self.dims[0]//2), nn.ReLU(inplace=True), nn.UpsamplingBilinear2d(scale_factor=2), nn.Conv2d(self.dims[0]//2, 8, kernel_size=3, stride=1, padding=1), nn.BatchNorm2d(8), nn.ReLU(inplace=True), nn.UpsamplingBilinear2d(scale_factor=2) ) self.conv_out_change = nn.Conv2d(8, 1, kernel_size=7, stride=1, padding=3) self.apply(self._init_weights) def _init_weights(self, m: nn.Module): if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): nn.init.constant_(m.bias, 0) nn.init.constant_(m.weight, 1.0) @staticmethod def _make_patch_embed_v2(in_chans=3, embed_dim=96, patch_size=4, patch_norm=True, norm_layer=nn.LayerNorm): assert patch_size == 4 return nn.Sequential( nn.Conv2d(in_chans, embed_dim // 2, kernel_size=3, stride=2, padding=1), (Permute(0, 2, 3, 1) if patch_norm else nn.Identity()), (norm_layer(embed_dim // 2) if patch_norm else nn.Identity()), (Permute(0, 3, 1, 2) if patch_norm else nn.Identity()), nn.GELU(), nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1), Permute(0, 2, 3, 1), (norm_layer(embed_dim) if patch_norm else nn.Identity()), ) @staticmethod def _make_downsample_v3(dim=96, out_dim=192, norm_layer=nn.LayerNorm): return nn.Sequential( Permute(0, 3, 1, 2), nn.Conv2d(dim, out_dim, kernel_size=3, stride=2, padding=1), Permute(0, 2, 3, 1), norm_layer(out_dim), ) @staticmethod def _make_layer( dim=96, drop_path=[0.1, 0.1], use_checkpoint=False, norm_layer=nn.LayerNorm, downsample=nn.Identity(), # =========================== ssm_d_state=16, ssm_ratio=2.0, ssm_dt_rank="auto", ssm_act_layer=nn.SiLU, ssm_conv=3, ssm_conv_bias=True, ssm_drop_rate=0.0, ssm_init="v0", forward_type="v2", # =========================== mlp_ratio=4.0, mlp_act_layer=nn.GELU, mlp_drop_rate=0.0, **kwargs, ): depth = len(drop_path) blocks = [] for d in range(depth): blocks.append(VSSBlock( hidden_dim=dim, drop_path=drop_path[d], norm_layer=norm_layer, ssm_d_state=ssm_d_state, ssm_ratio=ssm_ratio, ssm_dt_rank=ssm_dt_rank, ssm_act_layer=ssm_act_layer, ssm_conv=ssm_conv, ssm_conv_bias=ssm_conv_bias, ssm_drop_rate=ssm_drop_rate, ssm_init=ssm_init, forward_type=forward_type, mlp_ratio=mlp_ratio, mlp_act_layer=mlp_act_layer, mlp_drop_rate=mlp_drop_rate, use_checkpoint=use_checkpoint, )) return nn.Sequential(OrderedDict( # ZSJ 把downsample放到前面来,方便我取出encoder中每个尺度处理好的图像,而不是刚刚下采样完的图像 downsample=downsample, blocks=nn.Sequential(*blocks,), )) def forward(self, x1: torch.Tensor, x2: torch.Tensor): x1 = self.patch_embed(x1) x2 = self.patch_embed(x2) x1_1 = self.encoder_block1(x1) x1_2 = self.encoder_block2(x1_1) x1_3 = self.encoder_block3(x1_2) x1_4 = self.encoder_block4(x1_3) # b,h,w,c x2_1 = self.encoder_block1(x2) x2_2 = self.encoder_block2(x2_1) x2_3 = self.encoder_block3(x2_2) x2_4 = self.encoder_block4(x2_3) # b,h,w,c fuse_1 = self.fuse_block1(x1_1, x2_1) fuse_2 = self.fuse_block2(x1_2, x2_2) fuse_3 = self.fuse_block3(x1_3, x2_3) fuse_4 = self.fuse_block4(x1_4, x2_4) decode_3 = self.deocder_block3(fuse_4, fuse_3) decode_2 = self.deocder_block2(decode_3, fuse_2) decode_1 = self.deocder_block1(decode_2, fuse_1) output = self.upsample_x4(decode_1) output = self.conv_out_change(output) return output