xuehongyang
ser
83d8d3c
from .iresnet import iresnet100
from .iresnet import iresnet18
from .iresnet import iresnet200
from .iresnet import iresnet34
from .iresnet import iresnet50
from .mobilefacenet import get_mbf
# Keyword arguments shared by every ViT variant that get_model can build.
_VIT_COMMON = {
    "img_size": 112,
    "patch_size": 9,
    "num_heads": 8,
    "norm_layer": "ln",
}

# Per-variant ViT keyword arguments, merged with _VIT_COMMON at build time.
# The "*_dp005_*" variants (lower drop-path / mask ratio) were tagged
# "For WebFace42M" where they were originally defined.
# ``using_checkpoint`` is listed only for the variants that set it, so the
# remaining variants keep whatever default VisionTransformer defines.
_VIT_VARIANTS = {
    "vit_t": {
        "embed_dim": 256,
        "depth": 12,
        "drop_path_rate": 0.1,
        "mask_ratio": 0.1,
    },
    "vit_t_dp005_mask0": {
        "embed_dim": 256,
        "depth": 12,
        "drop_path_rate": 0.05,
        "mask_ratio": 0.0,
    },
    "vit_s": {
        "embed_dim": 512,
        "depth": 12,
        "drop_path_rate": 0.1,
        "mask_ratio": 0.1,
    },
    "vit_s_dp005_mask_0": {
        "embed_dim": 512,
        "depth": 12,
        "drop_path_rate": 0.05,
        "mask_ratio": 0.0,
    },
    "vit_b": {
        "embed_dim": 512,
        "depth": 24,
        "drop_path_rate": 0.1,
        "mask_ratio": 0.1,
        "using_checkpoint": True,
    },
    "vit_b_dp005_mask_005": {
        "embed_dim": 512,
        "depth": 24,
        "drop_path_rate": 0.05,
        "mask_ratio": 0.05,
        "using_checkpoint": True,
    },
    "vit_l_dp005_mask_005": {
        "embed_dim": 768,
        "depth": 24,
        "drop_path_rate": 0.05,
        "mask_ratio": 0.05,
        "using_checkpoint": True,
    },
}


def get_model(name, **kwargs):
    """Build a backbone network from its short name.

    Args:
        name: Model identifier. Supported values are ``"r18"``, ``"r34"``,
            ``"r50"``, ``"r100"``, ``"r200"`` and ``"r2060"`` (iresnet family);
            ``"mbf"`` and ``"mbf_large"`` (mobilefacenet); and the ViT
            variants ``"vit_t"``, ``"vit_t_dp005_mask0"``, ``"vit_s"``,
            ``"vit_s_dp005_mask_0"``, ``"vit_b"``, ``"vit_b_dp005_mask_005"``
            and ``"vit_l_dp005_mask_005"``.
        **kwargs: Handling depends on the family:
            - iresnet: all kwargs are forwarded unchanged to the constructor,
              after a positional ``False`` (presumably a "pretrained" flag;
              confirm against ``iresnet``).
            - mobilefacenet: only ``fp16`` (default ``False``) and
              ``num_features`` (default ``512``) are read; others are ignored.
            - ViT: only ``num_features`` (default ``512``) is read; it is
              passed as ``num_classes``. All other kwargs are ignored.

    Returns:
        The constructed model object.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    # iresnet family: every kwarg is forwarded to the constructor.
    if name == "r18":
        return iresnet18(False, **kwargs)
    if name == "r34":
        return iresnet34(False, **kwargs)
    if name == "r50":
        return iresnet50(False, **kwargs)
    if name == "r100":
        return iresnet100(False, **kwargs)
    if name == "r200":
        return iresnet200(False, **kwargs)
    if name == "r2060":
        # Imported lazily, as in the original; only this variant needs the module.
        from .iresnet2060 import iresnet2060

        return iresnet2060(False, **kwargs)

    # mobilefacenet family: only fp16 / num_features are honoured.
    if name in ("mbf", "mbf_large"):
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        if name == "mbf":
            return get_mbf(fp16=fp16, num_features=num_features)
        from .mobilefacenet import get_mbf_large

        return get_mbf_large(fp16=fp16, num_features=num_features)

    # ViT family. The isinstance guard keeps an unhashable ``name`` on the
    # ValueError path below instead of raising TypeError from the dict lookup.
    if isinstance(name, str) and name in _VIT_VARIANTS:
        num_features = kwargs.get("num_features", 512)
        # Imported lazily so the ViT module is loaded only when a ViT is requested.
        from .vit import VisionTransformer

        return VisionTransformer(
            num_classes=num_features,
            **_VIT_COMMON,
            **_VIT_VARIANTS[name],
        )

    supported = [
        "r18",
        "r34",
        "r50",
        "r100",
        "r200",
        "r2060",
        "mbf",
        "mbf_large",
        *_VIT_VARIANTS,
    ]
    raise ValueError(
        f"Unknown model name {name!r}; expected one of: {', '.join(supported)}"
    )