import os from functools import partial import torch from .vmamba import VSSM from .csms6s import flops_selective_scan_fn,flops_selective_scan_ref def build_vssm_model(config, **kwargs): model_type = config.MODEL.TYPE if model_type in ["vssm"]: model = VSSM( patch_size=config.MODEL.VSSM.PATCH_SIZE, in_chans=config.MODEL.VSSM.IN_CHANS, num_classes=config.MODEL.NUM_CLASSES, depths=config.MODEL.VSSM.DEPTHS, dims=config.MODEL.VSSM.EMBED_DIM, # =================== ssm_d_state=config.MODEL.VSSM.SSM_D_STATE, ssm_ratio=config.MODEL.VSSM.SSM_RATIO, ssm_rank_ratio=config.MODEL.VSSM.SSM_RANK_RATIO, ssm_dt_rank=("auto" if config.MODEL.VSSM.SSM_DT_RANK == "auto" else int(config.MODEL.VSSM.SSM_DT_RANK)), ssm_act_layer=config.MODEL.VSSM.SSM_ACT_LAYER, ssm_conv=config.MODEL.VSSM.SSM_CONV, ssm_conv_bias=config.MODEL.VSSM.SSM_CONV_BIAS, ssm_drop_rate=config.MODEL.VSSM.SSM_DROP_RATE, ssm_init=config.MODEL.VSSM.SSM_INIT, forward_type=config.MODEL.VSSM.SSM_FORWARDTYPE, # =================== mlp_ratio=config.MODEL.VSSM.MLP_RATIO, mlp_act_layer=config.MODEL.VSSM.MLP_ACT_LAYER, mlp_drop_rate=config.MODEL.VSSM.MLP_DROP_RATE, # =================== drop_path_rate=config.MODEL.DROP_PATH_RATE, patch_norm=config.MODEL.VSSM.PATCH_NORM, norm_layer=config.MODEL.VSSM.NORM_LAYER, downsample_version=config.MODEL.VSSM.DOWNSAMPLE, patchembed_version=config.MODEL.VSSM.PATCHEMBED, gmlp=config.MODEL.VSSM.GMLP, use_checkpoint=config.TRAIN.USE_CHECKPOINT, # =================== posembed=config.MODEL.VSSM.POSEMBED, imgsize=config.DATA.IMG_SIZE, ) return model return None def build_model(config, is_pretrain=False): model = None if model is None: model = build_vssm_model(config, is_pretrain) return model