|
|
from .iresnet import iresnet100 |
|
|
from .iresnet import iresnet18 |
|
|
from .iresnet import iresnet200 |
|
|
from .iresnet import iresnet34 |
|
|
from .iresnet import iresnet50 |
|
|
from .mobilefacenet import get_mbf |
|
|
|
|
|
|
|
|
# Arguments shared by every VisionTransformer variant built by get_model().
_VIT_COMMON = dict(img_size=112, patch_size=9, num_heads=8, norm_layer="ln")

# Per-name VisionTransformer hyper-parameters.  `using_checkpoint` is listed
# only for the variants that pass it explicitly; the others leave it to the
# constructor's own default.
# NOTE(review): the key spelling is inconsistent ("mask0" vs "mask_0") but is
# kept as-is because callers select models by these exact strings.
_VIT_PRESETS = {
    "vit_t": dict(
        embed_dim=256, depth=12, drop_path_rate=0.1, mask_ratio=0.1),
    "vit_t_dp005_mask0": dict(
        embed_dim=256, depth=12, drop_path_rate=0.05, mask_ratio=0.0),
    "vit_s": dict(
        embed_dim=512, depth=12, drop_path_rate=0.1, mask_ratio=0.1),
    "vit_s_dp005_mask_0": dict(
        embed_dim=512, depth=12, drop_path_rate=0.05, mask_ratio=0.0),
    "vit_b": dict(
        embed_dim=512, depth=24, drop_path_rate=0.1, mask_ratio=0.1,
        using_checkpoint=True),
    "vit_b_dp005_mask_005": dict(
        embed_dim=512, depth=24, drop_path_rate=0.05, mask_ratio=0.05,
        using_checkpoint=True),
    "vit_l_dp005_mask_005": dict(
        embed_dim=768, depth=24, drop_path_rate=0.05, mask_ratio=0.05,
        using_checkpoint=True),
}


def get_model(name, **kwargs):
    """Build the backbone network selected by ``name``.

    Supported names:

    * ``"r18"``, ``"r34"``, ``"r50"``, ``"r100"``, ``"r200"``, ``"r2060"``:
      the matching ``iresnet*`` factory is called as
      ``factory(False, **kwargs)``, so every keyword is forwarded.
    * ``"mbf"``, ``"mbf_large"``: only ``fp16`` (default ``False``) and
      ``num_features`` (default ``512``) are read from ``kwargs``; any other
      keyword is ignored.
    * ``"vit_t"``, ``"vit_t_dp005_mask0"``, ``"vit_s"``,
      ``"vit_s_dp005_mask_0"``, ``"vit_b"``, ``"vit_b_dp005_mask_005"``,
      ``"vit_l_dp005_mask_005"``: a ``VisionTransformer`` is built from a
      fixed preset; only ``num_features`` (default ``512``) is read from
      ``kwargs`` and passed as ``num_classes``.

    Args:
        name: Identifier of the backbone to build (one of the strings above).
        **kwargs: Options for the selected backbone, see above.

    Returns:
        Whatever the selected factory/constructor returns.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    # The leading positional ``False`` is presumably the ``pretrained`` flag
    # of the iresnet factories -- TODO confirm against .iresnet.
    if name == "r18":
        return iresnet18(False, **kwargs)
    if name == "r34":
        return iresnet34(False, **kwargs)
    if name == "r50":
        return iresnet50(False, **kwargs)
    if name == "r100":
        return iresnet100(False, **kwargs)
    if name == "r200":
        return iresnet200(False, **kwargs)
    if name == "r2060":
        # Imported lazily so the module is only loaded when requested.
        from .iresnet2060 import iresnet2060

        return iresnet2060(False, **kwargs)

    if name == "mbf":
        return get_mbf(
            fp16=kwargs.get("fp16", False),
            num_features=kwargs.get("num_features", 512),
        )
    if name == "mbf_large":
        from .mobilefacenet import get_mbf_large

        return get_mbf_large(
            fp16=kwargs.get("fp16", False),
            num_features=kwargs.get("num_features", 512),
        )

    # Guard the dict lookup so an unhashable ``name`` still ends in the
    # ValueError below instead of a TypeError.
    preset = _VIT_PRESETS.get(name) if isinstance(name, str) else None
    if preset is not None:
        from .vit import VisionTransformer

        return VisionTransformer(
            num_classes=kwargs.get("num_features", 512),
            **dict(_VIT_COMMON, **preset)
        )

    raise ValueError("unknown backbone name: {!r}".format(name))
|
|
|