|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .intern_vit_6b import InternViT6B |
|
|
|
|
|
|
|
|
def build_model(config):
    """Construct the model selected by ``config.MODEL.TYPE``.

    Only the ``'intern_vit_6b'`` type is handled here; for it an
    ``InternViT6B`` is created from values read off ``config``.

    Args:
        config: Configuration object exposing attribute access. The
            following attributes are read: ``MODEL.TYPE``,
            ``MODEL.NUM_CLASSES``, ``MODEL.DROP_PATH_RATE``,
            ``DATA.IMG_SIZE``, ``TRAIN.USE_CHECKPOINT`` and the
            ``MODEL.INTERN_VIT_6B`` sub-node (``PATCH_SIZE``,
            ``PRETRAIN_SIZE``, ``QKV_BIAS``, ``EMBED_DIM``, ``NUM_HEADS``,
            ``MLP_RATIO``, ``INIT_VALUES``, ``QK_NORMALIZATION``, ``DEPTH``,
            ``USE_FLASH_ATTN``, ``FREEZE_VIT``, ``PRETRAINED``,
            ``CLS_TARGET``, ``HEAD_NORM_TYPE``).

    Returns:
        The constructed ``InternViT6B`` instance.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not
            ``'intern_vit_6b'``.
    """
    model_type = config.MODEL.TYPE

    # Reject unsupported types up front so that no model-specific config
    # node is touched for a type this factory does not know about.
    if model_type != 'intern_vit_6b':
        raise NotImplementedError(f'Unknown model: {model_type}')

    # All InternViT-6B specific options live under this sub-node.
    vit_cfg = config.MODEL.INTERN_VIT_6B

    model = InternViT6B(
        num_classes=config.MODEL.NUM_CLASSES,
        patch_size=vit_cfg.PATCH_SIZE,
        img_size=config.DATA.IMG_SIZE,
        pretrain_size=vit_cfg.PRETRAIN_SIZE,
        qkv_bias=vit_cfg.QKV_BIAS,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        embed_dim=vit_cfg.EMBED_DIM,
        num_heads=vit_cfg.NUM_HEADS,
        mlp_ratio=vit_cfg.MLP_RATIO,
        init_values=vit_cfg.INIT_VALUES,
        qk_normalization=vit_cfg.QK_NORMALIZATION,
        depth=vit_cfg.DEPTH,
        use_flash_attn=vit_cfg.USE_FLASH_ATTN,
        with_cp=config.TRAIN.USE_CHECKPOINT,
        freeze_vit=vit_cfg.FREEZE_VIT,
        pretrained=vit_cfg.PRETRAINED,
        cls_target=vit_cfg.CLS_TARGET,
        head_norm_type=vit_cfg.HEAD_NORM_TYPE,
    )
    return model
|
|
|