File size: 3,657 Bytes
3bbb319 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf
# Hyper-parameters that differ between the VisionTransformer variants, keyed by
# the public model name.  Everything not listed here (img_size, patch_size,
# num_heads, norm_layer) is shared by all variants and set in ``_build_vit``.
# ``using_checkpoint`` is only passed for the variants that set it explicitly,
# so the other variants keep whatever default ``VisionTransformer`` defines.
# NOTE(review): ``using_checkpoint`` presumably turns on activation
# checkpointing for the deeper models -- confirm in ``.vit``.
_VIT_CONFIGS = {
    "vit_t": dict(
        embed_dim=256, depth=12, drop_path_rate=0.1, mask_ratio=0.1),
    # For WebFace42M.  The key has no underscore before the trailing 0, unlike
    # the other "mask_*" names; it is kept verbatim so the existing name keeps
    # working.
    "vit_t_dp005_mask0": dict(
        embed_dim=256, depth=12, drop_path_rate=0.05, mask_ratio=0.0),
    "vit_s": dict(
        embed_dim=512, depth=12, drop_path_rate=0.1, mask_ratio=0.1),
    # For WebFace42M
    "vit_s_dp005_mask_0": dict(
        embed_dim=512, depth=12, drop_path_rate=0.05, mask_ratio=0.0),
    "vit_b": dict(
        embed_dim=512, depth=24, drop_path_rate=0.1, mask_ratio=0.1,
        using_checkpoint=True),
    # For WebFace42M
    "vit_b_dp005_mask_005": dict(
        embed_dim=512, depth=24, drop_path_rate=0.05, mask_ratio=0.05,
        using_checkpoint=True),
    # For WebFace42M
    "vit_l_dp005_mask_005": dict(
        embed_dim=768, depth=24, drop_path_rate=0.05, mask_ratio=0.05,
        using_checkpoint=True),
}


def _build_vit(name, num_features):
    """Instantiate the VisionTransformer variant registered under ``name``."""
    # Imported lazily so the ViT module is only loaded when a ViT is requested.
    from .vit import VisionTransformer

    return VisionTransformer(
        img_size=112,
        patch_size=9,
        num_classes=num_features,
        num_heads=8,
        norm_layer="ln",
        **_VIT_CONFIGS[name]
    )


def get_model(name, **kwargs):
    """Build the backbone network identified by ``name``.

    Args:
        name: Model identifier.  Supported values are ``"r18"``, ``"r34"``,
            ``"r50"``, ``"r100"``, ``"r200"``, ``"r2060"``, ``"mbf"``,
            ``"mbf_large"`` and the ViT names listed in ``_VIT_CONFIGS``.
            Matching is exact and case-sensitive.
        **kwargs: Construction options.
            * ResNet variants (``r*``): forwarded unchanged to the
              constructor.
            * MobileFaceNet variants (``mbf``, ``mbf_large``): only ``fp16``
              (default ``False``) and ``num_features`` (default ``512``) are
              read; any other key is ignored.
            * ViT variants: only ``num_features`` (default ``512``) is read
              and passed as ``num_classes``; any other key is ignored.

    Returns:
        The object returned by the selected model constructor.

    Raises:
        ValueError: If ``name`` is not one of the supported identifiers.
    """
    # resnet
    # NOTE(review): the positional ``False`` is presumably the ``pretrained``
    # flag of the iresnet constructors -- confirm in ``.iresnet``.
    if name == "r18":
        return iresnet18(False, **kwargs)
    if name == "r34":
        return iresnet34(False, **kwargs)
    if name == "r50":
        return iresnet50(False, **kwargs)
    if name == "r100":
        return iresnet100(False, **kwargs)
    if name == "r200":
        return iresnet200(False, **kwargs)
    if name == "r2060":
        # Imported on demand; only needed for this variant.
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)

    # mobilefacenet
    if name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)
    if name == "mbf_large":
        from .mobilefacenet import get_mbf_large
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf_large(fp16=fp16, num_features=num_features)

    # vision transformer
    # The isinstance guard keeps unhashable ``name`` values on the ValueError
    # path below instead of failing inside the dict lookup with TypeError.
    if isinstance(name, str) and name in _VIT_CONFIGS:
        return _build_vit(name, kwargs.get("num_features", 512))

    # Name the rejected value so a typo in a config is easy to spot.
    raise ValueError("unknown model name: {!r}".format(name))
|