import os
import time
import math
import copy
from functools import partial
from typing import Optional, Callable, Any
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from einops import rearrange, repeat
from timm.models.layers import DropPath, trunc_normal_
from fvcore.nn import FlopCountAnalysis, flop_count_str, flop_count, parameter_count
from torchvision.models import VisionTransformer

DropPath.__repr__ = lambda self: f"timm.DropPath({self.drop_prob})"

# NOTE(review): these are process-wide side effects of importing this module.
# benchmark=True lets cuDNN pick algorithms by timing, which may differ between
# runs even with deterministic=True; PyTorch's reproducibility notes recommend
# benchmark=False when bit-exact reproducibility is the goal.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True

from rscd.models.backbones.lamba_util.csms6s import SelectiveScanCuda
from rscd.models.backbones.lamba_util.utils import Scan_FB_S, Merge_FB_S, CrossMergeS, CrossScanS, \
    local_scan_zero_ones, reverse_local_scan_zero_ones
from rscd.models.backbones.lamba_util.csms6s import flops_selective_scan_fn, flops_selective_scan_ref, selective_scan_flop_jit


def my_gumbel_softmax(logits, k):
    """Draw a hard top-k sample from ``logits`` perturbed with Gumbel noise.

    Despite the name, no softmax is applied: the result is a {0, 1} mask with
    exactly ``k`` ones along the last dimension (not differentiable).

    Args:
        logits: floating-point scores; selection is along the last dim.
        k: number of entries to select per row.

    Returns:
        Tensor shaped like ``logits`` with 1.0 at the selected positions and
        0.0 elsewhere.
    """
    # Gumbel(0, 1) noise is -log(-log(U)). torch.rand_like samples from [0, 1),
    # so clamp U away from 0 to keep log(U) finite.
    uniform = torch.rand_like(logits).clamp_min(torch.finfo(logits.dtype).tiny)
    gumbel_noise = -torch.log(-torch.log(uniform))
    gumbel_logits = logits + gumbel_noise
    # Indices of the k largest perturbed scores.
    topk_indices = torch.topk(gumbel_logits, k=k, dim=-1).indices
    # Scatter them into a multi-hot mask.
    topk_onehot = torch.zeros_like(logits)
    topk_onehot.scatter_(dim=-1, index=topk_indices, value=1.0)
    return topk_onehot


def window_expansion(x, H, W):
    """Broadcast a per-window value to every pixel of its window.

    Args:
        x: tensor of shape (b, 1, num_win); num_win must be a perfect square and
            the windows are laid out as a side x side grid in row-major order.
        H, W: target spatial size.

    Returns:
        Tensor of shape (b, 1, H, W) obtained by nearest-neighbour upsampling.

    Raises:
        ValueError: if num_win is not a perfect square.
    """
    b, _, num_win = x.shape
    side = int(round(num_win ** 0.5))
    if side * side != num_win:
        raise ValueError(f"num_win must be a perfect square, got {num_win}")
    x = x.reshape(b, 1, side, side)
    # size=(H, W) (instead of an integer scale factor) also covers non-square maps.
    return F.interpolate(x, size=(H, W))


def window_partition(x, quad_size=2):
    """Split a channel-first feature map into a quad_size x quad_size grid.

    Args:
        x: (B, C, H, W) tensor; H and W must be divisible by ``quad_size``.
        quad_size (int): number of windows along each spatial axis.

    Returns:
        windows: (B, quad_size * quad_size, H // quad_size, W // quad_size, C),
        windows ordered row-major over the grid.
    """
    B, C, H, W = x.shape
    H_quad = H // quad_size
    W_quad = W // quad_size
    x = x.view(B, C, quad_size, H_quad, quad_size, W_quad)
    windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(B, -1, H_quad, W_quad, C)
    return windows


