| | |
| | |
| | |
| | |
| | |
| |
|
| | from .intern_vit_6b import InternViT6B |
| |
|
| |
|
def build_model(config):
    """Build the model selected by ``config.MODEL.TYPE``.

    Only ``'intern_vit_6b'`` is supported. It produces an ``InternViT6B``
    whose hyper-parameters come from ``config.MODEL.INTERN_VIT_6B``, plus a
    few shared fields:

    - ``config.MODEL.NUM_CLASSES``
    - ``config.MODEL.DROP_PATH_RATE``
    - ``config.DATA.IMG_SIZE``
    - ``config.TRAIN.USE_CHECKPOINT`` (forwarded as ``with_cp``)

    Args:
        config: Attribute-style configuration object exposing the fields
            above (presumably a yacs ``CfgNode`` -- TODO confirm).

    Returns:
        The constructed ``InternViT6B`` instance.

    Raises:
        NotImplementedError: If ``config.MODEL.TYPE`` is not a supported
            model type.
    """
    model_type = config.MODEL.TYPE
    # Reject unsupported types up front so the happy path stays flat.
    if model_type != 'intern_vit_6b':
        raise NotImplementedError(f'Unknown model: {model_type}')

    vit_cfg = config.MODEL.INTERN_VIT_6B
    return InternViT6B(
        num_classes=config.MODEL.NUM_CLASSES,
        patch_size=vit_cfg.PATCH_SIZE,
        img_size=config.DATA.IMG_SIZE,
        pretrain_size=vit_cfg.PRETRAIN_SIZE,
        qkv_bias=vit_cfg.QKV_BIAS,
        drop_path_rate=config.MODEL.DROP_PATH_RATE,
        embed_dim=vit_cfg.EMBED_DIM,
        num_heads=vit_cfg.NUM_HEADS,
        mlp_ratio=vit_cfg.MLP_RATIO,
        init_values=vit_cfg.INIT_VALUES,
        qk_normalization=vit_cfg.QK_NORMALIZATION,
        depth=vit_cfg.DEPTH,
        use_flash_attn=vit_cfg.USE_FLASH_ATTN,
        # The checkpoint flag lives under TRAIN, not under the model section.
        with_cp=config.TRAIN.USE_CHECKPOINT,
        freeze_vit=vit_cfg.FREEZE_VIT,
        pretrained=vit_cfg.PRETRAINED,
        cls_target=vit_cfg.CLS_TARGET,
        head_norm_type=vit_cfg.HEAD_NORM_TYPE,
    )
| |
|