| | import os |
| | from functools import partial |
| | import torch |
| |
|
| | from .vmamba import VSSM |
| | from .csms6s import flops_selective_scan_fn,flops_selective_scan_ref |
| |
|
| |
|
def _vssm_dt_rank(raw):
    """Return ``"auto"`` unchanged; otherwise coerce ``raw`` with ``int()``."""
    return "auto" if raw == "auto" else int(raw)


def build_vssm_model(config, **kwargs):
    """Instantiate a ``VSSM`` network from a nested configuration object.

    Args:
        config: Configuration read through attribute access on its
            ``MODEL``, ``MODEL.VSSM``, ``TRAIN`` and ``DATA`` sections
            (presumably a yacs-style ``CfgNode`` -- confirm with callers).
        **kwargs: Accepted so callers may pass extra options by keyword;
            this function does not use them.

    Returns:
        A ``VSSM`` instance when ``config.MODEL.TYPE`` is ``"vssm"``,
        otherwise ``None``.
    """
    # Only the "vssm" model type is handled here; anything else yields None.
    if config.MODEL.TYPE not in ("vssm",):
        return None

    vssm_cfg = config.MODEL.VSSM
    model_kwargs = dict(
        # Stem / overall layout.
        patch_size=vssm_cfg.PATCH_SIZE,
        in_chans=vssm_cfg.IN_CHANS,
        num_classes=config.MODEL.NUM_CLASSES,
        depths=vssm_cfg.DEPTHS,
        dims=vssm_cfg.EMBED_DIM,
        # Selective-scan (SSM) branch.
        ssm_d_state=vssm_cfg.SSM_D_STATE,
        ssm_ratio=vssm_cfg.SSM_RATIO,
        ssm_rank_ratio=vssm_cfg.SSM_RANK_RATIO,
        # SSM_DT_RANK is either the literal "auto" or something int() accepts.
        ssm_dt_rank=_vssm_dt_rank(vssm_cfg.SSM_DT_RANK),
        ssm_act_layer=vssm_cfg.SSM_ACT_LAYER,
        ssm_conv=vssm_cfg.SSM_CONV,
        ssm_conv_bias=vssm_cfg.SSM_CONV_BIAS,
        ssm_drop_rate=vssm_cfg.SSM_DROP_RATE,
        ssm_init=vssm_cfg.SSM_INIT,
        forward_type=vssm_cfg.SSM_FORWARDTYPE,
        # MLP branch.
        mlp_ratio=vssm_cfg.MLP_RATIO,
        mlp_act_layer=vssm_cfg.MLP_ACT_LAYER,
        mlp_drop_rate=vssm_cfg.MLP_DROP_RATE,
        # Regularisation, normalisation and structural variants.
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        patch_norm=vssm_cfg.PATCH_NORM,
        norm_layer=vssm_cfg.NORM_LAYER,
        downsample_version=vssm_cfg.DOWNSAMPLE,
        patchembed_version=vssm_cfg.PATCHEMBED,
        gmlp=vssm_cfg.GMLP,
        use_checkpoint=config.TRAIN.USE_CHECKPOINT,
        # Positional embedding and the input image size it is built for.
        posembed=vssm_cfg.POSEMBED,
        imgsize=config.DATA.IMG_SIZE,
    )
    return VSSM(**model_kwargs)
| |
|
| |
|
def build_model(config, is_pretrain=False):
    """Build the model described by ``config``.

    Args:
        config: Configuration object handed unchanged to ``build_vssm_model``.
        is_pretrain: Forwarded to ``build_vssm_model`` as a keyword argument
            (that function currently accepts but does not use it).

    Returns:
        Whatever ``build_vssm_model`` returns: a ``VSSM`` instance for
        ``config.MODEL.TYPE == "vssm"``, otherwise ``None``.
    """
    # build_vssm_model(config, **kwargs) takes only one positional argument;
    # passing is_pretrain positionally raised TypeError on every call.
    return build_vssm_model(config, is_pretrain=is_pretrain)
| |
|
| |
|
| |
|
| |
|
| |
|