def window_reverse(windows):
    """Inverse of ``window_partition``.

    Args:
        windows: (B, N, H_l, W_l, C) tensor where N is a perfect square
            (the windows form a sqrt(N) x sqrt(N) grid, row-major).

    Returns:
        x: (B, C, H_l * sqrt(N), W_l * sqrt(N)) channel-first feature map.
    """
    B, N, H_l, W_l, C = windows.shape
    scale = int((N) ** 0.5)
    H = H_l * scale
    W = W_l * scale
    x = windows.permute(0, 4, 1, 2, 3)
    x = x.view(B, C, N // scale, N // scale, H_l, W_l)
    x = x.permute(0, 1, 2, 4, 3, 5).contiguous().view(B, C, H, W)
    return x


class Predictor(nn.Module):
    """Predict per-pixel 2-way log-probabilities from a feature map.

    After a LayerNorm/Linear/GELU projection, the first half of the channels is
    kept per pixel ("local") while the second half is average-pooled to a 2x2
    grid and broadcast back to full resolution ("global"). The concatenation is
    mapped to 2 log-probabilities by an MLP ending in LogSoftmax.
    """

    def __init__(self, embed_dim=384):
        super().__init__()
        self.in_conv = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.GELU()
        )

        self.out_conv = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.GELU(),
            nn.Linear(embed_dim // 2, embed_dim // 4),
            nn.GELU(),
            nn.Linear(embed_dim // 4, 2),
            nn.LogSoftmax(dim=-1)
        )

    def forward(self, x):
        """
        Args:
            x: (B, C, H, W) feature map, or (B, N, C) tokens where the spatial
                layout is taken as H = int(sqrt(N)) rows of N // H tokens.

        Returns:
            (B, 2, H, W) log-probabilities (normalized over dim 1).
        """
        if len(x.shape) == 4:
            B, C, H, W = x.size()
            x_rs = x.reshape(B, C, -1).permute(0, 2, 1)
        else:
            B, N, C = x.size()
            H = int(N ** 0.5)
            W = N // H
            x_rs = x

        x_rs = self.in_conv(x_rs)
        B, N, C = x_rs.size()

        local_x = x_rs[:, :, :C // 2]
        global_x = x_rs[:, :, C // 2:].view(B, H, W, C // 2).permute(0, 3, 1, 2)
        global_x_avg = F.adaptive_avg_pool2d(global_x, (2, 2))  # [b, c, 2, 2]
        # Upsample back to exactly (H, W); an integer scale factor of H // 2
        # would break for odd or non-square maps.
        global_x_avg_concat = F.interpolate(global_x_avg, size=(H, W))
        global_x_avg_concat = global_x_avg_concat.view(B, C // 2, -1).permute(0, 2, 1).contiguous()

        x_rs = torch.cat([local_x, global_x_avg_concat], dim=-1)

        x_score = self.out_conv(x_rs)
        x_score_rs = x_score.permute(0, 2, 1).reshape(B, 2, H, -1)
        return x_score_rs


# =====================================================
# we have this class as linear and conv init differ from each other
# this function enable loading from both conv2d or linear
class Linear2d(nn.Linear):
    """nn.Linear over the channel dim of a (B, C, H, W) tensor, run as a 1x1 conv.

    The weight keeps nn.Linear layout (out, in), so checkpoints holding either a
    Linear weight or a 1x1 Conv2d weight (out, in, 1, 1) can be loaded.
    """

    def forward(self, x: torch.Tensor):
        # B, C, H, W = x.shape
        return F.conv2d(x, self.weight[:, :, None, None], self.bias)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
                              error_msgs):
        key = prefix + "weight"
        # Only reshape when the key exists and the sizes are compatible; otherwise
        # let the base class record a missing key / size mismatch as usual
        # instead of raising KeyError or an opaque view() error here.
        if key in state_dict and state_dict[key].numel() == self.weight.numel():
            state_dict[key] = state_dict[key].view(self.weight.shape)
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys,
                                             unexpected_keys, error_msgs)


class LayerNorm2d(nn.LayerNorm):
    """LayerNorm over the channel dim of a (B, C, H, W) tensor."""

    def forward(self, x: torch.Tensor):
        x = x.permute(0, 2, 3, 1)
        x = nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        x = x.permute(0, 3, 1, 2)
        return x


class PatchMerging2D(nn.Module):
    """Halve the spatial resolution by stacking 2x2 neighbours, then project.

    Odd H/W are zero-padded on the bottom/right before merging.
    """

    def __init__(self, dim, out_dim=-1, norm_layer=nn.LayerNorm, channel_first=False):
        super().__init__()
        self.dim = dim
        Linear = Linear2d if channel_first else nn.Linear
        self._patch_merging_pad = self._patch_merging_pad_channel_first if channel_first else self._patch_merging_pad_channel_last
        self.reduction = Linear(4 * dim, (2 * dim) if out_dim < 0 else out_dim, bias=False)
        self.norm = norm_layer(4 * dim)

    @staticmethod
    def _patch_merging_pad_channel_last(x: torch.Tensor):
        H, W, _ = x.shape[-3:]
        if (W % 2 != 0) or (H % 2 != 0):
            # Layout (..., H, W, C): pad tuple is (C_l, C_r, W_l, W_r, H_l, H_r).
            x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[..., 0::2, 0::2, :]  # ... H/2 W/2 C
        x1 = x[..., 1::2, 0::2, :]  # ... H/2 W/2 C
        x2 = x[..., 0::2, 1::2, :]  # ... H/2 W/2 C
        x3 = x[..., 1::2, 1::2, :]  # ... H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # ... H/2 W/2 4*C
        return x

    @staticmethod
    def _patch_merging_pad_channel_first(x: torch.Tensor):
        H, W = x.shape[-2:]
        if (W % 2 != 0) or (H % 2 != 0):
            # Layout (..., C, H, W): pad tuple is (W_l, W_r, H_l, H_r). The old
            # tuple copied from the channel-last variant padded H and C instead.
            x = F.pad(x, (0, W % 2, 0, H % 2))
        x0 = x[..., 0::2, 0::2]  # ... H/2 W/2
        x1 = x[..., 1::2, 0::2]  # ... H/2 W/2
        x2 = x[..., 0::2, 1::2]  # ... H/2 W/2
        x3 = x[..., 1::2, 1::2]  # ... H/2 W/2
        x = torch.cat([x0, x1, x2, x3], 1)  # ... 4*C H/2 W/2
        return x

    def forward(self, x):
        x = self._patch_merging_pad(x)
        x = self.norm(x)
        x = self.reduction(x)

        return x


class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` with fixed dims."""

    def __init__(self, *args):
        super().__init__()
        self.args = args

    def forward(self, x: torch.Tensor):
        return x.permute(*self.args)


class Mlp(nn.Module):
    """Two-layer MLP: fc1 -> act -> dropout -> fc2 -> dropout."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 channels_first=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        Linear = Linear2d if channels_first else nn.Linear
        self.fc1 = Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class gMlp(nn.Module):
    """Gated MLP: fc1 produces (x, z); output is dropout(fc2(x * act(z)))."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.,
                 channels_first=False):
        super().__init__()
        self.channel_first = channels_first
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        Linear = Linear2d if channels_first else nn.Linear
        self.fc1 = Linear(in_features, 2 * hidden_features)
        self.act = act_layer()
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x, z = x.chunk(2, dim=(1 if self.channel_first else -1))
        x = self.fc2(x * self.act(z))
        x = self.drop(x)
        return x


class SoftmaxSpatial(nn.Softmax):
    """Softmax over the flattened spatial positions.

    dim=-1 expects (B, C, H, W); dim=1 expects (B, H, W, C).
    """

    def forward(self, x: torch.Tensor):
        if self.dim == -1:
            B, C, H, W = x.shape
            return super().forward(x.view(B, C, -1)).view(B, C, H, W)
        elif self.dim == 1:
            B, H, W, C = x.shape
            return super().forward(x.view(B, -1, C)).view(B, H, W, C)
        else:
            raise NotImplementedError


# =====================================================
class mamba_init:
    """Mixin with the parameter initializers used by the SS2D blocks."""

    @staticmethod
    def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
                **factory_kwargs):
        """Build the dt projection Linear(dt_rank -> d_inner) with Mamba-style init."""
        dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = dt_rank ** -0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        # dt_proj.bias._no_reinit = True

        return dt_proj

    @staticmethod
    def A_log_init(d_state, d_inner, copies=-1, device=None, merge=True):
        """Return log(A) with A[d, n] = n + 1, optionally repeated ``copies`` times."""
        # S4D real initialization
        A = repeat(
            torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        if copies > 0:
            A_log = repeat(A_log, "d n -> r d n", r=copies)
            if merge:
                A_log = A_log.flatten(0, 1)
        A_log = nn.Parameter(A_log)
        A_log._no_weight_decay = True
        return A_log

    @staticmethod
    def D_init(d_inner, copies=-1, device=None, merge=True):
        """Return the all-ones D skip parameter, optionally repeated ``copies`` times."""
        # D "skip" parameter
        D = torch.ones(d_inner, device=device)
        if copies > 0:
            D = repeat(D, "n1 -> r n1", r=copies)
            if merge:
                D = D.flatten(0, 1)
        D = nn.Parameter(D)  # Keep in fp32
        D._no_weight_decay = True
        return D


def shift_size_generate(index=0, H=0):
    """Return (shift, reverse_shift) roll offsets for block ``index``.

    The offset magnitude is H // 8; its diagonal direction cycles with
    ``index % 5``: 1 -> (+,+), 2 -> (-,-), 3 -> (+,-), 4 -> (-,+).
    ``index % 5 == 0`` means "no shift" and returns zero offsets (previously
    this case raised UnboundLocalError).
    """
    sz = int(H // 8)
    sign_h, sign_w = {
        0: (0, 0),
        1: (1, 1),
        2: (-1, -1),
        3: (1, -1),
        4: (-1, 1),
    }[index % 5]
    shift_size = (sign_h * sz, sign_w * sz)
    reverse_size = (-shift_size[0], -shift_size[1])
    return shift_size, reverse_size


# SS2D v2 core. The forward-type postfix switches of the original design
# (_no32, _oact, _noz, _nozact, out-norm variants) are hard-coded below.
class SS2Dv2:
    def __initv2__(
            self,
            # basic dims ===========
            d_model=96,
            d_state=16,
            ssm_ratio=2.0,
            dt_rank="auto",
            act_layer=nn.SiLU,
            # dwconv ===============
            d_conv=3,  # < 2 means no conv
            conv_bias=True,
            # ======================
            dropout=0.0,
            bias=False,
            # dt init ==============
            dt_min=0.001,
            dt_max=0.1,
            dt_init="random",
            dt_scale=1.0,
            dt_init_floor=1e-4,
            initialize="v0",
            # ======================
            forward_type="v2",
            channel_first=False,
            channel_divide=1,
            stage_num=0,
            depth_num=0,
            block_depth=0,
            # ======================
            **kwargs,
    ):
        factory_kwargs = {"device": None, "dtype": None}
        super().__init__()
        d_proj = int(ssm_ratio * d_model)
        self.channel_divide = int(channel_divide)
        d_inner = int((ssm_ratio * d_model) // channel_divide)
        dt_rank = math.ceil(d_model / 16) if dt_rank == "auto" else dt_rank
        self.d_conv = d_conv
        self.channel_first = channel_first
        self.with_dconv = d_conv > 1
        Linear = Linear2d if channel_first else nn.Linear
        self.forward = self.forwardv2

        # fixed feature switches ==============================
        # (a trailing comma previously made disable_force32 the truthy tuple (False,))
        self.disable_force32 = False
        self.oact = False
        self.disable_z = True
        self.disable_z_act = False
        self.out_norm_none = False
        self.out_norm_dwconv3 = False
        self.out_norm_softmax = False
        self.out_norm_sigmoid = False

        if self.out_norm_none:
            self.out_norm = nn.Identity()
        elif self.out_norm_dwconv3:
            self.out_norm = nn.Sequential(
                (nn.Identity() if channel_first else Permute(0, 3, 1, 2)),
                nn.Conv2d(d_proj, d_proj, kernel_size=3, padding=1, groups=d_proj, bias=False),
                (nn.Identity() if channel_first else Permute(0, 2, 3, 1)),
            )
        elif self.out_norm_softmax:
            self.out_norm = SoftmaxSpatial(dim=(-1 if channel_first else 1))
        elif self.out_norm_sigmoid:
            self.out_norm = nn.Sigmoid()
        else:
            LayerNorm = LayerNorm2d if channel_first else nn.LayerNorm
            self.out_norm = LayerNorm(d_proj * 2)

        # forward_type debug =======================================
        self.forward_core = partial(self.forward_core, force_fp32=True, no_einsum=True)
        self.stage_num = stage_num
        self.depth_num = depth_num
        self.quad_flag = False
        self.shift_flag = False
        if self.stage_num == 0 or self.stage_num == 1:
            k_group = 4
            self.score_predictor = Predictor(d_proj)
            self.quad_flag = True
            if self.depth_num % 5:
                self.shift_flag = True
        else:
            k_group = 4

        # in proj =======================================
        self.in_proj = Linear(d_model * 2, d_proj * 2, bias=bias, **factory_kwargs)
        self.act: nn.Module = act_layer()

        # conv =======================================
        if self.with_dconv:
            self.conv2d = nn.Conv2d(
                in_channels=d_proj * 2,
                out_channels=d_proj * 2,
                groups=d_proj * 2,
                bias=conv_bias,
                kernel_size=d_conv,
                padding=(d_conv - 1) // 2,
                **factory_kwargs,
            )

        # x proj ============================
        self.x_proj = [
            nn.Linear(d_inner, (dt_rank + d_state * 2), bias=False, **factory_kwargs)
            for _ in range(k_group)
        ]
        self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K, N, inner)
        del self.x_proj

        # out proj =======================================
        self.out_act = nn.GELU() if self.oact else nn.Identity()
        self.out_proj = Linear(d_proj * 2, d_model * 2, bias=bias, **factory_kwargs)
        self.dropout = nn.Dropout(dropout) if dropout > 0. else nn.Identity()

        if initialize:
            # dt proj ============================
            self.dt_projs = [
                self.dt_init(dt_rank, d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor, **factory_kwargs)
                for _ in range(k_group)
            ]
            self.dt_projs_weight = nn.Parameter(
                torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K, inner, rank)
            self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K, inner)
            del self.dt_projs

            # A, D =======================================
            self.A_logs = self.A_log_init(d_state, d_inner, copies=k_group, merge=True)  # (K * D, N)
            self.Ds = self.D_init(d_inner, copies=k_group, merge=True)  # (K * D)

    @staticmethod
    def _quad_local_scan(locality, x, h_scan=False):
        """Run ``local_scan_zero_ones`` per sample and stack the results.

        Returns (stacked zeros||ones sequences, len_zeros, indices_zeros,
        indices_ones, num_ones), the last four being per-sample lists.
        """
        # Pass h_scan only when requested so the default scan is invoked exactly
        # as before.
        scan_kwargs = {"h_scan": True} if h_scan else {}
        rows, len_zeros, indices_zeros, indices_ones, num_ones = [], [], [], [], []
        for loc_i, x_i in zip(locality, x):
            x_zeros, x_ones, sub_len_zeros, sub_indices_zeros, sub_indices_ones, sub_num_ones = \
                local_scan_zero_ones(loc_i, x_i, **scan_kwargs)
            len_zeros.append(sub_len_zeros)
            indices_zeros.append(sub_indices_zeros)
            indices_ones.append(sub_indices_ones)
            num_ones.append(sub_num_ones)
            rows.append(torch.cat([x_zeros, x_ones], dim=-1))
        # One stack instead of re-concatenating the growing batch every iteration.
        return torch.stack(rows, dim=0), len_zeros, indices_zeros, indices_ones, num_ones

    def forward_core(
            self,
            x: torch.Tensor = None,
            # ==============================
            to_dtype=True,  # True: final out to dtype
            force_fp32=False,  # True: input fp32
            # ==============================
            ssoflex=True,  # True: out fp32 in SSOflex; else, SSOflex is the same as SSCore
            # ==============================
            SelectiveScan=SelectiveScanCuda,
            CrossScan=CrossScanS,
            CrossMerge=CrossMergeS,
            no_einsum=False,  # replace einsum with linear or conv1d to raise throughput
            # ==============================
            **kwargs,
    ):
        """Selective-scan core on a channel-first (B, D, H, W) input.

        Returns (B, D, H, W) when ``channel_first`` else (B, H, W, D), passed
        through ``out_norm`` and cast back to ``x.dtype`` when ``to_dtype``.
        """
        x_proj_weight = self.x_proj_weight
        x_proj_bias = getattr(self, "x_proj_bias", None)
        dt_projs_weight = self.dt_projs_weight
        dt_projs_bias = self.dt_projs_bias
        A_logs = self.A_logs
        Ds = self.Ds
        delta_softplus = True
        out_norm = getattr(self, "out_norm", None)
        channel_first = self.channel_first
        to_fp32 = lambda *args: (_a.to(torch.float32) for _a in args)

        B, D, H, W = x.shape
        _, N = A_logs.shape
        K, _, R = dt_projs_weight.shape
        L = H * W

        def selective_scan(u, delta, A, B, C, D=None, delta_bias=None, delta_softplus=True):
            return SelectiveScan.apply(u, delta, A, B, C, D, delta_bias, delta_softplus, ssoflex, "mamba")

        if self.quad_flag:
            # score prediction
            quad_size = 2
            quad_number = quad_size * quad_size
            xA, xB = x.split(x.shape[1] // 2, 1)
            score = self.score_predictor(xA - xB)
            if self.shift_flag:
                # NOTE(review): score was computed from the un-rolled x, but the
                # locality mask derived from it is applied to the rolled x —
                # confirm this offset is intended.
                shift_size, reverse_size = shift_size_generate(self.depth_num, H)
                x = torch.roll(x, shifts=shift_size, dims=(2, 3))

            if H % quad_number != 0 or W % quad_number != 0:
                newH = math.ceil(H / quad_number) * quad_number
                newW = math.ceil(W / quad_number) * quad_number
                diff_H, diff_W = newH - H, newW - W
                # F.pad pads the LAST dim first: (W_left, W_right, H_left, H_right).
                x = F.pad(x, (0, diff_W, 0, diff_H))
                score = F.pad(score, (0, diff_W, 0, diff_H))

                B, D, H, W = x.shape
                L = H * W
                diff_flag = True
            else:
                diff_flag = False

            ### quad_one_stage
            score_window = F.adaptive_avg_pool2d(score[:, 1, :, :], (4, 4))  # (B, 4, 4)
            locality_decision = my_gumbel_softmax(score_window.view(B, 1, -1), k=6)  # (B, 1, 16)
            locality = window_expansion(locality_decision, H=int(H), W=int(W))  # (B, 1, H, W)

            xs_zeros_ones, len_zeros, indices_zeros, indices_ones, num_ones = \
                self._quad_local_scan(locality, x)
            xs_1 = Scan_FB_S.apply(xs_zeros_ones)  # b, k, c, l

            xs_zeros_ones_h, len_zeros_h, indices_zeros_h, indices_ones_h, num_ones_h = \
                self._quad_local_scan(locality, x, h_scan=True)
            xs_2 = Scan_FB_S.apply(xs_zeros_ones_h)  # b, k, c, l

            xs = torch.cat([xs_1, xs_2], dim=1)
        else:
            xs = CrossScan.apply(x)

        # Reinterpret each direction's (D, L) data as (D // 2, 2 * L).
        L = L * 2
        D = D // 2
        if no_einsum:
            x_dbl = F.conv1d(xs.view(B, -1, L), x_proj_weight.view(-1, D, 1),
                             bias=(x_proj_bias.view(-1) if x_proj_bias is not None else None), groups=K)
            dts, Bs, Cs = torch.split(x_dbl.view(B, K, -1, L), [R, N, N], dim=2)
            dts = F.conv1d(dts.contiguous().view(B, -1, L), dt_projs_weight.view(K * D, -1, 1), groups=K)
        else:
            # Same (B, K, D, L) reinterpretation as the conv1d branch above.
            x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, D, L), x_proj_weight)
            if x_proj_bias is not None:
                x_dbl = x_dbl + x_proj_bias.view(1, K, -1, 1)
            dts, Bs, Cs = torch.split(x_dbl, [R, N, N], dim=2)
            dts = torch.einsum("b k r l, k d r -> b k d l", dts, dt_projs_weight)

        xs = xs.view(B, -1, L)
        dts = dts.contiguous().view(B, -1, L)
        As = -torch.exp(A_logs.to(torch.float))  # (k * c, d_state)
        Bs = Bs.contiguous().view(B, K, N, L)
        Cs = Cs.contiguous().view(B, K, N, L)
        Ds = Ds.to(torch.float)  # (K * c)
        delta_bias = dt_projs_bias.view(-1).to(torch.float)

        if force_fp32:
            xs, dts, Bs, Cs = to_fp32(xs, dts, Bs, Cs)

        ys: torch.Tensor = selective_scan(
            xs, dts, As, Bs, Cs, Ds, delta_bias, delta_softplus
        ).view(B, K, -1, L)

        if self.quad_flag:
            y1 = Merge_FB_S.apply(ys[:, 0:2])  # BCL
            y2 = Merge_FB_S.apply(ys[:, 2:])  # BCL
            L = L // 2
            D = D * 2
            # for quad
            per_sample = []
            for i in range(B):
                y_1 = reverse_local_scan_zero_ones(indices_zeros[i], indices_ones[i], num_ones[i],
                                                   y1[i, ..., :len_zeros[i]], y1[i, ..., len_zeros[i]:])
                y_2 = reverse_local_scan_zero_ones(indices_zeros_h[i], indices_ones_h[i], num_ones_h[i],
                                                   y2[i, ..., :len_zeros_h[i]], y2[i, ..., len_zeros_h[i]:],
                                                   h_scan=True)
                per_sample.append(y_1 + y_2)
            y = torch.stack(per_sample, dim=0)

            if diff_flag:
                y = y.reshape(B, D, H, -1)
                # Explicit end indices: a `0:-diff` slice is empty when diff == 0.
                y = y[:, :, :H - diff_H, :W - diff_W].contiguous()
                H, W = H - diff_H, W - diff_W
            else:
                y = y.view(B, D, H, -1)

            if self.shift_flag:
                y = torch.roll(y, shifts=reverse_size, dims=(2, 3))
        else:
            ys = ys.view(B, K, D, H, W * 2)
            y: torch.Tensor = CrossMerge.apply(ys)
            L = L // 2
            D = D * 2

        y = y.view(B, -1, H, W)

        if not channel_first:
            y = y.view(B, -1, H * W).transpose(dim0=1, dim1=2).contiguous().view(B, H, W, -1)  # (B, L, C)
        y = out_norm(y)

        return (y.to(x.dtype) if to_dtype else y)

    def forwardv2(self, x: torch.Tensor, **kwargs):
        """in_proj -> (dwconv -> act) -> selective-scan core -> out_proj -> dropout."""
        x = self.in_proj(x)
        if not self.channel_first:
            x = x.permute(0, 3, 1, 2).contiguous()
        if self.with_dconv:
            x = self.conv2d(x)  # (b, d, h, w)
        x = self.act(x)
        y = self.forward_core(x)
        y = self.out_act(y)
        out = self.dropout(self.out_proj(y))
        return out


class SS2D(nn.Module, mamba_init, SS2Dv2):
    """nn.Module entry point that forwards all constructor arguments to ``SS2Dv2.__initv2__``."""

    def __init__(
            self,
            # basic dims ===========
            d_model=96,
            d_state=16,
            ssm_ratio=2.0,
            dt_rank="auto",
            act_layer=nn.SiLU,
            # dwconv ===============
            d_conv=3,  # < 2 means no conv
            conv_bias=True,
            # ======================
            dropout=0.0,
            bias=False,
            # dt init ==============
            dt_min=0.001,
            dt_max=0.1,
            dt_init="random",
            dt_scale=1.0,
            dt_init_floor=1e-4,
            initialize="v0",
            # ======================
            forward_type="v2",
            channel_first=False,
            channel_divide=1,
            stage_num=0,
            depth_num=0,
            block_depth=0,
            # ======================
            **kwargs,
    ):
        super().__init__()
        kwargs.update(
            d_model=d_model, d_state=d_state, ssm_ratio=ssm_ratio, dt_rank=dt_rank,
            act_layer=act_layer, d_conv=d_conv, conv_bias=conv_bias, dropout=dropout, bias=bias,
            dt_min=dt_min, dt_max=dt_max, dt_init=dt_init, dt_scale=dt_scale, dt_init_floor=dt_init_floor,
            initialize=initialize, forward_type=forward_type, channel_first=channel_first,
            channel_divide=channel_divide, stage_num=stage_num, depth_num=depth_num, block_depth=block_depth,
        )
        self.__initv2__(**kwargs)
        return