# NOTE: removed a Hugging Face Spaces status banner ("Spaces: Running") that
# was left at the top of this file by web scraping; it is not Python code.
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
from functools import partial

import torch
import torch.nn as nn
from timm.models.registry import register_model
from timm.models.vision_transformer import VisionTransformer, _cfg
@register_model
def deit_tiny_patch8_224(pretrained=False, **kwargs):
    """DeiT-Tiny, 8x8 patches, 224x224 input (embed_dim=192, depth=12, heads=3).

    Args:
        pretrained: if True, download the DeiT-Tiny/16 checkpoint and load it.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=8, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): this is the patch16 checkpoint; its patch-embedding and
        # position-embedding shapes cannot match patch_size=8, so this strict
        # load is expected to raise — confirm which weights were intended.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """DeiT-Tiny, 16x16 patches, 224x224 input (embed_dim=192, depth=12, heads=3).

    Args:
        pretrained: if True, download and load the official DeiT-Tiny/16 weights.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_tiny_patch16_d_6_224(pretrained=False, **kwargs):
    """DeiT-Tiny variant with half depth: 16x16 patches, depth=6, 224x224 input.

    Args:
        pretrained: if True, download the (depth-12) DeiT-Tiny/16 checkpoint.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=6, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): checkpoint has 12 blocks but this model has 6; strict
        # loading is expected to fail on unexpected keys — confirm intent.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_tiny_patch32_224(pretrained=False, **kwargs):
    """DeiT-Tiny, 32x32 patches, 224x224 input (embed_dim=192, depth=12, heads=3).

    Args:
        pretrained: if True, download the DeiT-Tiny/16 checkpoint and load it.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=32, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): patch16 checkpoint vs patch_size=32 — patch/pos
        # embedding shapes differ, so this strict load should raise.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_patch8_224(pretrained=False, **kwargs):
    """DeiT-Small, 8x8 patches, 224x224 input (embed_dim=384, depth=12, heads=6).

    Args:
        pretrained: if True, download the DeiT-Small/16 checkpoint and load it.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=8, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): patch16 checkpoint vs patch_size=8 — embedding shapes
        # differ, so this strict load is expected to raise.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    """DeiT-Small, 16x16 patches, 224x224 input (embed_dim=384, depth=12, heads=6).

    Args:
        pretrained: if True, download and load the official DeiT-Small/16 weights.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_patch16_d_6_224(pretrained=False, **kwargs):
    """DeiT-Small variant with half depth: 16x16 patches, depth=6, 224x224 input.

    Args:
        pretrained: if True, download the (depth-12) DeiT-Small/16 checkpoint.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=6, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): checkpoint has 12 blocks but this model has 6; strict
        # loading is expected to fail on unexpected keys — confirm intent.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_small_patch32_224(pretrained=False, **kwargs):
    """DeiT-Small, 32x32 patches, 224x224 input (embed_dim=384, depth=12, heads=6).

    Args:
        pretrained: if True, download the DeiT-Small/16 checkpoint and load it.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): patch16 checkpoint vs patch_size=32 — embedding shapes
        # differ, so this strict load is expected to raise.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch8_224(pretrained=False, **kwargs):
    """DeiT-Base, 8x8 patches, 224x224 input (embed_dim=768, depth=12, heads=12).

    Args:
        pretrained: if True, download the DeiT-Base/16 checkpoint and load it.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): patch16 checkpoint vs patch_size=8 — embedding shapes
        # differ, so this strict load is expected to raise.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    """DeiT-Base, 16x16 patches, 224x224 input (embed_dim=768, depth=12, heads=12).

    Args:
        pretrained: if True, download and load the official DeiT-Base/16 weights.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16_ft_224(pretrained=False, **kwargs):
    """DeiT-Base/16 @224 configured for head-only fine-tuning (linear probe).

    The backbone parameters are frozen and only the classifier head remains
    trainable.

    Args:
        pretrained: if True, download and load the official DeiT-Base/16 weights.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer`` with only ``model.head`` trainable.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    # Freeze the whole backbone, then re-enable gradients on the head only.
    # NOTE(review): the original indentation was lost in transit; the freeze is
    # reconstructed as unconditional — confirm it was not meant to apply only
    # when pretrained=True.
    for p in model.parameters():
        p.requires_grad = False
    for p in model.head.parameters():
        p.requires_grad = True
    return model
@register_model
def deit_base24_patch16_224(pretrained=False, **kwargs):
    """DeiT-Base variant with double depth: 16x16 patches, depth=24, 224x224.

    Args:
        pretrained: if True, download the (depth-12) DeiT-Base/16 checkpoint.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=24, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): checkpoint has 12 blocks but this model has 24; strict
        # loading is expected to fail on missing keys — confirm intent.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base16_patch16_224(pretrained=False, **kwargs):
    """DeiT-Base variant with 16 blocks: 16x16 patches, depth=16, 224x224 input.

    Args:
        pretrained: if True, download the (depth-12) DeiT-Base/16 checkpoint.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=16, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): checkpoint has 12 blocks but this model has 16; strict
        # loading is expected to fail on missing keys — confirm intent.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """DeiT-Base, 16x16 patches, 384x384 input (embed_dim=768, depth=12, heads=12).

    Args:
        pretrained: if True, download the 224-resolution DeiT-Base/16 checkpoint.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        img_size=384,
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): this is the 224x224 checkpoint; its position embedding
        # does not match img_size=384, so this strict load is expected to raise
        # unless pos_embed is resized first — confirm intended weights.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model
@register_model
def deit_base_patch32_224(pretrained=False, **kwargs):
    """DeiT-Base, 32x32 patches, 224x224 input (embed_dim=768, depth=12, heads=12).

    Args:
        pretrained: if True, download the DeiT-Base/16 checkpoint and load it.
        **kwargs: forwarded to ``VisionTransformer``.

    Returns:
        The constructed ``VisionTransformer``.
    """
    model = VisionTransformer(
        patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        # NOTE(review): patch16 checkpoint vs patch_size=32 — embedding shapes
        # differ, so this strict load is expected to raise.
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model