# Extraction header (not part of the module): File size: 2,130 Bytes; 7ff4dd0
import os
from functools import partial
import torch
from .vmamba import VSSM
from .csms6s import flops_selective_scan_fn,flops_selective_scan_ref
def build_vssm_model(config, **kwargs):
    """Create a ``VSSM`` from ``config`` when the configured type is ``"vssm"``.

    Args:
        config: Object exposing ``MODEL``, ``TRAIN`` and ``DATA`` attribute
            sections; the exact fields read are listed in the body below.
        **kwargs: Accepted for call-site compatibility; not used here.

    Returns:
        A ``VSSM`` instance when ``config.MODEL.TYPE`` is ``"vssm"``,
        otherwise ``None``.
    """
    model_cfg = config.MODEL

    # Guard clause: any other model type is not handled by this builder.
    if model_cfg.TYPE not in ("vssm",):
        return None

    vssm_cfg = model_cfg.VSSM

    # "auto" is forwarded as the literal string; anything else becomes an int.
    raw_dt_rank = vssm_cfg.SSM_DT_RANK
    if raw_dt_rank == "auto":
        dt_rank = "auto"
    else:
        dt_rank = int(raw_dt_rank)

    vssm_kwargs = {
        "patch_size": vssm_cfg.PATCH_SIZE,
        "in_chans": vssm_cfg.IN_CHANS,
        "num_classes": model_cfg.NUM_CLASSES,
        "depths": vssm_cfg.DEPTHS,
        "dims": vssm_cfg.EMBED_DIM,
        # --- ssm_* options ---
        "ssm_d_state": vssm_cfg.SSM_D_STATE,
        "ssm_ratio": vssm_cfg.SSM_RATIO,
        "ssm_rank_ratio": vssm_cfg.SSM_RANK_RATIO,
        "ssm_dt_rank": dt_rank,
        "ssm_act_layer": vssm_cfg.SSM_ACT_LAYER,
        "ssm_conv": vssm_cfg.SSM_CONV,
        "ssm_conv_bias": vssm_cfg.SSM_CONV_BIAS,
        "ssm_drop_rate": vssm_cfg.SSM_DROP_RATE,
        "ssm_init": vssm_cfg.SSM_INIT,
        "forward_type": vssm_cfg.SSM_FORWARDTYPE,
        # --- mlp_* options ---
        "mlp_ratio": vssm_cfg.MLP_RATIO,
        "mlp_act_layer": vssm_cfg.MLP_ACT_LAYER,
        "mlp_drop_rate": vssm_cfg.MLP_DROP_RATE,
        # --- remaining backbone options ---
        "drop_path_rate": model_cfg.DROP_PATH_RATE,
        "patch_norm": vssm_cfg.PATCH_NORM,
        "norm_layer": vssm_cfg.NORM_LAYER,
        "downsample_version": vssm_cfg.DOWNSAMPLE,
        "patchembed_version": vssm_cfg.PATCHEMBED,
        "gmlp": vssm_cfg.GMLP,
        "use_checkpoint": config.TRAIN.USE_CHECKPOINT,
        # --- positional embedding / input size ---
        "posembed": vssm_cfg.POSEMBED,
        "imgsize": config.DATA.IMG_SIZE,
    }
    return VSSM(**vssm_kwargs)
def build_model(config, is_pretrain=False):
    """Build the model described by ``config``.

    Args:
        config: Configuration object, forwarded unchanged to
            ``build_vssm_model``.
        is_pretrain: Forwarded to ``build_vssm_model`` as a keyword argument.

    Returns:
        Whatever ``build_vssm_model`` returns for ``config``.
    """
    # build_vssm_model is declared as (config, **kwargs): it accepts exactly
    # one positional argument. Passing is_pretrain positionally raised
    # TypeError on every call, so it must be passed by keyword.
    return build_vssm_model(config, is_pretrain=is_pretrain)
|