Add files using upload-large-folder tool
This view is limited to 50 files because it contains too many changes.
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/decathlon_datalist.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_dataset.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_reader.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/aspp.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/backbone_fpn_utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/crf.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/fcn.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/feature_pyramid_network.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/patchembedding.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/segresnet_block.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/unetr_block.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/upsample.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/convutils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/factories.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/simplelayers.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/weight_init.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/decorators.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/deprecate_utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/jupyter_utils.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/module.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/profiling.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/type_conversion.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/__init__.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/version.cpython-38.pyc +0 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/__init__.py +94 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_builder.py +484 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_factory.py +127 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features.py +368 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features_fx.py +141 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_hub.py +402 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_prune.py +113 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/beit.py +621 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byoanet.py +455 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byobnet.py +2245 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/coat.py +804 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convit.py +430 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/crossvit.py +627 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/cspnet.py +1106 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/deit.py +416 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/dla.py +515 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/eva.py +1109 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/factory.py +4 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/features.py +4 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/gcvit.py +592 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/ghostnet.py +432 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_v4.py +325 -0
- my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/levit.py +933 -0
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/decathlon_datalist.cpython-38.pyc
ADDED
Binary file (8.88 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_dataset.cpython-38.pyc
ADDED
Binary file (5.13 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_reader.cpython-38.pyc
ADDED
Binary file (41.2 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (2.1 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/aspp.cpython-38.pyc
ADDED
Binary file (4.2 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/backbone_fpn_utils.cpython-38.pyc
ADDED
Binary file (5 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/crf.cpython-38.pyc
ADDED
Binary file (3.92 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/fcn.cpython-38.pyc
ADDED
Binary file (7.53 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/feature_pyramid_network.cpython-38.pyc
ADDED
Binary file (8.21 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/patchembedding.cpython-38.pyc
ADDED
Binary file (7.19 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/segresnet_block.cpython-38.pyc
ADDED
Binary file (2.83 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/unetr_block.cpython-38.pyc
ADDED
Binary file (5.84 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/upsample.cpython-38.pyc
ADDED
Binary file (9.48 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (1.24 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/convutils.cpython-38.pyc
ADDED
Binary file (7.22 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/factories.cpython-38.pyc
ADDED
Binary file (14.1 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/simplelayers.cpython-38.pyc
ADDED
Binary file (20.1 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/weight_init.cpython-38.pyc
ADDED
Binary file (2.01 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (2.81 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/decorators.cpython-38.pyc
ADDED
Binary file (2.92 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/deprecate_utils.cpython-38.pyc
ADDED
Binary file (7.3 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/jupyter_utils.cpython-38.pyc
ADDED
Binary file (13.1 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/module.cpython-38.pyc
ADDED
Binary file (18.1 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/profiling.cpython-38.pyc
ADDED
Binary file (3.53 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/type_conversion.cpython-38.pyc
ADDED
Binary file (11.7 kB)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (571 Bytes)

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/version.cpython-38.pyc
ADDED
Binary file (177 Bytes)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/__init__.py
ADDED
@@ -0,0 +1,94 @@
+from .beit import *
+from .byoanet import *
+from .byobnet import *
+from .cait import *
+from .coat import *
+from .convit import *
+from .convmixer import *
+from .convnext import *
+from .crossvit import *
+from .cspnet import *
+from .davit import *
+from .deit import *
+from .densenet import *
+from .dla import *
+from .dpn import *
+from .edgenext import *
+from .efficientformer import *
+from .efficientformer_v2 import *
+from .efficientnet import *
+from .efficientvit_mit import *
+from .efficientvit_msra import *
+from .eva import *
+from .fastvit import *
+from .focalnet import *
+from .gcvit import *
+from .ghostnet import *
+from .hardcorenas import *
+from .hgnet import *
+from .hrnet import *
+from .inception_next import *
+from .inception_resnet_v2 import *
+from .inception_v3 import *
+from .inception_v4 import *
+from .levit import *
+from .maxxvit import *
+from .metaformer import *
+from .mlp_mixer import *
+from .mobilenetv3 import *
+from .mobilevit import *
+from .mvitv2 import *
+from .nasnet import *
+from .nest import *
+from .nextvit import *
+from .nfnet import *
+from .pit import *
+from .pnasnet import *
+from .pvt_v2 import *
+from .regnet import *
+from .repghost import *
+from .repvit import *
+from .res2net import *
+from .resnest import *
+from .resnet import *
+from .resnetv2 import *
+from .rexnet import *
+from .selecsls import *
+from .senet import *
+from .sequencer import *
+from .sknet import *
+from .swin_transformer import *
+from .swin_transformer_v2 import *
+from .swin_transformer_v2_cr import *
+from .tiny_vit import *
+from .tnt import *
+from .tresnet import *
+from .twins import *
+from .vgg import *
+from .visformer import *
+from .vision_transformer import *
+from .vision_transformer_hybrid import *
+from .vision_transformer_relpos import *
+from .vision_transformer_sam import *
+from .volo import *
+from .vovnet import *
+from .xception import *
+from .xception_aligned import *
+from .xcit import *
+
+from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
+    set_pretrained_download_progress, set_pretrained_check_hash
+from ._factory import create_model, parse_model_name, safe_model_name
+from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
+from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
+    register_notrace_module, is_notrace_module, get_notrace_modules, \
+    register_notrace_function, is_notrace_function, get_notrace_functions
+from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
+from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
+from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
+    group_modules, group_parameters, checkpoint_seq, adapt_input_conv
+from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
+from ._prune import adapt_model_from_string
+from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
+    register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
+    is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
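For orientation, a minimal usage sketch (not part of the diff) of the registry and factory surface this `__init__` re-exports; the model names are illustrative and assume a standard timm install:

```python
# Minimal sketch, assuming a standard timm install: list_models and
# create_model are the ._registry / ._factory re-exports shown above.
import timm

print(timm.list_models('efficientnet_b*')[:3])  # query the model registry
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
print(model.num_classes)  # 10
```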
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_builder.py
ADDED
@@ -0,0 +1,484 @@
+""" EfficientNet, MobileNetV3, etc Builder
+
+Assembles EfficientNet and related network feature blocks from string definitions.
+Handles stride, dilation calculations, and selects feature extraction points.
+
+Hacked together by / Copyright 2019, Ross Wightman
+"""
+
+import logging
+import math
+import re
+from copy import deepcopy
+from functools import partial
+from typing import Any, Dict, List
+
+import torch.nn as nn
+
+from ._efficientnet_blocks import *
+from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
+
+__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
+           'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
+
+_logger = logging.getLogger(__name__)
+
+
+_DEBUG_BUILDER = False
+
+# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
+# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
+# NOTE: momentum varies btw .99 and .9997 depending on source
+# .99 in official TF TPU impl
+# .9997 (/w .999 in search space) for paper
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+BlockArgs = List[List[Dict[str, Any]]]
+
+
+def get_bn_args_tf():
+    return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+    bn_args = {}
+    bn_momentum = kwargs.pop('bn_momentum', None)
+    if bn_momentum is not None:
+        bn_args['momentum'] = bn_momentum
+    bn_eps = kwargs.pop('bn_eps', None)
+    if bn_eps is not None:
+        bn_args['eps'] = bn_eps
+    return bn_args
+
+
+def resolve_act_layer(kwargs, default='relu'):
+    return get_act_layer(kwargs.pop('act_layer', default))
+
+
+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+    """Round number of filters based on depth multiplier."""
+    if not multiplier:
+        return channels
+    return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
+
+
+def _log_info_if(msg, condition):
+    if condition:
+        _logger.info(msg)
+
+
+def _parse_ksize(ss):
+    if ss.isdigit():
+        return int(ss)
+    else:
+        return [int(k) for k in ss.split('.')]
+
+
+def _decode_block_str(block_str):
+    """ Decode block definition string
+
+    Gets a list of block arg (dicts) through a string notation of arguments.
+    E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
+
+    All args can exist in any order with the exception of the leading string which
+    is assumed to indicate the block type.
+
+    leading string - block type (
+      ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
+    r - number of repeat blocks,
+    k - kernel size,
+    s - strides (1-9),
+    e - expansion ratio,
+    c - output channels,
+    se - squeeze/excitation ratio
+    n - activation fn ('re', 'r6', 'hs', or 'sw')
+    Args:
+        block_str: a string representation of block arguments.
+    Returns:
+        A list of block args (dicts)
+    Raises:
+        ValueError: if the string def not properly specified (TODO)
+    """
+    assert isinstance(block_str, str)
+    ops = block_str.split('_')
+    block_type = ops[0]  # take the block type off the front
+    ops = ops[1:]
+    options = {}
+    skip = None
+    for op in ops:
+        # string options being checked on individual basis, combine if they grow
+        if op == 'noskip':
+            skip = False  # force no skip connection
+        elif op == 'skip':
+            skip = True  # force a skip connection
+        elif op.startswith('n'):
+            # activation fn
+            key = op[0]
+            v = op[1:]
+            if v == 're':
+                value = get_act_layer('relu')
+            elif v == 'r6':
+                value = get_act_layer('relu6')
+            elif v == 'hs':
+                value = get_act_layer('hard_swish')
+            elif v == 'sw':
+                value = get_act_layer('swish')  # aka SiLU
+            elif v == 'mi':
+                value = get_act_layer('mish')
+            else:
+                continue
+            options[key] = value
+        else:
+            # all numeric options
+            splits = re.split(r'(\d.*)', op)
+            if len(splits) >= 2:
+                key, value = splits[:2]
+                options[key] = value
+
+    # if act_layer is None, the model default (passed to model init) will be used
+    act_layer = options['n'] if 'n' in options else None
+    exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+    pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+    force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
+    num_repeat = int(options['r'])
+
+    # each type of block has different valid arguments, fill accordingly
+    block_args = dict(
+        block_type=block_type,
+        out_chs=int(options['c']),
+        stride=int(options['s']),
+        act_layer=act_layer,
+    )
+    if block_type == 'ir':
+        block_args.update(dict(
+            dw_kernel_size=_parse_ksize(options['k']),
+            exp_kernel_size=exp_kernel_size,
+            pw_kernel_size=pw_kernel_size,
+            exp_ratio=float(options['e']),
+            se_ratio=float(options['se']) if 'se' in options else 0.,
+            noskip=skip is False,
+        ))
+        if 'cc' in options:
+            block_args['num_experts'] = int(options['cc'])
+    elif block_type == 'ds' or block_type == 'dsa':
+        block_args.update(dict(
+            dw_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
+            pw_act=block_type == 'dsa',
+            noskip=block_type == 'dsa' or skip is False,
+        ))
+    elif block_type == 'er':
+        block_args.update(dict(
+            exp_kernel_size=_parse_ksize(options['k']),
+            pw_kernel_size=pw_kernel_size,
+            exp_ratio=float(options['e']),
+            force_in_chs=force_in_chs,
+            se_ratio=float(options['se']) if 'se' in options else 0.,
+            noskip=skip is False,
+        ))
+    elif block_type == 'cn':
+        block_args.update(dict(
+            kernel_size=int(options['k']),
+            skip=skip is True,
+        ))
+    else:
+        assert False, 'Unknown block type (%s)' % block_type
+    if 'gs' in options:
+        block_args['group_size'] = options['gs']
+
+    return block_args, num_repeat
+
+
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+    """ Per-stage depth scaling
+    Scales the block repeats in each stage. This depth scaling impl maintains
+    compatibility with the EfficientNet scaling method, while allowing sensible
+    scaling for other models that may have multiple block arg definitions in each stage.
+    """
+
+    # We scale the total repeat count for each stage, there may be multiple
+    # block arg defs per stage so we need to sum.
+    num_repeat = sum(repeats)
+    if depth_trunc == 'round':
+        # Truncating to int by rounding allows stages with few repeats to remain
+        # proportionally smaller for longer. This is a good choice when stage definitions
+        # include single repeat stages that we'd prefer to keep that way as long as possible
+        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+    else:
+        # The default for EfficientNet truncates repeats to int via 'ceil'.
+        # Any multiplier > 1.0 will result in an increased depth for every stage.
+        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
+
+    # Proportionally distribute repeat count scaling to each block definition in the stage.
+    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
+    # The first block makes less sense to repeat in most of the arch definitions.
+    repeats_scaled = []
+    for r in repeats[::-1]:
+        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
+        repeats_scaled.append(rs)
+        num_repeat -= r
+        num_repeat_scaled -= rs
+    repeats_scaled = repeats_scaled[::-1]
+
+    # Apply the calculated scaling to each block arg in the stage
+    sa_scaled = []
+    for ba, rep in zip(stack_args, repeats_scaled):
+        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
+    return sa_scaled
+
+
+def decode_arch_def(
+        arch_def,
+        depth_multiplier=1.0,
+        depth_trunc='ceil',
+        experts_multiplier=1,
+        fix_first_last=False,
+        group_size=None,
+):
+    """ Decode block architecture definition strings -> block kwargs
+
+    Args:
+        arch_def: architecture definition strings, list of list of strings
+        depth_multiplier: network depth multiplier
+        depth_trunc: network depth truncation mode when applying multiplier
+        experts_multiplier: CondConv experts multiplier
+        fix_first_last: fix first and last block depths when multiplier is applied
+        group_size: group size override for all blocks that weren't explicitly set in arch string
+
+    Returns:
+        list of list of block kwargs
+    """
+    arch_args = []
+    if isinstance(depth_multiplier, tuple):
+        assert len(depth_multiplier) == len(arch_def)
+    else:
+        depth_multiplier = (depth_multiplier,) * len(arch_def)
+    for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
+        assert isinstance(block_strings, list)
+        stack_args = []
+        repeats = []
+        for block_str in block_strings:
+            assert isinstance(block_str, str)
+            ba, rep = _decode_block_str(block_str)
+            if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
+                ba['num_experts'] *= experts_multiplier
+            if group_size is not None:
+                ba.setdefault('group_size', group_size)
+            stack_args.append(ba)
+            repeats.append(rep)
+        if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
+            arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
+        else:
+            arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
+    return arch_args
+
+
+class EfficientNetBuilder:
+    """ Build Trunk Blocks
+
+    This ended up being somewhat of a cross between
+    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
+    and
+    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
+
+    """
+    def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
+                 act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
+        self.output_stride = output_stride
+        self.pad_type = pad_type
+        self.round_chs_fn = round_chs_fn
+        self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
+        self.act_layer = act_layer
+        self.norm_layer = norm_layer
+        self.se_layer = get_attn(se_layer)
+        try:
+            self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
+            self.se_has_ratio = True
+        except TypeError:
+            self.se_has_ratio = False
+        self.drop_path_rate = drop_path_rate
+        if feature_location == 'depthwise':
+            # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
+            _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
+            feature_location = 'expansion'
+        self.feature_location = feature_location
+        assert feature_location in ('bottleneck', 'expansion', '')
+        self.verbose = _DEBUG_BUILDER
+
+        # state updated during build, consumed by model
+        self.in_chs = None
+        self.features = []
+
+    def _make_block(self, ba, block_idx, block_count):
+        drop_path_rate = self.drop_path_rate * block_idx / block_count
+        bt = ba.pop('block_type')
+        ba['in_chs'] = self.in_chs
+        ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
+        if 'force_in_chs' in ba and ba['force_in_chs']:
+            # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
+            ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
+        ba['pad_type'] = self.pad_type
+        # block act fn overrides the model default
+        ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
+        assert ba['act_layer'] is not None
+        ba['norm_layer'] = self.norm_layer
+        ba['drop_path_rate'] = drop_path_rate
+        if bt != 'cn':
+            se_ratio = ba.pop('se_ratio')
+            if se_ratio and self.se_layer is not None:
+                if not self.se_from_exp:
+                    # adjust se_ratio by expansion ratio if calculating se channels from block input
+                    se_ratio /= ba.get('exp_ratio', 1.0)
+                if self.se_has_ratio:
+                    ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
+                else:
+                    ba['se_layer'] = self.se_layer
+
+        if bt == 'ir':
+            _log_info_if('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+            block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
+        elif bt == 'ds' or bt == 'dsa':
+            _log_info_if('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+            block = DepthwiseSeparableConv(**ba)
+        elif bt == 'er':
+            _log_info_if('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+            block = EdgeResidual(**ba)
+        elif bt == 'cn':
+            _log_info_if('  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+            block = ConvBnAct(**ba)
+        else:
+            assert False, 'Unknown block type (%s) while building model.' % bt
+
+        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
+        return block
+
+    def __call__(self, in_chs, model_block_args):
+        """ Build the blocks
+        Args:
+            in_chs: Number of input-channels passed to first block
+            model_block_args: A list of lists, outer list defines stages, inner
+                list contains strings defining block configuration(s)
+        Return:
+             List of block stacks (each stack wrapped in nn.Sequential)
+        """
+        _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
+        self.in_chs = in_chs
+        total_block_count = sum([len(x) for x in model_block_args])
+        total_block_idx = 0
+        current_stride = 2
+        current_dilation = 1
+        stages = []
+        if model_block_args[0][0]['stride'] > 1:
+            # if the first block starts with a stride, we need to extract first level feat from stem
+            feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
+            self.features.append(feature_info)
+
+        # outer list of block_args defines the stacks
+        for stack_idx, stack_args in enumerate(model_block_args):
+            last_stack = stack_idx + 1 == len(model_block_args)
+            _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
+            assert isinstance(stack_args, list)
+
+            blocks = []
+            # each stack (stage of blocks) contains a list of block arguments
+            for block_idx, block_args in enumerate(stack_args):
+                last_block = block_idx + 1 == len(stack_args)
+                _log_info_if(' Block: {}'.format(block_idx), self.verbose)
+
+                assert block_args['stride'] in (1, 2)
+                if block_idx >= 1:   # only the first block in any stack can have a stride > 1
+                    block_args['stride'] = 1
+
+                extract_features = False
+                if last_block:
+                    next_stack_idx = stack_idx + 1
+                    extract_features = next_stack_idx >= len(model_block_args) or \
+                        model_block_args[next_stack_idx][0]['stride'] > 1
+
+                next_dilation = current_dilation
+                if block_args['stride'] > 1:
+                    next_output_stride = current_stride * block_args['stride']
+                    if next_output_stride > self.output_stride:
+                        next_dilation = current_dilation * block_args['stride']
+                        block_args['stride'] = 1
+                        _log_info_if('  Converting stride to dilation to maintain output_stride=={}'.format(
+                            self.output_stride), self.verbose)
+                    else:
+                        current_stride = next_output_stride
+                block_args['dilation'] = current_dilation
+                if next_dilation != current_dilation:
+                    current_dilation = next_dilation
+
+                # create the block
+                block = self._make_block(block_args, total_block_idx, total_block_count)
+                blocks.append(block)
+
+                # stash feature module name and channel info for model feature extraction
+                if extract_features:
+                    feature_info = dict(
+                        stage=stack_idx + 1,
+                        reduction=current_stride,
+                        **block.feature_info(self.feature_location),
+                    )
+                    leaf_name = feature_info.get('module', '')
+                    if leaf_name:
+                        feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
+                    else:
+                        assert last_block
+                        feature_info['module'] = f'blocks.{stack_idx}'
+                    self.features.append(feature_info)
+
+                total_block_idx += 1  # incr global block idx (across all stacks)
+            stages.append(nn.Sequential(*blocks))
+        return stages
+
+
+def _init_weight_goog(m, n='', fix_group_fanout=True):
+    """ Weight initialization as per Tensorflow official implementations.
+
+    Args:
+        m (nn.Module): module to init
+        n (str): module name
+        fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
+
+    Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
+    * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
+    * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+    """
+    if isinstance(m, CondConv2d):
+        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        if fix_group_fanout:
+            fan_out //= m.groups
+        init_weight_fn = get_condconv_initializer(
+            lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
+        init_weight_fn(m.weight)
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.Conv2d):
+        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        if fix_group_fanout:
+            fan_out //= m.groups
+        nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
+        if m.bias is not None:
+            nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.BatchNorm2d):
+        nn.init.ones_(m.weight)
+        nn.init.zeros_(m.bias)
+    elif isinstance(m, nn.Linear):
+        fan_out = m.weight.size(0)  # fan-out
+        fan_in = 0
+        if 'routing_fn' in n:
+            fan_in = m.weight.size(1)
+        init_range = 1.0 / math.sqrt(fan_in + fan_out)
+        nn.init.uniform_(m.weight, -init_range, init_range)
+        nn.init.zeros_(m.bias)
+
+
+def efficientnet_init_weights(model: nn.Module, init_fn=None):
+    init_fn = init_fn or _init_weight_goog
+    for n, m in model.named_modules():
+        init_fn(m, n)
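As a quick illustration of the block-string notation documented in `_decode_block_str` above, a small sketch (not part of the diff); the arch strings are made-up examples in the stated format:

```python
# Illustrative sketch of decode_arch_def(): block strings follow the notation
# documented above (ds/ir = block type, r = repeats, k = kernel, s = stride,
# e = expansion ratio, c = output channels, se = squeeze/excitation ratio).
from timm.models._efficientnet_builder import decode_arch_def

arch_def = [
    ['ds_r1_k3_s1_e1_c16_se0.25'],  # stage 0: one depthwise-separable block
    ['ir_r2_k3_s2_e6_c24_se0.25'],  # stage 1: two inverted-residual blocks
]
# depth_multiplier scales each stage's total repeat count (ceil by default):
# stage 0 -> ceil(1 * 1.2) = 2 blocks, stage 1 -> ceil(2 * 1.2) = 3 blocks
stages = decode_arch_def(arch_def, depth_multiplier=1.2, depth_trunc='ceil')
print([len(stage) for stage in stages])  # [2, 3]
print(stages[1][0]['block_type'], stages[1][0]['out_chs'])  # ir 24
```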
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_factory.py
ADDED
@@ -0,0 +1,127 @@
+import os
+from typing import Any, Dict, Optional, Union
+from urllib.parse import urlsplit
+
+from timm.layers import set_layer_config
+from ._helpers import load_checkpoint
+from ._hub import load_model_config_from_hf
+from ._pretrained import PretrainedCfg
+from ._registry import is_model, model_entrypoint, split_model_name_tag
+
+
+__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
+
+
+def parse_model_name(model_name: str):
+    if model_name.startswith('hf_hub'):
+        # NOTE for backwards compat, deprecate hf_hub use
+        model_name = model_name.replace('hf_hub', 'hf-hub')
+    parsed = urlsplit(model_name)
+    assert parsed.scheme in ('', 'timm', 'hf-hub')
+    if parsed.scheme == 'hf-hub':
+        # FIXME may use fragment as revision, currently `@` in URI path
+        return parsed.scheme, parsed.path
+    else:
+        model_name = os.path.split(parsed.path)[-1]
+        return 'timm', model_name
+
+
+def safe_model_name(model_name: str, remove_source: bool = True):
+    # return a filename / path safe model name
+    def make_safe(name):
+        return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
+    if remove_source:
+        model_name = parse_model_name(model_name)[-1]
+    return make_safe(model_name)
+
+
+def create_model(
+        model_name: str,
+        pretrained: bool = False,
+        pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
+        pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
+        checkpoint_path: str = '',
+        scriptable: Optional[bool] = None,
+        exportable: Optional[bool] = None,
+        no_jit: Optional[bool] = None,
+        **kwargs,
+):
+    """Create a model.
+
+    Lookup model's entrypoint function and pass relevant args to create a new model.
+
+    <Tip>
+        **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
+        and then the model class __init__(). kwargs values set to None are pruned before passing.
+    </Tip>
+
+    Args:
+        model_name: Name of model to instantiate.
+        pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+        pretrained_cfg: Pass in an external pretrained_cfg for model.
+        pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
+        checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+        scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
+        exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
+        no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
+
+    Keyword Args:
+        drop_rate (float): Classifier dropout rate for training.
+        drop_path_rate (float): Stochastic depth drop rate for training.
+        global_pool (str): Classifier global pooling type.
+
+    Example:
+
+    ```py
+    >>> from timm import create_model
+
+    >>> # Create a MobileNetV3-Large model with no pretrained weights.
+    >>> model = create_model('mobilenetv3_large_100')
+
+    >>> # Create a MobileNetV3-Large model with pretrained weights.
+    >>> model = create_model('mobilenetv3_large_100', pretrained=True)
+    >>> model.num_classes
+    1000
+
+    >>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
+    >>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
+    >>> model.num_classes
+    10
+    ```
+    """
+    # Parameters that aren't supported by all models or are intended to only override model defaults if set
+    # should default to None in command line args/cfg. Remove them if they are present and not set so that
+    # non-supporting models don't break and default args remain in effect.
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    model_source, model_name = parse_model_name(model_name)
+    if model_source == 'hf-hub':
+        assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
+        # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+        # load model weights + pretrained_cfg from Hugging Face hub.
+        pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+        if model_args:
+            for k, v in model_args.items():
+                kwargs.setdefault(k, v)
+    else:
+        model_name, pretrained_tag = split_model_name_tag(model_name)
+        if pretrained_tag and not pretrained_cfg:
+            # a valid pretrained_cfg argument takes priority over tag in model name
+            pretrained_cfg = pretrained_tag
+
+    if not is_model(model_name):
+        raise RuntimeError('Unknown model (%s)' % model_name)
+
+    create_fn = model_entrypoint(model_name)
+    with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+        model = create_fn(
+            pretrained=pretrained,
+            pretrained_cfg=pretrained_cfg,
+            pretrained_cfg_overlay=pretrained_cfg_overlay,
+            **kwargs,
+        )
+
+    if checkpoint_path:
+        load_checkpoint(model, checkpoint_path)
+
+    return model
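A short sketch (not part of the diff) of the name parsing implemented above; the repo name is a hypothetical example of the `hf-hub:` scheme:

```python
# Illustrative sketch of parse_model_name() / safe_model_name() behavior.
from timm.models._factory import parse_model_name, safe_model_name

print(parse_model_name('resnet50'))                 # ('timm', 'resnet50')
print(parse_model_name('hf-hub:user/my-vit-repo'))  # ('hf-hub', 'user/my-vit-repo')
print(safe_model_name('hf-hub:user/my-vit-repo'))   # 'user_my_vit_repo'
```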
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features.py
ADDED
|
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" PyTorch Feature Extraction Helpers
|
| 2 |
+
|
| 3 |
+
A collection of classes, functions, modules to help extract features from models
|
| 4 |
+
and provide a common interface for describing them.
|
| 5 |
+
|
| 6 |
+
The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
|
| 7 |
+
https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
|
| 8 |
+
|
| 9 |
+
Hacked together by / Copyright 2020 Ross Wightman
|
| 10 |
+
"""
|
| 11 |
+
from collections import OrderedDict, defaultdict
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from functools import partial
|
| 14 |
+
from typing import Dict, List, Sequence, Tuple, Union
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
from torch.utils.checkpoint import checkpoint
|
| 19 |
+
|
| 20 |
+
from timm.layers import Format
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class FeatureInfo:
|
| 27 |
+
|
| 28 |
+
def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
|
| 29 |
+
prev_reduction = 1
|
| 30 |
+
for i, fi in enumerate(feature_info):
|
| 31 |
+
# sanity check the mandatory fields, there may be additional fields depending on the model
|
| 32 |
+
assert 'num_chs' in fi and fi['num_chs'] > 0
|
| 33 |
+
assert 'reduction' in fi and fi['reduction'] >= prev_reduction
|
| 34 |
+
prev_reduction = fi['reduction']
|
| 35 |
+
assert 'module' in fi
|
| 36 |
+
fi.setdefault('index', i)
|
| 37 |
+
self.out_indices = out_indices
|
| 38 |
+
self.info = feature_info
|
| 39 |
+
|
| 40 |
+
def from_other(self, out_indices: Tuple[int]):
|
| 41 |
+
return FeatureInfo(deepcopy(self.info), out_indices)
|
| 42 |
+
|
| 43 |
+
def get(self, key, idx=None):
|
| 44 |
+
""" Get value by key at specified index (indices)
|
| 45 |
+
if idx == None, returns value for key at each output index
|
| 46 |
+
if idx is an integer, return value for that feature module index (ignoring output indices)
|
| 47 |
+
if idx is a list/tupple, return value for each module index (ignoring output indices)
|
| 48 |
+
"""
|
| 49 |
+
if idx is None:
|
| 50 |
+
return [self.info[i][key] for i in self.out_indices]
|
| 51 |
+
if isinstance(idx, (tuple, list)):
|
| 52 |
+
return [self.info[i][key] for i in idx]
|
| 53 |
+
else:
|
| 54 |
+
return self.info[idx][key]
|
| 55 |
+
|
| 56 |
+
def get_dicts(self, keys=None, idx=None):
|
| 57 |
+
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
|
| 58 |
+
"""
|
| 59 |
+
if idx is None:
|
| 60 |
+
if keys is None:
|
| 61 |
+
return [self.info[i] for i in self.out_indices]
|
| 62 |
+
else:
|
| 63 |
+
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
|
| 64 |
+
if isinstance(idx, (tuple, list)):
|
| 65 |
+
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
|
| 66 |
+
else:
|
| 67 |
+
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
|
| 68 |
+
|
| 69 |
+
def channels(self, idx=None):
|
| 70 |
+
""" feature channels accessor
|
| 71 |
+
"""
|
| 72 |
+
return self.get('num_chs', idx)
|
| 73 |
+
|
| 74 |
+
def reduction(self, idx=None):
|
| 75 |
+
""" feature reduction (output stride) accessor
|
| 76 |
+
"""
|
| 77 |
+
return self.get('reduction', idx)
|
| 78 |
+
|
| 79 |
+
def module_name(self, idx=None):
|
| 80 |
+
""" feature module name accessor
|
| 81 |
+
"""
|
| 82 |
+
return self.get('module', idx)
|
| 83 |
+
|
| 84 |
+
def __getitem__(self, item):
|
| 85 |
+
return self.info[item]
|
| 86 |
+
|
| 87 |
+
def __len__(self):
|
| 88 |
+
return len(self.info)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class FeatureHooks:
|
| 92 |
+
""" Feature Hook Helper
|
| 93 |
+
|
| 94 |
+
This module helps with the setup and extraction of hooks for extracting features from
|
| 95 |
+
internal nodes in a model by node name.
|
| 96 |
+
|
| 97 |
+
FIXME This works well in eager Python but needs redesign for torchscript.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
hooks: Sequence[str],
|
| 103 |
+
named_modules: dict,
|
| 104 |
+
out_map: Sequence[Union[int, str]] = None,
|
| 105 |
+
default_hook_type: str = 'forward',
|
| 106 |
+
):
|
| 107 |
+
# setup feature hooks
|
| 108 |
+
self._feature_outputs = defaultdict(OrderedDict)
|
| 109 |
+
modules = {k: v for k, v in named_modules}
|
| 110 |
+
for i, h in enumerate(hooks):
|
| 111 |
+
hook_name = h['module']
|
| 112 |
+
m = modules[hook_name]
|
| 113 |
+
hook_id = out_map[i] if out_map else hook_name
|
| 114 |
+
hook_fn = partial(self._collect_output_hook, hook_id)
|
| 115 |
+
hook_type = h.get('hook_type', default_hook_type)
|
| 116 |
+
if hook_type == 'forward_pre':
|
| 117 |
+
m.register_forward_pre_hook(hook_fn)
|
| 118 |
+
elif hook_type == 'forward':
|
| 119 |
+
m.register_forward_hook(hook_fn)
|
| 120 |
+
else:
|
| 121 |
+
assert False, "Unsupported hook type"
|
| 122 |
+
|
| 123 |
+
def _collect_output_hook(self, hook_id, *args):
|
| 124 |
+
x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
|
| 125 |
+
if isinstance(x, tuple):
|
| 126 |
+
x = x[0] # unwrap input tuple
|
| 127 |
+
self._feature_outputs[x.device][hook_id] = x
|
| 128 |
+
|
| 129 |
+
def get_output(self, device) -> Dict[str, torch.tensor]:
|
| 130 |
+
output = self._feature_outputs[device]
|
| 131 |
+
self._feature_outputs[device] = OrderedDict() # clear after reading
|
| 132 |
+
return output


def _module_list(module, flatten_sequential=False):
    # a yield/iter would be better for this but wouldn't be compatible with torchscript
    ml = []
    for name, module in module.named_children():
        if flatten_sequential and isinstance(module, nn.Sequential):
            # first level of Sequential containers is flattened into containing model
            for child_name, child_module in module.named_children():
                combined = [name, child_name]
                ml.append(('_'.join(combined), '.'.join(combined), child_module))
        else:
            ml.append((name, name, module))
    return ml


def _get_feature_info(net, out_indices):
    feature_info = getattr(net, 'feature_info')
    if isinstance(feature_info, FeatureInfo):
        return feature_info.from_other(out_indices)
    elif isinstance(feature_info, (list, tuple)):
        return FeatureInfo(net.feature_info, out_indices)
    else:
        assert False, "Provided feature_info is not valid"


def _get_return_layers(feature_info, out_map):
    module_names = feature_info.module_name()
    return_layers = {}
    for i, name in enumerate(module_names):
        return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
    return return_layers


class FeatureDictNet(nn.ModuleDict):
    """ Feature extractor with OrderedDict return

    Wrap a model and extract features as specified by the out indices, the network is
    partially re-built from contained modules.

    There is a strong assumption that the modules have been registered into the model in the same
    order as they are used. There should be no reuse of the same nn.Module more than once, including
    trivial modules like `self.relu = nn.ReLU`.

    Only submodules that are directly assigned to the model class (`model.feature1`) or at most
    one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
    All Sequential containers that are directly assigned to the original model will have their
    modules assigned to this module with the name `model.features.1` being changed to `model.features_1`.
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super(FeatureDictNet, self).__init__()
        self.feature_info = _get_feature_info(model, out_indices)
        self.output_fmt = Format(output_fmt)
        self.concat = feature_concat
        self.grad_checkpointing = False
        self.return_layers = {}

        return_layers = _get_return_layers(self.feature_info, out_map)
        modules = _module_list(model, flatten_sequential=flatten_sequential)
        remaining = set(return_layers.keys())
        layers = OrderedDict()
        for new_name, old_name, module in modules:
            layers[new_name] = module
            if old_name in remaining:
                # return id has to be consistently str type for torchscript
                self.return_layers[new_name] = str(return_layers[old_name])
                remaining.remove(old_name)
            if not remaining:
                break
        assert not remaining and len(self.return_layers) == len(return_layers), \
            f'Return layers ({remaining}) are not present in model'
        self.update(layers)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def _collect(self, x) -> (Dict[str, torch.Tensor]):
        out = OrderedDict()
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)

            if name in self.return_layers:
                out_id = self.return_layers[name]
                if isinstance(x, (tuple, list)):
                    # If model tap is a tuple or list, concat or select first element
                    # FIXME this may need to be more generic / flexible for some nets
                    out[out_id] = torch.cat(x, 1) if self.concat else x[0]
                else:
                    out[out_id] = x
        return out

    def forward(self, x) -> Dict[str, torch.Tensor]:
        return self._collect(x)
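
# Illustrative sketch (not part of the original file): FeatureDictNet requires the wrapped model
# to expose a `feature_info` attribute, as timm models do. The model name is just an example.
#
#   import timm
#   body = FeatureDictNet(timm.create_model('resnet50'), out_indices=(1, 2, 3))
#   feats = body(torch.randn(1, 3, 224, 224))  # OrderedDict with keys '1', '2', '3'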


class FeatureListNet(FeatureDictNet):
    """ Feature extractor with list return

    A specialization of FeatureDictNet that always returns features as a list (values() of dict).
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            output_fmt: str = 'NCHW',
            feature_concat: bool = False,
            flatten_sequential: bool = False,
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
                first element e.g. `x[0]`
            flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
        """
        super().__init__(
            model,
            out_indices=out_indices,
            output_fmt=output_fmt,
            feature_concat=feature_concat,
            flatten_sequential=flatten_sequential,
        )

    def forward(self, x) -> (List[torch.Tensor]):
        return list(self._collect(x).values())


class FeatureHookNet(nn.ModuleDict):
    """ FeatureHookNet

    Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.

    If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
    network in any way.

    If `no_rewrite` is False, the model will be re-written as in the
    FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.

    FIXME this does not currently work with Torchscript, see FeatureHooks class
    """
    def __init__(
            self,
            model: nn.Module,
            out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
            out_map: Sequence[Union[int, str]] = None,
            return_dict: bool = False,
            output_fmt: str = 'NCHW',
            no_rewrite: bool = False,
            flatten_sequential: bool = False,
            default_hook_type: str = 'forward',
    ):
        """
        Args:
            model: Model from which to extract features.
            out_indices: Output indices of the model features to extract.
            out_map: Return id mapping for each output index, otherwise str(index) is used.
            return_dict: Output features as a dict.
            no_rewrite: Enforce that model is not re-written if True, i.e. no modules are removed / changed.
                flatten_sequential arg must also be False if this is set True.
            flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
            default_hook_type: The default hook type to use if not specified in model.feature_info.
        """
        super().__init__()
        assert not torch.jit.is_scripting()
        self.feature_info = _get_feature_info(model, out_indices)
        self.return_dict = return_dict
        self.output_fmt = Format(output_fmt)
        self.grad_checkpointing = False

        layers = OrderedDict()
        hooks = []
        if no_rewrite:
            assert not flatten_sequential
            if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
                model.reset_classifier(0)
            layers['body'] = model
            hooks.extend(self.feature_info.get_dicts())
        else:
            modules = _module_list(model, flatten_sequential=flatten_sequential)
            remaining = {
                f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
                for f in self.feature_info.get_dicts()
            }
            for new_name, old_name, module in modules:
                layers[new_name] = module
                for fn, fm in module.named_modules(prefix=old_name):
                    if fn in remaining:
                        hooks.append(dict(module=fn, hook_type=remaining[fn]))
                        del remaining[fn]
                if not remaining:
                    break
            assert not remaining, f'Return layers ({remaining}) are not present in model'
        self.update(layers)
        self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)

    def set_grad_checkpointing(self, enable: bool = True):
        self.grad_checkpointing = enable

    def forward(self, x):
        for i, (name, module) in enumerate(self.items()):
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # Skipping checkpoint of first module because need a gradient at input
                # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
                # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
                first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
                x = module(x) if first_or_last_module else checkpoint(module, x)
            else:
                x = module(x)
        out = self.hooks.get_output(x.device)
        return out if self.return_dict else list(out.values())
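A minimal usage sketch for the extractors above (not part of the diff; the model name is just an example). In timm, `create_model(..., features_only=True)` builds a FeatureListNet over the named model by default, while FeatureHookNet serves models whose intermediate features can only be tapped via hooks:

    import timm, torch
    model = timm.create_model('resnet50', features_only=True, out_indices=(1, 2, 3))
    feats = model(torch.randn(1, 3, 224, 224))  # list of 3 feature maps at strides 4, 8, 16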
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features_fx.py
ADDED
@@ -0,0 +1,141 @@
""" PyTorch FX Based Feature Extraction Helpers
Using https://pytorch.org/vision/stable/feature_extraction.html
"""
from typing import Callable, List, Dict, Union, Type

import torch
from torch import nn

from ._features import _get_feature_info, _get_return_layers

try:
    from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
    has_fx_feature_extraction = True
except ImportError:
    has_fx_feature_extraction = False

# Layers we want to treat as leaf modules
from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
from timm.layers.non_local_attn import BilinearAttnTransform
from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
from timm.layers.norm_act import (
    BatchNormAct2d,
    SyncBatchNormAct,
    FrozenBatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d
)

__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
           'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
           'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']


# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
# BUT modules from timm.models should use the registration mechanism below
_leaf_modules = {
    BilinearAttnTransform,  # reason: flow control t <= 1
    # Reason: get_same_padding has a max which raises a control flow error
    Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
    CondConv2d,  # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0])
    BatchNormAct2d,
    SyncBatchNormAct,
    FrozenBatchNormAct2d,
    GroupNormAct,
    GroupNorm1Act,
    LayerNormAct,
    LayerNormAct2d,
}

try:
    from timm.layers import InplaceAbn
    _leaf_modules.add(InplaceAbn)
except ImportError:
    pass


def register_notrace_module(module: Type[nn.Module]):
    """
    Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
    """
    _leaf_modules.add(module)
    return module


def is_notrace_module(module: Type[nn.Module]):
    return module in _leaf_modules


def get_notrace_modules():
    return list(_leaf_modules)


# Functions we want to autowrap (treat them as leaves)
_autowrap_functions = set()


def register_notrace_function(func: Callable):
    """
    Decorator for functions which ought not to be traced through
    """
    _autowrap_functions.add(func)
    return func


def is_notrace_function(func: Callable):
    return func in _autowrap_functions


def get_notrace_functions():
    return list(_autowrap_functions)
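
# Illustrative sketch (not part of the original file): marking a helper with data-dependent
# control flow as a traced leaf. `my_window_partition` is a hypothetical function.
#
#   @register_notrace_function
#   def my_window_partition(x, window_size):
#       ...  # shape-dependent logic that FX symbolic tracing cannot handle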


def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
    assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
    return _create_feature_extractor(
        model, return_nodes,
        tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
    )


class FeatureGraphNet(nn.Module):
    """ A FX Graph based feature extractor that works with the model feature_info metadata
    """
    def __init__(self, model, out_indices, out_map=None):
        super().__init__()
        assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
        self.feature_info = _get_feature_info(model, out_indices)
        if out_map is not None:
            assert len(out_map) == len(out_indices)
        return_nodes = _get_return_layers(self.feature_info, out_map)
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x):
        return list(self.graph_module(x).values())


class GraphExtractNet(nn.Module):
    """ A standalone feature extraction wrapper that maps dict -> list or single tensor

    NOTE:
        * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
          metadata for builtin feature extraction mode
        * create_feature_extractor can be used directly if dictionary output is acceptable

    Args:
        model: model to extract features from
        return_nodes: node names to return features from (dict or list)
        squeeze_out: if only one output, and output in list format, flatten to single tensor
    """
    def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
        super().__init__()
        self.squeeze_out = squeeze_out
        self.graph_module = create_feature_extractor(model, return_nodes)

    def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
        out = list(self.graph_module(x).values())
        if self.squeeze_out and len(out) == 1:
            return out[0]
        return out
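A minimal usage sketch for GraphExtractNet (not part of the diff). Torchvision allows truncated node names such as 'layer3', which resolve to the last node within that submodule; exact names can be listed with torchvision's get_graph_node_names:

    import torch, torchvision
    net = GraphExtractNet(torchvision.models.resnet18(), ['layer3'], squeeze_out=True)
    feat = net(torch.randn(1, 3, 224, 224))  # single tensor, since only one node was requested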
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_hub.py
ADDED
@@ -0,0 +1,402 @@
import hashlib
import json
import logging
import os
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union

import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse

try:
    from torch.hub import get_dir
except ImportError:
    from torch.hub import _get_torch_home as get_dir

try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

from timm import __version__
from timm.models._pretrained import filter_pretrained_cfg

try:
    from huggingface_hub import (
        create_repo, get_hf_file_metadata,
        hf_hub_download, hf_hub_url,
        repo_type_and_id_from_hf_id, upload_folder)
    from huggingface_hub.utils import EntryNotFoundError
    hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
    _has_hf_hub = True
except ImportError:
    hf_hub_download = None
    _has_hf_hub = False

_logger = logging.getLogger(__name__)

__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
           'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version
HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # default pytorch pkl
HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version


def get_cache_dir(child_dir=''):
    """
    Returns the location of the directory where models are cached (and creates it if necessary).
    """
    # Issue warning to move data if old env is set
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    hub_dir = get_dir()
    child_dir = () if not child_dir else (child_dir,)
    model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir


def download_cached_file(url, check_hash=True, progress=False):
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if not os.path.exists(cached_file):
        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
        hash_prefix = None
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file


def check_cached_file(url, check_hash=True):
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        parts = urlparse(url)
        filename = os.path.basename(parts.path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if os.path.exists(cached_file):
        if check_hash:
            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
            hash_prefix = r.group(1) if r else None
            if hash_prefix:
                with open(cached_file, 'rb') as f:
                    hd = hashlib.sha256(f.read()).hexdigest()
                if hd[:len(hash_prefix)] != hash_prefix:
                    return False
        return True
    return False


def has_hf_hub(necessary=False):
    if not _has_hf_hub and necessary:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub


def hf_split(hf_id: str):
    # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
    rev_split = hf_id.split('@')
    assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
    hf_model_id = rev_split[0]
    hf_revision = rev_split[-1] if len(rev_split) > 1 else None
    return hf_model_id, hf_revision
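
# Illustrative sketch (not part of the original file): the optional '@' suffix selects a revision.
#
#   >>> hf_split('timm/resnet50.a1_in1k@main')
#   ('timm/resnet50.a1_in1k', 'main')
#   >>> hf_split('timm/resnet50.a1_in1k')
#   ('timm/resnet50.a1_in1k', None)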


def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    with open(json_file, "r", encoding="utf-8") as reader:
        text = reader.read()
    return json.loads(text)


def download_from_hf(model_id: str, filename: str):
    hf_model_id, hf_revision = hf_split(model_id)
    return hf_hub_download(hf_model_id, filename, revision=hf_revision)


def load_model_config_from_hf(model_id: str):
    assert has_hf_hub(True)
    cached_file = download_from_hf(model_id, 'config.json')

    hf_config = load_cfg_from_json(cached_file)
    if 'pretrained_cfg' not in hf_config:
        # old form, pull pretrain_cfg out of the base dict
        pretrained_cfg = hf_config
        hf_config = {}
        hf_config['architecture'] = pretrained_cfg.pop('architecture')
        hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
        hf_config['pretrained_cfg'] = pretrained_cfg

    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
    pretrained_cfg = hf_config['pretrained_cfg']
    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['source'] = 'hf-hub'

    # model should be created with base config num_classes if it exists
    if 'num_classes' in hf_config:
        pretrained_cfg['num_classes'] = hf_config['num_classes']

    # label meta-data in base config overrides saved pretrained_cfg on load
    if 'label_names' in hf_config:
        pretrained_cfg['label_names'] = hf_config.pop('label_names')
    if 'label_descriptions' in hf_config:
        pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')

    model_args = hf_config.get('model_args', {})
    model_name = hf_config['architecture']
    return pretrained_cfg, model_name, model_args


def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
    assert has_hf_hub(True)
    hf_model_id, hf_revision = hf_split(model_id)

    # Look for .safetensors alternatives and load from it if it exists
    if _has_safetensors:
        for safe_filename in _get_safe_alternatives(filename):
            try:
                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
                _logger.info(
                    f"[{model_id}] Safe alternative available for '{filename}' "
                    f"(as '{safe_filename}'). Loading weights using safetensors.")
                return safetensors.torch.load_file(cached_safe_file, device="cpu")
            except EntryNotFoundError:
                pass

    # Otherwise, load using pytorch.load
    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
    return torch.load(cached_file, map_location='cpu')


def save_config_for_hf(
        model,
        config_path: Path,
        model_config: Optional[dict] = None,
        model_args: Optional[dict] = None
):
    model_config = model_config or {}
    hf_config = {}
    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
    # set some values at root config level
    hf_config['architecture'] = pretrained_cfg.pop('architecture')
    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)

    # NOTE these attr saved for informational purposes, do not impact model build
    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
    global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
    if isinstance(global_pool_type, str) and global_pool_type:
        hf_config['global_pool'] = global_pool_type

    # Save class label info
    if 'labels' in model_config:
        _logger.warning(
            "'labels' as a config field is deprecated. Please use 'label_names' and 'label_descriptions'."
            " Renaming provided 'labels' field to 'label_names'.")
        model_config.setdefault('label_names', model_config.pop('labels'))

    label_names = model_config.pop('label_names', None)
    if label_names:
        assert isinstance(label_names, (dict, list, tuple))
        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
        hf_config['label_names'] = label_names

    label_descriptions = model_config.pop('label_descriptions', None)
    if label_descriptions:
        assert isinstance(label_descriptions, dict)
        # maps label names -> descriptions
        hf_config['label_descriptions'] = label_descriptions

    if model_args:
        hf_config['model_args'] = model_args

    hf_config['pretrained_cfg'] = pretrained_cfg
    hf_config.update(model_config)

    with config_path.open('w') as f:
        json.dump(hf_config, f, indent=2)


def save_for_hf(
        model,
        save_directory: str,
        model_config: Optional[dict] = None,
        model_args: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    assert has_hf_hub(True)
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
    tensors = model.state_dict()
    if safe_serialization is True or safe_serialization == "both":
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
    if safe_serialization is False or safe_serialization == "both":
        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

    config_path = save_directory / 'config.json'
    save_config_for_hf(
        model,
        config_path,
        model_config=model_config,
        model_args=model_args,
    )


def push_to_hf_hub(
        model: torch.nn.Module,
        repo_id: str,
        commit_message: str = 'Add model',
        token: Optional[str] = None,
        revision: Optional[str] = None,
        private: bool = False,
        create_pr: bool = False,
        model_config: Optional[dict] = None,
        model_card: Optional[dict] = None,
        model_args: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    """
    Arguments:
        (...)
        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            Can be set to `"both"` in order to push both safe and unsafe weights.
    """
    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Infer complete repo_id from repo_url
    # Can be different from the input `repo_id` if repo_owner was implicit
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Check if README file already exists in repo
    try:
        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(
            model,
            tmpdir,
            model_config=model_config,
            model_args=model_args,
            safe_serialization=safe_serialization,
        )

        # Add readme if it does not exist
        if not has_readme:
            model_card = model_card or {}
            model_name = repo_id.split('/')[-1]
            readme_path = Path(tmpdir) / "README.md"
            readme_text = generate_readme(model_card, model_name)
            readme_path.write_text(readme_text)

        # Upload model and return
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
        )


def generate_readme(model_card: dict, model_name: str):
    readme_text = "---\n"
    readme_text += "tags:\n- image-classification\n- timm\n"
    readme_text += "library_name: timm\n"
    readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    if 'details' in model_card and 'Dataset' in model_card['details']:
        readme_text += 'datasets:\n'
        if isinstance(model_card['details']['Dataset'], (tuple, list)):
            for d in model_card['details']['Dataset']:
                readme_text += f"- {d.lower()}\n"
        else:
            readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
        if 'Pretrain Dataset' in model_card['details']:
            if isinstance(model_card['details']['Pretrain Dataset'], (tuple, list)):
                for d in model_card['details']['Pretrain Dataset']:
                    readme_text += f"- {d.lower()}\n"
            else:
                readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
    readme_text += "---\n"
    readme_text += f"# Model card for {model_name}\n"
    if 'description' in model_card:
        readme_text += f"\n{model_card['description']}\n"
    if 'details' in model_card:
        readme_text += "\n## Model Details\n"
        for k, v in model_card['details'].items():
            if isinstance(v, (list, tuple)):
                readme_text += f"- **{k}:**\n"
                for vi in v:
                    readme_text += f"  - {vi}\n"
            elif isinstance(v, dict):
                readme_text += f"- **{k}:**\n"
                for ki, vi in v.items():
                    readme_text += f"  - {ki}: {vi}\n"
            else:
                readme_text += f"- **{k}:** {v}\n"
    if 'usage' in model_card:
        readme_text += "\n## Model Usage\n"
        readme_text += model_card['usage']
        readme_text += '\n'

    if 'comparison' in model_card:
        readme_text += "\n## Model Comparison\n"
        readme_text += model_card['comparison']
        readme_text += '\n'

    if 'citation' in model_card:
        readme_text += "\n## Citation\n"
        if not isinstance(model_card['citation'], (list, tuple)):
            citations = [model_card['citation']]
        else:
            citations = model_card['citation']
        for c in citations:
            readme_text += f"```bibtex\n{c}\n```\n"
    return readme_text


def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Returns potential safetensors alternatives for a given filename.

    Use case:
        When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
    """
    if filename == HF_WEIGHTS_NAME:
        yield HF_SAFE_WEIGHTS_NAME
    if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
        yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
    if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
        yield filename[:-4] + ".safetensors"
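A minimal usage sketch for the hub helpers above (not part of the diff; the repo id is a placeholder and a logged-in huggingface_hub session is assumed):

    import timm
    model = timm.create_model('resnet18', pretrained=False, num_classes=10)
    push_to_hf_hub(model, 'my-user/my-resnet18', model_card={'license': 'apache-2.0'}, safe_serialization='both')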
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_prune.py
ADDED
@@ -0,0 +1,113 @@
import os
import pkgutil
from copy import deepcopy

from torch import nn as nn

from timm.layers import Conv2dSame, BatchNormAct2d, Linear

__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']


def extract_layer(model, layer):
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    if not hasattr(model, 'module') and layer[0] == 'module':
        layer = layer[1:]
    for l in layer:
        if hasattr(module, l):
            if not l.isdigit():
                module = getattr(module, l)
            else:
                module = module[int(l)]
        else:
            return module
    return module


def set_layer(model, layer, val):
    layer = layer.split('.')
    module = model
    if hasattr(model, 'module') and layer[0] != 'module':
        module = model.module
    lst_index = 0
    module2 = module
    for l in layer:
        if hasattr(module2, l):
            if not l.isdigit():
                module2 = getattr(module2, l)
            else:
                module2 = module2[int(l)]
            lst_index += 1
    lst_index -= 1
    for l in layer[:lst_index]:
        if not l.isdigit():
            module = getattr(module, l)
        else:
            module = module[int(l)]
    l = layer[lst_index]
    setattr(module, l, val)


def adapt_model_from_string(parent_module, model_string):
    separator = '***'
    state_dict = {}
    lst_shape = model_string.split(separator)
    for k in lst_shape:
        k = k.split(':')
        key = k[0]
        shape = k[1][1:-1].split(',')
        if shape[0] != '':
            state_dict[key] = [int(i) for i in shape]

    new_module = deepcopy(parent_module)
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)
        if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
            if isinstance(old_module, Conv2dSame):
                conv = Conv2dSame
            else:
                conv = nn.Conv2d
            s = state_dict[n + '.weight']
            in_channels = s[1]
            out_channels = s[0]
            g = 1
            if old_module.groups > 1:
                in_channels = out_channels
                g = in_channels
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=g, stride=old_module.stride)
            set_layer(new_module, n, new_conv)
        elif isinstance(old_module, BatchNormAct2d):
            new_bn = BatchNormAct2d(
                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            new_bn.drop = old_module.drop
            new_bn.act = old_module.act
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = state_dict[n + '.weight'][1]
            new_fc = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features
    new_module.eval()
    parent_module.eval()

    return new_module


def adapt_model_from_file(parent_module, model_variant):
    adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
    return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())
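A sketch of the pruning-description format consumed by adapt_model_from_string (not part of the diff; the module names and shapes are made up). Entries are `key:[shape]` pairs joined by `***`; note that the adaptation loop indexes the description for every Conv2d/BatchNorm/Linear it visits, so a real description must cover all of them:

    desc = 'conv1.weight:[32, 3, 7, 7]***bn1.weight:[32]***fc.weight:[10, 32]'
    pruned = adapt_model_from_string(tiny_model, desc)  # tiny_model: hypothetical net with conv1/bn1/fc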
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/beit.py
ADDED
@@ -0,0 +1,621 @@
""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)

Model from official source: https://github.com/microsoft/unilm/tree/master/beit

@inproceedings{beit,
    title={{BEiT}: {BERT} Pre-Training of Image Transformers},
    author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
    booktitle={International Conference on Learning Representations},
    year={2022},
    url={https://openreview.net/forum?id=p-BhZSz59o4}
}

BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2

@article{beitv2,
    title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
    author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
    year={2022},
    eprint={2208.06366},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

At this point only the 1k fine-tuned classification weights and model configs have been added,
see original source above for pre-training models and procedure.

Modifications by / Copyright 2021 Ross Wightman, original copyrights below
"""
# --------------------------------------------------------
# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
# Github source: https://github.com/microsoft/unilm/tree/master/beit
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# By Hangbo Bao
# Based on timm and DeiT code bases
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit/
# https://github.com/facebookresearch/dino
# --------------------------------------------------------

import math
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid


from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model
from .vision_transformer import checkpoint_filter_fn

__all__ = ['Beit']


def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
    num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
    # cls to token & token to cls & cls to cls
    # get pair-wise relative position index for each token inside the window
    window_area = window_size[0] * window_size[1]
    coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1])))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
    relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    relative_position_index[0, 0:] = num_relative_distance - 3
    relative_position_index[0:, 0] = num_relative_distance - 2
    relative_position_index[0, 0] = num_relative_distance - 1
    return relative_position_index
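
# Illustrative note (not part of the original file): for a 2x2 window the bias table has
# (2*2 - 1) * (2*2 - 1) + 3 = 12 relative-distance entries, and the returned index is a
# (5, 5) tensor (4 patch tokens + 1 cls token) with values in [0, 11]:
#
#   idx = gen_relative_position_index((2, 2))
#   assert idx.shape == (5, 5) and int(idx.max()) == 11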
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class Attention(nn.Module):
|
| 82 |
+
fused_attn: torch.jit.Final[bool]
|
| 83 |
+
|
| 84 |
+
def __init__(
|
| 85 |
+
self,
|
| 86 |
+
dim: int,
|
| 87 |
+
num_heads: int = 8,
|
| 88 |
+
qkv_bias: bool = False,
|
| 89 |
+
attn_drop: float = 0.,
|
| 90 |
+
proj_drop: float = 0.,
|
| 91 |
+
window_size: Optional[Tuple[int, int]] = None,
|
| 92 |
+
attn_head_dim: Optional[int] = None,
|
| 93 |
+
):
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.num_heads = num_heads
|
| 96 |
+
head_dim = dim // num_heads
|
| 97 |
+
if attn_head_dim is not None:
|
| 98 |
+
head_dim = attn_head_dim
|
| 99 |
+
all_head_dim = head_dim * self.num_heads
|
| 100 |
+
self.scale = head_dim ** -0.5
|
| 101 |
+
self.fused_attn = use_fused_attn()
|
| 102 |
+
|
| 103 |
+
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
| 104 |
+
if qkv_bias:
|
| 105 |
+
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 106 |
+
self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
|
| 107 |
+
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
| 108 |
+
else:
|
| 109 |
+
self.q_bias = None
|
| 110 |
+
self.k_bias = None
|
| 111 |
+
self.v_bias = None
|
| 112 |
+
|
| 113 |
+
if window_size:
|
| 114 |
+
self.window_size = window_size
|
| 115 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
| 116 |
+
self.relative_position_bias_table = nn.Parameter(
|
| 117 |
+
torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
|
| 118 |
+
self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
|
| 119 |
+
else:
|
| 120 |
+
self.window_size = None
|
| 121 |
+
self.relative_position_bias_table = None
|
| 122 |
+
self.relative_position_index = None
|
| 123 |
+
|
| 124 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 125 |
+
self.proj = nn.Linear(all_head_dim, dim)
|
| 126 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 127 |
+
|
| 128 |
+
def _get_rel_pos_bias(self):
|
| 129 |
+
relative_position_bias = self.relative_position_bias_table[
|
| 130 |
+
self.relative_position_index.view(-1)].view(
|
| 131 |
+
self.window_size[0] * self.window_size[1] + 1,
|
| 132 |
+
self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
|
| 133 |
+
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
| 134 |
+
return relative_position_bias.unsqueeze(0)
|
| 135 |
+
|
| 136 |
+
def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
|
| 137 |
+
B, N, C = x.shape
|
| 138 |
+
|
| 139 |
+
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
|
| 140 |
+
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
| 141 |
+
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
| 142 |
+
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
|
| 143 |
+
|
| 144 |
+
if self.fused_attn:
|
| 145 |
+
rel_pos_bias = None
|
| 146 |
+
if self.relative_position_bias_table is not None:
|
| 147 |
+
rel_pos_bias = self._get_rel_pos_bias()
|
| 148 |
+
if shared_rel_pos_bias is not None:
|
| 149 |
+
rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
|
| 150 |
+
elif shared_rel_pos_bias is not None:
|
| 151 |
+
rel_pos_bias = shared_rel_pos_bias
|
| 152 |
+
|
| 153 |
+
x = F.scaled_dot_product_attention(
|
| 154 |
+
q, k, v,
|
| 155 |
+
attn_mask=rel_pos_bias,
|
| 156 |
+
dropout_p=self.attn_drop.p if self.training else 0.,
|
| 157 |
+
)
|
| 158 |
+
else:
|
| 159 |
+
q = q * self.scale
|
| 160 |
+
attn = (q @ k.transpose(-2, -1))
|
| 161 |
+
|
| 162 |
+
if self.relative_position_bias_table is not None:
|
| 163 |
+
attn = attn + self._get_rel_pos_bias()
|
| 164 |
+
if shared_rel_pos_bias is not None:
|
| 165 |
+
attn = attn + shared_rel_pos_bias
|
| 166 |
+
|
| 167 |
+
attn = attn.softmax(dim=-1)
|
| 168 |
+
attn = self.attn_drop(attn)
|
| 169 |
+
x = attn @ v
|
| 170 |
+
|
| 171 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 172 |
+
x = self.proj(x)
|
| 173 |
+
x = self.proj_drop(x)
|
| 174 |
+
return x
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class Block(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int,
            qkv_bias: bool = False,
            mlp_ratio: float = 4.,
            scale_mlp: bool = False,
            swiglu_mlp: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            init_values: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            window_size: Optional[Tuple[int, int]] = None,
            attn_head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        if swiglu_mlp:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
            )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=int(dim * mlp_ratio),
                act_layer=act_layer,
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
            )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        if init_values:
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
        if self.gamma_1 is None:
            x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
            x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
        return x

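# --- Illustrative sketch (added, not in the original file): a standalone Block with layer
# scale enabled. When init_values is set, the residual branches are scaled by the learnable
# gamma_1 / gamma_2 parameters (initialized to init_values), which keeps deep residual
# contributions small at initialization.
def _sketch_block_usage():
    blk = Block(dim=768, num_heads=12, qkv_bias=True, init_values=0.1, window_size=(14, 14))
    tokens = torch.randn(2, 14 * 14 + 1, 768)  # (B, N, C), patches + cls token
    out = blk(tokens)  # output shape matches input: (2, 197, 768)
    return out.shape
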
class RelativePositionBias(nn.Module):

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
        # trunc_normal_(self.relative_position_bias_table, std=.02)
        self.register_buffer("relative_position_index", gen_relative_position_index(window_size))

    def forward(self):
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_area + 1, self.window_area + 1, -1)  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

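# --- Note + sketch (added for clarity, not in the original file): the table above reserves
# 3 rows beyond the (2*Wh-1)*(2*Ww-1) patch-to-patch offsets; per the BEiT reference
# implementation these cover cls-to-patch, patch-to-cls, and cls-to-cls relative positions,
# which gen_relative_position_index() maps to the last three table indices.
def _sketch_shared_bias():
    rpb = RelativePositionBias(window_size=(14, 14), num_heads=12)
    bias = rpb()  # (num_heads, 197, 197); shared by all blocks when use_shared_rel_pos_bias=True
    return bias.shape
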
class Beit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            qkv_bias: bool = True,
            mlp_ratio: float = 4.,
            swiglu_mlp: bool = False,
            scale_mlp: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            norm_layer: Callable = LayerNorm,
            init_values: Optional[float] = None,
            use_abs_pos_emb: bool = True,
            use_rel_pos_bias: bool = False,
            use_shared_rel_pos_bias: bool = False,
            head_init_scale: float = 0.001,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1
        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.grid_size,
                num_heads=num_heads,
            )
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                mlp_ratio=mlp_ratio,
                scale_mlp=scale_mlp,
                swiglu_mlp=swiglu_mlp,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
            )
            for i in range(depth)])

        use_fc_norm = self.global_pool == 'avg'
        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        self.fix_init_weight()
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=.02)
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        nwd = {'pos_embed', 'cls_token'}
        for n, _ in self.named_parameters():
            if 'relative_position_bias_table' in n:
                nwd.add(n)
        return nwd

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
        )
        return matcher

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
            else:
                x = blk(x, shared_rel_pos_bias=rel_pos_bias)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

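# --- Illustrative sketch (added, not part of the original file): constructing the model
# class directly and running a forward pass, mirroring the beit_base_patch16_224 config
# registered below.
def _sketch_beit_forward():
    model = Beit(
        img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))  # (1, 1000) with the default head
    return logits.shape
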
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/'),
    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
        hf_hub_id='timm/',
        num_classes=21841,
    ),
    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/'),
    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/',
        input_size=(3, 512, 512), crop_pct=1.0,
    ),
    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
        hf_hub_id='timm/',
        num_classes=21841,
    ),

    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_base_patch16_224.in1k_ft_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft1k.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
        hf_hub_id='timm/',
        num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft1k.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
        hf_hub_id='timm/',
        num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
})

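# --- Sketch (added, not in the original file): what one merged entry above resolves to.
# _cfg() only layers per-model overrides on top of the shared defaults, so the
# 'beit_base_patch16_384' entry is equivalent to this plain dict:
_SKETCH_BEIT_384_CFG = {
    'url': '', 'hf_hub_id': 'timm/',
    'num_classes': 1000, 'input_size': (3, 384, 384), 'pool_size': None,
    'crop_pct': 1.0, 'interpolation': 'bicubic', 'fixed_input_size': True,
    'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
    'first_conv': 'patch_embed.proj', 'classifier': 'head',
}
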
def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('module', state_dict)
    # beit v2 didn't strip module

    out_dict = {}
    for k, v in state_dict.items():
        if 'relative_position_index' in k:
            continue
        if 'patch_embed.proj.weight' in k:
            O, I, H, W = model.patch_embed.proj.weight.shape
            if v.shape[-1] != W or v.shape[-2] != H:
                v = resample_patch_embed(
                    v,
                    (H, W),
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            num_prefix_tokens = 1
            v = resample_abs_pos_embed(
                v,
                new_size=model.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        elif k.endswith('relative_position_bias_table'):
            m = model.get_submodule(k[:-29])
            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
                v = resize_rel_pos_bias_table(
                    v,
                    new_window_size=m.window_size,
                    new_bias_shape=m.relative_position_bias_table.shape,
                )
        out_dict[k] = v
    return out_dict

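# --- Illustrative sketch (an assumption, not in the original file): the filter fn above is
# normally invoked by build_model_with_cfg during pretrained loading, but it can also be
# applied manually when adapting a checkpoint to a different input resolution.
def _sketch_filter_checkpoint(model, checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    # resamples patch embed / abs pos embed / rel pos bias tables to match `model`
    state_dict = _beit_checkpoint_filter_fn(state_dict, model)
    model.load_state_dict(state_dict)
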
def _create_beit(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for BEiT models.')

    model = build_model_with_cfg(
        Beit, variant, pretrained,
        # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
        pretrained_filter_fn=_beit_checkpoint_filter_fn,
        **kwargs)
    return model

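# --- Usage sketch (added): the registered entrypoints below are normally reached through
# timm's factory rather than called directly.
def _sketch_create_beit():
    import timm  # local import; this module lives inside the timm package
    return timm.create_model('beit_base_patch16_224', pretrained=False)
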
@register_model
def beit_base_patch16_224(pretrained=False, **kwargs) -> Beit:
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
    model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_base_patch16_384(pretrained=False, **kwargs) -> Beit:
    model_args = dict(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
    model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_large_patch16_224(pretrained=False, **kwargs) -> Beit:
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_large_patch16_384(pretrained=False, **kwargs) -> Beit:
    model_args = dict(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_large_patch16_512(pretrained=False, **kwargs) -> Beit:
    model_args = dict(
        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beitv2_base_patch16_224(pretrained=False, **kwargs) -> Beit:
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs) -> Beit:
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
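# --- Sketch (added): the variants above differ mainly in resolution, width/depth, and the
# layer-scale init (0.1 for base, 1e-5 for large and beitv2). They can be enumerated via:
def _sketch_list_beit_models():
    import timm
    return timm.list_models('beit*')  # e.g. ['beit_base_patch16_224', ..., 'beitv2_large_patch16_224']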
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byoanet.py
ADDED
@@ -0,0 +1,455 @@
""" Bring-Your-Own-Attention Network
|
| 2 |
+
|
| 3 |
+
A flexible network w/ dataclass based config for stacking NN blocks including
|
| 4 |
+
self-attention (or similar) layers.
|
| 5 |
+
|
| 6 |
+
Currently used to implement experimental variants of:
|
| 7 |
+
* Bottleneck Transformers
|
| 8 |
+
* Lambda ResNets
|
| 9 |
+
* HaloNets
|
| 10 |
+
|
| 11 |
+
Consider all of the models definitions here as experimental WIP and likely to change.
|
| 12 |
+
|
| 13 |
+
Hacked together by / copyright Ross Wightman, 2021.
|
| 14 |
+
"""
|
| 15 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 16 |
+
from ._builder import build_model_with_cfg
|
| 17 |
+
from ._registry import register_model, generate_default_cfgs
|
| 18 |
+
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
|
| 19 |
+
|
| 20 |
+
__all__ = []
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
model_cfgs = dict(

    botnet26t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        fixed_input_size=True,
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict()
    ),
    sebotnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
        num_features=1280,
        attn_layer='se',
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict()
    ),
    botnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        fixed_input_size=True,
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict()
    ),
    eca_botnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        fixed_input_size=True,
        act_layer='silu',
        attn_layer='eca',
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict(dim_head=16)
    ),

    halonet_h1=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
            ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
            ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
            ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
        ),
        stem_chs=64,
        stem_type='7x7',
        stem_pool='maxpool',

        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3),
    ),
    halonet26t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=2)
    ),
    sehalonet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
        num_features=1280,
        attn_layer='se',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3)
    ),
    halonet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(
                types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3)
    ),
    eca_halonext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='eca',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
    ),

    lambda_resnet26t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        self_attn_layer='lambda',
        self_attn_kwargs=dict(r=9)
    ),
    lambda_resnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        self_attn_layer='lambda',
        self_attn_kwargs=dict(r=9)
    ),
    lambda_resnet26rpt_256=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        self_attn_layer='lambda',
        self_attn_kwargs=dict(r=None)
    ),

    # experimental
    haloregnetz_b=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
            interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
            ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
    ),

    # experimental
    lamhalobotnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
                self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
                self_attn_layer='bottleneck', self_attn_kwargs=dict()),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
    ),
    halo2botnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
                self_attn_layer='bottleneck', self_attn_kwargs=dict()),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
    ),
)

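# --- Illustrative sketch (added, not in the original file): composing a hypothetical extra
# config using the same building blocks as the entries above. interleave_blocks() expands a
# stage of depth d into per-block cfgs, substituting the 'self_attn' type at the interleave
# positions; everything else (channels, stride, group size, bottleneck ratio) is shared.
def _sketch_custom_byoanet_cfg():
    return ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=2),
    )
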
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )

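# --- Note + sketch (added): several entrypoints below pass a second name, so that e.g. the
# registered/pretrained variant name 'botnet26t_256' reuses the 'botnet26t' model_cfg:
def _sketch_create_variant():
    return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=False, img_size=256)
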
def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
        'fixed_input_size': False, 'min_input_size': (3, 224, 224),
        **kwargs
    }

default_cfgs = generate_default_cfgs({
    # Bottleneck Transformer (BoTNet) weights
    'botnet26t_256.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
    'sebotnet33ts_256.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
    'botnet50ts_256.untrained': _cfg(
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
    'eca_botnext26ts_256.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),

    'halonet_h1.untrained': _cfg(input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
    'halonet26t.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
    'sehalonet33ts.ra2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
    'halonet50ts.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
    'eca_halonext26ts.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),

    'lambda_resnet26t.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
        hf_hub_id='timm/',
        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
    'lambda_resnet50ts.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
        hf_hub_id='timm/',
        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
    'lambda_resnet26rpt_256.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),

    'haloregnetz_b.ra3_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),

    'lamhalobotnet50ts_256.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
    'halo2botnet50ts_256.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
})

@register_model
def botnet26t_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ ResNet26-T backbone.
    """
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)


@register_model
def sebotnet33ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ a ResNet33-T backbone, SE attn for non-Halo blocks, SiLU act.
    """
    return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)


@register_model
def botnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ ResNet50-T backbone, SiLU act.
    """
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)


@register_model
def eca_botnext26ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ ResNet26-T backbone, ECA attn, SiLU act.
    """
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)


@register_model
def halonet_h1(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet-H1. Halo attention in all stages as per the paper.
    NOTE: This runs very slowly!
    """
    return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)


@register_model
def halonet26t(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet26-T backbone. Halo attention in final two stages.
    """
    return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)


@register_model
def sehalonet33ts(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet33-T backbone, SE attn for non-Halo blocks, SiLU act, 1-2 Halo blocks in stages 2, 3, 4.
    """
    return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs)


@register_model
def halonet50ts(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet50-T backbone, SiLU act. Halo attention in final two stages.
    """
    return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)


@register_model
def eca_halonext26ts(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet26-T backbone, ECA attn, SiLU act. Halo attention in final two stages.
    """
    return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet26t(pretrained=False, **kwargs) -> ByobNet:
    """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
    """
    return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet50ts(pretrained=False, **kwargs) -> ByobNet:
    """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
    """
    return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet26rpt_256(pretrained=False, **kwargs) -> ByobNet:
    """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
    """
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)


@register_model
def haloregnetz_b(pretrained=False, **kwargs) -> ByobNet:
    """ Halo + RegNetZ
    """
    return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)


@register_model
def lamhalobotnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Combo Attention (Lambda + Halo + Bot) Network
    """
    return _create_byoanet('lamhalobotnet50ts_256', 'lamhalobotnet50ts', pretrained=pretrained, **kwargs)


@register_model
def halo2botnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Combo Attention (Halo + Halo + Bot) Network
    """
    return _create_byoanet('halo2botnet50ts_256', 'halo2botnet50ts', pretrained=pretrained, **kwargs)
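# --- Usage sketch (added): as with the BEiT entrypoints, these registered models are meant
# to be built through timm's factory.
def _sketch_create_halonet():
    import timm   # local imports; this module lives inside the timm package
    import torch
    model = timm.create_model('halonet26t', pretrained=False)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 256, 256))  # (1, 1000)
    return out.shape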
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byobnet.py
ADDED
@@ -0,0 +1,2245 @@
""" Bring-Your-Own-Blocks Network

A flexible network w/ dataclass based config for stacking those NN blocks.

This model is currently used to implement the following networks:

GPU Efficient (ResNets) - gernet_l/m/s (original versions called genet, but this was already used (by SENet author)).
Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
Code and weights: https://github.com/idstcv/GPU-Efficient-Networks, licensed Apache 2.0

RepVGG - repvgg_*
Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT

MobileOne - mobileone_*
Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040
Code and weights: https://github.com/apple/ml-mobileone, licensed MIT

In all cases the models have been modified to fit within the design of ByobNet. I've remapped
the original weights and verified accuracies.

For GPU Efficient nets, I used the original names for the blocks since they were for the most part
the same as original residual blocks in ResNe(X)t, DarkNet, and other existing models. Note that some
changes introduced in RegNet were also present in the stem and bottleneck blocks for this model.

A significant number of different network archs can be implemented here, including variants of the
above nets that include attention.

Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field, replace
from functools import partial
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
    create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model

__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']

@dataclass
class ByoBlockCfg:
    type: Union[str, nn.Module]
    d: int  # block depth (number of block repeats in stage)
    c: int  # number of output channels for each block in stage
    s: int = 2  # stride of stage (first block)
    gs: Optional[Union[int, Callable]] = None  # group-size of blocks in stage, conv is depthwise if gs == 1
    br: float = 1.  # bottleneck-ratio of blocks in stage

    # NOTE: these config items override the model cfgs that are applied to all blocks by default
    attn_layer: Optional[str] = None
    attn_kwargs: Optional[Dict[str, Any]] = None
    self_attn_layer: Optional[str] = None
    self_attn_kwargs: Optional[Dict[str, Any]] = None
    block_kwargs: Optional[Dict[str, Any]] = None


@dataclass
class ByoModelCfg:
    blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...]
    downsample: str = 'conv1x1'
    stem_type: str = '3x3'
    stem_pool: Optional[str] = 'maxpool'
    stem_chs: int = 32
    width_factor: float = 1.0
    num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
    fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation

    act_layer: str = 'relu'
    norm_layer: str = 'batchnorm'

    # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
    attn_layer: Optional[str] = None
    attn_kwargs: dict = field(default_factory=lambda: dict())
    self_attn_layer: Optional[str] = None
    self_attn_kwargs: dict = field(default_factory=lambda: dict())
    block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict())

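# Usage sketch (illustrative only, values arbitrary): a minimal two-stage config
# built from the dataclasses above. Real configurations live in `model_cfgs` below.
#
#   cfg = ByoModelCfg(
#       blocks=(
#           ByoBlockCfg(type='basic', d=2, c=64, s=1),
#           ByoBlockCfg(type='bottle', d=2, c=128, s=2, br=0.25),
#       ),
#       stem_chs=32,
#   )
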
def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
    c = (64, 128, 256, 512)
    group_size = 0
    if groups > 0:
        group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0
    bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
    return bcfg


def _mobileone_bcfg(d=(2, 8, 10, 1), wf=(1., 1., 1., 1.), se_blocks=(), num_conv_branches=1):
    c = (64, 128, 256, 512)
    prev_c = min(64, c[0] * wf[0])
    se_blocks = se_blocks or (0,) * len(d)
    bcfg = []
    for d, c, w, se in zip(d, c, wf, se_blocks):
        scfg = []
        for i in range(d):
            out_c = c * w
            bk = dict(num_conv_branches=num_conv_branches)
            ak = {}
            if i >= d - se:
                ak['attn_layer'] = 'se'
            scfg += [ByoBlockCfg(type='one', d=1, c=prev_c, gs=1, block_kwargs=bk, **ak)]  # depthwise block
            scfg += [ByoBlockCfg(
                type='one', d=1, c=out_c, gs=0, block_kwargs=dict(kernel_size=1, **bk), **ak)]  # pointwise block
            prev_c = out_c
        bcfg += [scfg]
    return bcfg

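# Note (descriptive): with defaults, _rep_vgg_bcfg() expands to four stages of 'rep'
# blocks with depths (4, 6, 16, 1) and base widths (64, 128, 256, 512) scaled by wf;
# a positive `groups` applies the group-size callable to every second block only
# (the `(idx + 1) % 2 == 0` test above).
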
def interleave_blocks(
        types: Tuple[str, str], d,
        every: Union[int, List[int]] = 1,
        first: bool = False,
        **kwargs,
) -> Tuple[ByoBlockCfg, ...]:
    """ interleave 2 block types in stack
    """
    assert len(types) == 2
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every + 1))
        if not every:
            every = [d - 1]
    every = set(every)  # set for fast membership tests below
    blocks = []
    for i in range(d):
        block_type = types[1] if i in every else types[0]
        blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)]
    return tuple(blocks)

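# Illustrative: with d=4 and the default every=1, `every` expands to {1, 3}, so
#   interleave_blocks(('bottle', 'self_attn'), d=4, c=256)
# yields block types ('bottle', 'self_attn', 'bottle', 'self_attn');
# first=True shifts the second type to indices {0, 2} instead.
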
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
    if not isinstance(stage_blocks_cfg, Sequence):
        stage_blocks_cfg = (stage_blocks_cfg,)
    block_cfgs = []
    for i, cfg in enumerate(stage_blocks_cfg):
        block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)]
    return block_cfgs


def num_groups(group_size, channels):
    if not group_size:  # 0 or None
        return 1  # normal conv with 1 group
    else:
        # NOTE group_size == 1 -> depthwise conv
        assert channels % group_size == 0
        return channels // group_size

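# Illustrative of the group-size convention used throughout this file:
#   num_groups(None, 64) -> 1    # normal conv
#   num_groups(1, 64)    -> 64   # depthwise, one channel per group
#   num_groups(32, 256)  -> 8    # grouped conv, 32 channels per group
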
@dataclass
class LayerFn:
    conv_norm_act: Callable = ConvNormAct
    norm_act: Callable = BatchNormAct2d
    act: Callable = nn.ReLU
    attn: Optional[Callable] = None
    self_attn: Optional[Callable] = None


class DownsampleAvg(nn.Module):
    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dilation: int = 1,
            apply_act: bool = False,
            layers: LayerFn = None,
    ):
        """ AvgPool Downsampling as in 'D' ResNet variants."""
        super(DownsampleAvg, self).__init__()
        layers = layers or LayerFn()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)

    def forward(self, x):
        return self.conv(self.pool(x))


def create_shortcut(
        downsample_type: str,
        in_chs: int,
        out_chs: int,
        stride: int,
        dilation: Tuple[int, int],
        layers: LayerFn,
        **kwargs,
):
    assert downsample_type in ('avg', 'conv1x1', '')
    if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
        if not downsample_type:
            return None  # no shortcut
        elif downsample_type == 'avg':
            return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs)
        else:
            return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, stride=stride, dilation=dilation[0], **kwargs)
    else:
        return nn.Identity()  # identity shortcut

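# Illustrative: the shortcut is nn.Identity() only when the block preserves shape, e.g.
#   create_shortcut('avg', 64, 64, stride=1, dilation=(1, 1), layers=LayerFn())
# returns nn.Identity(); changing stride or channel count instead returns a
# DownsampleAvg (or a strided 1x1 conv for 'conv1x1', or None for '').
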
class BasicBlock(nn.Module):
    """ ResNet Basic Block - kxk + kxk
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            group_size: Optional[int] = None,
            bottle_ratio: float = 1.0,
            downsample: str = 'avg',
            attn_last: bool = True,
            linear_out: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super(BasicBlock, self).__init__()
        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
        self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, out_chs, kernel_size,
            dilation=dilation[1], groups=groups, drop_layer=drop_block, apply_act=False,
        )
        self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv2_kxk.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        shortcut = x
        x = self.conv1_kxk(x)
        x = self.conv2_kxk(x)
        x = self.attn(x)
        x = self.attn_last(x)  # as in the other blocks; one of attn / attn_last is always nn.Identity()
        x = self.drop_path(x)
        if self.shortcut is not None:
            x = x + self.shortcut(shortcut)
        return self.act(x)

class BottleneckBlock(nn.Module):
    """ ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            attn_last: bool = False,
            linear_out: bool = False,
            extra_conv: bool = False,
            bottle_in: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super(BottleneckBlock, self).__init__()
        layers = layers or LayerFn()
        mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, mid_chs, kernel_size,
            stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
        )
        if extra_conv:
            self.conv2b_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups)
        else:
            self.conv2b_kxk = nn.Identity()
        self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv3_1x1.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        shortcut = x
        x = self.conv1_1x1(x)
        x = self.conv2_kxk(x)
        x = self.conv2b_kxk(x)
        x = self.attn(x)
        x = self.conv3_1x1(x)
        x = self.attn_last(x)
        x = self.drop_path(x)
        if self.shortcut is not None:
            x = x + self.shortcut(shortcut)
        return self.act(x)

class DarkBlock(nn.Module):
    """ DarkNet-like (1x1 + 3x3 w/ stride) block

    The GE-Net impl included a 1x1 + 3x3 block in their search space. It was not used in the feature models.
    This block is pretty much a DarkNet block (also DenseNet) hence the name. Neither DarkNet nor DenseNet
    uses strides within the block (external 3x3 or maxpool downsampling is done in front of the block repeats).

    If one does want to use a lot of these blocks w/ stride, I'd recommend using the EdgeBlock (3x3 w/ stride + 1x1)
    for more optimal compute.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            attn_last: bool = True,
            linear_out: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super(DarkBlock, self).__init__()
        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, out_chs, kernel_size,
            stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
        )
        self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv2_kxk.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        shortcut = x
        x = self.conv1_1x1(x)
        x = self.attn(x)
        x = self.conv2_kxk(x)
        x = self.attn_last(x)
        x = self.drop_path(x)
        if self.shortcut is not None:
            x = x + self.shortcut(shortcut)
        return self.act(x)

class EdgeBlock(nn.Module):
    """ EdgeResidual-like (3x3 + 1x1) block

    A two layer block like DarkBlock, but with the order of the 3x3 and 1x1 convs reversed.
    Very similar to the EfficientNet Edge-Residual block, but this block ends with activations, is
    intended to be used with either expansion or bottleneck contraction, and can use DW/group/non-grouped convs.

    FIXME is there a more common 3x3 + 1x1 conv block to name this after?
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            attn_last: bool = False,
            linear_out: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super(EdgeBlock, self).__init__()
        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_kxk = layers.conv_norm_act(
            in_chs, mid_chs, kernel_size,
            stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
        )
        self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
        self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        if zero_init_last and self.shortcut is not None and getattr(self.conv2_1x1.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv2_1x1.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        shortcut = x
        x = self.conv1_kxk(x)
        x = self.attn(x)
        x = self.conv2_1x1(x)
        x = self.attn_last(x)
        x = self.drop_path(x)
        if self.shortcut is not None:
            x = x + self.shortcut(shortcut)
        return self.act(x)

class RepVggBlock(nn.Module):
    """ RepVGG Block.

    Adapted from impl at https://github.com/DingXiaoH/RepVGG
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            downsample: str = '',
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
            inference_mode: bool = False
    ):
        super(RepVggBlock, self).__init__()
        self.groups = groups = num_groups(group_size, in_chs)
        layers = layers or LayerFn()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            self.reparam_conv = None
            use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
            self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
            self.conv_kxk = layers.conv_norm_act(
                in_chs, out_chs, kernel_size,
                stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
            )
            self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
            self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()

        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
        self.act = layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        # NOTE this init overrides that base model init with specific changes for the block type
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, .1, .1)
                nn.init.normal_(m.bias, 0, .1)
        if hasattr(self.attn, 'reset_parameters'):
            self.attn.reset_parameters()

    def forward(self, x):
        if self.reparam_conv is not None:
            return self.act(self.attn(self.reparam_conv(x)))

        if self.identity is None:
            x = self.conv_1x1(x) + self.conv_kxk(x)
        else:
            identity = self.identity(x)
            x = self.conv_1x1(x) + self.conv_kxk(x)
            x = self.drop_path(x)  # not in the paper / official impl, experimental
            x += identity
        x = self.attn(x)  # no attn in the paper / official impl, experimental
        return self.act(x)

    def reparameterize(self):
        """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.reparam_conv is not None:
            return

        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.conv_kxk.conv.in_channels,
            out_channels=self.conv_kxk.conv.out_channels,
            kernel_size=self.conv_kxk.conv.kernel_size,
            stride=self.conv_kxk.conv.stride,
            padding=self.conv_kxk.conv.padding,
            dilation=self.conv_kxk.conv.dilation,
            groups=self.conv_kxk.conv.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__('conv_kxk')
        self.__delattr__('conv_1x1')
        self.__delattr__('identity')
        self.__delattr__('drop_path')

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
        """
        # get weights and bias of scale branch
        kernel_1x1 = 0
        bias_1x1 = 0
        if self.conv_1x1 is not None:
            kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.conv_kxk.conv.kernel_size[0] // 2
            kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.identity is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        # get weights and bias of conv branches
        kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)

        kernel_final = kernel_conv + kernel_1x1 + kernel_identity
        bias_final = bias_conv + bias_1x1 + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to fuse batchnorm layer with preceding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
        """
        if isinstance(branch, ConvNormAct):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                in_chs = self.conv_kxk.conv.in_channels
                input_dim = in_chs // self.groups
                kernel_size = self.conv_kxk.conv.kernel_size
                kernel_value = torch.zeros_like(self.conv_kxk.conv.weight)
                for i in range(in_chs):
                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

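# Usage sketch (illustrative, not part of the original file): fusing the train-time
# branches of a RepVggBlock into the single reparam_conv for deployment. BN running
# stats are used in the fusion, so outputs match only in eval() mode.
#
#   block = RepVggBlock(64, 64, kernel_size=3).eval()
#   x = torch.randn(1, 64, 32, 32)
#   y_multi = block(x)            # conv_kxk + conv_1x1 + identity branches
#   block.reparameterize()        # collapses the branches into reparam_conv
#   y_fused = block(x)            # single-conv path, matches y_multi up to float tolerance
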
class MobileOneBlock(nn.Module):
    """ MobileOne building block.

    This block has a multi-branched architecture at train-time
    and a plain-CNN style architecture at inference time.
    For more details, please refer to the paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,  # unused
            group_size: Optional[int] = None,
            downsample: str = '',  # unused
            inference_mode: bool = False,
            num_conv_branches: int = 1,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ) -> None:
        """ Construct a MobileOneBlock module.
        """
        super(MobileOneBlock, self).__init__()
        self.num_conv_branches = num_conv_branches
        self.groups = groups = num_groups(group_size, in_chs)
        layers = layers or LayerFn()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=kernel_size,
                stride=stride,
                dilation=dilation,
                groups=groups,
                bias=True)
        else:
            self.reparam_conv = None

            # Re-parameterizable skip connection
            use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
            self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None

            # Re-parameterizable conv branches
            convs = []
            for _ in range(self.num_conv_branches):
                convs.append(layers.conv_norm_act(
                    in_chs, out_chs, kernel_size=kernel_size,
                    stride=stride, groups=groups, apply_act=False))
            self.conv_kxk = nn.ModuleList(convs)

            # Re-parameterizable scale branch
            self.conv_scale = None
            if kernel_size > 1:
                self.conv_scale = layers.conv_norm_act(
                    in_chs, out_chs, kernel_size=1,
                    stride=stride, groups=groups, apply_act=False)
            self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()

        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
        self.act = layers.act(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply forward pass. """
        # Inference mode forward pass.
        if self.reparam_conv is not None:
            return self.act(self.attn(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.identity is not None:
            identity_out = self.identity(x)

        # Scale branch output
        scale_out = 0
        if self.conv_scale is not None:
            scale_out = self.conv_scale(x)

        # Other branches
        out = scale_out
        for ck in self.conv_kxk:
            out += ck(x)
        out = self.drop_path(out)
        out += identity_out

        return self.act(self.attn(out))

    def reparameterize(self):
        """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.reparam_conv is not None:
            return

        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.conv_kxk[0].conv.in_channels,
            out_channels=self.conv_kxk[0].conv.out_channels,
            kernel_size=self.conv_kxk[0].conv.kernel_size,
            stride=self.conv_kxk[0].conv.stride,
            padding=self.conv_kxk[0].conv.padding,
            dilation=self.conv_kxk[0].conv.dilation,
            groups=self.conv_kxk[0].conv.groups,
            bias=True)
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__('conv_kxk')
        self.__delattr__('conv_scale')
        self.__delattr__('identity')
        self.__delattr__('drop_path')

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.conv_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.conv_kxk[0].conv.kernel_size[0] // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.identity is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to fuse batchnorm layer with preceding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
        """
        if isinstance(branch, ConvNormAct):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                in_chs = self.conv_kxk[0].conv.in_channels
                input_dim = in_chs // self.groups
                kernel_size = self.conv_kxk[0].conv.kernel_size
                kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight)
                for i in range(in_chs):
                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

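# Note (descriptive): MobileOneBlock generalizes the RepVGG fusion above;
# _get_kernel_bias() sums the BN-fused kernels of all `num_conv_branches` kxk
# branches, the 1x1 scale branch (zero-padded up to kxk), and the BN-only
# identity branch.
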
class SelfAttnBlock(nn.Module):
    """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            extra_conv: bool = False,
            linear_out: bool = False,
            bottle_in: bool = False,
            post_attn_na: bool = True,
            feat_size: Optional[Tuple[int, int]] = None,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super(SelfAttnBlock, self).__init__()
        assert layers is not None
        mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        if extra_conv:
            self.conv2_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size,
                stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
            )
            stride = 1  # striding done via conv if enabled
        else:
            self.conv2_kxk = nn.Identity()
        opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
        # FIXME need to dilate self attn to have dilated network support, moop moop
        self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
        self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv3_1x1.bn.weight)
        if hasattr(self.self_attn, 'reset_parameters'):
            self.self_attn.reset_parameters()

    def forward(self, x):
        shortcut = x
        x = self.conv1_1x1(x)
        x = self.conv2_kxk(x)
        x = self.self_attn(x)
        x = self.post_attn(x)
        x = self.conv3_1x1(x)
        x = self.drop_path(x)
        if self.shortcut is not None:
            x = x + self.shortcut(shortcut)
        return self.act(x)


_block_registry = dict(
    basic=BasicBlock,
    bottle=BottleneckBlock,
    dark=DarkBlock,
    edge=EdgeBlock,
    rep=RepVggBlock,
    one=MobileOneBlock,
    self_attn=SelfAttnBlock,
)


def register_block(block_type: str, block_fn: nn.Module):
    _block_registry[block_type] = block_fn


def create_block(block: Union[str, nn.Module], **kwargs):
    if isinstance(block, (nn.Module, partial)):
        return block(**kwargs)
    assert block in _block_registry, f'Unknown block type ({block})'
    return _block_registry[block](**kwargs)

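# Illustrative: blocks are looked up by the short names registered above, e.g.
#   blk = create_block('bottle', in_chs=64, out_chs=256, stride=1, layers=LayerFn())
# Custom blocks can be exposed to configs via register_block('custom', CustomBlock),
# where 'custom' / CustomBlock are hypothetical names for any nn.Module accepting
# the standard block kwargs.
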
class Stem(nn.Sequential):

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 4,
            pool: str = 'maxpool',
            num_rep: int = 3,
            num_act: Optional[int] = None,
            chs_decay: float = 0.5,
            layers: LayerFn = None,
    ):
        super().__init__()
        assert stride in (2, 4)
        layers = layers or LayerFn()

        if isinstance(out_chs, (list, tuple)):
            num_rep = len(out_chs)
            stem_chs = out_chs
        else:
            stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]

        self.stride = stride
        self.feature_info = []  # track intermediate features
        prev_feat = ''
        stem_strides = [2] + [1] * (num_rep - 1)
        if stride == 4 and not pool:
            # set last conv in stack to be strided if stride == 4 and no pooling layer
            stem_strides[-1] = 2

        num_act = num_rep if num_act is None else num_act
        # if num_act < num_rep, first convs in stack won't have bn + act
        stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
        prev_chs = in_chs
        curr_stride = 1
        for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
            layer_fn = layers.conv_norm_act if na else create_conv2d
            conv_name = f'conv{i + 1}'
            if i > 0 and s > 1:
                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
            self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
            prev_chs = ch
            curr_stride *= s
            prev_feat = conv_name

        if pool and 'max' in pool.lower():
            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
            curr_stride *= 2
            prev_feat = 'pool'

        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
        assert curr_stride == stride

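# Illustrative: with the default chs_decay=0.5, Stem(3, 64, num_rep=3) uses conv
# channels [16, 32, 64]; passing a tuple such as (24, 32, 64) (as the 'tiered' stem
# in create_byob_stem below does) sets the per-conv channels explicitly.
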
def create_byob_stem(
        in_chs: int,
        out_chs: int,
        stem_type: str = '',
        pool_type: str = '',
        feat_prefix: str = 'stem',
        layers: LayerFn = None,
):
    layers = layers or LayerFn()
    assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', 'one', '7x7', '3x3')
    if 'quad' in stem_type:
        # based on NFNet stem, stack of 4 3x3 convs
        num_act = 2 if 'quad2' in stem_type else None
        stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
    elif 'tiered' in stem_type:
        # 3x3 stack of 3 convs as in my ResNet-T
        stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
    elif 'deep' in stem_type:
        # 3x3 stack of 3 convs as in ResNet-D
        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
    elif 'rep' in stem_type:
        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
    elif 'one' in stem_type:
        stem = MobileOneBlock(in_chs, out_chs, kernel_size=3, stride=2, layers=layers)
    elif '7x7' in stem_type:
        # 7x7 stem conv as in ResNet
        if pool_type:
            stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
        else:
            stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
    else:
        # 3x3 stem conv as in RegNet is the default
        if pool_type:
            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
        else:
            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)

    if isinstance(stem, Stem):
        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
    else:
        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
    return stem, feature_info


def reduce_feat_size(feat_size, stride=2):
    return None if feat_size is None else tuple([s // stride for s in feat_size])

def override_kwargs(block_kwargs, model_kwargs):
    """ Override model level attn/self-attn/block kwargs w/ block level

    NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs
    for the block if set to anything that isn't None.

    i.e. an empty block_kwargs dict will remove kwargs set at model level for that block
    """
    out_kwargs = block_kwargs if block_kwargs is not None else model_kwargs
    return out_kwargs or {}  # make sure None isn't returned


def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg):
    layer_fns = block_kwargs['layers']

    # override attn layer / args with block local config
    attn_set = block_cfg.attn_layer is not None
    if attn_set or block_cfg.attn_kwargs is not None:
        # override attn layer config
        if attn_set and not block_cfg.attn_layer:
            # empty string for attn_layer type will disable attn for this block
            attn_layer = None
        else:
            attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
            attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
            attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
        layer_fns = replace(layer_fns, attn=attn_layer)

    # override self-attn layer / args with block local cfg
    self_attn_set = block_cfg.self_attn_layer is not None
    if self_attn_set or block_cfg.self_attn_kwargs is not None:
        # override self-attn layer config
        if self_attn_set and not block_cfg.self_attn_layer:  # self_attn_layer == ''
            # empty string for self_attn_layer type will disable attn for this block
            self_attn_layer = None
        else:
            self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
            self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
            self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
                if self_attn_layer is not None else None
        layer_fns = replace(layer_fns, self_attn=self_attn_layer)

    block_kwargs['layers'] = layer_fns

    # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set
    block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs))

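# Illustrative of the replace-not-merge rule documented in override_kwargs:
#   override_kwargs(dict(rd_ratio=0.25), dict(rd_ratio=0.0625, rd_divisor=8))
#       -> {'rd_ratio': 0.25}   # block-level kwargs fully replace model-level ones
#   override_kwargs(None, dict(rd_ratio=0.0625)) -> {'rd_ratio': 0.0625}
#   override_kwargs({}, dict(rd_ratio=0.0625))   -> {}  # empty dict disables model kwargs
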
def create_byob_stages(
        cfg: ByoModelCfg,
        drop_path_rate: float,
        output_stride: int,
        stem_feat: Dict[str, Any],
        feat_size: Optional[int] = None,
        layers: Optional[LayerFn] = None,
        block_kwargs_fn: Optional[Callable] = update_block_kwargs,
):
    layers = layers or LayerFn()
    feature_info = []
    block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
    depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    dilation = 1
    net_stride = stem_feat['reduction']
    prev_chs = stem_feat['num_chs']
    prev_feat = stem_feat
    stages = []
    for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
        stride = stage_block_cfgs[0].s
        if stride != 1 and prev_feat:
            feature_info.append(prev_feat)
        if net_stride >= output_stride and stride > 1:
            dilation *= stride
            stride = 1
        net_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2

        blocks = []
        for block_idx, block_cfg in enumerate(stage_block_cfgs):
            out_chs = make_divisible(block_cfg.c * cfg.width_factor)
            group_size = block_cfg.gs
            if isinstance(group_size, Callable):
                group_size = group_size(out_chs, block_idx)
            block_kwargs = dict(  # Blocks used in this model must accept these arguments
                in_chs=prev_chs,
                out_chs=out_chs,
                stride=stride if block_idx == 0 else 1,
                dilation=(first_dilation, dilation),
                group_size=group_size,
                bottle_ratio=block_cfg.br,
                downsample=cfg.downsample,
                drop_path_rate=dpr[stage_idx][block_idx],
                layers=layers,
            )
            if block_cfg.type in ('self_attn',):
                # add feat_size arg for blocks that support/need it
                block_kwargs['feat_size'] = feat_size
            block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg)
            blocks += [create_block(block_cfg.type, **block_kwargs)]
            first_dilation = dilation
            prev_chs = out_chs
            if stride > 1 and block_idx == 0:
                feat_size = reduce_feat_size(feat_size, stride)

        stages += [nn.Sequential(*blocks)]
        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')

    feature_info.append(prev_feat)
    return nn.Sequential(*stages), feature_info


def get_layer_fns(cfg: ByoModelCfg):
    act = get_act_layer(cfg.act_layer)
    norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
    conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
    self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
    layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
    return layer_fn

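# Note (descriptive): drop-path rates ramp linearly over all blocks in the model,
# e.g. stage depths [2, 2] with drop_path_rate=0.3 give per-stage rates
# [[0.0, 0.1], [0.2, 0.3]] via torch.linspace(0, 0.3, 4).split([2, 2]).
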
class ByobNet(nn.Module):
    """ 'Bring-your-own-blocks' Net

    A flexible network backbone that allows building model stem + blocks via
    dataclass cfg definition w/ factory functions for module instantiation.

    Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
    """

    def __init__(
            self,
            cfg: ByoModelCfg,
            num_classes: int = 1000,
            in_chans: int = 3,
            global_pool: str = 'avg',
            output_stride: int = 32,
            img_size: Optional[Union[int, Tuple[int, int]]] = None,
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            zero_init_last: bool = True,
            **kwargs,
    ):
        """
        Args:
            cfg: Model architecture configuration.
            num_classes: Number of classifier classes.
            in_chans: Number of input channels.
            global_pool: Global pooling type.
            output_stride: Output stride of network, one of (8, 16, 32).
            img_size: Image size for fixed image size models (i.e. self-attn).
            drop_rate: Classifier dropout rate.
            drop_path_rate: Stochastic depth drop-path rate.
            zero_init_last: Zero-init last weight of residual path.
            **kwargs: Extra kwargs overlayed onto cfg.
        """
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False

        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
        layers = get_layer_fns(cfg)
        if cfg.fixed_input_size:
            assert img_size is not None, 'img_size argument is required for fixed input size model'
        feat_size = to_2tuple(img_size) if img_size is not None else None

        self.feature_info = []
        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
        self.feature_info.extend(stem_feat[:-1])
        feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])

        self.stages, stage_feat = create_byob_stages(
            cfg,
            drop_path_rate,
            output_stride,
            stem_feat[-1],
            layers=layers,
            feat_size=feat_size,
        )
        self.feature_info.extend(stage_feat[:-1])

        prev_chs = stage_feat[-1]['num_chs']
        if cfg.num_features:
            self.num_features = int(round(cfg.width_factor * cfg.num_features))
            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
        else:
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        self.feature_info += [
            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]

        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
        )

        # init weights
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^stem',
            blocks=[
                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
                (r'^final_conv', (99999,))
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.final_conv(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _init_weights(module, name='', zero_init_last=False):
    if isinstance(module, nn.Conv2d):
        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
        fan_out //= module.groups
        module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=0.01)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights(zero_init_last=zero_init_last)

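# Usage sketch (illustrative): building a model directly from one of the configs
# defined below, bypassing timm's registry/factory helpers.
#
#   model = ByobNet(model_cfgs['gernet_s'], num_classes=1000)
#   logits = model(torch.randn(2, 3, 224, 224))   # -> shape (2, 1000)
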
| 1280 |
+
model_cfgs = dict(
    gernet_l=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
            ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
            ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
            ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
            ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
        ),
        stem_chs=32,
        stem_pool=None,
        num_features=2560,
    ),
    gernet_m=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
            ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
            ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
            ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
            ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
        ),
        stem_chs=32,
        stem_pool=None,
        num_features=2560,
    ),
    gernet_s=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
            ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
            ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
            ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
            ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
        ),
        stem_chs=13,
        stem_pool=None,
        num_features=1920,
    ),

    repvgg_a0=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)),
        stem_type='rep',
        stem_chs=48,
    ),
    repvgg_a1=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_a2=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b0=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b1=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b1g4=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b2=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b2g4=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b3=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b3g4=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_d2se=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(8, 14, 24, 1), wf=(2.5, 2.5, 2.5, 5.)),
        stem_type='rep',
        stem_chs=64,
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.0625, rd_divisor=1),
    ),

    # 4 x conv stem w/ 2 act, no maxpool, 2,4,6,4 repeats, group size 32 in first 3 blocks
    # DW convs in last block, 2048 pre-FC, silu act
    resnet51q=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
        ),
        stem_chs=128,
        stem_type='quad2',
        stem_pool=None,
        num_features=2048,
        act_layer='silu',
    ),

    # 4 x conv stem w/ 4 act, no maxpool, 1,4,6,4 repeats, edge block first, group size 32 in next 2 blocks
    # DW convs in last block, 4 conv for each bottle block, 2048 pre-FC, silu act
    resnet61q=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
        ),
        stem_chs=128,
        stem_type='quad',
        stem_pool=None,
        num_features=2048,
        act_layer='silu',
        block_kwargs=dict(extra_conv=True),
    ),

    # A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
    # and a tiered stem w/ maxpool
    resnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
    ),
    gcresnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='gca',
    ),
    seresnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='se',
    ),
    eca_resnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='eca',
    ),
    bat_resnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='bat',
        attn_kwargs=dict(block_size=8)
    ),

    # ResNet-32 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, no pre-fc feat layer, tiered stem w/o maxpool
    resnet32ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=0,
        act_layer='silu',
    ),

    # ResNet-33 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, 1280 pre-FC feat, tiered stem w/o maxpool
    resnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
    ),

    # A series of ResNet-33 (2, 3, 3, 2) models w/ one of GC, SE, ECA attn, no groups, SiLU act, 1280 pre-FC feat
    # and a tiered stem w/ no maxpool
    gcresnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
        attn_layer='gca',
    ),
    seresnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
        attn_layer='se',
    ),
    eca_resnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
        attn_layer='eca',
    ),

    gcresnet50t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        attn_layer='gca',
    ),

    gcresnext50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='gca',
    ),

    # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
    regnetz_b16=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_c16=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_d32=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4),
            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        downsample='',
        num_features=1792,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_d8=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        downsample='',
        num_features=1792,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_e8=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=96, s=1, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=8, c=192, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=16, c=384, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=8, br=4),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        downsample='',
        num_features=2048,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),

    # experimental EvoNorm configs
    regnetz_b16_evos=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        norm_layer=partial(EvoNorm2dS0a, group_size=16),
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_c16_evos=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        norm_layer=partial(EvoNorm2dS0a, group_size=16),
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_d8_evos=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
        ),
        stem_chs=64,
        stem_type='deep',
        stem_pool='',
        downsample='',
        num_features=1792,
        act_layer='silu',
        norm_layer=partial(EvoNorm2dS0a, group_size=16),
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),

    mobileone_s0=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(0.75, 1.0, 1.0, 2.), num_conv_branches=4),
        stem_type='one',
        stem_chs=48,
    ),
    mobileone_s1=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(1.5, 1.5, 2.0, 2.5)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s2=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(1.5, 2.0, 2.5, 4.0)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s3=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(2.0, 2.5, 3.0, 4.0)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s4=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(3.0, 3.5, 3.5, 4.0), se_blocks=(0, 0, 5, 1)),
        stem_type='one',
        stem_chs=64,
    ),
)

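# Note on the configs above (illustrative summary; field names come from the ByoBlockCfg /
# ByoModelCfg dataclasses defined earlier in this file): d = depth (blocks per stage),
# c = output channels, s = stride, gs = group size (0 disables grouping, 1 yields
# depthwise convs, per the "DW convs" comments above), br = bottleneck ratio. A
# hypothetical minimal entry would look like:
#
#   ByoModelCfg(
#       blocks=(
#           ByoBlockCfg(type='basic', d=1, c=64, s=2, gs=0, br=1.),
#           ByoBlockCfg(type='bottle', d=2, c=128, s=2, gs=0, br=0.25),
#       ),
#       stem_chs=32,
#       stem_pool=None,
#   )
#
# Channel counts are additionally scaled by cfg.width_factor inside ByobNet (see the
# num_features handling in __init__ above).
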
def _create_byobnet(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        **kwargs
    }


def _cfgr(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
        'crop_pct': 0.9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
        **kwargs
    }

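# The two helpers above only differ in their defaults: _cfg() describes 224x224 bilinear
# eval with crop_pct 0.875, _cfgr() 256x256 bicubic with crop_pct 0.9; any kwarg overrides
# either. For example (illustrative):
#
#   _cfgr(test_input_size=(3, 288, 288))['input_size']   # (3, 256, 256), default kept
#   _cfg(input_size=(3, 256, 256))['crop_pct']           # 0.875, default kept
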
default_cfgs = generate_default_cfgs({
    # GPU-Efficient (ResNet) weights
    'gernet_s.idstcv_in1k': _cfg(hf_hub_id='timm/'),
    'gernet_m.idstcv_in1k': _cfg(hf_hub_id='timm/'),
    'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)),

    # RepVGG weights
    'repvgg_a0.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_a1.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_a2.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_b0.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_b1.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_b1g4.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_b2.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_b2g4.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_b3.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_b3g4.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
    'repvgg_d2se.rvgg_in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit',
        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
    ),

    # experimental ResNet configs
    'resnet51q.ra2_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
        first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'resnet61q.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    # ResNeXt-26 models with different attention in Bottleneck blocks
    'resnext26ts.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'seresnext26ts.ch_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'gcresnext26ts.ch_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'eca_resnext26ts.ch_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'bat_resnext26ts.ch_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth',
        min_input_size=(3, 256, 256)),

    # ResNet-32 / 33 models with different attention in Bottleneck blocks
    'resnet32ts.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'resnet33ts.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'gcresnet33ts.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'seresnet33ts.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'eca_resnet33ts.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'gcresnet50t.ra2_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'gcresnext50ts.ch_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    # custom `timm` specific RegNetZ inspired models w/ different sizing from paper
    'regnetz_b16.ra3_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth',
        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.94, test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'regnetz_c16.ra3_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth',
        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
    'regnetz_d32.ra3_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320)),
    'regnetz_d8.ra3_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d8_bh-afc03c55.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
    'regnetz_e8.ra3_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),

    'regnetz_b16_evos.untrained': _cfgr(
        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.95, test_input_size=(3, 288, 288)),
    'regnetz_c16_evos.ch_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_c16_evos_ch-d8311942.pth',
        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        crop_pct=0.95, test_input_size=(3, 320, 320)),
    'regnetz_d8_evos.ch_in1k': _cfgr(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),

    'mobileone_s0.apple_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.875,
        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
    ),
    'mobileone_s1.apple_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9,
        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
    ),
    'mobileone_s2.apple_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9,
        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
    ),
    'mobileone_s3.apple_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9,
        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
    ),
    'mobileone_s4.apple_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.9,
        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
    ),
})

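# Pretrained usage sketch (illustrative; assumes network access to the Hugging Face Hub):
# the tag after the '.' in each default_cfgs key selects a specific weight set at create
# time.
#
#   import timm
#   model = timm.create_model('repvgg_b0.rvgg_in1k', pretrained=True).eval()
#   print(model.pretrained_cfg['input_size'])  # (3, 224, 224), resolved from _cfg() above
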
@register_model
def gernet_l(pretrained=False, **kwargs) -> ByobNet:
    """ GEResNet-Large (GENet-Large from official impl)
    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
    """
    return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)


@register_model
def gernet_m(pretrained=False, **kwargs) -> ByobNet:
    """ GEResNet-Medium (GENet-Normal from official impl)
    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
    """
    return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)


@register_model
def gernet_s(pretrained=False, **kwargs) -> ByobNet:
    """ GEResNet-Small (GENet-Small from official impl)
    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
    """
    return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)


@register_model
def repvgg_a0(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-A0
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs)


@register_model
def repvgg_a1(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-A1
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs)


@register_model
def repvgg_a2(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-A2
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b0(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B0
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b1(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B1
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b1g4(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B1g4
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b2(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B2
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b2g4(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B2g4
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b3(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B3
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b3g4(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B3g4
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)


@register_model
def repvgg_d2se(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-D2se
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_d2se', pretrained=pretrained, **kwargs)


@register_model
def resnet51q(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('resnet51q', pretrained=pretrained, **kwargs)


@register_model
def resnet61q(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('resnet61q', pretrained=pretrained, **kwargs)


@register_model
def resnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('resnext26ts', pretrained=pretrained, **kwargs)


@register_model
def gcresnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('gcresnext26ts', pretrained=pretrained, **kwargs)


@register_model
def seresnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('seresnext26ts', pretrained=pretrained, **kwargs)


@register_model
def eca_resnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('eca_resnext26ts', pretrained=pretrained, **kwargs)


@register_model
def bat_resnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('bat_resnext26ts', pretrained=pretrained, **kwargs)


@register_model
def resnet32ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('resnet32ts', pretrained=pretrained, **kwargs)


@register_model
def resnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('resnet33ts', pretrained=pretrained, **kwargs)


@register_model
def gcresnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('gcresnet33ts', pretrained=pretrained, **kwargs)


@register_model
def seresnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('seresnet33ts', pretrained=pretrained, **kwargs)


@register_model
def eca_resnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('eca_resnet33ts', pretrained=pretrained, **kwargs)


@register_model
def gcresnet50t(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs)


@register_model
def gcresnext50ts(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('gcresnext50ts', pretrained=pretrained, **kwargs)


@register_model
def regnetz_b16(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_b16', pretrained=pretrained, **kwargs)


@register_model
def regnetz_c16(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_c16', pretrained=pretrained, **kwargs)


@register_model
def regnetz_d32(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_d32', pretrained=pretrained, **kwargs)


@register_model
def regnetz_d8(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_d8', pretrained=pretrained, **kwargs)


@register_model
def regnetz_e8(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)


@register_model
def regnetz_b16_evos(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_b16_evos', pretrained=pretrained, **kwargs)


@register_model
def regnetz_c16_evos(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_c16_evos', pretrained=pretrained, **kwargs)


@register_model
def regnetz_d8_evos(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s0(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('mobileone_s0', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s1(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('mobileone_s1', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s2(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('mobileone_s2', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s3(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('mobileone_s3', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
    """
    """
    return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs)
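# Registry sketch (illustrative): @register_model makes each entry point above
# discoverable through the timm registry, e.g.
#
#   import timm
#   timm.list_models('repvgg*')  # ['repvgg_a0', 'repvgg_a1', 'repvgg_a2', ...]
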
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/coat.py
ADDED
@@ -0,0 +1,804 @@
"""
CoaT architecture.

Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399

Official CoaT code at: https://github.com/mlpc-ucsd/CoaT

Modified from timm/models/vision_transformer.py
"""
from functools import partial
from typing import Tuple, List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, LayerNorm
from ._builder import build_model_with_cfg
from ._registry import register_model, generate_default_cfgs

__all__ = ['CoaT']


class ConvRelPosEnc(nn.Module):
    """ Convolutional relative position encoding. """
    def __init__(self, head_chs, num_heads, window):
        """
        Initialization.
            Ch: Channels per head.
            h: Number of heads.
            window: Window size(s) in convolutional relative positional encoding. It can have two forms:
                1. An integer of window size, which assigns all attention heads with the same window size
                   in ConvRelPosEnc.
                2. A dict mapping window size to #attention head splits (
                   e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2})
                   It will apply different window size to the attention head splits.
        """
        super().__init__()

        if isinstance(window, int):
            # Set the same window size for all attention heads.
            window = {window: num_heads}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError()

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1
            # Determine padding size.
            # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            cur_conv = nn.Conv2d(
                cur_head_split * head_chs,
                cur_head_split * head_chs,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * head_chs,
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        self.channel_splits = [x * head_chs for x in self.head_splits]

    def forward(self, q, v, size: Tuple[int, int]):
        B, num_heads, N, C = q.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        # Convolutional relative position encoding.
        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
        v_img = v[:, :, 1:, :]  # [B, h, H*W, Ch]

        v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W)
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)  # Split according to channels
        conv_v_img_list = []
        for i, conv in enumerate(self.conv_list):
            conv_v_img_list.append(conv(v_img_list[i]))
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
        conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2)

        EV_hat = q_img * conv_v_img
        EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0))  # [B, h, N, Ch].
        return EV_hat


class FactorAttnConvRelPosEnc(nn.Module):
    """ Factorized attention with convolutional relative position encoding class. """
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            shared_crpe=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # Note: attn_drop is actually not used.
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Shared convolutional relative position encoding.
        self.crpe = shared_crpe

    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape

        # Generate Q, K, V.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # [B, h, N, Ch]

        # Factorized attention.
        k_softmax = k.softmax(dim=2)
        factor_att = k_softmax.transpose(-1, -2) @ v
        factor_att = q @ factor_att

        # Convolutional relative position encoding.
        crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]

        # Merge and reshape.
        x = self.scale * factor_att + crpe
        x = x.transpose(1, 2).reshape(B, N, C)  # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]

        # Output projection.
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

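# Shape sketch (illustrative, not upstream code): the factorized attention above computes
# Q @ (softmax_N(K)^T @ V) instead of softmax(Q @ K^T) @ V, so the N x N attention matrix
# is never formed; per head, cost drops from O(N^2 * Ch) to O(N * Ch^2).
#
#   import torch
#   B, h, N, Ch = 2, 8, 197, 64
#   q, k, v = (torch.randn(B, h, N, Ch) for _ in range(3))
#   ctx = k.softmax(dim=2).transpose(-1, -2) @ v  # [B, h, Ch, Ch], independent of N
#   out = q @ ctx                                 # [B, h, N, Ch]
#   assert out.shape == (B, h, N, Ch)
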
class ConvPosEnc(nn.Module):
|
| 142 |
+
""" Convolutional Position Encoding.
|
| 143 |
+
Note: This module is similar to the conditional position encoding in CPVT.
|
| 144 |
+
"""
|
| 145 |
+
def __init__(self, dim, k=3):
|
| 146 |
+
super(ConvPosEnc, self).__init__()
|
| 147 |
+
self.proj = nn.Conv2d(dim, dim, k, 1, k//2, groups=dim)
|
| 148 |
+
|
| 149 |
+
def forward(self, x, size: Tuple[int, int]):
|
| 150 |
+
B, N, C = x.shape
|
| 151 |
+
H, W = size
|
| 152 |
+
_assert(N == 1 + H * W, '')
|
| 153 |
+
|
| 154 |
+
# Extract CLS token and image tokens.
|
| 155 |
+
cls_token, img_tokens = x[:, :1], x[:, 1:] # [B, 1, C], [B, H*W, C]
|
| 156 |
+
|
| 157 |
+
# Depthwise convolution.
|
| 158 |
+
feat = img_tokens.transpose(1, 2).view(B, C, H, W)
|
| 159 |
+
x = self.proj(feat) + feat
|
| 160 |
+
x = x.flatten(2).transpose(1, 2)
|
| 161 |
+
|
| 162 |
+
# Combine with CLS token.
|
| 163 |
+
x = torch.cat((cls_token, x), dim=1)
|
| 164 |
+
|
| 165 |
+
return x
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class SerialBlock(nn.Module):
|
| 169 |
+
""" Serial block class.
|
| 170 |
+
Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
|
| 171 |
+
def __init__(
|
| 172 |
+
self,
|
| 173 |
+
dim,
|
| 174 |
+
num_heads,
|
| 175 |
+
mlp_ratio=4.,
|
| 176 |
+
qkv_bias=False,
|
| 177 |
+
proj_drop=0.,
|
| 178 |
+
attn_drop=0.,
|
| 179 |
+
drop_path=0.,
|
| 180 |
+
act_layer=nn.GELU,
|
| 181 |
+
norm_layer=nn.LayerNorm,
|
| 182 |
+
shared_cpe=None,
|
| 183 |
+
shared_crpe=None,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
|
| 187 |
+
# Conv-Attention.
|
| 188 |
+
self.cpe = shared_cpe
|
| 189 |
+
|
| 190 |
+
self.norm1 = norm_layer(dim)
|
| 191 |
+
self.factoratt_crpe = FactorAttnConvRelPosEnc(
|
| 192 |
+
dim,
|
| 193 |
+
num_heads=num_heads,
|
| 194 |
+
qkv_bias=qkv_bias,
|
| 195 |
+
attn_drop=attn_drop,
|
| 196 |
+
proj_drop=proj_drop,
|
| 197 |
+
shared_crpe=shared_crpe,
|
| 198 |
+
)
|
| 199 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 200 |
+
|
| 201 |
+
# MLP.
|
| 202 |
+
self.norm2 = norm_layer(dim)
|
| 203 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 204 |
+
self.mlp = Mlp(
|
| 205 |
+
in_features=dim,
|
| 206 |
+
hidden_features=mlp_hidden_dim,
|
| 207 |
+
act_layer=act_layer,
|
| 208 |
+
drop=proj_drop,
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
def forward(self, x, size: Tuple[int, int]):
|
| 212 |
+
# Conv-Attention.
|
| 213 |
+
x = self.cpe(x, size)
|
| 214 |
+
cur = self.norm1(x)
|
| 215 |
+
cur = self.factoratt_crpe(cur, size)
|
| 216 |
+
x = x + self.drop_path(cur)
|
| 217 |
+
|
| 218 |
+
# MLP.
|
| 219 |
+
cur = self.norm2(x)
|
| 220 |
+
cur = self.mlp(cur)
|
| 221 |
+
x = x + self.drop_path(cur)
|
| 222 |
+
|
| 223 |
+
return x
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class ParallelBlock(nn.Module):
    """ Parallel block class. """
    def __init__(
            self,
            dims,
            num_heads,
            mlp_ratios=[],
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            shared_crpes=None,
    ):
        super().__init__()

        # Conv-Attention.
        self.norm12 = norm_layer(dims[1])
        self.norm13 = norm_layer(dims[2])
        self.norm14 = norm_layer(dims[3])
        self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
            dims[1],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            shared_crpe=shared_crpes[1],
        )
        self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
            dims[2],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            shared_crpe=shared_crpes[2],
        )
        self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
            dims[3],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            shared_crpe=shared_crpes[3],
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP.
        self.norm22 = norm_layer(dims[1])
        self.norm23 = norm_layer(dims[2])
        self.norm24 = norm_layer(dims[3])
        # In parallel block, we assume dimensions are the same and share the linear transformation.
        assert dims[1] == dims[2] == dims[3]
        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
            in_features=dims[1],
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=proj_drop,
        )

    def upsample(self, x, factor: float, size: Tuple[int, int]):
        """ Feature map up-sampling. """
        return self.interpolate(x, scale_factor=factor, size=size)

    def downsample(self, x, factor: float, size: Tuple[int, int]):
        """ Feature map down-sampling. """
        return self.interpolate(x, scale_factor=1.0/factor, size=size)

    def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
        """ Feature map interpolation. """
        B, N, C = x.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        cls_token = x[:, :1, :]
        img_tokens = x[:, 1:, :]

        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
        img_tokens = F.interpolate(
            img_tokens,
            scale_factor=scale_factor,
            recompute_scale_factor=False,
            mode='bilinear',
            align_corners=False,
        )
        img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

        out = torch.cat((cls_token, img_tokens), dim=1)

        return out

    def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
        _, S2, S3, S4 = sizes
        cur2 = self.norm12(x2)
        cur3 = self.norm13(x3)
        cur4 = self.norm14(x4)
        cur2 = self.factoratt_crpe2(cur2, size=S2)
        cur3 = self.factoratt_crpe3(cur3, size=S3)
        cur4 = self.factoratt_crpe4(cur4, size=S4)
        upsample3_2 = self.upsample(cur3, factor=2., size=S3)
        upsample4_3 = self.upsample(cur4, factor=2., size=S4)
        upsample4_2 = self.upsample(cur4, factor=4., size=S4)
        downsample2_3 = self.downsample(cur2, factor=2., size=S2)
        downsample3_4 = self.downsample(cur3, factor=2., size=S3)
        downsample2_4 = self.downsample(cur2, factor=4., size=S2)
        cur2 = cur2 + upsample3_2 + upsample4_2
        cur3 = cur3 + upsample4_3 + downsample2_3
        cur4 = cur4 + downsample3_4 + downsample2_4
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        # MLP.
        cur2 = self.norm22(x2)
        cur3 = self.norm23(x3)
        cur4 = self.norm24(x4)
        cur2 = self.mlp2(cur2)
        cur3 = self.mlp3(cur3)
        cur4 = self.mlp4(cur4)
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        return x1, x2, x3, x4

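
# Cross-scale fusion note (added for clarity, not in the original file): after the
# per-branch conv-attention in ParallelBlock.forward() above, each of the three
# finest branches receives the other two resampled by the ratio of their grids.
# Adjacent scales differ by 2x in H and W, so with e.g. S2=(28, 28), S3=(14, 14),
# S4=(7, 7):
#   branch 2 adds upsample(cur3, 2x) and upsample(cur4, 4x)
#   branch 3 adds upsample(cur4, 2x) and downsample(cur2, 2x)
#   branch 4 adds downsample(cur3, 2x) and downsample(cur2, 4x)
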
class CoaT(nn.Module):
    """ CoaT class. """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            embed_dims=(64, 128, 320, 512),
            serial_depths=(3, 4, 6, 3),
            parallel_depth=0,
            num_heads=8,
            mlp_ratios=(4, 4, 4, 4),
            qkv_bias=True,
            drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=LayerNorm,
            return_interm_layers=False,
            out_features=None,
            crpe_window=None,
            global_pool='token',
    ):
        super().__init__()
        assert global_pool in ('token', 'avg')
        crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
        self.return_interm_layers = return_interm_layers
        self.out_features = out_features
        self.embed_dims = embed_dims
        self.num_features = embed_dims[-1]
        self.num_classes = num_classes
        self.global_pool = global_pool

        # Patch embeddings.
        img_size = to_2tuple(img_size)
        self.patch_embed1 = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
        self.patch_embed2 = PatchEmbed(
            img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
            embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
        self.patch_embed3 = PatchEmbed(
            img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
            embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
        self.patch_embed4 = PatchEmbed(
            img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
            embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)

        # Class tokens.
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        # Convolutional position encodings.
        self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
        self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
        self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
        self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)

        # Convolutional relative position encodings.
        self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window)
        self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window)
        self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window)
        self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window)

        # Disable stochastic depth.
        dpr = drop_path_rate
        assert dpr == 0.0
        skwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_drop=proj_drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            norm_layer=norm_layer,
        )

        # Serial blocks 1.
        self.serial_blocks1 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[0],
                mlp_ratio=mlp_ratios[0],
                shared_cpe=self.cpe1,
                shared_crpe=self.crpe1,
                **skwargs,
            )
            for _ in range(serial_depths[0])]
        )

        # Serial blocks 2.
        self.serial_blocks2 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[1],
                mlp_ratio=mlp_ratios[1],
                shared_cpe=self.cpe2,
                shared_crpe=self.crpe2,
                **skwargs,
            )
            for _ in range(serial_depths[1])]
        )

        # Serial blocks 3.
        self.serial_blocks3 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[2],
                mlp_ratio=mlp_ratios[2],
                shared_cpe=self.cpe3,
                shared_crpe=self.crpe3,
                **skwargs,
            )
            for _ in range(serial_depths[2])]
        )

        # Serial blocks 4.
        self.serial_blocks4 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[3],
                mlp_ratio=mlp_ratios[3],
                shared_cpe=self.cpe4,
                shared_crpe=self.crpe4,
                **skwargs,
            )
            for _ in range(serial_depths[3])]
        )

        # Parallel blocks.
        self.parallel_depth = parallel_depth
        if self.parallel_depth > 0:
            self.parallel_blocks = nn.ModuleList([
                ParallelBlock(
                    dims=embed_dims,
                    mlp_ratios=mlp_ratios,
                    shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4),
                    **skwargs,
                )
                for _ in range(parallel_depth)]
            )
        else:
            self.parallel_blocks = None

        # Classification head(s).
        if not self.return_interm_layers:
            if self.parallel_blocks is not None:
                self.norm2 = norm_layer(embed_dims[1])
                self.norm3 = norm_layer(embed_dims[2])
            else:
                self.norm2 = self.norm3 = None
            self.norm4 = norm_layer(embed_dims[3])

            if self.parallel_depth > 0:
                # CoaT series: Aggregate features of last three scales for classification.
                assert embed_dims[1] == embed_dims[2] == embed_dims[3]
                self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
                self.head_drop = nn.Dropout(drop_rate)
                self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
            else:
                # CoaT-Lite series: Use feature of last scale for classification.
                self.aggregate = None
                self.head_drop = nn.Dropout(drop_rate)
                self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Initialize weights.
        trunc_normal_(self.cls_token1, std=.02)
        trunc_normal_(self.cls_token2, std=.02)
        trunc_normal_(self.cls_token3, std=.02)
        trunc_normal_(self.cls_token4, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem1=r'^cls_token1|patch_embed1|crpe1|cpe1',
            serial_blocks1=r'^serial_blocks1\.(\d+)',
            stem2=r'^cls_token2|patch_embed2|crpe2|cpe2',
            serial_blocks2=r'^serial_blocks2\.(\d+)',
            stem3=r'^cls_token3|patch_embed3|crpe3|cpe3',
            serial_blocks3=r'^serial_blocks3\.(\d+)',
            stem4=r'^cls_token4|patch_embed4|crpe4|cpe4',
            serial_blocks4=r'^serial_blocks4\.(\d+)',
            parallel_blocks=[  # FIXME (partially?) overlap parallel w/ serial blocks??
                (r'^parallel_blocks\.(\d+)', None),
                (r'^norm|aggregate', (99999,)),
            ]
        )
        return matcher

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('token', 'avg')
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x0):
        B = x0.shape[0]

        # Serial blocks 1.
        x1 = self.patch_embed1(x0)
        H1, W1 = self.patch_embed1.grid_size
        x1 = insert_cls(x1, self.cls_token1)
        for blk in self.serial_blocks1:
            x1 = blk(x1, size=(H1, W1))
        x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 2.
        x2 = self.patch_embed2(x1_nocls)
        H2, W2 = self.patch_embed2.grid_size
        x2 = insert_cls(x2, self.cls_token2)
        for blk in self.serial_blocks2:
            x2 = blk(x2, size=(H2, W2))
        x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 3.
        x3 = self.patch_embed3(x2_nocls)
        H3, W3 = self.patch_embed3.grid_size
        x3 = insert_cls(x3, self.cls_token3)
        for blk in self.serial_blocks3:
            x3 = blk(x3, size=(H3, W3))
        x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 4.
        x4 = self.patch_embed4(x3_nocls)
        H4, W4 = self.patch_embed4.grid_size
        x4 = insert_cls(x4, self.cls_token4)
        for blk in self.serial_blocks4:
            x4 = blk(x4, size=(H4, W4))
        x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()

        # Only serial blocks: Early return.
        if self.parallel_blocks is None:
            if not torch.jit.is_scripting() and self.return_interm_layers:
                # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
                feat_out = {}
                if 'x1_nocls' in self.out_features:
                    feat_out['x1_nocls'] = x1_nocls
                if 'x2_nocls' in self.out_features:
                    feat_out['x2_nocls'] = x2_nocls
                if 'x3_nocls' in self.out_features:
                    feat_out['x3_nocls'] = x3_nocls
                if 'x4_nocls' in self.out_features:
                    feat_out['x4_nocls'] = x4_nocls
                return feat_out
            else:
                # Return features for classification.
                x4 = self.norm4(x4)
                return x4

        # Parallel blocks.
        for blk in self.parallel_blocks:
            x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])

        if not torch.jit.is_scripting() and self.return_interm_layers:
            # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
            feat_out = {}
            if 'x1_nocls' in self.out_features:
                x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x1_nocls'] = x1_nocls
            if 'x2_nocls' in self.out_features:
                x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x2_nocls'] = x2_nocls
            if 'x3_nocls' in self.out_features:
                x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x3_nocls'] = x3_nocls
            if 'x4_nocls' in self.out_features:
                x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x4_nocls'] = x4_nocls
            return feat_out
        else:
            x2 = self.norm2(x2)
            x3 = self.norm3(x3)
            x4 = self.norm4(x4)
            return [x2, x3, x4]

    def forward_head(self, x_feat: Union[torch.Tensor, List[torch.Tensor]], pre_logits: bool = False):
        if isinstance(x_feat, list):
            assert self.aggregate is not None
            if self.global_pool == 'avg':
                x = torch.cat([xl[:, 1:].mean(dim=1, keepdim=True) for xl in x_feat], dim=1)  # [B, 3, C]
            else:
                x = torch.stack([xl[:, 0] for xl in x_feat], dim=1)  # [B, 3, C]
            x = self.aggregate(x).squeeze(dim=1)  # Shape: [B, C]
        else:
            x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x) -> torch.Tensor:
        if not torch.jit.is_scripting() and self.return_interm_layers:
            # Return intermediate features (for down-stream tasks).
            return self.forward_features(x)
        else:
            # Return features for classification.
            x_feat = self.forward_features(x)
            x = self.forward_head(x_feat)
            return x

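
# Head behavior note (added for clarity, not in the original file): with
# parallel_depth > 0 (the CoaT variants), forward_features returns the last three
# scales as a list and forward_head fuses their CLS/pooled tokens with the 1x1
# Conv1d `aggregate` before the linear head; with parallel_depth == 0 (the
# CoaT-Lite variants), only the final-scale feature feeds the head.
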
def insert_cls(x, cls_token):
    """ Insert CLS token. """
    cls_tokens = cls_token.expand(x.shape[0], -1, -1)
    x = torch.cat((cls_tokens, x), dim=1)
    return x


def remove_cls(x):
    """ Remove CLS token. """
    return x[:, 1:, :]


def checkpoint_filter_fn(state_dict, model):
    out_dict = {}
    state_dict = state_dict.get('model', state_dict)
    for k, v in state_dict.items():
        # original model had unused norm layers, removing them requires filtering pretrained checkpoints
        if k.startswith('norm1') or \
                (k.startswith('norm2') and getattr(model, 'norm2', None) is None) or \
                (k.startswith('norm3') and getattr(model, 'norm3', None) is None) or \
                (k.startswith('norm4') and getattr(model, 'norm4', None) is None) or \
                (k.startswith('aggregate') and getattr(model, 'aggregate', None) is None) or \
                (k.startswith('head') and getattr(model, 'head', None) is None):
            continue
        out_dict[k] = v
    return out_dict


def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    model = build_model_with_cfg(
        CoaT,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
    return model


def _cfg_coat(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed1.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'coat_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_small.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_small.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_medium.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_medium_384.in1k': _cfg_coat(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash',
    ),
})


@register_model
def coat_tiny(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6)
    model = _create_coat('coat_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def coat_mini(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6)
    model = _create_coat('coat_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def coat_small(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6, **kwargs)
    model = _create_coat('coat_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def coat_lite_tiny(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
    model = _create_coat('coat_lite_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def coat_lite_mini(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
    model = _create_coat('coat_lite_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def coat_lite_small(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], mlp_ratios=[8, 8, 4, 4])
    model = _create_coat('coat_lite_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def coat_lite_medium(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
    model = _create_coat('coat_lite_medium', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def coat_lite_medium_384(pretrained=False, **kwargs) -> CoaT:
    model_cfg = dict(
        img_size=384, patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
    model = _create_coat('coat_lite_medium_384', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model
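
A minimal usage sketch for the CoaT entrypoints registered above (an illustration, assuming this package is importable as `timm`; `create_model` and the model names come from the registrations in this file):

    import timm
    import torch

    model = timm.create_model('coat_lite_mini', pretrained=False)  # random init; pretrained=True pulls the .in1k weights
    model.eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])
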
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convit.py
ADDED
@@ -0,0 +1,430 @@
""" ConViT Model

@article{d2021convit,
  title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
  author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
  journal={arXiv preprint arXiv:2103.10697},
  year={2021}
}

Paper link: https://arxiv.org/abs/2103.10697
Original code: https://github.com/facebookresearch/convit, original copyright below

Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
'''These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''

from functools import partial

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_module
from ._registry import register_model, generate_default_cfgs
from .vision_transformer_hybrid import HybridEmbed


__all__ = ['ConVit']


@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            locality_strength=1.,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.locality_strength = locality_strength

        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.pos_proj = nn.Linear(3, num_heads)
        self.proj_drop = nn.Dropout(proj_drop)
        self.gating_param = nn.Parameter(torch.ones(self.num_heads))
        self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3)  # silly torchscript hack, won't work with None

    def forward(self, x):
        B, N, C = x.shape
        if self.rel_indices is None or self.rel_indices.shape[1] != N:
            self.rel_indices = self.get_rel_indices(N)
        attn = self.get_attention(x)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def get_attention(self, x):
        B, N, C = x.shape
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qk[0], qk[1]
        pos_score = self.rel_indices.expand(B, -1, -1, -1)
        pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
        patch_score = (q @ k.transpose(-2, -1)) * self.scale
        patch_score = patch_score.softmax(dim=-1)
        pos_score = pos_score.softmax(dim=-1)

        gating = self.gating_param.view(1, -1, 1, 1)
        attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
        attn /= attn.sum(dim=-1).unsqueeze(-1)
        attn = self.attn_drop(attn)
        return attn

    def get_attention_map(self, x, return_map=False):
        attn_map = self.get_attention(x).mean(0)  # average over batch
        distances = self.rel_indices.squeeze()[:, :, -1] ** .5
        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
        if return_map:
            return dist, attn_map
        else:
            return dist

    def local_init(self):
        self.v.weight.data.copy_(torch.eye(self.dim))
        locality_distance = 1  # max(1,1/locality_strength**.5)

        kernel_size = int(self.num_heads ** .5)
        center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
        for h1 in range(kernel_size):
            for h2 in range(kernel_size):
                position = h1 + kernel_size * h2
                self.pos_proj.weight.data[position, 2] = -1
                self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
                self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
        self.pos_proj.weight.data *= self.locality_strength

    def get_rel_indices(self, num_patches: int) -> torch.Tensor:
        img_size = int(num_patches ** .5)
        rel_indices = torch.zeros(1, num_patches, num_patches, 3)
        ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx ** 2 + indy ** 2
        rel_indices[:, :, :, 2] = indd.unsqueeze(0)
        rel_indices[:, :, :, 1] = indy.unsqueeze(0)
        rel_indices[:, :, :, 0] = indx.unsqueeze(0)
        device = self.qk.weight.device
        return rel_indices.to(device)

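
# Gating note (added for clarity, not in the original file): each GPSA head mixes
# content and positional attention with a learned scalar gate g (gating_param):
#   attn = (1 - sigmoid(g)) * softmax(q @ k^T * scale) + sigmoid(g) * softmax(pos_proj(rel_indices))
# gating_param is initialized to 1, so sigmoid(g) ~= 0.73 and the positional term,
# which local_init() shapes into a convolution-like neighborhood per head, carries
# most of the weight at the start of training.
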
class MHSA(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def get_attention_map(self, x, return_map=False):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn_map = (q @ k.transpose(-2, -1)) * self.scale
        attn_map = attn_map.softmax(dim=-1).mean(0)

        img_size = int(N ** .5)
        ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx ** 2 + indy ** 2
        distances = indd ** .5
        distances = distances.to(x.device)

        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
        if return_map:
            return dist, attn_map
        else:
            return dist

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=LayerNorm,
            use_gpsa=True,
            locality_strength=1.,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.use_gpsa = use_gpsa
        if self.use_gpsa:
            self.attn = GPSA(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
                locality_strength=locality_strength,
            )
        else:
            self.attn = MHSA(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                attn_drop=attn_drop,
                proj_drop=proj_drop,
            )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=proj_drop,
        )

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

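
# Layering note (added for clarity, not in the original file): ConVit below uses
# GPSA for the first local_up_to_layer blocks and plain MHSA afterwards. The CLS
# token is only concatenated once the MHSA stage begins (see forward_features),
# because GPSA's rel_indices assume a square grid of patch tokens with no CLS entry.
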
class ConVit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=False,
            drop_rate=0.,
            pos_drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            hybrid_backbone=None,
            norm_layer=LayerNorm,
            local_up_to_layer=3,
            locality_strength=1.,
            use_pos_embed=True,
    ):
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        embed_dim *= num_heads
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.local_up_to_layer = local_up_to_layer
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.locality_strength = locality_strength
        self.use_pos_embed = use_pos_embed

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        if self.use_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.pos_embed, std=.02)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                use_gpsa=i < local_up_to_layer,
                locality_strength=locality_strength,
            ) for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        for n, m in self.named_modules():
            if hasattr(m, 'local_init'):
                m.local_init()

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'token', 'avg')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.use_pos_embed:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        for u, blk in enumerate(self.blocks):
            if u == self.local_up_to_layer:
                x = torch.cat((cls_tokens, x), dim=1)
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_convit(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(ConVit, variant, pretrained, **kwargs)


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    # ConViT
    'convit_tiny.fb_in1k': _cfg(hf_hub_id='timm/'),
    'convit_small.fb_in1k': _cfg(hf_hub_id='timm/'),
    'convit_base.fb_in1k': _cfg(hf_hub_id='timm/')
})


@register_model
def convit_tiny(pretrained=False, **kwargs) -> ConVit:
    model_args = dict(
        local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=4)
    model = _create_convit(variant='convit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def convit_small(pretrained=False, **kwargs) -> ConVit:
    model_args = dict(
        local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=9)
    model = _create_convit(variant='convit_small', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def convit_base(pretrained=False, **kwargs) -> ConVit:
    model_args = dict(
        local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=16)
    model = _create_convit(variant='convit_base', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
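
One quirk worth noting before the next file: in `ConVit.__init__` the effective width is `embed_dim * num_heads`, so the `embed_dim=48` passed by the entrypoints above yields 192/432/768 channels for tiny/small/base. A minimal usage sketch (an illustration, assuming this package is importable as `timm`):

    import timm
    import torch

    model = timm.create_model('convit_tiny', pretrained=False).eval()
    assert model.embed_dim == 48 * 4  # embed_dim is scaled by num_heads (4 for convit_tiny)
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])
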
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/crossvit.py
ADDED
@@ -0,0 +1,627 @@
""" CrossViT Model

@inproceedings{
    chen2021crossvit,
    title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
    author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
    booktitle={International Conference on Computer Vision (ICCV)},
    year={2021}
}

Paper link: https://arxiv.org/abs/2103.14899
Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py

NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408

Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
"""

# Copyright IBM All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


"""
Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py

"""
from functools import partial
from typing import List
from typing import Tuple

import torch
import torch.hub
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._registry import register_model, generate_default_cfgs
from .vision_transformer import Block

__all__ = ['CrossVit']  # model_registry will add each entrypoint fn to this


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        if multi_conv:
            if patch_size[0] == 12:
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                )
            elif patch_size[0] == 16:
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
                )
        else:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        _assert(H == self.img_size[0],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x


class CrossAttention(nn.Module):
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = head_dim ** -0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # B1C -> B1H(C/H) -> BH1(C/H)
        q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # BNC -> BNH(C/H) -> BHN(C/H)
        k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        # BNC -> BNH(C/H) -> BHN(C/H)
        v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # BH1(C/H) @ BH(C/H)N -> BH1N
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, 1, C)  # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

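
# Shape note (added for clarity, not in the original file): only the CLS token is
# used as the query, so CrossAttention maps [B, N, C] -> [B, 1, C]; the block below
# accordingly applies its residual connection to the CLS token alone.
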
class CrossAttentionBlock(nn.Module):
|
| 130 |
+
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
dim,
|
| 134 |
+
num_heads,
|
| 135 |
+
mlp_ratio=4.,
|
| 136 |
+
qkv_bias=False,
|
| 137 |
+
proj_drop=0.,
|
| 138 |
+
attn_drop=0.,
|
| 139 |
+
drop_path=0.,
|
| 140 |
+
act_layer=nn.GELU,
|
| 141 |
+
norm_layer=nn.LayerNorm,
|
| 142 |
+
):
|
| 143 |
+
super().__init__()
|
| 144 |
+
self.norm1 = norm_layer(dim)
|
| 145 |
+
self.attn = CrossAttention(
|
| 146 |
+
dim,
|
| 147 |
+
num_heads=num_heads,
|
| 148 |
+
qkv_bias=qkv_bias,
|
| 149 |
+
attn_drop=attn_drop,
|
| 150 |
+
proj_drop=proj_drop,
|
| 151 |
+
)
|
| 152 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 153 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
|
| 157 |
+
return x
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class MultiScaleBlock(nn.Module):
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
dim,
|
| 165 |
+
patches,
|
| 166 |
+
depth,
|
| 167 |
+
num_heads,
|
| 168 |
+
mlp_ratio,
|
| 169 |
+
qkv_bias=False,
|
| 170 |
+
proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        num_branches = len(dim)
        self.num_branches = num_branches
        # different branch could have different embedding size, the first one is the base
        self.blocks = nn.ModuleList()
        for d in range(num_branches):
            tmp = []
            for i in range(depth[d]):
                tmp.append(Block(
                    dim=dim[d],
                    num_heads=num_heads[d],
                    mlp_ratio=mlp_ratio[d],
                    qkv_bias=qkv_bias,
                    proj_drop=proj_drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i],
                    norm_layer=norm_layer,
                ))
            if len(tmp) != 0:
                self.blocks.append(nn.Sequential(*tmp))

        if len(self.blocks) == 0:
            self.blocks = None

        self.projs = nn.ModuleList()
        for d in range(num_branches):
            if dim[d] == dim[(d + 1) % num_branches] and False:
                tmp = [nn.Identity()]
            else:
                tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
            self.projs.append(nn.Sequential(*tmp))

        self.fusion = nn.ModuleList()
        for d in range(num_branches):
            d_ = (d + 1) % num_branches
            nh = num_heads[d_]
            if depth[-1] == 0:  # backward compatibility
                self.fusion.append(
                    CrossAttentionBlock(
                        dim=dim[d_],
                        num_heads=nh,
                        mlp_ratio=mlp_ratio[d],
                        qkv_bias=qkv_bias,
                        proj_drop=proj_drop,
                        attn_drop=attn_drop,
                        drop_path=drop_path[-1],
                        norm_layer=norm_layer,
                    ))
            else:
                tmp = []
                for _ in range(depth[-1]):
                    tmp.append(CrossAttentionBlock(
                        dim=dim[d_],
                        num_heads=nh,
                        mlp_ratio=mlp_ratio[d],
                        qkv_bias=qkv_bias,
                        proj_drop=proj_drop,
                        attn_drop=attn_drop,
                        drop_path=drop_path[-1],
                        norm_layer=norm_layer,
                    ))
                self.fusion.append(nn.Sequential(*tmp))

        self.revert_projs = nn.ModuleList()
        for d in range(num_branches):
            if dim[(d + 1) % num_branches] == dim[d] and False:
                tmp = [nn.Identity()]
            else:
                tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
                       nn.Linear(dim[(d + 1) % num_branches], dim[d])]
            self.revert_projs.append(nn.Sequential(*tmp))

    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:

        outs_b = []
        for i, block in enumerate(self.blocks):
            outs_b.append(block(x[i]))

        # only take the cls token out
        proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
        for i, proj in enumerate(self.projs):
            proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))

        # cross attention
        outs = []
        for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
            tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
            tmp = fusion(tmp)
            reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
            tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
            outs.append(tmp)
        return outs
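
# Walk-through of the fusion above (an illustration, not in the upstream file): for each
# branch i, its CLS token is projected to the other branch's width, attends over that
# branch's patch tokens via CrossAttentionBlock, is projected back to dim[i], and is then
# re-attached in front of branch i's own patch tokens. Only CLS tokens carry information
# across branches; patch tokens stay within their branch.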


def _compute_num_patches(img_size, patches):
    return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
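# e.g. for img_size [(240, 240), (240, 240)] and patches [12, 16] this gives
# [20 * 20, 15 * 15] = [400, 225] patches per branch.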


@register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False):  # annotations for torchscript
    """
    Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
    Args:
        x (Tensor): input image
        ss (tuple[int, int]): height and width to scale to
        crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
    Returns:
        Tensor: the "scaled" image batch tensor
    """
    H, W = x.shape[-2:]
    if H != ss[0] or W != ss[1]:
        if crop_scale and ss[0] <= H and ss[1] <= W:
            cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
            x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
        else:
            x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
    return x
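# e.g. scaling a (B, 3, 240, 240) batch to ss=(224, 224) with crop_scale=True takes a
# center crop at offsets cu = cl = round((240 - 224) / 2) = 8; with crop_scale=False the
# batch is resized to the target size with bicubic interpolation instead.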


class CrossVit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage
    """

    def __init__(
            self,
            img_size=224,
            img_scale=(1.0, 1.0),
            patch_size=(8, 16),
            in_chans=3,
            num_classes=1000,
            embed_dim=(192, 384),
            depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
            num_heads=(6, 12),
            mlp_ratio=(2., 2., 4.),
            multi_conv=False,
            crop_scale=False,
            qkv_bias=True,
            drop_rate=0.,
            pos_drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            global_pool='token',
    ):
        super().__init__()
        assert global_pool in ('token', 'avg')

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.img_size = to_2tuple(img_size)
        img_scale = to_2tuple(img_scale)
        self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
        self.crop_scale = crop_scale  # crop instead of interpolate for scale
        num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
        self.num_branches = len(patch_size)
        self.embed_dim = embed_dim
        self.num_features = sum(embed_dim)
        self.patch_embed = nn.ModuleList()

        # hard-coded for torch jit script
        for i in range(self.num_branches):
            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))

        for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
            self.patch_embed.append(
                PatchEmbed(
                    img_size=im_s,
                    patch_size=p,
                    in_chans=in_chans,
                    embed_dim=d,
                    multi_conv=multi_conv,
                ))

        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        total_depth = sum([sum(x[-2:]) for x in depth])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]  # stochastic depth decay rule
        dpr_ptr = 0
        self.blocks = nn.ModuleList()
        for idx, block_cfg in enumerate(depth):
            curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
            dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
            blk = MultiScaleBlock(
                embed_dim,
                num_patches,
                block_cfg,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr_,
                norm_layer=norm_layer,
            )
            dpr_ptr += curr_depth
            self.blocks.append(blk)

        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.ModuleList([
            nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
            for i in range(self.num_branches)])

        for i in range(self.num_branches):
            trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
            trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        out = set()
        for i in range(self.num_branches):
            out.add(f'cls_token_{i}')
            pe = getattr(self, f'pos_embed_{i}', None)
            if pe is not None and pe.requires_grad:
                out.add(f'pos_embed_{i}')
        return out

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('token', 'avg')
            self.global_pool = global_pool
        self.head = nn.ModuleList(
            [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
             range(self.num_branches)])

    def forward_features(self, x) -> List[torch.Tensor]:
        B = x.shape[0]
        xs = []
        for i, patch_embed in enumerate(self.patch_embed):
            x_ = x
            ss = self.img_size_scaled[i]
            x_ = scale_image(x_, ss, self.crop_scale)
            x_ = patch_embed(x_)
            cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
            cls_tokens = cls_tokens.expand(B, -1, -1)
            x_ = torch.cat((cls_tokens, x_), dim=1)
            pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1  # hard-coded for torch jit script
            x_ = x_ + pos_embed
            x_ = self.pos_drop(x_)
            xs.append(x_)

        for i, blk in enumerate(self.blocks):
            xs = blk(xs)

        # NOTE: this was previously applied before the branch-token pooling; moved here to
        # ensure all branch tokens pass through the final layer norm
        xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
        return xs

    def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
        xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
        xs = [self.head_drop(x) for x in xs]
        if pre_logits or isinstance(self.head[0], nn.Identity):
            return torch.cat([x for x in xs], dim=1)
        return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)

    def forward(self, x):
        xs = self.forward_features(x)
        x = self.forward_head(xs)
        return x
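
# Usage sketch (illustrative, not part of the upstream file): with the defaults above, a
# 224x224 input is processed by both branches and the per-branch head logits are averaged.
#   model = CrossVit()
#   logits = model(torch.randn(2, 3, 224, 224))  # -> torch.Size([2, 1000])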


def _create_crossvit(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    def pretrained_filter_fn(state_dict):
        new_state_dict = {}
        for key in state_dict.keys():
            if 'pos_embed' in key or 'cls_token' in key:
                new_key = key.replace(".", "_")
            else:
                new_key = key
            new_state_dict[new_key] = state_dict[key]
        return new_state_dict

    return build_model_with_cfg(
        CrossVit,
        variant,
        pretrained,
        pretrained_filter_fn=pretrained_filter_fn,
        **kwargs,
    )
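# The filter above remaps dotted checkpoint keys (presumably e.g. 'pos_embed.0',
# 'cls_token.1') to the underscore-named attributes ('pos_embed_0', 'cls_token_1')
# registered on CrossVit for torchscript compatibility.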


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
        'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
        'classifier': ('head.0', 'head.1'),
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'crossvit_15_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_15_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_15_dagger_408.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
    ),
    'crossvit_18_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_18_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_18_dagger_408.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
    ),
    'crossvit_9_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_9_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_base_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_small_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_tiny_240.in1k': _cfg(hf_hub_id='timm/'),
})


@register_model
def crossvit_tiny_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
        num_heads=[3, 3], mlp_ratio=[4, 4, 1])
    model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_small_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
        num_heads=[6, 6], mlp_ratio=[4, 4, 1])
    model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_base_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
        num_heads=[12, 12], mlp_ratio=[4, 4, 1])
    model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_9_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
        num_heads=[4, 4], mlp_ratio=[3, 3, 1])
    model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_15_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
        num_heads=[6, 6], mlp_ratio=[3, 3, 1])
    model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_18_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
        num_heads=[7, 7], mlp_ratio=[3, 3, 1])
    model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_9_dagger_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
        num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True)
    model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_15_dagger_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
        num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
    model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_15_dagger_408(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
        num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
    model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_18_dagger_240(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
        num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
    model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def crossvit_18_dagger_408(pretrained=False, **kwargs) -> CrossVit:
    model_args = dict(
        img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
        num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
    model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/cspnet.py
ADDED
@@ -0,0 +1,1106 @@
"""PyTorch CspNet

A PyTorch implementation of Cross Stage Partial Networks including:
* CSPResNet50
* CSPResNeXt50
* CSPDarkNet53
* and DarkNet53 for good measure

Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929

Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks

Hacked together by / Copyright 2020 Ross Wightman
"""
from dataclasses import dataclass, asdict, replace
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, MATCH_PREV_GROUP
from ._registry import register_model, generate_default_cfgs

__all__ = ['CspNet']  # model_registry will add each entrypoint fn to this


@dataclass
class CspStemCfg:
    out_chs: Union[int, Tuple[int, ...]] = 32
    stride: Union[int, Tuple[int, ...]] = 2
    kernel_size: int = 3
    padding: Union[int, str] = ''
    pool: Optional[str] = ''


def _pad_arg(x, n):
    # pads an argument tuple to specified n by padding with last value
    if not isinstance(x, (tuple, list)):
        x = (x,)
    curr_n = len(x)
    pad_n = n - curr_n
    if pad_n <= 0:
        return x[:n]
    return tuple(x) + (x[-1],) * pad_n
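# e.g. _pad_arg(2, 4) -> (2, 2, 2, 2) and _pad_arg((1, 2), 4) -> (1, 2, 2, 2)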


@dataclass
class CspStagesCfg:
    depth: Tuple[int, ...] = (3, 3, 5, 2)  # block depth (number of block repeats in stages)
    out_chs: Tuple[int, ...] = (128, 256, 512, 1024)  # number of output channels for blocks in stage
    stride: Union[int, Tuple[int, ...]] = 2  # stride of stage
    groups: Union[int, Tuple[int, ...]] = 1  # num kxk conv groups
    block_ratio: Union[float, Tuple[float, ...]] = 1.0
    bottle_ratio: Union[float, Tuple[float, ...]] = 1.  # bottleneck-ratio of blocks in stage
    avg_down: Union[bool, Tuple[bool, ...]] = False
    attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
    attn_kwargs: Optional[Union[Dict, Tuple[Dict]]] = None
    stage_type: Union[str, Tuple[str]] = 'csp'  # stage type ('csp', 'cs3', 'dark')
    block_type: Union[str, Tuple[str]] = 'bottle'  # blocks type for stages ('bottle', 'dark')

    # cross-stage only
    expand_ratio: Union[float, Tuple[float, ...]] = 1.0
    cross_linear: Union[bool, Tuple[bool, ...]] = False
    down_growth: Union[bool, Tuple[bool, ...]] = False

    def __post_init__(self):
        n = len(self.depth)
        assert len(self.out_chs) == n
        self.stride = _pad_arg(self.stride, n)
        self.groups = _pad_arg(self.groups, n)
        self.block_ratio = _pad_arg(self.block_ratio, n)
        self.bottle_ratio = _pad_arg(self.bottle_ratio, n)
        self.avg_down = _pad_arg(self.avg_down, n)
        self.attn_layer = _pad_arg(self.attn_layer, n)
        self.attn_kwargs = _pad_arg(self.attn_kwargs, n)
        self.stage_type = _pad_arg(self.stage_type, n)
        self.block_type = _pad_arg(self.block_type, n)

        self.expand_ratio = _pad_arg(self.expand_ratio, n)
        self.cross_linear = _pad_arg(self.cross_linear, n)
        self.down_growth = _pad_arg(self.down_growth, n)


@dataclass
class CspModelCfg:
    stem: CspStemCfg
    stages: CspStagesCfg
    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
    act_layer: str = 'leaky_relu'
    norm_layer: str = 'batchnorm'
    aa_layer: Optional[str] = None  # FIXME support string factory for this


def _cs3_cfg(
        width_multiplier=1.0,
        depth_multiplier=1.0,
        avg_down=False,
        act_layer='silu',
        focus=False,
        attn_layer=None,
        attn_kwargs=None,
        bottle_ratio=1.0,
        block_type='dark',
):
    if focus:
        stem_cfg = CspStemCfg(
            out_chs=make_divisible(64 * width_multiplier),
            kernel_size=6, stride=2, padding=2, pool='')
    else:
        stem_cfg = CspStemCfg(
            out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]),
            kernel_size=3, stride=2, pool='')
    return CspModelCfg(
        stem=stem_cfg,
        stages=CspStagesCfg(
            out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]),
            depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]),
            stride=2,
            bottle_ratio=bottle_ratio,
            block_ratio=0.5,
            avg_down=avg_down,
            attn_layer=attn_layer,
            attn_kwargs=attn_kwargs,
            stage_type='cs3',
            block_type=block_type,
        ),
        act_layer=act_layer,
    )
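# e.g. _cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5) (the cs3darknet_s cfg below)
# yields stage out_chs (64, 128, 256, 512) and depths (1, 3, 4, 1).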


class BottleneckBlock(nn.Module):
    """ ResNe(X)t Bottleneck Block
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            dilation=1,
            bottle_ratio=0.25,
            groups=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_last=False,
            attn_layer=None,
            drop_block=None,
            drop_path=0.
    ):
        super(BottleneckBlock, self).__init__()
        mid_chs = int(round(out_chs * bottle_ratio))
        ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
        attn_last = attn_layer is not None and attn_last
        attn_first = attn_layer is not None and not attn_last

        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
        self.conv2 = ConvNormAct(
            mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
            drop_layer=drop_block, **ckwargs)
        self.attn2 = attn_layer(mid_chs, act_layer=act_layer) if attn_first else nn.Identity()
        self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
        self.attn3 = attn_layer(out_chs, act_layer=act_layer) if attn_last else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
        self.act3 = create_act_layer(act_layer)

    def zero_init_last(self):
        nn.init.zeros_(self.conv3.bn.weight)

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attn2(x)
        x = self.conv3(x)
        x = self.attn3(x)
        x = self.drop_path(x) + shortcut
        # FIXME a partial shortcut would be needed if the first block were handled as in the
        # original impl; not used in this implementation
        # x[:, :shortcut.size(1)] += shortcut
        x = self.act3(x)
        return x


class DarkBlock(nn.Module):
    """ DarkNet Block
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            dilation=1,
            bottle_ratio=0.5,
            groups=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_layer=None,
            drop_block=None,
            drop_path=0.
    ):
        super(DarkBlock, self).__init__()
        mid_chs = int(round(out_chs * bottle_ratio))
        ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)

        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
        self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity()
        self.conv2 = ConvNormAct(
            mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
            drop_layer=drop_block, **ckwargs)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def zero_init_last(self):
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.attn(x)
        x = self.conv2(x)
        x = self.drop_path(x) + shortcut
        return x


class EdgeBlock(nn.Module):
    """ EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output)
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            dilation=1,
            bottle_ratio=0.5,
            groups=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_layer=None,
            drop_block=None,
            drop_path=0.
    ):
        super(EdgeBlock, self).__init__()
        mid_chs = int(round(out_chs * bottle_ratio))
        ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)

        self.conv1 = ConvNormAct(
            in_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
            drop_layer=drop_block, **ckwargs)
        self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity()
        self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **ckwargs)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def zero_init_last(self):
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.attn(x)
        x = self.conv2(x)
        x = self.drop_path(x) + shortcut
        return x


class CrossStage(nn.Module):
    """Cross Stage."""
    def __init__(
            self,
            in_chs,
            out_chs,
            stride,
            dilation,
            depth,
            block_ratio=1.,
            bottle_ratio=1.,
            expand_ratio=1.,
            groups=1,
            first_dilation=None,
            avg_down=False,
            down_growth=False,
            cross_linear=False,
            block_dpr=None,
            block_fn=BottleneckBlock,
            **block_kwargs,
    ):
        super(CrossStage, self).__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                )
            else:
                self.conv_down = ConvNormActAa(
                    in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                    aa_layer=aa_layer, **conv_kwargs)
            prev_chs = down_chs
        else:
            self.conv_down = nn.Identity()
            prev_chs = in_chs

        # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs.
        # Also, there is a special case for the first stage of some of the models that results in an
        # uneven split across the two paths. I did it this way for simplicity for now.
        self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
        prev_chs = exp_chs // 2  # output of conv_exp is always split in two

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
            ))
            prev_chs = block_out_chs

        # transition convs
        self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)

    def forward(self, x):
        x = self.conv_down(x)
        x = self.conv_exp(x)
        xs, xb = x.split(self.expand_chs // 2, dim=1)
        xb = self.blocks(xb)
        xb = self.conv_transition_b(xb).contiguous()
        out = self.conv_transition(torch.cat([xs, xb], dim=1))
        return out
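# e.g. with the cspresnet50 cfg below (expand_ratio=2., out_chs=128 for the first stage),
# conv_exp widens to 256 channels, split into two 128-channel paths: xb runs through the
# bottleneck blocks while xs crosses over untouched, and conv_transition re-joins them.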


class CrossStage3(nn.Module):
    """Cross Stage 3.
    Similar to CrossStage, but with only one transition conv for the output.
    """
    def __init__(
            self,
            in_chs,
            out_chs,
            stride,
            dilation,
            depth,
            block_ratio=1.,
            bottle_ratio=1.,
            expand_ratio=1.,
            groups=1,
            first_dilation=None,
            avg_down=False,
            down_growth=False,
            cross_linear=False,
            block_dpr=None,
            block_fn=BottleneckBlock,
            **block_kwargs,
    ):
        super(CrossStage3, self).__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                )
            else:
                self.conv_down = ConvNormActAa(
                    in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                    aa_layer=aa_layer, **conv_kwargs)
            prev_chs = down_chs
        else:
            self.conv_down = nn.Identity()  # identity so forward() also works when no downsample is needed
            prev_chs = in_chs

        # expansion conv
        self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
        prev_chs = exp_chs // 2  # expanded output is split in 2 for blocks and cross stage

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
            ))
            prev_chs = block_out_chs

        # transition conv
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)

    def forward(self, x):
        x = self.conv_down(x)
        x = self.conv_exp(x)
        x1, x2 = x.split(self.expand_chs // 2, dim=1)
        x1 = self.blocks(x1)
        out = self.conv_transition(torch.cat([x1, x2], dim=1))
        return out


class DarkStage(nn.Module):
    """DarkNet stage."""

    def __init__(
            self,
            in_chs,
            out_chs,
            stride,
            dilation,
            depth,
            block_ratio=1.,
            bottle_ratio=1.,
            groups=1,
            first_dilation=None,
            avg_down=False,
            block_fn=BottleneckBlock,
            block_dpr=None,
            **block_kwargs,
    ):
        super(DarkStage, self).__init__()
        first_dilation = first_dilation or dilation
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        aa_layer = block_kwargs.pop('aa_layer', None)

        if avg_down:
            self.conv_down = nn.Sequential(
                nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
            )
        else:
            self.conv_down = ConvNormActAa(
                in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                aa_layer=aa_layer, **conv_kwargs)

        prev_chs = out_chs
        block_out_chs = int(round(out_chs * block_ratio))
        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs
            ))
            prev_chs = block_out_chs

    def forward(self, x):
        x = self.conv_down(x)
        x = self.blocks(x)
        return x


def create_csp_stem(
        in_chans=3,
        out_chs=32,
        kernel_size=3,
        stride=2,
        pool='',
        padding='',
        act_layer=nn.ReLU,
        norm_layer=nn.BatchNorm2d,
        aa_layer=None,
):
    stem = nn.Sequential()
    feature_info = []
    if not isinstance(out_chs, (tuple, list)):
        out_chs = [out_chs]
    stem_depth = len(out_chs)
    assert stem_depth
    assert stride in (1, 2, 4)
    prev_feat = None
    prev_chs = in_chans
    last_idx = stem_depth - 1
    stem_stride = 1
    for i, chs in enumerate(out_chs):
        conv_name = f'conv{i + 1}'
        conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
        if conv_stride > 1 and prev_feat is not None:
            feature_info.append(prev_feat)
        stem.add_module(conv_name, ConvNormAct(
            prev_chs, chs, kernel_size,
            stride=conv_stride,
            padding=padding if i == 0 else '',
            act_layer=act_layer,
            norm_layer=norm_layer,
        ))
        stem_stride *= conv_stride
        prev_chs = chs
        prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
    if pool:
        assert stride > 2
        if prev_feat is not None:
            feature_info.append(prev_feat)
        if aa_layer is not None:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
            stem.add_module('aa', aa_layer(channels=prev_chs, stride=2))
            pool_name = 'aa'
        else:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            pool_name = 'pool'
        stem_stride *= 2
        prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
    feature_info.append(prev_feat)
    return stem, feature_info


def _get_stage_fn(stage_args):
    stage_type = stage_args.pop('stage_type')
    assert stage_type in ('dark', 'csp', 'cs3')
    if stage_type == 'dark':
        stage_args.pop('expand_ratio', None)
        stage_args.pop('cross_linear', None)
        stage_args.pop('down_growth', None)
        stage_fn = DarkStage
    elif stage_type == 'csp':
        stage_fn = CrossStage
    else:
        stage_fn = CrossStage3
    return stage_fn, stage_args


def _get_block_fn(stage_args):
    block_type = stage_args.pop('block_type')
    assert block_type in ('dark', 'edge', 'bottle')
    if block_type == 'dark':
        return DarkBlock, stage_args
    elif block_type == 'edge':
        return EdgeBlock, stage_args
    else:
        return BottleneckBlock, stage_args


def _get_attn_fn(stage_args):
    attn_layer = stage_args.pop('attn_layer')
    attn_kwargs = stage_args.pop('attn_kwargs', None) or {}
    if attn_layer is not None:
        attn_layer = get_attn(attn_layer)
        if attn_kwargs:
            attn_layer = partial(attn_layer, **attn_kwargs)
    return attn_layer, stage_args


def create_csp_stages(
        cfg: CspModelCfg,
        drop_path_rate: float,
        output_stride: int,
        stem_feat: Dict[str, Any],
):
    cfg_dict = asdict(cfg.stages)
    num_stages = len(cfg.stages.depth)
    cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
        [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
    per_stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
    block_kwargs = dict(
        act_layer=cfg.act_layer,
        norm_layer=cfg.norm_layer,
    )

    dilation = 1
    net_stride = stem_feat['reduction']
    prev_chs = stem_feat['num_chs']
    prev_feat = stem_feat
    feature_info = []
    stages = []
    for stage_idx, stage_args in enumerate(per_stage_args):
        stage_fn, stage_args = _get_stage_fn(stage_args)
        block_fn, stage_args = _get_block_fn(stage_args)
        attn_fn, stage_args = _get_attn_fn(stage_args)
        stride = stage_args.pop('stride')
        if stride != 1 and prev_feat:
            feature_info.append(prev_feat)
        if net_stride >= output_stride and stride > 1:
            dilation *= stride
            stride = 1
        net_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2

        stages += [stage_fn(
            prev_chs,
            **stage_args,
            stride=stride,
            first_dilation=first_dilation,
            dilation=dilation,
            block_fn=block_fn,
            aa_layer=cfg.aa_layer,
            attn_layer=attn_fn,  # will be passed through stage as block_kwargs
            **block_kwargs,
        )]
        prev_chs = stage_args['out_chs']
        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')

    feature_info.append(prev_feat)
    return nn.Sequential(*stages), feature_info
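# Note on the stochastic depth handling above: block_dpr ramps the per-block drop-path
# rate linearly from 0 to drop_path_rate over all blocks in the network, then splits the
# schedule into per-stage lists via torch.linspace(...).split(cfg.stages.depth).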


class CspNet(nn.Module):
    """Cross Stage Partial base model.

    Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
    Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks

    NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
    darknet impl. I did it this way for simplicity and less special cases.
    """

    def __init__(
            self,
            cfg: CspModelCfg,
            in_chans=3,
            num_classes=1000,
            output_stride=32,
            global_pool='avg',
            drop_rate=0.,
            drop_path_rate=0.,
            zero_init_last=True,
            **kwargs,
    ):
        """
        Args:
            cfg (CspModelCfg): Model architecture configuration
            in_chans (int): Number of input channels (default: 3)
            num_classes (int): Number of classifier classes (default: 1000)
            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
            global_pool (str): Global pooling type (default: 'avg')
            drop_rate (float): Dropout rate (default: 0.)
            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
            zero_init_last (bool): Zero-init last weight of residual path
            kwargs (dict): Extra kwargs overlaid onto cfg
        """
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        assert output_stride in (8, 16, 32)

        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
        layer_args = dict(
            act_layer=cfg.act_layer,
            norm_layer=cfg.norm_layer,
            aa_layer=cfg.aa_layer
        )
        self.feature_info = []

        # Construct the stem
        self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args)
        self.feature_info.extend(stem_feat_info[:-1])

        # Construct the stages
        self.stages, stage_feat_info = create_csp_stages(
            cfg,
            drop_path_rate=drop_path_rate,
            output_stride=output_stride,
            stem_feat=stem_feat_info[-1],
        )
        prev_chs = stage_feat_info[-1]['num_chs']
        self.feature_info.extend(stage_feat_info)

        # Construct the head
        self.num_features = prev_chs
        self.head = ClassifierHead(
            in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)

        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP),  # map to last block in stage
                (r'^stages\.(\d+)', (0,)),
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
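
# Usage sketch (illustrative): CspNet is normally built via the registered variant fns
# below, but it can also be constructed directly from one of the configs in model_cfgs:
#   model = CspNet(model_cfgs['cspresnet50'], num_classes=1000)
#   out = model(torch.randn(1, 3, 256, 256))  # -> torch.Size([1, 1000])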
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
def _init_weights(module, name, zero_init_last=False):
|
| 722 |
+
if isinstance(module, nn.Conv2d):
|
| 723 |
+
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
|
| 724 |
+
if module.bias is not None:
|
| 725 |
+
nn.init.zeros_(module.bias)
|
| 726 |
+
elif isinstance(module, nn.Linear):
|
| 727 |
+
nn.init.normal_(module.weight, mean=0.0, std=0.01)
|
| 728 |
+
if module.bias is not None:
|
| 729 |
+
nn.init.zeros_(module.bias)
|
| 730 |
+
elif zero_init_last and hasattr(module, 'zero_init_last'):
|
| 731 |
+
module.zero_init_last()
|
| 732 |
+
|
| 733 |
+
|
| 734 |
+
model_cfgs = dict(
|
| 735 |
+
cspresnet50=CspModelCfg(
|
| 736 |
+
stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
|
| 737 |
+
stages=CspStagesCfg(
|
| 738 |
+
depth=(3, 3, 5, 2),
|
| 739 |
+
out_chs=(128, 256, 512, 1024),
|
| 740 |
+
stride=(1, 2),
|
| 741 |
+
expand_ratio=2.,
|
| 742 |
+
bottle_ratio=0.5,
|
| 743 |
+
cross_linear=True,
|
| 744 |
+
),
|
| 745 |
+
),
|
| 746 |
+
cspresnet50d=CspModelCfg(
|
| 747 |
+
stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
|
| 748 |
+
stages=CspStagesCfg(
|
| 749 |
+
depth=(3, 3, 5, 2),
|
| 750 |
+
out_chs=(128, 256, 512, 1024),
|
| 751 |
+
stride=(1,) + (2,),
|
| 752 |
+
expand_ratio=2.,
|
| 753 |
+
bottle_ratio=0.5,
|
| 754 |
+
block_ratio=1.,
|
| 755 |
+
cross_linear=True,
|
| 756 |
+
),
|
| 757 |
+
),
|
    cspresnet50w=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1,) + (2,),
            expand_ratio=1.,
            bottle_ratio=0.25,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),
    cspresnext50=CspModelCfg(
        stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1,) + (2,),
            groups=32,
            expand_ratio=1.,
            bottle_ratio=1.,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),
    cspdarknet53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            expand_ratio=(2.,) + (1.,),
            bottle_ratio=(0.5,) + (1.,),
            block_ratio=(1.,) + (0.5,),
            down_growth=True,
            block_type='dark',
        ),
    ),
    darknet17=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1,) * 5,
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',
        ),
    ),
    darknet21=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',
        ),
    ),
    sedarknet21=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            attn_layer='se',
            stage_type='dark',
            block_type='dark',
        ),
    ),
    darknet53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            stage_type='dark',
            block_type='dark',
        ),
    ),
    darknetaa53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            avg_down=True,
            stage_type='dark',
            block_type='dark',
        ),
    ),

    cs3darknet_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5),
    cs3darknet_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67),
    cs3darknet_l=_cs3_cfg(),
    cs3darknet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33),

    cs3darknet_focus_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
    cs3darknet_focus_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
    cs3darknet_focus_l=_cs3_cfg(focus=True),
    cs3darknet_focus_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),

    cs3sedarknet_l=_cs3_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
    cs3sedarknet_x=_cs3_cfg(attn_layer='se', width_multiplier=1.25, depth_multiplier=1.33),

    cs3sedarknet_xdw=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
        stages=CspStagesCfg(
            depth=(3, 6, 12, 4),
            out_chs=(256, 512, 1024, 2048),
            stride=2,
            groups=(1, 1, 256, 512),
            bottle_ratio=0.5,
            block_ratio=0.5,
            attn_layer='se',
        ),
        act_layer='silu',
    ),

    cs3edgenet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge'),
    cs3se_edgenet_x=_cs3_cfg(
        width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge',
        attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
)
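
Several CspStagesCfg fields above mix scalars with short tuples (stride=2 in one config, stride=(1,) + (2,) in another), which only works if per-stage values are broadcast to one entry per stage. The resolution logic sits outside this hunk, so the sketch below is illustrative only; pad_per_stage is a hypothetical name for the assumed behavior of padding a short tuple by repeating its last value.

# Illustrative sketch (not part of the diff): expand scalar or short-tuple
# per-stage settings to one value per stage. Padding-by-last-value is an
# assumption about how the configs above are consumed.
def pad_per_stage(x, num_stages):
    if not isinstance(x, (tuple, list)):
        x = (x,)  # a scalar applies to every stage
    return tuple(x) + (x[-1],) * (num_stages - len(x))

assert pad_per_stage(2, 4) == (2, 2, 2, 2)
assert pad_per_stage((1,) + (2,), 4) == (1, 2, 2, 2)  # first stage 1, then 2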


def _create_cspnet(variant, pretrained=False, **kwargs):
    if variant.startswith('darknet') or variant.startswith('cspdarknet'):
        # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
        default_out_indices = (0, 1, 2, 3, 4, 5)
    else:
        default_out_indices = (0, 1, 2, 3, 4)
    out_indices = kwargs.pop('out_indices', default_out_indices)
    return build_model_with_cfg(
        CspNet, variant, pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs)
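
A brief usage sketch (not part of the diff): an out_indices keyword passed to timm.create_model() is popped here and forwarded into feature_cfg, so darknet/cspdarknet variants can expose up to six feature levels, including the stride==1 stage. Assumes timm is installed; pretrained=True would additionally need download access.

import timm
import torch

# Grab three intermediate feature maps from a randomly initialized backbone.
model = timm.create_model('cspdarknet53', pretrained=False, features_only=True, out_indices=(2, 3, 4))
feats = model(torch.randn(1, 3, 256, 256))
print([f.shape for f in feats])  # one tensor per requested level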


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
        'crop_pct': 0.887, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'cspresnet50.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'),
    'cspresnet50d.untrained': _cfg(),
    'cspresnet50w.untrained': _cfg(),
    'cspresnext50.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
    ),
    'cspdarknet53.ra_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),

    'darknet17.untrained': _cfg(),
    'darknet21.untrained': _cfg(),
    'sedarknet21.untrained': _cfg(),
    'darknet53.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'darknetaa53.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
        test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
    'cs3darknet_m.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95,
    ),
    'cs3darknet_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3darknet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'cs3darknet_focus_s.untrained': _cfg(interpolation='bicubic'),
    'cs3darknet_focus_m.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3darknet_focus_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),

    'cs3sedarknet_l.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'cs3sedarknet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),

    'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),

    'cs3edgenet_x.c2_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
        interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
    'cs3se_edgenet_x.c2ns_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
})


@register_model
def cspresnet50(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs)


@register_model
def cspresnet50d(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs)


@register_model
def cspresnet50w(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs)


@register_model
def cspresnext50(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs)


@register_model
def cspdarknet53(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs)


@register_model
def darknet17(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('darknet17', pretrained=pretrained, **kwargs)


@register_model
def darknet21(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('darknet21', pretrained=pretrained, **kwargs)


@register_model
def sedarknet21(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs)


@register_model
def darknet53(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('darknet53', pretrained=pretrained, **kwargs)


@register_model
def darknetaa53(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_s(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_m(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_l(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_x(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_focus_s(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_focus_m(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_focus_l(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs)


@register_model
def cs3darknet_focus_x(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)


@register_model
def cs3sedarknet_l(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3sedarknet_l', pretrained=pretrained, **kwargs)


@register_model
def cs3sedarknet_x(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3sedarknet_x', pretrained=pretrained, **kwargs)


@register_model
def cs3sedarknet_xdw(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)


@register_model
def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3edgenet_x', pretrained=pretrained, **kwargs)


@register_model
def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
    return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)
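
Sketch of how the registrations above are consumed (not part of the diff): each @register_model entrypoint lands in timm's model registry, and the 'model.tag' keys in default_cfgs select a pretrained weight source. Assumes timm is installed; pretrained weights need download access.

import timm

print(timm.list_models('cs3edgenet*'))  # entrypoints registered by this file
model = timm.create_model('cs3edgenet_x.c2_in1k', pretrained=True)  # tag picks the weight source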
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/deit.py
ADDED
@@ -0,0 +1,416 @@
""" DeiT - Data-efficient Image Transformers

DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below

paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877

paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118

Modifications copyright 2021, Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
from functools import partial
from typing import Sequence, Union

import torch
from torch import nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import resample_abs_pos_embed
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq
from ._registry import generate_default_cfgs, register_model, register_model_deprecations

__all__ = ['VisionTransformerDistilled']  # model_registry will add each entrypoint fn to this


class VisionTransformerDistilled(VisionTransformer):
    """ Vision Transformer w/ Distillation Token and Head

    Distillation token & head support for `DeiT: Data-efficient Image Transformers`
        - https://arxiv.org/abs/2012.12877
    """

    def __init__(self, *args, **kwargs):
        weight_init = kwargs.pop('weight_init', '')
        super().__init__(*args, **kwargs, weight_init='skip')
        assert self.global_pool in ('token',)

        self.num_prefix_tokens = 2
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

        self.init_weights(weight_init)

    def init_weights(self, mode=''):
        trunc_normal_(self.dist_token, std=.02)
        super().init_weights(mode=mode)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed|dist_token',
            blocks=[
                (r'^blocks\.(\d+)', None),
                (r'^norm', (99999,))]  # final norm w/ last block
        )

    @torch.jit.ignore
    def get_classifier(self):
        return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        self.distilled_training = enable

    def _pos_embed(self, x):
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            pos_embed = resample_abs_pos_embed(
                self.pos_embed,
                (H, W),
                num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
            )
            x = x.view(B, -1, C)
        else:
            pos_embed = self.pos_embed
        if self.no_embed_class:
            # deit-3, updated JAX (big vision)
            # position embedding does not overlap with class token, add then concat
            x = x + pos_embed
            x = torch.cat((
                self.cls_token.expand(x.shape[0], -1, -1),
                self.dist_token.expand(x.shape[0], -1, -1),
                x),
                dim=1)
        else:
            # original timm, JAX, and deit vit impl
            # pos_embed has entry for class token, concat then add
            x = torch.cat((
                self.cls_token.expand(x.shape[0], -1, -1),
                self.dist_token.expand(x.shape[0], -1, -1),
                x),
                dim=1)
            x = x + pos_embed
        return self.pos_drop(x)

    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
        x, x_dist = x[:, 0], x[:, 1]
        if pre_logits:
            return (x + x_dist) / 2
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
            return x, x_dist
        else:
            # during standard train/finetune and at inference, average the classifier predictions
            return (x + x_dist) / 2


def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    model_cls = VisionTransformerDistilled if distilled else VisionTransformer
    model = build_model_with_cfg(
        model_cls,
        variant,
        pretrained,
        pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
        **kwargs,
    )
    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    # deit models (FB weights)
    'deit_tiny_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
    'deit_small_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
    'deit_base_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
    'deit_base_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
        input_size=(3, 384, 384), crop_pct=1.0),

    'deit_tiny_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
        classifier=('head', 'head_dist')),
    'deit_small_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
        classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
        classifier=('head', 'head_dist')),
    'deit_base_distilled_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
        input_size=(3, 384, 384), crop_pct=1.0,
        classifier=('head', 'head_dist')),

    'deit3_small_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
    'deit3_small_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_medium_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
    'deit3_base_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
    'deit3_base_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_large_patch16_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
    'deit3_large_patch16_384.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_huge_patch14_224.fb_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),

    'deit3_small_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
        crop_pct=1.0),
    'deit3_small_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_medium_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
        crop_pct=1.0),
    'deit3_base_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
        crop_pct=1.0),
    'deit3_base_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_large_patch16_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
        crop_pct=1.0),
    'deit3_large_patch16_384.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
        input_size=(3, 384, 384), crop_pct=1.0),
    'deit3_huge_patch14_224.fb_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
        crop_pct=1.0),
})


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
    model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
    model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
    model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
    model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
    model = _create_deit(
        'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
    model = _create_deit(
        'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
    model = _create_deit(
        'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs) -> VisionTransformerDistilled:
    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
    model = _create_deit(
        'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 medium model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def deit3_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
    """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.
    """
    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6)
    model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


register_model_deprecations(__name__, {
    'deit3_small_patch16_224_in21ft1k': 'deit3_small_patch16_224.fb_in22k_ft_in1k',
    'deit3_small_patch16_384_in21ft1k': 'deit3_small_patch16_384.fb_in22k_ft_in1k',
    'deit3_medium_patch16_224_in21ft1k': 'deit3_medium_patch16_224.fb_in22k_ft_in1k',
    'deit3_base_patch16_224_in21ft1k': 'deit3_base_patch16_224.fb_in22k_ft_in1k',
    'deit3_base_patch16_384_in21ft1k': 'deit3_base_patch16_384.fb_in22k_ft_in1k',
    'deit3_large_patch16_224_in21ft1k': 'deit3_large_patch16_224.fb_in22k_ft_in1k',
    'deit3_large_patch16_384_in21ft1k': 'deit3_large_patch16_384.fb_in22k_ft_in1k',
    'deit3_huge_patch14_224_in21ft1k': 'deit3_huge_patch14_224.fb_in22k_ft_in1k'
})
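
Illustrative training-step sketch (not part of this file): after model.set_distilled_training(True), forward() returns a (cls_logits, dist_logits) pair in training mode, which a DeiT-style hard-distillation loss can consume. The frozen teacher here is a stand-in assumption; the equal 0.5 weighting follows the hard-distillation recipe of the DeiT paper.

import timm
import torch
import torch.nn.functional as F

model = timm.create_model('deit_tiny_distilled_patch16_224', pretrained=False)
model.set_distilled_training(True)
model.train()

x = torch.randn(2, 3, 224, 224)
targets = torch.randint(0, 1000, (2,))
teacher_logits = torch.randn(2, 1000)  # stand-in for a real (frozen) teacher forward pass

cls_logits, dist_logits = model(x)  # separate heads, only in distilled training mode
loss = 0.5 * F.cross_entropy(cls_logits, targets) \
    + 0.5 * F.cross_entropy(dist_logits, teacher_logits.argmax(dim=1))
loss.backward()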
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/dla.py
ADDED
@@ -0,0 +1,515 @@
| 1 |
+
""" Deep Layer Aggregation and DLA w/ Res2Net
|
| 2 |
+
DLA original adapted from Official Pytorch impl at: https://github.com/ucbdrive/dla
|
| 3 |
+
DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484
|
| 4 |
+
|
| 5 |
+
Res2Net additions from: https://github.com/gasvn/Res2Net/
|
| 6 |
+
Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
|
| 7 |
+
"""
|
| 8 |
+
import math
|
| 9 |
+
from typing import List, Optional
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 16 |
+
from timm.layers import create_classifier
|
| 17 |
+
from ._builder import build_model_with_cfg
|
| 18 |
+
from ._registry import register_model, generate_default_cfgs
|
| 19 |
+
|
| 20 |
+
__all__ = ['DLA']
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class DlaBasic(nn.Module):
|
| 24 |
+
"""DLA Basic"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
|
| 27 |
+
super(DlaBasic, self).__init__()
|
| 28 |
+
self.conv1 = nn.Conv2d(
|
| 29 |
+
inplanes, planes, kernel_size=3,
|
| 30 |
+
stride=stride, padding=dilation, bias=False, dilation=dilation)
|
| 31 |
+
self.bn1 = nn.BatchNorm2d(planes)
|
| 32 |
+
self.relu = nn.ReLU(inplace=True)
|
| 33 |
+
self.conv2 = nn.Conv2d(
|
| 34 |
+
planes, planes, kernel_size=3,
|
| 35 |
+
stride=1, padding=dilation, bias=False, dilation=dilation)
|
| 36 |
+
self.bn2 = nn.BatchNorm2d(planes)
|
| 37 |
+
self.stride = stride
|
| 38 |
+
|
| 39 |
+
def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
|
| 40 |
+
if shortcut is None:
|
| 41 |
+
shortcut = x
|
| 42 |
+
|
| 43 |
+
out = self.conv1(x)
|
| 44 |
+
out = self.bn1(out)
|
| 45 |
+
out = self.relu(out)
|
| 46 |
+
|
| 47 |
+
out = self.conv2(out)
|
| 48 |
+
out = self.bn2(out)
|
| 49 |
+
|
| 50 |
+
out += shortcut
|
| 51 |
+
out = self.relu(out)
|
| 52 |
+
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class DlaBottleneck(nn.Module):
|
| 57 |
+
"""DLA/DLA-X Bottleneck"""
|
| 58 |
+
expansion = 2
|
| 59 |
+
|
| 60 |
+
def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64):
|
| 61 |
+
super(DlaBottleneck, self).__init__()
|
| 62 |
+
self.stride = stride
|
| 63 |
+
mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
|
| 64 |
+
mid_planes = mid_planes // self.expansion
|
| 65 |
+
|
| 66 |
+
self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
|
| 67 |
+
self.bn1 = nn.BatchNorm2d(mid_planes)
|
| 68 |
+
self.conv2 = nn.Conv2d(
|
| 69 |
+
mid_planes, mid_planes, kernel_size=3,
|
| 70 |
+
stride=stride, padding=dilation, bias=False, dilation=dilation, groups=cardinality)
|
| 71 |
+
self.bn2 = nn.BatchNorm2d(mid_planes)
|
| 72 |
+
self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False)
|
| 73 |
+
self.bn3 = nn.BatchNorm2d(outplanes)
|
| 74 |
+
self.relu = nn.ReLU(inplace=True)
|
| 75 |
+
|
| 76 |
+
def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
|
| 77 |
+
if shortcut is None:
|
| 78 |
+
shortcut = x
|
| 79 |
+
|
| 80 |
+
out = self.conv1(x)
|
| 81 |
+
out = self.bn1(out)
|
| 82 |
+
out = self.relu(out)
|
| 83 |
+
|
| 84 |
+
out = self.conv2(out)
|
| 85 |
+
out = self.bn2(out)
|
| 86 |
+
out = self.relu(out)
|
| 87 |
+
|
| 88 |
+
out = self.conv3(out)
|
| 89 |
+
out = self.bn3(out)
|
| 90 |
+
|
| 91 |
+
out += shortcut
|
| 92 |
+
out = self.relu(out)
|
| 93 |
+
|
| 94 |
+
return out
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class DlaBottle2neck(nn.Module):
|
| 98 |
+
""" Res2Net/Res2NeXT DLA Bottleneck
|
| 99 |
+
Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py
|
| 100 |
+
"""
|
| 101 |
+
expansion = 2
|
| 102 |
+
|
| 103 |
+
def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4):
|
| 104 |
+
super(DlaBottle2neck, self).__init__()
|
| 105 |
+
self.is_first = stride > 1
|
| 106 |
+
self.scale = scale
|
| 107 |
+
mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
|
| 108 |
+
mid_planes = mid_planes // self.expansion
|
| 109 |
+
self.width = mid_planes
|
| 110 |
+
|
| 111 |
+
self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False)
|
| 112 |
+
self.bn1 = nn.BatchNorm2d(mid_planes * scale)
|
| 113 |
+
|
| 114 |
+
num_scale_convs = max(1, scale - 1)
|
| 115 |
+
convs = []
|
| 116 |
+
bns = []
|
| 117 |
+
for _ in range(num_scale_convs):
|
| 118 |
+
convs.append(nn.Conv2d(
|
| 119 |
+
mid_planes, mid_planes, kernel_size=3,
|
| 120 |
+
stride=stride, padding=dilation, dilation=dilation, groups=cardinality, bias=False))
|
| 121 |
+
bns.append(nn.BatchNorm2d(mid_planes))
|
| 122 |
+
self.convs = nn.ModuleList(convs)
|
| 123 |
+
self.bns = nn.ModuleList(bns)
|
| 124 |
+
self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if self.is_first else None
|
| 125 |
+
|
| 126 |
+
self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False)
|
| 127 |
+
self.bn3 = nn.BatchNorm2d(outplanes)
|
| 128 |
+
self.relu = nn.ReLU(inplace=True)
|
| 129 |
+
|
| 130 |
+
def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
|
| 131 |
+
if shortcut is None:
|
| 132 |
+
shortcut = x
|
| 133 |
+
|
| 134 |
+
out = self.conv1(x)
|
| 135 |
+
out = self.bn1(out)
|
| 136 |
+
out = self.relu(out)
|
| 137 |
+
|
| 138 |
+
spx = torch.split(out, self.width, 1)
|
| 139 |
+
spo = []
|
| 140 |
+
sp = spx[0] # redundant, for torchscript
|
| 141 |
+
for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
|
| 142 |
+
if i == 0 or self.is_first:
|
| 143 |
+
sp = spx[i]
|
| 144 |
+
else:
|
| 145 |
+
sp = sp + spx[i]
|
| 146 |
+
sp = conv(sp)
|
| 147 |
+
sp = bn(sp)
|
| 148 |
+
sp = self.relu(sp)
|
| 149 |
+
spo.append(sp)
|
| 150 |
+
if self.scale > 1:
|
| 151 |
+
if self.pool is not None: # self.is_first == True, None check for torchscript
|
| 152 |
+
spo.append(self.pool(spx[-1]))
|
| 153 |
+
else:
|
| 154 |
+
spo.append(spx[-1])
|
| 155 |
+
out = torch.cat(spo, 1)
|
| 156 |
+
|
| 157 |
+
out = self.conv3(out)
|
| 158 |
+
out = self.bn3(out)
|
| 159 |
+
|
| 160 |
+
out += shortcut
|
| 161 |
+
out = self.relu(out)
|
| 162 |
+
|
| 163 |
+
return out
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
class DlaRoot(nn.Module):
|
| 167 |
+
def __init__(self, in_channels, out_channels, kernel_size, shortcut):
|
| 168 |
+
super(DlaRoot, self).__init__()
|
| 169 |
+
self.conv = nn.Conv2d(
|
| 170 |
+
in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
|
| 171 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
| 172 |
+
self.relu = nn.ReLU(inplace=True)
|
| 173 |
+
self.shortcut = shortcut
|
| 174 |
+
|
| 175 |
+
def forward(self, x_children: List[torch.Tensor]):
|
| 176 |
+
x = self.conv(torch.cat(x_children, 1))
|
| 177 |
+
x = self.bn(x)
|
| 178 |
+
if self.shortcut:
|
| 179 |
+
x += x_children[0]
|
| 180 |
+
x = self.relu(x)
|
| 181 |
+
|
| 182 |
+
return x
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class DlaTree(nn.Module):
|
| 186 |
+
def __init__(
|
| 187 |
+
self,
|
| 188 |
+
levels,
|
| 189 |
+
block,
|
| 190 |
+
in_channels,
|
| 191 |
+
out_channels,
|
| 192 |
+
stride=1,
|
| 193 |
+
dilation=1,
|
| 194 |
+
cardinality=1,
|
| 195 |
+
base_width=64,
|
| 196 |
+
level_root=False,
|
| 197 |
+
root_dim=0,
|
| 198 |
+
root_kernel_size=1,
|
| 199 |
+
root_shortcut=False,
|
| 200 |
+
):
|
| 201 |
+
super(DlaTree, self).__init__()
|
| 202 |
+
if root_dim == 0:
|
| 203 |
+
root_dim = 2 * out_channels
|
| 204 |
+
if level_root:
|
| 205 |
+
root_dim += in_channels
|
| 206 |
+
self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity()
|
| 207 |
+
self.project = nn.Identity()
|
| 208 |
+
cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width)
|
| 209 |
+
if levels == 1:
|
| 210 |
+
self.tree1 = block(in_channels, out_channels, stride, **cargs)
|
| 211 |
+
self.tree2 = block(out_channels, out_channels, 1, **cargs)
|
| 212 |
+
if in_channels != out_channels:
|
| 213 |
+
# NOTE the official impl/weights have project layers in levels > 1 case that are never
|
| 214 |
+
# used, I've moved the project layer here to avoid wasted params but old checkpoints will
|
| 215 |
+
# need strict=False while loading.
|
| 216 |
+
self.project = nn.Sequential(
|
| 217 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
|
| 218 |
+
nn.BatchNorm2d(out_channels))
|
| 219 |
+
self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
|
| 220 |
+
else:
|
| 221 |
+
cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
|
| 222 |
+
self.tree1 = DlaTree(
|
| 223 |
+
levels - 1,
|
| 224 |
+
block,
|
| 225 |
+
in_channels,
|
| 226 |
+
out_channels,
|
| 227 |
+
stride,
|
| 228 |
+
root_dim=0,
|
| 229 |
+
**cargs,
|
| 230 |
+
)
|
| 231 |
+
self.tree2 = DlaTree(
|
| 232 |
+
levels - 1,
|
| 233 |
+
block,
|
| 234 |
+
out_channels,
|
| 235 |
+
out_channels,
|
| 236 |
+
root_dim=root_dim + out_channels,
|
| 237 |
+
**cargs,
|
| 238 |
+
)
|
| 239 |
+
self.root = None
|
| 240 |
+
self.level_root = level_root
|
| 241 |
+
self.root_dim = root_dim
|
| 242 |
+
self.levels = levels
|
| 243 |
+
|
| 244 |
+
def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
|
| 245 |
+
if children is None:
|
| 246 |
+
children = []
|
| 247 |
+
bottom = self.downsample(x)
|
| 248 |
+
shortcut = self.project(bottom)
|
| 249 |
+
if self.level_root:
|
| 250 |
+
children.append(bottom)
|
| 251 |
+
x1 = self.tree1(x, shortcut)
|
| 252 |
+
if self.root is not None: # levels == 1
|
| 253 |
+
x2 = self.tree2(x1)
|
| 254 |
+
x = self.root([x2, x1] + children)
|
| 255 |
+
else:
|
| 256 |
+
children.append(x1)
|
| 257 |
+
x = self.tree2(x1, None, children)
|
| 258 |
+
return x
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class DLA(nn.Module):
|
| 262 |
+
def __init__(
|
| 263 |
+
self,
|
| 264 |
+
levels,
|
| 265 |
+
channels,
|
| 266 |
+
output_stride=32,
|
| 267 |
+
num_classes=1000,
|
| 268 |
+
in_chans=3,
|
| 269 |
+
global_pool='avg',
|
| 270 |
+
cardinality=1,
|
| 271 |
+
base_width=64,
|
| 272 |
+
block=DlaBottle2neck,
|
| 273 |
+
shortcut_root=False,
|
| 274 |
+
drop_rate=0.0,
|
| 275 |
+
):
|
| 276 |
+
super(DLA, self).__init__()
|
| 277 |
+
self.channels = channels
|
| 278 |
+
self.num_classes = num_classes
|
| 279 |
+
self.cardinality = cardinality
|
| 280 |
+
self.base_width = base_width
|
| 281 |
+
assert output_stride == 32 # FIXME support dilation
|
| 282 |
+
|
| 283 |
+
self.base_layer = nn.Sequential(
|
| 284 |
+
nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
|
| 285 |
+
nn.BatchNorm2d(channels[0]),
|
| 286 |
+
nn.ReLU(inplace=True),
|
| 287 |
+
)
|
| 288 |
+
self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
|
| 289 |
+
self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
|
| 290 |
+
cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
|
| 291 |
+
self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
|
| 292 |
+
self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
|
| 293 |
+
self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
|
| 294 |
+
self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs)
|
| 295 |
+
self.feature_info = [
|
| 296 |
+
dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level
|
| 297 |
+
dict(num_chs=channels[1], reduction=2, module='level1'),
|
| 298 |
+
dict(num_chs=channels[2], reduction=4, module='level2'),
|
| 299 |
+
dict(num_chs=channels[3], reduction=8, module='level3'),
|
| 300 |
+
dict(num_chs=channels[4], reduction=16, module='level4'),
|
| 301 |
+
dict(num_chs=channels[5], reduction=32, module='level5'),
|
| 302 |
+
]
|
| 303 |
+
|
| 304 |
+
self.num_features = channels[-1]
|
| 305 |
+
self.global_pool, self.head_drop, self.fc = create_classifier(
|
| 306 |
+
self.num_features,
|
| 307 |
+
self.num_classes,
|
| 308 |
+
pool_type=global_pool,
|
| 309 |
+
use_conv=True,
|
| 310 |
+
drop_rate=drop_rate,
|
| 311 |
+
)
|
| 312 |
+
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
| 313 |
+
|
| 314 |
+
for m in self.modules():
|
| 315 |
+
if isinstance(m, nn.Conv2d):
|
| 316 |
+
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 317 |
+
m.weight.data.normal_(0, math.sqrt(2. / n))
|
| 318 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 319 |
+
m.weight.data.fill_(1)
|
| 320 |
+
m.bias.data.zero_()
|
| 321 |
+
|
| 322 |
+
def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
|
| 323 |
+
modules = []
|
| 324 |
+
for i in range(convs):
|
| 325 |
+
modules.extend([
|
| 326 |
+
nn.Conv2d(
|
| 327 |
+
inplanes, planes, kernel_size=3,
|
| 328 |
+
stride=stride if i == 0 else 1,
|
| 329 |
+
padding=dilation, bias=False, dilation=dilation),
|
| 330 |
+
nn.BatchNorm2d(planes),
|
| 331 |
+
nn.ReLU(inplace=True)])
|
| 332 |
+
inplanes = planes
|
| 333 |
+
return nn.Sequential(*modules)
|
| 334 |
+
|
| 335 |
+
@torch.jit.ignore
|
| 336 |
+
def group_matcher(self, coarse=False):
|
| 337 |
+
matcher = dict(
|
| 338 |
+
stem=r'^base_layer',
|
| 339 |
+
blocks=r'^level(\d+)' if coarse else [
|
| 340 |
+
# an unusual arch, this achieves somewhat more granularity without getting super messy
|
| 341 |
+
(r'^level(\d+)\.tree(\d+)', None),
|
| 342 |
+
(r'^level(\d+)\.root', (2,)),
|
| 343 |
+
(r'^level(\d+)', (1,))
|
| 344 |
+
]
|
| 345 |
+
)
|
| 346 |
+
return matcher
|
| 347 |
+
|
| 348 |
+
@torch.jit.ignore
|
| 349 |
+
def set_grad_checkpointing(self, enable=True):
|
| 350 |
+
assert not enable, 'gradient checkpointing not supported'
    @torch.jit.ignore
    def get_classifier(self):
        return self.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.global_pool, self.fc = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()

    def forward_features(self, x):
        x = self.base_layer(x)
        x = self.level0(x)
        x = self.level1(x)
        x = self.level2(x)
        x = self.level3(x)
        x = self.level4(x)
        x = self.level5(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.head_drop(x)
        if pre_logits:
            return self.flatten(x)
        x = self.fc(x)
        return self.flatten(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_dla(variant, pretrained=False, **kwargs):
    return build_model_with_cfg(
        DLA,
        variant,
        pretrained,
        pretrained_strict=False,
        feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
        **kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'base_layer.0', 'classifier': 'fc',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'dla34.in1k': _cfg(hf_hub_id='timm/'),
    'dla46_c.in1k': _cfg(hf_hub_id='timm/'),
    'dla46x_c.in1k': _cfg(hf_hub_id='timm/'),
    'dla60x_c.in1k': _cfg(hf_hub_id='timm/'),
    'dla60.in1k': _cfg(hf_hub_id='timm/'),
    'dla60x.in1k': _cfg(hf_hub_id='timm/'),
    'dla102.in1k': _cfg(hf_hub_id='timm/'),
    'dla102x.in1k': _cfg(hf_hub_id='timm/'),
    'dla102x2.in1k': _cfg(hf_hub_id='timm/'),
    'dla169.in1k': _cfg(hf_hub_id='timm/'),
    'dla60_res2net.in1k': _cfg(hf_hub_id='timm/'),
    'dla60_res2next.in1k': _cfg(hf_hub_id='timm/'),
})


@register_model
def dla60_res2net(pretrained=False, **kwargs) -> DLA:
    model_args = dict(
        levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
        block=DlaBottle2neck, cardinality=1, base_width=28)
    return _create_dla('dla60_res2net', pretrained, **dict(model_args, **kwargs))


@register_model
def dla60_res2next(pretrained=False, **kwargs) -> DLA:
    model_args = dict(
        levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
        block=DlaBottle2neck, cardinality=8, base_width=4)
    return _create_dla('dla60_res2next', pretrained, **dict(model_args, **kwargs))


@register_model
def dla34(pretrained=False, **kwargs) -> DLA:  # DLA-34
    model_args = dict(
        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], block=DlaBasic)
    return _create_dla('dla34', pretrained, **dict(model_args, **kwargs))


@register_model
def dla46_c(pretrained=False, **kwargs) -> DLA:  # DLA-46-C
    model_args = dict(
        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], block=DlaBottleneck)
    return _create_dla('dla46_c', pretrained, **dict(model_args, **kwargs))


@register_model
def dla46x_c(pretrained=False, **kwargs) -> DLA:  # DLA-X-46-C
    model_args = dict(
        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
        block=DlaBottleneck, cardinality=32, base_width=4)
    return _create_dla('dla46x_c', pretrained, **dict(model_args, **kwargs))


@register_model
def dla60x_c(pretrained=False, **kwargs) -> DLA:  # DLA-X-60-C
    model_args = dict(
        levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
        block=DlaBottleneck, cardinality=32, base_width=4)
    return _create_dla('dla60x_c', pretrained, **dict(model_args, **kwargs))


@register_model
def dla60(pretrained=False, **kwargs) -> DLA:  # DLA-60
    model_args = dict(
        levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
        block=DlaBottleneck)
    return _create_dla('dla60', pretrained, **dict(model_args, **kwargs))


@register_model
def dla60x(pretrained=False, **kwargs) -> DLA:  # DLA-X-60
    model_args = dict(
        levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
        block=DlaBottleneck, cardinality=32, base_width=4)
    return _create_dla('dla60x', pretrained, **dict(model_args, **kwargs))


@register_model
def dla102(pretrained=False, **kwargs) -> DLA:  # DLA-102
    model_args = dict(
        levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
        block=DlaBottleneck, shortcut_root=True)
    return _create_dla('dla102', pretrained, **dict(model_args, **kwargs))


@register_model
def dla102x(pretrained=False, **kwargs) -> DLA:  # DLA-X-102
    model_args = dict(
        levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True)
    return _create_dla('dla102x', pretrained, **dict(model_args, **kwargs))


@register_model
def dla102x2(pretrained=False, **kwargs) -> DLA:  # DLA-X-102 64
    model_args = dict(
        levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True)
    return _create_dla('dla102x2', pretrained, **dict(model_args, **kwargs))


@register_model
def dla169(pretrained=False, **kwargs) -> DLA:  # DLA-169
    model_args = dict(
        levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
        block=DlaBottleneck, shortcut_root=True)
    return _create_dla('dla169', pretrained, **dict(model_args, **kwargs))
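A minimal usage sketch for the DLA registrations above, assuming a standard timm install; the chosen variant, class count, and input shape are illustrative only:

import torch
import timm

# build a DLA-34 classifier; pretrained=False avoids a hub weight download
model = timm.create_model('dla34', pretrained=False, num_classes=10)
model.eval()

x = torch.randn(1, 3, 224, 224)  # input size from the _cfg defaults above
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 10])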
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/eva.py
ADDED
@@ -0,0 +1,1109 @@
""" EVA

EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636

@article{EVA,
  title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
  author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
  Tiejun and Wang, Xinlong and Cao, Yue},
  journal={arXiv preprint arXiv:2211.07636},
  year={2022}
}

EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
@article{EVA02,
  title={EVA-02: A Visual Representation for Neon Genesis},
  author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
  journal={arXiv preprint arXiv:2303.11331},
  year={2023}
}

This file contains EVA & EVA02 model implementations evolved from BEiT, additional models in vision_transformer.py.

Modifications by / Copyright 2023 Ross Wightman, original copyrights below
"""
# EVA models Copyright (c) 2022 BAAI-Vision
# EVA02 models Copyright (c) 2023 BAAI-Vision

import math
from typing import Callable, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
    apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \
    to_2tuple, use_fused_attn

from ._builder import build_model_with_cfg
from ._registry import generate_default_cfgs, register_model

__all__ = ['Eva']


class EvaAttention(nn.Module):
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            attn_head_dim: Optional[int] = None,
            norm_layer: Optional[Callable] = None,
    ):
        """
        Args:
            dim: input embedding dimension
            num_heads: number of attention heads
            qkv_bias: enable bias for the q & v projections (k bias stays zero)
            qkv_fused: use a single fused qkv projection instead of separate q/k/v
            attn_drop: attention dropout rate
            proj_drop: output projection dropout rate
            attn_head_dim: override per-head dimension (defaults to dim // num_heads)
            norm_layer: optional norm applied to the attention output (inner attn layernorm)
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        if qkv_fused:
            self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
            self.q_proj = self.k_proj = self.v_proj = None
            if qkv_bias:
                self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
                self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
                self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
            else:
                self.q_bias = self.k_bias = self.v_bias = None
        else:
            self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
            self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
            self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
            self.qkv = None
            self.q_bias = self.k_bias = self.v_bias = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity()
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(
            self,
            x,
            rope: Optional[torch.Tensor] = None,
            attn_mask: Optional[torch.Tensor] = None,
    ):
        B, N, C = x.shape

        if self.qkv is not None:
            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
            qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
        else:
            q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)  # B, num_heads, N, C
            k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
            v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)

        if rope is not None:
            q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
            k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            if attn_mask is not None:
                # apply the mask before softmax so masked positions get zero weight
                attn_mask = attn_mask.to(torch.bool)
                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.norm(x)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EvaBlock(nn.Module):

    def __init__(
            self,
            dim: int,
            num_heads: int,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            mlp_ratio: float = 4.,
            swiglu_mlp: bool = False,
            scale_mlp: bool = False,
            scale_attn_inner: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            init_values: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            attn_head_dim: Optional[int] = None,
    ):
        """
        Args:
            dim: embedding dimension
            num_heads: number of attention heads
            qkv_bias: enable bias for q & v projections
            qkv_fused: use a single fused qkv projection
            mlp_ratio: ratio of MLP hidden dim to embedding dim
            swiglu_mlp: use SwiGLU / GluMlp instead of plain Mlp
            scale_mlp: apply extra norm inside the MLP (normformer style)
            scale_attn_inner: apply extra norm to the attention output
            proj_drop: projection dropout rate
            attn_drop: attention dropout rate
            drop_path: stochastic depth rate
            init_values: layer-scale init values (layer-scale disabled if None)
            act_layer: activation layer
            norm_layer: normalization layer
            attn_head_dim: override per-head dimension
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = EvaAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qkv_fused=qkv_fused,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            attn_head_dim=attn_head_dim,
            norm_layer=norm_layer if scale_attn_inner else None,
        )
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        if swiglu_mlp:
            if scale_mlp:
                # when norm in SwiGLU used, an impl with separate fc for gate & x is used
                self.mlp = SwiGLU(
                    in_features=dim,
                    hidden_features=hidden_features,
                    norm_layer=norm_layer if scale_mlp else None,
                    drop=proj_drop,
                )
            else:
                # w/o any extra norm, an impl with packed weights is used, matches existing GluMLP
                self.mlp = GluMlp(
                    in_features=dim,
                    hidden_features=hidden_features * 2,
                    norm_layer=norm_layer if scale_mlp else None,
                    act_layer=nn.SiLU,
                    gate_last=False,
                    drop=proj_drop,
                )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=hidden_features,
                act_layer=act_layer,
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
            )
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
        if self.gamma_1 is None:
            x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
            x = x + self.drop_path2(self.mlp(self.norm2(x)))
        else:
            x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
            x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
        return x


class EvaBlockPostNorm(nn.Module):
    """ EVA block w/ post-norm and support for swiglu, MLP norm scale, ROPE. """
    def __init__(
            self,
            dim: int,
            num_heads: int,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            mlp_ratio: float = 4.,
            swiglu_mlp: bool = False,
            scale_mlp: bool = False,
            scale_attn_inner: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            init_values: Optional[float] = None,  # ignore for post-norm
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
            attn_head_dim: Optional[int] = None,
    ):
        """
        Args: identical to EvaBlock; init_values is accepted for interface
        compatibility but ignored (no layer-scale with post-norm blocks).
        """
        super().__init__()
        self.attn = EvaAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qkv_fused=qkv_fused,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            attn_head_dim=attn_head_dim,
            norm_layer=norm_layer if scale_attn_inner else None,
        )
        self.norm1 = norm_layer(dim)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        hidden_features = int(dim * mlp_ratio)
        if swiglu_mlp:
            if scale_mlp:
                # when norm in SwiGLU used, an impl with separate fc for gate & x is used
                self.mlp = SwiGLU(
                    in_features=dim,
                    hidden_features=hidden_features,
                    norm_layer=norm_layer if scale_mlp else None,
                    drop=proj_drop,
                )
            else:
                # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
                self.mlp = GluMlp(
                    in_features=dim,
                    hidden_features=hidden_features * 2,
                    norm_layer=norm_layer if scale_mlp else None,
                    act_layer=nn.SiLU,
                    gate_last=False,
                    drop=proj_drop,
                )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=hidden_features,
                act_layer=act_layer,
                norm_layer=norm_layer if scale_mlp else None,
                drop=proj_drop,
            )
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
        x = x + self.drop_path1(self.norm1(self.attn(x, rope=rope, attn_mask=attn_mask)))
        x = x + self.drop_path2(self.norm2(self.mlp(x)))
        return x


class Eva(nn.Module):
    """ Eva Vision Transformer w/ Abs & Rotary Pos Embed

    This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
      * EVA - abs pos embed, global avg pool
      * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (a la normformer)
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            qkv_bias: bool = True,
            qkv_fused: bool = True,
            mlp_ratio: float = 4.,
            swiglu_mlp: bool = False,
            scale_mlp: bool = False,
            scale_attn_inner: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            patch_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            norm_layer: Callable = LayerNorm,
            init_values: Optional[float] = None,
            class_token: bool = True,
            use_abs_pos_emb: bool = True,
            use_rot_pos_emb: bool = False,
            use_post_norm: bool = False,
            dynamic_img_size: bool = False,
            dynamic_img_pad: bool = False,
            ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
            head_init_scale: float = 0.001,
    ):
        """
        Args:
            img_size: input image size
            patch_size: patch size
            in_chans: number of input image channels
            num_classes: number of classifier classes
            global_pool: type of global pooling ('avg' or 'token')
            embed_dim: embedding dimension
            depth: number of transformer blocks
            num_heads: number of attention heads
            qkv_bias: enable bias for q & v projections
            qkv_fused: use a single fused qkv projection
            mlp_ratio: ratio of MLP hidden dim to embedding dim
            swiglu_mlp: use SwiGLU / GluMlp instead of plain Mlp
            scale_mlp: apply extra norm inside the MLP (normformer style)
            scale_attn_inner: apply extra norm to attention output
            drop_rate: classifier head dropout rate
            pos_drop_rate: position embedding dropout rate
            patch_drop_rate: patch (token) dropout rate
            proj_drop_rate: projection dropout rate
            attn_drop_rate: attention dropout rate
            drop_path_rate: stochastic depth rate
            norm_layer: normalization layer
            init_values: layer-scale init values (disabled if None)
            class_token: use a class token
            use_abs_pos_emb: use absolute (learned) position embedding
            use_rot_pos_emb: use rotary position embedding
            use_post_norm: use post-norm transformer blocks
            dynamic_img_size: support variable input sizes by resampling pos embed at runtime
            dynamic_img_pad: pad input to a multiple of patch size
            ref_feat_shape: reference feature grid shape for rotary embedding scaling
            head_init_scale: initial scale applied to classifier head weights
        """
        super().__init__()
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.dynamic_img_size = dynamic_img_size
        self.grad_checkpointing = False

        embed_args = {}
        if dynamic_img_size:
            # flatten deferred until after pos embed
            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
            dynamic_img_pad=dynamic_img_pad,
            **embed_args,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None

        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)
        if patch_drop_rate > 0:
            self.patch_drop = PatchDropout(
                patch_drop_rate,
                num_prefix_tokens=self.num_prefix_tokens,
                return_indices=True,
            )
        else:
            self.patch_drop = None

        if use_rot_pos_emb:
            ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
            self.rope = RotaryEmbeddingCat(
                embed_dim // num_heads,
                in_pixels=False,
                feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
                ref_feat_shape=ref_feat_shape,
            )
        else:
            self.rope = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
        self.blocks = nn.ModuleList([
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qkv_fused=qkv_fused,
                mlp_ratio=mlp_ratio,
                swiglu_mlp=swiglu_mlp,
                scale_mlp=scale_mlp,
                scale_attn_inner=scale_attn_inner,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
            )
            for i in range(depth)])

        use_fc_norm = self.global_pool == 'avg'
        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        if self.cls_token is not None:
            trunc_normal_(self.cls_token, std=.02)

        self.fix_init_weight()
        if isinstance(self.head, nn.Linear):
            trunc_normal_(self.head.weight, std=.02)
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)

    def fix_init_weight(self):
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        nwd = {'pos_embed', 'cls_token'}
        return nwd

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
        )
        return matcher

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        if self.dynamic_img_size:
            B, H, W, C = x.shape
            if self.pos_embed is not None:
                pos_embed = resample_abs_pos_embed(
                    self.pos_embed,
                    (H, W),
                    num_prefix_tokens=self.num_prefix_tokens,
                )
            else:
                pos_embed = None
            x = x.view(B, -1, C)
            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
        else:
            pos_embed = self.pos_embed
            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None

        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        if pos_embed is not None:
            x = x + pos_embed
        x = self.pos_drop(x)

        # obtain shared rotary position embedding and apply patch dropout
        if self.patch_drop is not None:
            x, keep_indices = self.patch_drop(x)
            if rot_pos_embed is not None and keep_indices is not None:
                rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
        return x, rot_pos_embed

    def forward_features(self, x):
        x = self.patch_embed(x)
        x, rot_pos_embed = self._pos_embed(x)
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, rope=rot_pos_embed)
            else:
                x = blk(x, rope=rot_pos_embed)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(
        state_dict,
        model,
        interpolation='bicubic',
        antialias=True,
):
    """ convert patch embedding weight from manual patchify + linear proj to conv """
    out_dict = {}
    state_dict = state_dict.get('model_ema', state_dict)
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('module', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    # prefix for loading OpenCLIP compatible weights
    if 'visual.trunk.pos_embed' in state_dict:
        prefix = 'visual.trunk.'
    elif 'visual.pos_embed' in state_dict:
        prefix = 'visual.'
    else:
        prefix = ''
    mim_weights = prefix + 'mask_token' in state_dict
    no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict

    len_prefix = len(prefix)
    for k, v in state_dict.items():
        if prefix:
            if k.startswith(prefix):
                k = k[len_prefix:]
            else:
                continue

        if 'rope' in k:
            # fixed embedding, no need to load buffer from checkpoint
            continue

        if 'patch_embed.proj.weight' in k:
            _, _, H, W = model.patch_embed.proj.weight.shape
            if v.shape[-1] != W or v.shape[-2] != H:
                v = resample_patch_embed(
                    v,
                    (H, W),
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
            v = resample_abs_pos_embed(
                v,
                new_size=model.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )

        k = k.replace('mlp.ffn_ln', 'mlp.norm')
        k = k.replace('attn.inner_attn_ln', 'attn.norm')
        k = k.replace('mlp.w12', 'mlp.fc1')
        k = k.replace('mlp.w1', 'mlp.fc1_g')
        k = k.replace('mlp.w2', 'mlp.fc1_x')
        k = k.replace('mlp.w3', 'mlp.fc2')
        if no_qkv:
            k = k.replace('q_bias', 'q_proj.bias')
            k = k.replace('v_bias', 'v_proj.bias')

        if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
            if k == 'norm.weight' or k == 'norm.bias':
                # try moving norm -> fc norm on fine-tune, probably a better starting point than new init
                k = k.replace('norm', 'fc_norm')
            else:
                # skip pretrain mask token & head weights
                continue

        out_dict[k] = v

    return out_dict


def _create_eva(variant, pretrained=False, **kwargs):
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Eva models.')

    model = build_model_with_cfg(
        Eva, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
    return model


def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        'license': 'mit', **kwargs
    }


default_cfgs = generate_default_cfgs({

    # EVA 01 CLIP fine-tuned on imagenet-1k
    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
        hf_hub_id='timm/',
    ),
    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
        hf_hub_id='timm/',
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),

    # MIM EVA 01 pretrain, ft on in22k -> in1k
    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),

    # in22k or m38m MIM pretrain w/ intermediate in22k fine-tune and final in1k fine-tune
    'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
    ),
    'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
    ),
    'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
        hf_hub_id='timm/',
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
    ),

    # in22k or m38m MIM pretrain w/ in1k fine-tune
    'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 336, 336), crop_pct=1.0,
    ),
    'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 336, 336), crop_pct=1.0,
    ),
    'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0,
    ),
    'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0,
    ),
    'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0,
    ),

    # in22k or m38m MIM pretrain w/ in22k fine-tune
    'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
    ),
    'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
    ),
    'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
    ),

    # in22k or m38m MIM pretrain
    'eva02_tiny_patch14_224.mim_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_small_patch14_224.mim_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_base_patch14_224.mim_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_large_patch14_224.mim_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_large_patch14_224.mim_m38m': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),

    # EVA01 and EVA02 CLIP image towers
    'eva_giant_patch14_clip_224.laion400m': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
        hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva_giant_patch14_clip_224.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
        hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva02_base_patch16_clip_224.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
        hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=512,
    ),
    'eva02_large_patch14_clip_224.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
        hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=768,
    ),
    'eva02_large_patch14_clip_336.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
        hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 336, 336), crop_pct=1.0,
        num_classes=768,
    ),
    'eva02_enormous_patch14_clip_224.laion2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
        hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
        hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k',  # bfloat16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva02_enormous_patch14_clip_224.pretrain': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
        num_classes=0,
    ),

})


@register_model
def eva_giant_patch14_224(pretrained=False, **kwargs) -> Eva:
    """ EVA-g model https://arxiv.org/abs/2211.07636 """
    model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
    model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva_giant_patch14_336(pretrained=False, **kwargs) -> Eva:
    """ EVA-g model https://arxiv.org/abs/2211.07636 """
    model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
    model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva_giant_patch14_560(pretrained=False, **kwargs) -> Eva:
    """ EVA-g model https://arxiv.org/abs/2211.07636 """
    model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
    model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_tiny_patch14_224(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=224,
        patch_size=14,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_small_patch14_224(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=224,
        patch_size=14,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_small_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_base_patch14_224(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=224,
        patch_size=14,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_fused=False,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_large_patch14_224(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=224,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4 * 2 / 3,
        qkv_fused=False,
        swiglu_mlp=True,
        scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_tiny_patch14_336(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=336,
        patch_size=14,
        embed_dim=192,
        depth=12,
        num_heads=3,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_small_patch14_336(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=336,
        patch_size=14,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_small_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_base_patch14_448(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=448,
        patch_size=14,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_fused=False,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_base_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_large_patch14_448(pretrained=False, **kwargs) -> Eva:
    model_args = dict(
        img_size=448,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4 * 2 / 3,
        qkv_fused=False,
        swiglu_mlp=True,
        scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva_giant_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
    """ EVA-g CLIP model (only difference from non-CLIP is the pooling) """
    model_args = dict(
        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408,
        global_pool=kwargs.pop('global_pool', 'token'))
    model = _create_eva('eva_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_base_patch16_clip_224(pretrained=False, **kwargs) -> Eva:
    """ An EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_base """
    model_args = dict(
        img_size=224,
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        qkv_fused=False,
        mlp_ratio=4 * 2 / 3,
        swiglu_mlp=True,
        scale_mlp=True,
        scale_attn_inner=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    model = _create_eva('eva02_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_large_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
    """ An EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
    model_args = dict(
        img_size=224,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4 * 2 / 3,
        qkv_fused=False,
        swiglu_mlp=True,
        scale_mlp=True,
        scale_attn_inner=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    model = _create_eva('eva02_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_large_patch14_clip_336(pretrained=False, **kwargs) -> Eva:
    """ An EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
    model_args = dict(
        img_size=336,
        patch_size=14,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4 * 2 / 3,
        qkv_fused=False,
        swiglu_mlp=True,
        scale_mlp=True,
        scale_attn_inner=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    model = _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
    """ An EVA-CLIP specific variant that uses residual post-norm in blocks """
    model_args = dict(
        img_size=224,
        patch_size=14,
        embed_dim=1792,
        depth=64,
        num_heads=16,
        mlp_ratio=15360 / 1792,
        use_post_norm=True,
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
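A minimal usage sketch for the EVA02 registrations above, assuming a standard timm install; the model name, class count, and input are illustrative. Note that when pretrained=True, build_model_with_cfg routes checkpoints through checkpoint_filter_fn, which strips OpenCLIP prefixes and remaps keys such as mlp.w12 -> mlp.fc1 before loading:

import torch
import timm

# build the smallest EVA02 variant; pretrained=False skips the hub download
model = timm.create_model('eva02_tiny_patch14_224', pretrained=False, num_classes=5)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model.forward_features(x)   # (1, 257, 192): 16x16 patches + cls token
    logits = model.forward_head(feats)  # (1, 5)
print(feats.shape, logits.shape)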
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/factory.py
ADDED
@@ -0,0 +1,4 @@
from ._factory import *

import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/features.py
ADDED
@@ -0,0 +1,4 @@
from ._features import *

import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
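Both shims above follow the same pattern: re-export everything from the new underscore-prefixed module and warn once at import time. A small sketch of how the warning surfaces, assuming a fresh interpreter (DeprecationWarning is hidden by default outside __main__, so a filter is needed; Python also caches modules, so a second import will not warn again):

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter('always')
    import timm.models.features  # old flat path, now a thin shim
print([str(w.message) for w in caught])
# expected to include: "Importing from timm.models.features is deprecated, please import via timm.models"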
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/gcvit.py
ADDED
@@ -0,0 +1,592 @@
| 1 |
+
""" Global Context ViT
|
| 2 |
+
|
| 3 |
+
From scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py
|
| 4 |
+
|
| 5 |
+
Global Context Vision Transformers -https://arxiv.org/abs/2206.09959
|
| 6 |
+
|
| 7 |
+
@article{hatamizadeh2022global,
|
| 8 |
+
title={Global Context Vision Transformers},
|
| 9 |
+
author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
|
| 10 |
+
journal={arXiv preprint arXiv:2206.09959},
|
| 11 |
+
year={2022}
|
| 12 |
+
}
|
| 13 |
+
|
| 14 |
+
Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit.
|
| 15 |
+
The license for this code release is Apache 2.0 with no commercial restrictions.
|
| 16 |
+
|
| 17 |
+
However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license
|
| 18 |
+
(https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones...
|
| 19 |
+
|
| 20 |
+
Hacked together by / Copyright 2022, Ross Wightman
|
| 21 |
+
"""
|
| 22 |
+
import math
|
| 23 |
+
from functools import partial
|
| 24 |
+
from typing import Callable, List, Optional, Tuple, Union
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
import torch.utils.checkpoint as checkpoint
|
| 29 |
+
|
| 30 |
+
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 31 |
+
from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
|
| 32 |
+
get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
|
| 33 |
+
from ._builder import build_model_with_cfg
|
| 34 |
+
from ._features_fx import register_notrace_function
|
| 35 |
+
from ._manipulate import named_apply
|
| 36 |
+
from ._registry import register_model, generate_default_cfgs
|
| 37 |
+
|
| 38 |
+
__all__ = ['GlobalContextVit']
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class MbConvBlock(nn.Module):
|
| 42 |
+
""" A depthwise separable / fused mbconv style residual block with SE, `no norm.
|
| 43 |
+
"""
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
in_chs,
|
| 47 |
+
out_chs=None,
|
| 48 |
+
expand_ratio=1.0,
|
| 49 |
+
attn_layer='se',
|
| 50 |
+
bias=False,
|
| 51 |
+
act_layer=nn.GELU,
|
| 52 |
+
):
|
| 53 |
+
super().__init__()
|
| 54 |
+
attn_kwargs = dict(act_layer=act_layer)
|
| 55 |
+
if isinstance(attn_layer, str) and attn_layer == 'se' or attn_layer == 'eca':
|
| 56 |
+
attn_kwargs['rd_ratio'] = 0.25
|
| 57 |
+
attn_kwargs['bias'] = False
|
| 58 |
+
attn_layer = get_attn(attn_layer)
|
| 59 |
+
out_chs = out_chs or in_chs
|
| 60 |
+
mid_chs = int(expand_ratio * in_chs)
|
| 61 |
+
|
| 62 |
+
self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias)
|
| 63 |
+
self.act = act_layer()
|
| 64 |
+
self.se = attn_layer(mid_chs, **attn_kwargs)
|
| 65 |
+
self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias)
|
| 66 |
+
|
| 67 |
+
def forward(self, x):
|
| 68 |
+
shortcut = x
|
| 69 |
+
x = self.conv_dw(x)
|
| 70 |
+
x = self.act(x)
|
| 71 |
+
x = self.se(x)
|
| 72 |
+
x = self.conv_pw(x)
|
| 73 |
+
x = x + shortcut
|
| 74 |
+
return x
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Downsample2d(nn.Module):
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
dim,
|
| 81 |
+
dim_out=None,
|
| 82 |
+
reduction='conv',
|
| 83 |
+
act_layer=nn.GELU,
|
| 84 |
+
norm_layer=LayerNorm2d, # NOTE in NCHW
|
| 85 |
+
):
|
| 86 |
+
super().__init__()
|
| 87 |
+
dim_out = dim_out or dim
|
| 88 |
+
|
| 89 |
+
self.norm1 = norm_layer(dim) if norm_layer is not None else nn.Identity()
|
| 90 |
+
self.conv_block = MbConvBlock(dim, act_layer=act_layer)
|
| 91 |
+
assert reduction in ('conv', 'max', 'avg')
|
| 92 |
+
if reduction == 'conv':
|
| 93 |
+
self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False)
|
| 94 |
+
elif reduction == 'max':
|
| 95 |
+
assert dim == dim_out
|
| 96 |
+
self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 97 |
+
else:
|
| 98 |
+
assert dim == dim_out
|
| 99 |
+
self.reduction = nn.AvgPool2d(kernel_size=2)
|
| 100 |
+
self.norm2 = norm_layer(dim_out) if norm_layer is not None else nn.Identity()
|
| 101 |
+
|
| 102 |
+
def forward(self, x):
|
| 103 |
+
x = self.norm1(x)
|
| 104 |
+
x = self.conv_block(x)
|
| 105 |
+
x = self.reduction(x)
|
| 106 |
+
x = self.norm2(x)
|
| 107 |
+
return x
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class FeatureBlock(nn.Module):
|
| 111 |
+
def __init__(
|
| 112 |
+
self,
|
| 113 |
+
dim,
|
| 114 |
+
levels=0,
|
| 115 |
+
reduction='max',
|
| 116 |
+
act_layer=nn.GELU,
|
| 117 |
+
):
|
| 118 |
+
super().__init__()
|
| 119 |
+
reductions = levels
|
| 120 |
+
levels = max(1, levels)
|
| 121 |
+
if reduction == 'avg':
|
| 122 |
+
pool_fn = partial(nn.AvgPool2d, kernel_size=2)
|
| 123 |
+
else:
|
| 124 |
+
pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
|
| 125 |
+
self.blocks = nn.Sequential()
|
| 126 |
+
for i in range(levels):
|
| 127 |
+
self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer))
|
| 128 |
+
if reductions:
|
| 129 |
+
self.blocks.add_module(f'pool{i+1}', pool_fn())
|
| 130 |
+
reductions -= 1
|
| 131 |
+
|
| 132 |
+
def forward(self, x):
|
| 133 |
+
return self.blocks(x)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class Stem(nn.Module):
|
| 137 |
+
def __init__(
|
| 138 |
+
self,
|
| 139 |
+
in_chs: int = 3,
|
| 140 |
+
out_chs: int = 96,
|
| 141 |
+
act_layer: Callable = nn.GELU,
|
| 142 |
+
norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW
|
| 143 |
+
):
|
| 144 |
+
super().__init__()
|
| 145 |
+
self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
|
| 146 |
+
self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
|
| 147 |
+
|
| 148 |
+
def forward(self, x):
|
| 149 |
+
x = self.conv1(x)
|
| 150 |
+
x = self.down(x)
|
| 151 |
+
return x
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class WindowAttentionGlobal(nn.Module):
|
| 155 |
+
|
| 156 |
+
def __init__(
|
| 157 |
+
self,
|
| 158 |
+
dim: int,
|
| 159 |
+
num_heads: int,
|
| 160 |
+
window_size: Tuple[int, int],
|
| 161 |
+
use_global: bool = True,
|
| 162 |
+
qkv_bias: bool = True,
|
| 163 |
+
attn_drop: float = 0.,
|
| 164 |
+
proj_drop: float = 0.,
|
| 165 |
+
):
|
| 166 |
+
super().__init__()
|
| 167 |
+
window_size = to_2tuple(window_size)
|
| 168 |
+
self.window_size = window_size
|
| 169 |
+
self.num_heads = num_heads
|
| 170 |
+
self.head_dim = dim // num_heads
|
| 171 |
+
self.scale = self.head_dim ** -0.5
|
| 172 |
+
self.use_global = use_global
|
| 173 |
+
|
| 174 |
+
self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)
|
| 175 |
+
if self.use_global:
|
| 176 |
+
self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
| 177 |
+
else:
|
| 178 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 179 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 180 |
+
self.proj = nn.Linear(dim, dim)
|
| 181 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 182 |
+
|
| 183 |
+
def forward(self, x, q_global: Optional[torch.Tensor] = None):
|
| 184 |
+
B, N, C = x.shape
|
| 185 |
+
if self.use_global and q_global is not None:
|
| 186 |
+
_assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal')
|
| 187 |
+
|
| 188 |
+
kv = self.qkv(x)
|
| 189 |
+
kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 190 |
+
k, v = kv.unbind(0)
|
| 191 |
+
|
| 192 |
+
q = q_global.repeat(B // q_global.shape[0], 1, 1, 1)
|
| 193 |
+
q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
|
| 194 |
+
else:
|
| 195 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 196 |
+
q, k, v = qkv.unbind(0)
|
| 197 |
+
q = q * self.scale
|
| 198 |
+
|
| 199 |
+
attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
|
| 200 |
+
attn = self.rel_pos(attn)
|
| 201 |
+
attn = attn.softmax(dim=-1)
|
| 202 |
+
attn = self.attn_drop(attn)
|
| 203 |
+
|
| 204 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 205 |
+
x = self.proj(x)
|
| 206 |
+
x = self.proj_drop(x)
|
| 207 |
+
return x
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def window_partition(x, window_size: Tuple[int, int]):
|
| 211 |
+
B, H, W, C = x.shape
|
| 212 |
+
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
|
| 213 |
+
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
|
| 214 |
+
return windows
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
@register_notrace_function # reason: int argument is a Proxy
|
| 218 |
+
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
|
| 219 |
+
H, W = img_size
|
| 220 |
+
C = windows.shape[-1]
|
| 221 |
+
x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
|
| 222 |
+
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
|
| 223 |
+
return x
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class LayerScale(nn.Module):
|
| 227 |
+
def __init__(self, dim, init_values=1e-5, inplace=False):
|
| 228 |
+
super().__init__()
|
| 229 |
+
self.inplace = inplace
|
| 230 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 231 |
+
|
| 232 |
+
def forward(self, x):
|
| 233 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class GlobalContextVitBlock(nn.Module):
|
| 237 |
+
def __init__(
|
| 238 |
+
self,
|
| 239 |
+
dim: int,
|
| 240 |
+
feat_size: Tuple[int, int],
|
| 241 |
+
num_heads: int,
|
| 242 |
+
window_size: int = 7,
|
| 243 |
+
mlp_ratio: float = 4.,
|
| 244 |
+
use_global: bool = True,
|
| 245 |
+
qkv_bias: bool = True,
|
| 246 |
+
layer_scale: Optional[float] = None,
|
| 247 |
+
proj_drop: float = 0.,
|
| 248 |
+
attn_drop: float = 0.,
|
| 249 |
+
drop_path: float = 0.,
|
| 250 |
+
attn_layer: Callable = WindowAttentionGlobal,
|
| 251 |
+
act_layer: Callable = nn.GELU,
|
| 252 |
+
norm_layer: Callable = nn.LayerNorm,
|
| 253 |
+
):
|
| 254 |
+
super().__init__()
|
| 255 |
+
feat_size = to_2tuple(feat_size)
|
| 256 |
+
window_size = to_2tuple(window_size)
|
| 257 |
+
self.window_size = window_size
|
| 258 |
+
self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1]))
|
| 259 |
+
|
| 260 |
+
self.norm1 = norm_layer(dim)
|
| 261 |
+
self.attn = attn_layer(
|
| 262 |
+
dim,
|
| 263 |
+
num_heads=num_heads,
|
| 264 |
+
window_size=window_size,
|
| 265 |
+
use_global=use_global,
|
| 266 |
+
qkv_bias=qkv_bias,
|
| 267 |
+
attn_drop=attn_drop,
|
| 268 |
+
proj_drop=proj_drop,
|
| 269 |
+
)
|
| 270 |
+
self.ls1 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
|
| 271 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 272 |
+
|
| 273 |
+
self.norm2 = norm_layer(dim)
|
| 274 |
+
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
|
| 275 |
+
self.ls2 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
|
| 276 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 277 |
+
|
| 278 |
+
def _window_attn(self, x, q_global: Optional[torch.Tensor] = None):
|
| 279 |
+
B, H, W, C = x.shape
|
| 280 |
+
x_win = window_partition(x, self.window_size)
|
| 281 |
+
x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C)
|
| 282 |
+
attn_win = self.attn(x_win, q_global)
|
| 283 |
+
x = window_reverse(attn_win, self.window_size, (H, W))
|
| 284 |
+
return x
|
| 285 |
+
|
| 286 |
+
def forward(self, x, q_global: Optional[torch.Tensor] = None):
|
| 287 |
+
x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global)))
|
| 288 |
+
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
|
| 289 |
+
return x
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class GlobalContextVitStage(nn.Module):
|
| 293 |
+
def __init__(
|
| 294 |
+
self,
|
| 295 |
+
dim,
|
| 296 |
+
depth: int,
|
| 297 |
+
num_heads: int,
|
| 298 |
+
feat_size: Tuple[int, int],
|
| 299 |
+
window_size: Tuple[int, int],
|
| 300 |
+
downsample: bool = True,
|
| 301 |
+
global_norm: bool = False,
|
| 302 |
+
stage_norm: bool = False,
|
| 303 |
+
mlp_ratio: float = 4.,
|
| 304 |
+
qkv_bias: bool = True,
|
| 305 |
+
layer_scale: Optional[float] = None,
|
| 306 |
+
proj_drop: float = 0.,
|
| 307 |
+
attn_drop: float = 0.,
|
| 308 |
+
drop_path: Union[List[float], float] = 0.0,
|
| 309 |
+
act_layer: Callable = nn.GELU,
|
| 310 |
+
norm_layer: Callable = nn.LayerNorm,
|
| 311 |
+
norm_layer_cl: Callable = LayerNorm2d,
|
| 312 |
+
):
|
| 313 |
+
super().__init__()
|
| 314 |
+
if downsample:
|
| 315 |
+
self.downsample = Downsample2d(
|
| 316 |
+
dim=dim,
|
| 317 |
+
dim_out=dim * 2,
|
| 318 |
+
norm_layer=norm_layer,
|
| 319 |
+
)
|
| 320 |
+
dim = dim * 2
|
| 321 |
+
feat_size = (feat_size[0] // 2, feat_size[1] // 2)
|
| 322 |
+
else:
|
| 323 |
+
self.downsample = nn.Identity()
|
| 324 |
+
self.feat_size = feat_size
|
| 325 |
+
window_size = to_2tuple(window_size)
|
| 326 |
+
|
| 327 |
+
feat_levels = int(math.log2(min(feat_size) / min(window_size)))
|
| 328 |
+
self.global_block = FeatureBlock(dim, feat_levels)
|
| 329 |
+
self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
|
| 330 |
+
|
| 331 |
+
self.blocks = nn.ModuleList([
|
| 332 |
+
GlobalContextVitBlock(
|
| 333 |
+
dim=dim,
|
| 334 |
+
num_heads=num_heads,
|
| 335 |
+
feat_size=feat_size,
|
| 336 |
+
window_size=window_size,
|
| 337 |
+
mlp_ratio=mlp_ratio,
|
| 338 |
+
qkv_bias=qkv_bias,
|
| 339 |
+
use_global=(i % 2 != 0),
|
| 340 |
+
layer_scale=layer_scale,
|
| 341 |
+
proj_drop=proj_drop,
|
| 342 |
+
attn_drop=attn_drop,
|
| 343 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
| 344 |
+
act_layer=act_layer,
|
| 345 |
+
norm_layer=norm_layer_cl,
|
| 346 |
+
)
|
| 347 |
+
for i in range(depth)
|
| 348 |
+
])
|
| 349 |
+
self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity()
|
| 350 |
+
self.dim = dim
|
| 351 |
+
self.feat_size = feat_size
|
| 352 |
+
self.grad_checkpointing = False
|
| 353 |
+
|
| 354 |
+
def forward(self, x):
|
| 355 |
+
# input NCHW, downsample & global block are 2d conv + pooling
|
| 356 |
+
x = self.downsample(x)
|
| 357 |
+
global_query = self.global_block(x)
|
| 358 |
+
|
| 359 |
+
# reshape NCHW --> NHWC for transformer blocks
|
| 360 |
+
x = x.permute(0, 2, 3, 1)
|
| 361 |
+
global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
|
| 362 |
+
for blk in self.blocks:
|
| 363 |
+
if self.grad_checkpointing and not torch.jit.is_scripting():
|
| 364 |
+
x = checkpoint.checkpoint(blk, x)
|
| 365 |
+
else:
|
| 366 |
+
x = blk(x, global_query)
|
| 367 |
+
x = self.norm(x)
|
| 368 |
+
x = x.permute(0, 3, 1, 2).contiguous() # back to NCHW
|
| 369 |
+
return x
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
class GlobalContextVit(nn.Module):
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
in_chans: int = 3,
|
| 376 |
+
num_classes: int = 1000,
|
| 377 |
+
global_pool: str = 'avg',
|
| 378 |
+
img_size: Tuple[int, int] = 224,
|
| 379 |
+
window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
|
| 380 |
+
window_size: Tuple[int, ...] = None,
|
| 381 |
+
embed_dim: int = 64,
|
| 382 |
+
depths: Tuple[int, ...] = (3, 4, 19, 5),
|
| 383 |
+
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
|
| 384 |
+
mlp_ratio: float = 3.0,
|
| 385 |
+
qkv_bias: bool = True,
|
| 386 |
+
layer_scale: Optional[float] = None,
|
| 387 |
+
drop_rate: float = 0.,
|
| 388 |
+
proj_drop_rate: float = 0.,
|
| 389 |
+
attn_drop_rate: float = 0.,
|
| 390 |
+
drop_path_rate: float = 0.,
|
| 391 |
+
weight_init='',
|
| 392 |
+
act_layer: str = 'gelu',
|
| 393 |
+
norm_layer: str = 'layernorm2d',
|
| 394 |
+
norm_layer_cl: str = 'layernorm',
|
| 395 |
+
norm_eps: float = 1e-5,
|
| 396 |
+
):
|
| 397 |
+
super().__init__()
|
| 398 |
+
act_layer = get_act_layer(act_layer)
|
| 399 |
+
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
|
| 400 |
+
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
|
| 401 |
+
|
| 402 |
+
img_size = to_2tuple(img_size)
|
| 403 |
+
feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
|
| 404 |
+
self.global_pool = global_pool
|
| 405 |
+
self.num_classes = num_classes
|
| 406 |
+
self.drop_rate = drop_rate
|
| 407 |
+
num_stages = len(depths)
|
| 408 |
+
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
|
| 409 |
+
if window_size is not None:
|
| 410 |
+
window_size = to_ntuple(num_stages)(window_size)
|
| 411 |
+
else:
|
| 412 |
+
assert window_ratio is not None
|
| 413 |
+
window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
|
| 414 |
+
|
| 415 |
+
self.stem = Stem(
|
| 416 |
+
in_chs=in_chans,
|
| 417 |
+
out_chs=embed_dim,
|
| 418 |
+
act_layer=act_layer,
|
| 419 |
+
norm_layer=norm_layer
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
|
| 423 |
+
stages = []
|
| 424 |
+
for i in range(num_stages):
|
| 425 |
+
last_stage = i == num_stages - 1
|
| 426 |
+
stage_scale = 2 ** max(i - 1, 0)
|
| 427 |
+
stages.append(GlobalContextVitStage(
|
| 428 |
+
dim=embed_dim * stage_scale,
|
| 429 |
+
depth=depths[i],
|
| 430 |
+
num_heads=num_heads[i],
|
| 431 |
+
feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale),
|
| 432 |
+
window_size=window_size[i],
|
| 433 |
+
downsample=i != 0,
|
| 434 |
+
stage_norm=last_stage,
|
| 435 |
+
mlp_ratio=mlp_ratio,
|
| 436 |
+
qkv_bias=qkv_bias,
|
| 437 |
+
layer_scale=layer_scale,
|
| 438 |
+
proj_drop=proj_drop_rate,
|
| 439 |
+
attn_drop=attn_drop_rate,
|
| 440 |
+
drop_path=dpr[i],
|
| 441 |
+
act_layer=act_layer,
|
| 442 |
+
norm_layer=norm_layer,
|
| 443 |
+
norm_layer_cl=norm_layer_cl,
|
| 444 |
+
))
|
| 445 |
+
self.stages = nn.Sequential(*stages)
|
| 446 |
+
|
| 447 |
+
# Classifier head
|
| 448 |
+
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
|
| 449 |
+
|
| 450 |
+
if weight_init:
|
| 451 |
+
named_apply(partial(self._init_weights, scheme=weight_init), self)
|
| 452 |
+
|
| 453 |
+
def _init_weights(self, module, name, scheme='vit'):
|
| 454 |
+
# note Conv2d left as default init
|
| 455 |
+
if scheme == 'vit':
|
| 456 |
+
if isinstance(module, nn.Linear):
|
| 457 |
+
nn.init.xavier_uniform_(module.weight)
|
| 458 |
+
if module.bias is not None:
|
| 459 |
+
if 'mlp' in name:
|
| 460 |
+
nn.init.normal_(module.bias, std=1e-6)
|
| 461 |
+
else:
|
| 462 |
+
nn.init.zeros_(module.bias)
|
| 463 |
+
else:
|
| 464 |
+
if isinstance(module, nn.Linear):
|
| 465 |
+
nn.init.normal_(module.weight, std=.02)
|
| 466 |
+
if module.bias is not None:
|
| 467 |
+
nn.init.zeros_(module.bias)
|
| 468 |
+
|
| 469 |
+
@torch.jit.ignore
|
| 470 |
+
def no_weight_decay(self):
|
| 471 |
+
return {
|
| 472 |
+
k for k, _ in self.named_parameters()
|
| 473 |
+
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
|
| 474 |
+
|
| 475 |
+
@torch.jit.ignore
|
| 476 |
+
def group_matcher(self, coarse=False):
|
| 477 |
+
matcher = dict(
|
| 478 |
+
stem=r'^stem', # stem and embed
|
| 479 |
+
blocks=r'^stages\.(\d+)'
|
| 480 |
+
)
|
| 481 |
+
return matcher
|
| 482 |
+
|
| 483 |
+
@torch.jit.ignore
|
| 484 |
+
def set_grad_checkpointing(self, enable=True):
|
| 485 |
+
for s in self.stages:
|
| 486 |
+
s.grad_checkpointing = enable
|
| 487 |
+
|
| 488 |
+
@torch.jit.ignore
|
| 489 |
+
def get_classifier(self):
|
| 490 |
+
return self.head.fc
|
| 491 |
+
|
| 492 |
+
def reset_classifier(self, num_classes, global_pool=None):
|
| 493 |
+
self.num_classes = num_classes
|
| 494 |
+
if global_pool is None:
|
| 495 |
+
global_pool = self.head.global_pool.pool_type
|
| 496 |
+
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
|
| 497 |
+
|
| 498 |
+
def forward_features(self, x: torch.Tensor) -> torch.Tensor:
|
| 499 |
+
x = self.stem(x)
|
| 500 |
+
x = self.stages(x)
|
| 501 |
+
return x
|
| 502 |
+
|
| 503 |
+
def forward_head(self, x, pre_logits: bool = False):
|
| 504 |
+
return self.head(x, pre_logits=pre_logits)
|
| 505 |
+
|
| 506 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 507 |
+
x = self.forward_features(x)
|
| 508 |
+
x = self.forward_head(x)
|
| 509 |
+
return x
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def _create_gcvit(variant, pretrained=False, **kwargs):
|
| 513 |
+
if kwargs.get('features_only', None):
|
| 514 |
+
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
| 515 |
+
model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
|
| 516 |
+
return model
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
def _cfg(url='', **kwargs):
|
| 520 |
+
return {
|
| 521 |
+
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
|
| 522 |
+
'crop_pct': 0.875, 'interpolation': 'bicubic',
|
| 523 |
+
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
|
| 524 |
+
'first_conv': 'stem.conv1', 'classifier': 'head.fc',
|
| 525 |
+
'fixed_input_size': True,
|
| 526 |
+
**kwargs
|
| 527 |
+
}
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
default_cfgs = generate_default_cfgs({
|
| 531 |
+
'gcvit_xxtiny.in1k': _cfg(
|
| 532 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'),
|
| 533 |
+
'gcvit_xtiny.in1k': _cfg(
|
| 534 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'),
|
| 535 |
+
'gcvit_tiny.in1k': _cfg(
|
| 536 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'),
|
| 537 |
+
'gcvit_small.in1k': _cfg(
|
| 538 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'),
|
| 539 |
+
'gcvit_base.in1k': _cfg(
|
| 540 |
+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'),
|
| 541 |
+
})
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
@register_model
|
| 545 |
+
def gcvit_xxtiny(pretrained=False, **kwargs) -> GlobalContextVit:
|
| 546 |
+
model_kwargs = dict(
|
| 547 |
+
depths=(2, 2, 6, 2),
|
| 548 |
+
num_heads=(2, 4, 8, 16),
|
| 549 |
+
**kwargs)
|
| 550 |
+
return _create_gcvit('gcvit_xxtiny', pretrained=pretrained, **model_kwargs)
|
| 551 |
+
|
| 552 |
+
|
| 553 |
+
@register_model
|
| 554 |
+
def gcvit_xtiny(pretrained=False, **kwargs) -> GlobalContextVit:
|
| 555 |
+
model_kwargs = dict(
|
| 556 |
+
depths=(3, 4, 6, 5),
|
| 557 |
+
num_heads=(2, 4, 8, 16),
|
| 558 |
+
**kwargs)
|
| 559 |
+
return _create_gcvit('gcvit_xtiny', pretrained=pretrained, **model_kwargs)
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
@register_model
|
| 563 |
+
def gcvit_tiny(pretrained=False, **kwargs) -> GlobalContextVit:
|
| 564 |
+
model_kwargs = dict(
|
| 565 |
+
depths=(3, 4, 19, 5),
|
| 566 |
+
num_heads=(2, 4, 8, 16),
|
| 567 |
+
**kwargs)
|
| 568 |
+
return _create_gcvit('gcvit_tiny', pretrained=pretrained, **model_kwargs)
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
@register_model
|
| 572 |
+
def gcvit_small(pretrained=False, **kwargs) -> GlobalContextVit:
|
| 573 |
+
model_kwargs = dict(
|
| 574 |
+
depths=(3, 4, 19, 5),
|
| 575 |
+
num_heads=(3, 6, 12, 24),
|
| 576 |
+
embed_dim=96,
|
| 577 |
+
mlp_ratio=2,
|
| 578 |
+
layer_scale=1e-5,
|
| 579 |
+
**kwargs)
|
| 580 |
+
return _create_gcvit('gcvit_small', pretrained=pretrained, **model_kwargs)
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
@register_model
|
| 584 |
+
def gcvit_base(pretrained=False, **kwargs) -> GlobalContextVit:
|
| 585 |
+
model_kwargs = dict(
|
| 586 |
+
depths=(3, 4, 19, 5),
|
| 587 |
+
num_heads=(4, 8, 16, 32),
|
| 588 |
+
embed_dim=128,
|
| 589 |
+
mlp_ratio=2,
|
| 590 |
+
layer_scale=1e-5,
|
| 591 |
+
**kwargs)
|
| 592 |
+
return _create_gcvit('gcvit_base', pretrained=pretrained, **model_kwargs)
|
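A key mechanic in this file is the lossless window round trip: window_partition reshapes an NHWC feature map into non-overlapping windows for attention, and window_reverse inverts it exactly whenever H and W are divisible by the window size. As an illustrative aside (not part of the diffed file), a minimal sketch under that assumption, importing the helpers from the module above:

import torch
from timm.models.gcvit import window_partition, window_reverse

x = torch.randn(2, 14, 14, 96)                     # B, H, W, C; 14 is divisible by 7
wins = window_partition(x, (7, 7))                 # -> (2 * 2 * 2, 7, 7, 96)
restored = window_reverse(wins, (7, 7), (14, 14))  # -> (2, 14, 14, 96)
assert torch.equal(x, restored)                    # pure reshape/permute, exact inverse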
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/ghostnet.py
ADDED
@@ -0,0 +1,432 @@
"""
An implementation of GhostNet & GhostNetV2 Models as defined in:
GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
GhostNetV2: Enhance Cheap Operation with Long-Range Attention. https://proceedings.neurips.cc/paper_files/paper/2022/file/40b60852a4abdaa696b5a1a78da34635-Paper-Conference.pdf

The train script & code of models at:
Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
Original model: https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv2_pytorch/model/ghostnetv2_torch.py
"""
import math
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
from ._builder import build_model_with_cfg
from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
from ._manipulate import checkpoint_seq
from ._registry import register_model, generate_default_cfgs

__all__ = ['GhostNet']


_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))


class GhostModule(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=1,
            ratio=2,
            dw_size=3,
            stride=1,
            use_act=True,
            act_layer=nn.ReLU,
    ):
        super(GhostModule, self).__init__()
        self.out_chs = out_chs
        init_chs = math.ceil(out_chs / ratio)
        new_chs = init_chs * (ratio - 1)

        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_chs),
            act_layer(inplace=True) if use_act else nn.Identity(),
        )

        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False),
            nn.BatchNorm2d(new_chs),
            act_layer(inplace=True) if use_act else nn.Identity(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.out_chs, :, :]


class GhostModuleV2(nn.Module):
    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=1,
            ratio=2,
            dw_size=3,
            stride=1,
            use_act=True,
            act_layer=nn.ReLU,
    ):
        super().__init__()
        self.gate_fn = nn.Sigmoid()
        self.out_chs = out_chs
        init_chs = math.ceil(out_chs / ratio)
        new_chs = init_chs * (ratio - 1)
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_chs),
            act_layer(inplace=True) if use_act else nn.Identity(),
        )
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False),
            nn.BatchNorm2d(new_chs),
            act_layer(inplace=True) if use_act else nn.Identity(),
        )
        self.short_conv = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(out_chs),
            nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False),
            nn.BatchNorm2d(out_chs),
            nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False),
            nn.BatchNorm2d(out_chs),
        )

    def forward(self, x):
        res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        out = torch.cat([x1, x2], dim=1)
        return out[:, :self.out_chs, :, :] * F.interpolate(
            self.gate_fn(res), size=(out.shape[-2], out.shape[-1]), mode='nearest')


class GhostBottleneck(nn.Module):
    """ Ghost bottleneck w/ optional SE"""

    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            dw_kernel_size=3,
            stride=1,
            act_layer=nn.ReLU,
            se_ratio=0.,
            mode='original',
    ):
        super(GhostBottleneck, self).__init__()
        has_se = se_ratio is not None and se_ratio > 0.
        self.stride = stride

        # Point-wise expansion
        if mode == 'original':
            self.ghost1 = GhostModule(in_chs, mid_chs, use_act=True, act_layer=act_layer)
        else:
            self.ghost1 = GhostModuleV2(in_chs, mid_chs, use_act=True, act_layer=act_layer)

        # Depth-wise convolution
        if self.stride > 1:
            self.conv_dw = nn.Conv2d(
                mid_chs, mid_chs, dw_kernel_size, stride=stride,
                padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False)
            self.bn_dw = nn.BatchNorm2d(mid_chs)
        else:
            self.conv_dw = None
            self.bn_dw = None

        # Squeeze-and-excitation
        self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None

        # Point-wise linear projection
        self.ghost2 = GhostModule(mid_chs, out_chs, use_act=False)

        # shortcut
        if in_chs == out_chs and self.stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_chs, in_chs, dw_kernel_size, stride=stride,
                    padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        shortcut = x

        # 1st ghost bottleneck
        x = self.ghost1(x)

        # Depth-wise convolution
        if self.conv_dw is not None:
            x = self.conv_dw(x)
            x = self.bn_dw(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # 2nd ghost bottleneck
        x = self.ghost2(x)

        x += self.shortcut(shortcut)
        return x


class GhostNet(nn.Module):
    def __init__(
            self,
            cfgs,
            num_classes=1000,
            width=1.0,
            in_chans=3,
            output_stride=32,
            global_pool='avg',
            drop_rate=0.2,
            version='v1',
    ):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
        self.cfgs = cfgs
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        # building first layer
        stem_chs = make_divisible(16 * width, 4)
        self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem'))
        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.act1 = nn.ReLU(inplace=True)
        prev_chs = stem_chs

        # building inverted residual blocks
        stages = nn.ModuleList([])
        stage_idx = 0
        layer_idx = 0
        net_stride = 2
        for cfg in self.cfgs:
            layers = []
            s = 1
            for k, exp_size, c, se_ratio, s in cfg:
                out_chs = make_divisible(c * width, 4)
                mid_chs = make_divisible(exp_size * width, 4)
                layer_kwargs = {}
                if version == 'v2' and layer_idx > 1:
                    layer_kwargs['mode'] = 'attn'
                layers.append(GhostBottleneck(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, **layer_kwargs))
                prev_chs = out_chs
                layer_idx += 1
            if s > 1:
                net_stride *= 2
                self.feature_info.append(dict(
                    num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
            stages.append(nn.Sequential(*layers))
            stage_idx += 1

        out_chs = make_divisible(exp_size * width, 4)
        stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
        self.pool_dim = prev_chs = out_chs

        self.blocks = nn.Sequential(*stages)

        # building last several layers
        self.num_features = out_chs = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()

        # FIXME init

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^conv_stem|bn1',
            blocks=[
                (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
                (r'conv_head', (99999,))
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x, flatten=True)
        else:
            x = self.blocks(x)
        return x

    def forward_head(self, x):
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        x = self.flatten(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classifier(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def checkpoint_filter_fn(state_dict, model: nn.Module):
    out_dict = {}
    for k, v in state_dict.items():
        if 'total' in k:
            continue
        out_dict[k] = v
    return out_dict


def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
    """
    Constructs a GhostNet model
    """
    cfgs = [
        # k, t, c, SE, s
        # stage1
        [[3, 16, 16, 0, 1]],
        # stage2
        [[3, 48, 24, 0, 2]],
        [[3, 72, 24, 0, 1]],
        # stage3
        [[5, 72, 40, 0.25, 2]],
        [[5, 120, 40, 0.25, 1]],
        # stage4
        [[3, 240, 80, 0, 2]],
        [[3, 200, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 480, 112, 0.25, 1],
         [3, 672, 112, 0.25, 1]
         ],
        # stage5
        [[5, 672, 160, 0.25, 2]],
        [[5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1],
         [5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1]
         ]
    ]
    model_kwargs = dict(
        cfgs=cfgs,
        width=width,
        **kwargs,
    )
    return build_model_with_cfg(
        GhostNet,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True),
        **model_kwargs,
    )


def _cfg(url='', **kwargs):
    return {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem', 'classifier': 'classifier',
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    'ghostnet_050.untrained': _cfg(),
    'ghostnet_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'
    ),
    'ghostnet_130.untrained': _cfg(),
    'ghostnetv2_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_10.pth.tar'
    ),
    'ghostnetv2_130.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_13.pth.tar'
    ),
    'ghostnetv2_160.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_16.pth.tar'
    ),
})


@register_model
def ghostnet_050(pretrained=False, **kwargs) -> GhostNet:
    """ GhostNet-0.5x """
    model = _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
    return model


@register_model
def ghostnet_100(pretrained=False, **kwargs) -> GhostNet:
    """ GhostNet-1.0x """
    model = _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
    return model


@register_model
def ghostnet_130(pretrained=False, **kwargs) -> GhostNet:
    """ GhostNet-1.3x """
    model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
    return model


@register_model
def ghostnetv2_100(pretrained=False, **kwargs) -> GhostNet:
    """ GhostNetV2-1.0x """
    model = _create_ghostnet('ghostnetv2_100', width=1.0, pretrained=pretrained, version='v2', **kwargs)
    return model


@register_model
def ghostnetv2_130(pretrained=False, **kwargs) -> GhostNet:
    """ GhostNetV2-1.3x """
    model = _create_ghostnet('ghostnetv2_130', width=1.3, pretrained=pretrained, version='v2', **kwargs)
    return model


@register_model
def ghostnetv2_160(pretrained=False, **kwargs) -> GhostNet:
    """ GhostNetV2-1.6x """
    model = _create_ghostnet('ghostnetv2_160', width=1.6, pretrained=pretrained, version='v2', **kwargs)
    return model
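The GhostModule above is the heart of the file: with ratio=2, half the output channels come from the dense primary 1x1 conv and the rest from a 'cheap' depthwise conv over that primary output, with the concatenation sliced back to out_chs. As an illustrative aside (not part of the diffed file), a minimal shape check, assuming the module is importable as timm.models.ghostnet:

import torch
from timm.models.ghostnet import GhostModule

m = GhostModule(in_chs=16, out_chs=32, ratio=2)  # init_chs=16 primary + 16 cheap channels
y = m(torch.randn(1, 16, 56, 56))
print(y.shape)  # torch.Size([1, 32, 56, 56]); spatial size unchanged at stride=1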
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_v4.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Pytorch Inception-V4 implementation
|
| 2 |
+
Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
|
| 3 |
+
based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
|
| 4 |
+
"""
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
|
| 11 |
+
from timm.layers import create_classifier, ConvNormAct
|
| 12 |
+
from ._builder import build_model_with_cfg
|
| 13 |
+
from ._registry import register_model, generate_default_cfgs
|
| 14 |
+
|
| 15 |
+
__all__ = ['InceptionV4']
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Mixed3a(nn.Module):
|
| 19 |
+
def __init__(self, conv_block=ConvNormAct):
|
| 20 |
+
super(Mixed3a, self).__init__()
|
| 21 |
+
self.maxpool = nn.MaxPool2d(3, stride=2)
|
| 22 |
+
self.conv = conv_block(64, 96, kernel_size=3, stride=2)
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
x0 = self.maxpool(x)
|
| 26 |
+
x1 = self.conv(x)
|
| 27 |
+
out = torch.cat((x0, x1), 1)
|
| 28 |
+
return out
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Mixed4a(nn.Module):
|
| 32 |
+
def __init__(self, conv_block=ConvNormAct):
|
| 33 |
+
super(Mixed4a, self).__init__()
|
| 34 |
+
|
| 35 |
+
self.branch0 = nn.Sequential(
|
| 36 |
+
conv_block(160, 64, kernel_size=1, stride=1),
|
| 37 |
+
conv_block(64, 96, kernel_size=3, stride=1)
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
self.branch1 = nn.Sequential(
|
| 41 |
+
conv_block(160, 64, kernel_size=1, stride=1),
|
| 42 |
+
conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 43 |
+
conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 44 |
+
conv_block(64, 96, kernel_size=(3, 3), stride=1)
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
def forward(self, x):
|
| 48 |
+
x0 = self.branch0(x)
|
| 49 |
+
x1 = self.branch1(x)
|
| 50 |
+
out = torch.cat((x0, x1), 1)
|
| 51 |
+
return out
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class Mixed5a(nn.Module):
|
| 55 |
+
def __init__(self, conv_block=ConvNormAct):
|
| 56 |
+
super(Mixed5a, self).__init__()
|
| 57 |
+
self.conv = conv_block(192, 192, kernel_size=3, stride=2)
|
| 58 |
+
self.maxpool = nn.MaxPool2d(3, stride=2)
|
| 59 |
+
|
| 60 |
+
def forward(self, x):
|
| 61 |
+
x0 = self.conv(x)
|
| 62 |
+
x1 = self.maxpool(x)
|
| 63 |
+
out = torch.cat((x0, x1), 1)
|
| 64 |
+
return out
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class InceptionA(nn.Module):
|
| 68 |
+
def __init__(self, conv_block=ConvNormAct):
|
| 69 |
+
super(InceptionA, self).__init__()
|
| 70 |
+
self.branch0 = conv_block(384, 96, kernel_size=1, stride=1)
|
| 71 |
+
|
| 72 |
+
self.branch1 = nn.Sequential(
|
| 73 |
+
conv_block(384, 64, kernel_size=1, stride=1),
|
| 74 |
+
conv_block(64, 96, kernel_size=3, stride=1, padding=1)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
self.branch2 = nn.Sequential(
|
| 78 |
+
conv_block(384, 64, kernel_size=1, stride=1),
|
| 79 |
+
conv_block(64, 96, kernel_size=3, stride=1, padding=1),
|
| 80 |
+
conv_block(96, 96, kernel_size=3, stride=1, padding=1)
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
self.branch3 = nn.Sequential(
|
| 84 |
+
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
| 85 |
+
conv_block(384, 96, kernel_size=1, stride=1)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
x0 = self.branch0(x)
|
| 90 |
+
x1 = self.branch1(x)
|
| 91 |
+
x2 = self.branch2(x)
|
| 92 |
+
x3 = self.branch3(x)
|
| 93 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
| 94 |
+
return out
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ReductionA(nn.Module):
|
| 98 |
+
def __init__(self, conv_block=ConvNormAct):
|
| 99 |
+
super(ReductionA, self).__init__()
|
| 100 |
+
self.branch0 = conv_block(384, 384, kernel_size=3, stride=2)
|
| 101 |
+
|
| 102 |
+
self.branch1 = nn.Sequential(
|
| 103 |
+
conv_block(384, 192, kernel_size=1, stride=1),
|
| 104 |
+
conv_block(192, 224, kernel_size=3, stride=1, padding=1),
|
| 105 |
+
conv_block(224, 256, kernel_size=3, stride=2)
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
self.branch2 = nn.MaxPool2d(3, stride=2)
|
| 109 |
+
|
| 110 |
+
def forward(self, x):
|
| 111 |
+
x0 = self.branch0(x)
|
| 112 |
+
x1 = self.branch1(x)
|
| 113 |
+
x2 = self.branch2(x)
|
| 114 |
+
out = torch.cat((x0, x1, x2), 1)
|
| 115 |
+
return out
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class InceptionB(nn.Module):
|
| 119 |
+
def __init__(self, conv_block=ConvNormAct):
|
| 120 |
+
super(InceptionB, self).__init__()
|
| 121 |
+
self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1)
|
| 122 |
+
|
| 123 |
+
self.branch1 = nn.Sequential(
|
| 124 |
+
conv_block(1024, 192, kernel_size=1, stride=1),
|
| 125 |
+
conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 126 |
+
conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.branch2 = nn.Sequential(
|
| 130 |
+
conv_block(1024, 192, kernel_size=1, stride=1),
|
| 131 |
+
conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 132 |
+
conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 133 |
+
conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 134 |
+
conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
self.branch3 = nn.Sequential(
|
| 138 |
+
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
| 139 |
+
conv_block(1024, 128, kernel_size=1, stride=1)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def forward(self, x):
|
| 143 |
+
x0 = self.branch0(x)
|
| 144 |
+
x1 = self.branch1(x)
|
| 145 |
+
x2 = self.branch2(x)
|
| 146 |
+
x3 = self.branch3(x)
|
| 147 |
+
out = torch.cat((x0, x1, x2, x3), 1)
|
| 148 |
+
return out
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class ReductionB(nn.Module):
|
| 152 |
+
def __init__(self, conv_block=ConvNormAct):
|
| 153 |
+
super(ReductionB, self).__init__()
|
| 154 |
+
|
| 155 |
+
self.branch0 = nn.Sequential(
|
| 156 |
+
conv_block(1024, 192, kernel_size=1, stride=1),
|
| 157 |
+
conv_block(192, 192, kernel_size=3, stride=2)
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
self.branch1 = nn.Sequential(
|
| 161 |
+
conv_block(1024, 256, kernel_size=1, stride=1),
|
| 162 |
+
conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
|
| 163 |
+
conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
|
| 164 |
+
conv_block(320, 320, kernel_size=3, stride=2)
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
self.branch2 = nn.MaxPool2d(3, stride=2)
|
| 168 |
+
|
| 169 |
+
def forward(self, x):
|
| 170 |
+
x0 = self.branch0(x)
|
| 171 |
+
x1 = self.branch1(x)
|
| 172 |
+
x2 = self.branch2(x)
|
| 173 |
+
out = torch.cat((x0, x1, x2), 1)
|
| 174 |
+
return out
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class InceptionC(nn.Module):
    def __init__(self, conv_block=ConvNormAct):
        super(InceptionC, self).__init__()

        self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1)

        self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            conv_block(1536, 256, kernel_size=1, stride=1)
        )

    def forward(self, x):
        x0 = self.branch0(x)

        x1_0 = self.branch1_0(x)
        x1_1a = self.branch1_1a(x1_0)
        x1_1b = self.branch1_1b(x1_0)
        x1 = torch.cat((x1_1a, x1_1b), 1)

        x2_0 = self.branch2_0(x)
        x2_1 = self.branch2_1(x2_0)
        x2_2 = self.branch2_2(x2_1)
        x2_3a = self.branch2_3a(x2_2)
        x2_3b = self.branch2_3b(x2_2)
        x2 = torch.cat((x2_3a, x2_3b), 1)

        x3 = self.branch3(x)

        out = torch.cat((x0, x1, x2, x3), 1)
        return out


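# Editorial sanity check (not part of the original file): InceptionC concatenates
# four branches on dim 1, and their widths must sum to the 1536 channels that
# InceptionV4 below reports as num_features. A minimal sketch, guarded so it only
# runs when this module is executed directly.
if __name__ == '__main__':
    _branch0 = 256            # 1x1 conv
    _branch1 = 256 + 256      # parallel (1,3) and (3,1) convs, concatenated
    _branch2 = 256 + 256      # same split at the end of the deeper branch
    _branch3 = 256            # avg-pool + 1x1 conv
    assert _branch0 + _branch1 + _branch2 + _branch3 == 1536

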
class InceptionV4(nn.Module):
    def __init__(
            self,
            num_classes=1000,
            in_chans=3,
            output_stride=32,
            drop_rate=0.,
            global_pool='avg',
            norm_layer='batchnorm2d',
            norm_eps=1e-3,
            act_layer='relu',
    ):
        super(InceptionV4, self).__init__()
        assert output_stride == 32
        self.num_classes = num_classes
        self.num_features = 1536
        conv_block = partial(
            ConvNormAct,
            padding=0,
            norm_layer=norm_layer,
            act_layer=act_layer,
            norm_kwargs=dict(eps=norm_eps),
            act_kwargs=dict(inplace=True),
        )

        features = [
            conv_block(in_chans, 32, kernel_size=3, stride=2),
            conv_block(32, 32, kernel_size=3, stride=1),
            conv_block(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed3a(conv_block),
            Mixed4a(conv_block),
            Mixed5a(conv_block),
        ]
        features += [InceptionA(conv_block) for _ in range(4)]
        features += [ReductionA(conv_block)]  # Mixed6a
        features += [InceptionB(conv_block) for _ in range(7)]
        features += [ReductionB(conv_block)]  # Mixed7a
        features += [InceptionC(conv_block) for _ in range(3)]
        self.features = nn.Sequential(*features)
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='features.2'),
            dict(num_chs=160, reduction=4, module='features.3'),
            dict(num_chs=384, reduction=8, module='features.9'),
            dict(num_chs=1024, reduction=16, module='features.17'),
            dict(num_chs=1536, reduction=32, module='features.21'),
        ]
        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^features\.[012]\.',
            blocks=r'^features\.(\d+)'
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        return self.features(x)

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.head_drop(x)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


def _create_inception_v4(variant, pretrained=False, **kwargs) -> InceptionV4:
    return build_model_with_cfg(
        InceptionV4,
        variant,
        pretrained,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )


default_cfgs = generate_default_cfgs({
    'inception_v4.tf_in1k': {
        'hf_hub_id': 'timm/',
        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'features.0.conv', 'classifier': 'last_linear',
    }
})


@register_model
def inception_v4(pretrained=False, **kwargs):
    return _create_inception_v4('inception_v4', pretrained, **kwargs)


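# Editorial usage sketch (not part of the original file): with a timm build that
# ships this module, the registered 'inception_v4' entry point is reachable via
# timm.create_model, and the feature_info above implies a five-level pyramid under
# features_only. A hedged sketch, untested here; pretrained weights would need
# network access. Guarded so it only runs when the module is executed directly.
if __name__ == '__main__':
    import timm

    _model = timm.create_model('inception_v4', pretrained=False, num_classes=10)
    _model.eval()
    with torch.no_grad():
        _logits = _model(torch.randn(2, 3, 299, 299))   # default cfg input size
    print(_logits.shape)                                # torch.Size([2, 10])

    _features = timm.create_model('inception_v4', features_only=True)
    with torch.no_grad():
        _fmaps = _features(torch.randn(2, 3, 299, 299))
    print([f.shape[1] for f in _fmaps])                 # [64, 160, 384, 1024, 1536]
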

my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/levit.py
ADDED
@@ -0,0 +1,933 @@
""" LeViT
|
| 2 |
+
|
| 3 |
+
Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
|
| 4 |
+
- https://arxiv.org/abs/2104.01136
|
| 5 |
+
|
| 6 |
+
@article{graham2021levit,
|
| 7 |
+
title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
|
| 8 |
+
author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
|
| 9 |
+
journal={arXiv preprint arXiv:22104.01136},
|
| 10 |
+
year={2021}
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow.
|
| 14 |
+
|
| 15 |
+
This version combines both conv/linear models and fixes torchscript compatibility.
|
| 16 |
+
|
| 17 |
+
Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 21 |
+
# All rights reserved.
|
| 22 |
+
|
| 23 |
+
# Modified from
|
| 24 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
| 25 |
+
# Copyright 2020 Ross Wightman, Apache-2.0 License
|
| 26 |
+
from collections import OrderedDict
|
| 27 |
+
from functools import partial
|
| 28 |
+
from typing import Dict
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
|
| 33 |
+
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
|
| 34 |
+
from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
|
| 35 |
+
from ._builder import build_model_with_cfg
|
| 36 |
+
from ._manipulate import checkpoint_seq
|
| 37 |
+
from ._registry import generate_default_cfgs, register_model
|
| 38 |
+
|
| 39 |
+
__all__ = ['Levit']
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ConvNorm(nn.Module):
    def __init__(
            self, in_chs, out_chs, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False)
        self.bn = nn.BatchNorm2d(out_chs)

        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        c, bn = self.linear, self.bn
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        m = nn.Conv2d(
            w.size(1), w.size(0), w.shape[2:], stride=self.linear.stride,
            padding=self.linear.padding, dilation=self.linear.dilation, groups=self.linear.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        return self.bn(self.linear(x))


class LinearNorm(nn.Module):
    def __init__(self, in_features, out_features, bn_weight_init=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.bn = nn.BatchNorm1d(out_features)

        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        l, bn = self.linear, self.bn
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = l.weight * w[:, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        m = nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        x = self.linear(x)
        return self.bn(x.flatten(0, 1)).reshape_as(x)


class NormLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.drop = nn.Dropout(drop)
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        trunc_normal_(self.linear.weight, std=std)
        if self.linear.bias is not None:
            nn.init.constant_(self.linear.bias, 0)

    @torch.no_grad()
    def fuse(self):
        bn, l = self.bn, self.linear
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = l.weight * w[None, :]
        if l.bias is None:
            b = b @ self.linear.weight.T
        else:
            b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
        m = nn.Linear(w.size(1), w.size(0))
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        return self.linear(self.drop(self.bn(x)))


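# Editorial sketch (not part of the original file): the three modules above pair a
# bias-free linear op with a BatchNorm whose affine transform fuse() folds into a
# single layer for inference. A minimal equivalence check for ConvNorm, assuming
# eval mode so fuse() reads meaningful running statistics; guarded so it only runs
# when this module is executed directly.
if __name__ == '__main__':
    _m = ConvNorm(8, 16, kernel_size=3, padding=1)
    _m.train()
    _ = _m(torch.randn(4, 8, 14, 14))    # a forward pass populates BN running stats
    _m.eval()
    _x = torch.randn(2, 8, 14, 14)
    with torch.no_grad():
        _fused = _m.fuse()               # single Conv2d with BN scale/shift folded in
        assert torch.allclose(_m(_x), _fused(_x), atol=1e-5)

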
class Stem8(nn.Sequential):
    def __init__(self, in_chs, out_chs, act_layer):
        super().__init__()
        self.stride = 8

        self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1))
        self.add_module('act1', act_layer())
        self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
        self.add_module('act2', act_layer())
        self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))


class Stem16(nn.Sequential):
    def __init__(self, in_chs, out_chs, act_layer):
        super().__init__()
        self.stride = 16

        self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1))
        self.add_module('act1', act_layer())
        self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1))
        self.add_module('act2', act_layer())
        self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
        self.add_module('act3', act_layer())
        self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))


class Downsample(nn.Module):
    def __init__(self, stride, resolution, use_pool=False):
        super().__init__()
        self.stride = stride
        self.resolution = to_2tuple(resolution)
        self.pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if use_pool else None

    def forward(self, x):
        B, N, C = x.shape
        x = x.view(B, self.resolution[0], self.resolution[1], C)
        if self.pool is not None:
            x = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        else:
            x = x[:, ::self.stride, ::self.stride]
        return x.reshape(B, -1, C)


class Attention(nn.Module):
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            dim,
            key_dim,
            num_heads=8,
            attn_ratio=4.,
            resolution=14,
            use_conv=False,
            act_layer=nn.SiLU,
    ):
        super().__init__()
        ln_layer = ConvNorm if use_conv else LinearNorm
        resolution = to_2tuple(resolution)

        self.use_conv = use_conv
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = int(attn_ratio * key_dim) * num_heads

        self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2)
        self.proj = nn.Sequential(OrderedDict([
            ('act', act_layer()),
            ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0))
        ]))

        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
        rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
        rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
        self.attention_bias_cache = {}

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):  # x (B,C,H,W)
        if self.use_conv:
            B, C, H, W = x.shape
            q, k, v = self.qkv(x).view(
                B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
        else:
            B, N, C = x.shape
            q, k, v = self.qkv(x).view(
                B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 3, 1)
            v = v.permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
        x = self.proj(x)
        return x


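# Editorial sketch (not part of the original file): attention_bias_idxs above maps
# every (query, key) pair to one of resolution[0] * resolution[1] learned bias
# slots, keyed by the absolute (dy, dx) offset between their grid positions. The
# same construction standalone, for a 4x4 token grid, so the shapes are visible;
# guarded so it only runs when this module is executed directly.
if __name__ == '__main__':
    _H = _W = 4
    _pos = torch.stack(ndgrid(torch.arange(_H), torch.arange(_W))).flatten(1)  # (2, 16)
    _rel = (_pos[..., :, None] - _pos[..., None, :]).abs()                     # (2, 16, 16): |dy|, |dx| per pair
    _idx = _rel[0] * _W + _rel[1]                                              # (16, 16): slot per (q, k) pair
    assert int(_idx.max()) == (_H - 1) * _W + (_W - 1)                         # at most H*W distinct slots
    _biases = torch.zeros(8, _H * _W)                  # like the (num_heads, H*W) parameter
    print(_biases[:, _idx].shape)                      # torch.Size([8, 16, 16])

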
class AttentionDownsample(nn.Module):
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            num_heads=8,
            attn_ratio=2.0,
            stride=2,
            resolution=14,
            use_conv=False,
            use_pool=False,
            act_layer=nn.SiLU,
    ):
        super().__init__()
        resolution = to_2tuple(resolution)

        self.stride = stride
        self.resolution = resolution
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = self.val_dim * self.num_heads
        self.scale = key_dim ** -0.5
        self.use_conv = use_conv

        if self.use_conv:
            ln_layer = ConvNorm
            sub_layer = partial(
                nn.AvgPool2d,
                kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
        else:
            ln_layer = LinearNorm
            sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool)

        self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim)
        self.q = nn.Sequential(OrderedDict([
            ('down', sub_layer(stride=stride)),
            ('ln', ln_layer(in_dim, self.key_attn_dim))
        ]))
        self.proj = nn.Sequential(OrderedDict([
            ('act', act_layer()),
            ('ln', ln_layer(self.val_attn_dim, out_dim))
        ]))

        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
        k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
        q_pos = torch.stack(ndgrid(
            torch.arange(0, resolution[0], step=stride),
            torch.arange(0, resolution[1], step=stride)
        )).flatten(1)
        rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
        rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)

        self.attention_bias_cache = {}  # per-device attention_biases cache

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):
        if self.use_conv:
            B, C, H, W = x.shape
            HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1
            k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
            q = self.q(x).view(B, self.num_heads, self.key_dim, -1)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
        else:
            B, N, C = x.shape
            k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
            k = k.permute(0, 2, 3, 1)  # BHCN
            v = v.permute(0, 2, 1, 3)  # BHNC
            q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
        x = self.proj(x)
        return x


class LevitMlp(nn.Module):
    """ MLP for Levit w/ normalization + ability to switch btw conv and linear
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            use_conv=False,
            act_layer=nn.SiLU,
            drop=0.
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        ln_layer = ConvNorm if use_conv else LinearNorm

        self.ln1 = ln_layer(in_features, hidden_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0)

    def forward(self, x):
        x = self.ln1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.ln2(x)
        return x


class LevitDownsample(nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            num_heads=8,
            attn_ratio=4.,
            mlp_ratio=2.,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            resolution=14,
            use_conv=False,
            use_pool=False,
            drop_path=0.,
    ):
        super().__init__()
        attn_act_layer = attn_act_layer or act_layer

        self.attn_downsample = AttentionDownsample(
            in_dim=in_dim,
            out_dim=out_dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            act_layer=attn_act_layer,
            resolution=resolution,
            use_conv=use_conv,
            use_pool=use_pool,
        )

        self.mlp = LevitMlp(
            out_dim,
            int(out_dim * mlp_ratio),
            use_conv=use_conv,
            act_layer=act_layer
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = self.attn_downsample(x)
        x = x + self.drop_path(self.mlp(x))
        return x


class LevitBlock(nn.Module):
    def __init__(
            self,
            dim,
            key_dim,
            num_heads=8,
            attn_ratio=4.,
            mlp_ratio=2.,
            resolution=14,
            use_conv=False,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            drop_path=0.,
    ):
        super().__init__()
        attn_act_layer = attn_act_layer or act_layer

        self.attn = Attention(
            dim=dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            resolution=resolution,
            use_conv=use_conv,
            act_layer=attn_act_layer,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.mlp = LevitMlp(
            dim,
            int(dim * mlp_ratio),
            use_conv=use_conv,
            act_layer=act_layer
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.attn(x))
        x = x + self.drop_path2(self.mlp(x))
        return x


class LevitStage(nn.Module):
    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            depth=4,
            num_heads=8,
            attn_ratio=4.0,
            mlp_ratio=4.0,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            resolution=14,
            downsample='',
            use_conv=False,
            drop_path=0.,
    ):
        super().__init__()
        resolution = to_2tuple(resolution)

        if downsample:
            self.downsample = LevitDownsample(
                in_dim,
                out_dim,
                key_dim=key_dim,
                num_heads=in_dim // key_dim,
                attn_ratio=4.,
                mlp_ratio=2.,
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                drop_path=drop_path,
            )
            resolution = [(r - 1) // 2 + 1 for r in resolution]
        else:
            assert in_dim == out_dim
            self.downsample = nn.Identity()

        blocks = []
        for _ in range(depth):
            blocks += [LevitBlock(
                out_dim,
                key_dim,
                num_heads=num_heads,
                attn_ratio=attn_ratio,
                mlp_ratio=mlp_ratio,
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                drop_path=drop_path,
            )]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x


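# Editorial note (not part of the original file): each downsampling stage above
# shrinks the token grid with ceil division, (r - 1) // 2 + 1, matching the
# stride-2 attention downsample. For the default 224-pixel input with the s16
# stem the grid goes 14 -> 7 -> 4 across the three stages. A sketch of that
# arithmetic, guarded so it only runs when this module is executed directly.
if __name__ == '__main__':
    _r = 224 // 16               # 14x14 token grid after the s16 stem
    _grids = [_r]
    for _ in range(2):           # stages 1 and 2 each downsample by 2
        _r = (_r - 1) // 2 + 1   # ceil(r / 2), as in LevitStage above
        _grids.append(_r)
    print(_grids)                # [14, 7, 4]

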
class Levit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    NOTE: distillation defaults to True since the pretrained weights use it; this will
    cause problems with train scripts that don't expect tuple outputs.
    """

    def __init__(
            self,
            img_size=224,
            in_chans=3,
            num_classes=1000,
            embed_dim=(192,),
            key_dim=64,
            depth=(12,),
            num_heads=(3,),
            attn_ratio=2.,
            mlp_ratio=2.,
            stem_backbone=None,
            stem_stride=None,
            stem_type='s16',
            down_op='subsample',
            act_layer='hard_swish',
            attn_act_layer=None,
            use_conv=False,
            global_pool='avg',
            drop_rate=0.,
            drop_path_rate=0.):
        super().__init__()
        act_layer = get_act_layer(act_layer)
        attn_act_layer = get_act_layer(attn_act_layer or act_layer)
        self.use_conv = use_conv
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = embed_dim[-1]
        self.embed_dim = embed_dim
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        num_stages = len(embed_dim)
        assert len(depth) == num_stages
        num_heads = to_ntuple(num_stages)(num_heads)
        attn_ratio = to_ntuple(num_stages)(attn_ratio)
        mlp_ratio = to_ntuple(num_stages)(mlp_ratio)

        if stem_backbone is not None:
            assert stem_stride >= 2
            self.stem = stem_backbone
            stride = stem_stride
        else:
            assert stem_type in ('s16', 's8')
            if stem_type == 's16':
                self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer)
            else:
                self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer)
            stride = self.stem.stride
        resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])

        in_dim = embed_dim[0]
        stages = []
        for i in range(num_stages):
            stage_stride = 2 if i > 0 else 1
            stages += [LevitStage(
                in_dim,
                embed_dim[i],
                key_dim,
                depth=depth[i],
                num_heads=num_heads[i],
                attn_ratio=attn_ratio[i],
                mlp_ratio=mlp_ratio[i],
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                downsample=down_op if stage_stride == 2 else '',
                drop_path=drop_path_rate
            )]
            stride *= stage_stride
            resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
            in_dim = embed_dim[i]
        self.stages = nn.Sequential(*stages)

        # Classifier head
        self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = NormLinear(
            self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
        if not self.use_conv:
            x = x.flatten(2).transpose(1, 2)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x


class LevitDistilled(Levit):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

    @torch.jit.ignore
    def get_classifier(self):
        return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = NormLinear(
            self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
        self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        self.distilled_training = enable

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
        if pre_logits:
            return x
        x, x_dist = self.head(x), self.head_dist(x)
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
            return x, x_dist
        else:
            # during standard train / finetune / inference, average the two classifier predictions
            return (x + x_dist) / 2


def checkpoint_filter_fn(state_dict, model):
    if 'model' in state_dict:
        state_dict = state_dict['model']

    # filter out attn biases, should not have been persistent
    state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}

    D = model.state_dict()
    out_dict = {}
    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
        if va.ndim == 4 and vb.ndim == 2:
            vb = vb[:, :, None, None]
        if va.shape != vb.shape:
            # head or first-conv shapes may change for fine-tune
            assert 'head' in ka or 'stem.conv1.linear' in ka
        out_dict[ka] = vb

    return out_dict


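# Editorial sketch (not part of the original file): checkpoint_filter_fn lets one
# set of released weights serve both layouts, broadcasting a Linear weight of
# shape (out, in) to the (out, in, 1, 1) weight of the equivalent 1x1 Conv2d. A
# minimal illustration of that remap; guarded so it only runs when this module
# is executed directly.
if __name__ == '__main__':
    _lin_w = torch.randn(64, 32)             # like nn.Linear(32, 64).weight
    _conv_w = _lin_w[:, :, None, None]       # -> nn.Conv2d(32, 64, 1).weight shape
    assert _conv_w.shape == (64, 32, 1, 1)
    # a 1x1 conv with the remapped weight computes the same map as the linear layer
    _x = torch.randn(2, 32, 7, 7)
    _y_conv = torch.nn.functional.conv2d(_x, _conv_w)
    _y_lin = (_x.permute(0, 2, 3, 1) @ _lin_w.T).permute(0, 3, 1, 2)
    assert torch.allclose(_y_conv, _y_lin, atol=1e-5)

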
model_cfgs = dict(
    levit_128s=dict(
        embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
    levit_128=dict(
        embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
    levit_192=dict(
        embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
    levit_256=dict(
        embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
    levit_384=dict(
        embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),

    # stride-8 stem experiments
    levit_384_s8=dict(
        embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4),
        act_layer='silu', stem_type='s8'),
    levit_512_s8=dict(
        embed_dim=(512, 640, 896), key_dim=64, num_heads=(8, 10, 14), depth=(4, 4, 4),
        act_layer='silu', stem_type='s8'),

    # wider experiments
    levit_512=dict(
        embed_dim=(512, 768, 1024), key_dim=64, num_heads=(8, 12, 16), depth=(4, 4, 4), act_layer='silu'),

    # deeper experiments
    levit_256d=dict(
        embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6), act_layer='silu'),
    levit_512d=dict(
        embed_dim=(512, 640, 768), key_dim=64, num_heads=(8, 10, 12), depth=(4, 8, 6), act_layer='silu'),
)


def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
    is_conv = '_conv' in variant
    out_indices = kwargs.pop('out_indices', (0, 1, 2))
    if kwargs.get('features_only', None):
        if not is_conv:
            raise RuntimeError('features_only not implemented for LeVit in non-convolutional mode.')
    if cfg_variant is None:
        if variant in model_cfgs:
            cfg_variant = variant
        elif is_conv:
            cfg_variant = variant.replace('_conv', '')

    model_cfg = dict(model_cfgs[cfg_variant], **kwargs)
    model = build_model_with_cfg(
        LevitDistilled if distilled else Levit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **model_cfg,
    )
    return model


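# Editorial usage sketch for the factory above (not part of the original file):
# features_only is only supported in conv mode, so feature extraction should go
# through a levit_conv_* variant. A hedged sketch, untested here; guarded so it
# only runs when this module is executed directly.
if __name__ == '__main__':
    import timm

    _backbone = timm.create_model('levit_conv_128s', features_only=True)
    with torch.no_grad():
        _fmaps = _backbone(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in _fmaps])
    # expect 14x14, 7x7 and 4x4 grids with widths (128, 256, 384)

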
def _cfg(url='', **kwargs):
    return {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.linear', 'classifier': ('head.linear', 'head_dist.linear'),
        **kwargs
    }


default_cfgs = generate_default_cfgs({
    # weights in nn.Linear mode
    'levit_128s.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_128.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_192.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_256.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_384.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),

    # weights in nn.Conv2d mode
    'levit_conv_128s.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_128.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_192.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_256.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_384.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),

    'levit_384_s8.untrained': _cfg(classifier='head.linear'),
    'levit_512_s8.untrained': _cfg(classifier='head.linear'),
    'levit_512.untrained': _cfg(classifier='head.linear'),
    'levit_256d.untrained': _cfg(classifier='head.linear'),
    'levit_512d.untrained': _cfg(classifier='head.linear'),

    'levit_conv_384_s8.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512_s8.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512.untrained': _cfg(classifier='head.linear'),
    'levit_conv_256d.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512d.untrained': _cfg(classifier='head.linear'),
})


@register_model
def levit_128s(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_128s', pretrained=pretrained, **kwargs)


@register_model
def levit_128(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_128', pretrained=pretrained, **kwargs)


@register_model
def levit_192(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_192', pretrained=pretrained, **kwargs)


@register_model
def levit_256(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_256', pretrained=pretrained, **kwargs)


@register_model
def levit_384(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_384', pretrained=pretrained, **kwargs)


@register_model
def levit_384_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_384_s8', pretrained=pretrained, **kwargs)


@register_model
def levit_512_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_512_s8', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_512(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_512', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_256d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_256d', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_512d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_512d', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_conv_128s(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_128s', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_128(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_128', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_192(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_192', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_256(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_256', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_384(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_384', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_384_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_384_s8', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_512_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_512_s8', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)


@register_model
def levit_conv_512(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_512', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)


@register_model
def levit_conv_256d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_256d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)


@register_model
def levit_conv_512d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_512d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
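
# Editorial usage sketch for the registered variants (not part of the original
# file): the fb_dist_in1k weights load into the LevitDistilled wrapper, which
# averages the two heads at inference, so eval-mode calls return plain logits.
# A hedged sketch, untested here; pretrained weights would need network access.
# Guarded so it only runs when this module is executed directly.
if __name__ == '__main__':
    import timm

    _model = timm.create_model('levit_128s', pretrained=False)   # LevitDistilled by default
    _model.eval()
    with torch.no_grad():
        _logits = _model(torch.randn(1, 3, 224, 224))   # two heads averaged in eval mode
    print(_logits.shape)                                # torch.Size([1, 1000])

    # for distillation training, request separate (head, head_dist) outputs;
    # batch size > 1 keeps the BatchNorm1d in the heads valid in train mode
    _model.set_distilled_training(True)
    _model.train()
    _out, _out_dist = _model(torch.randn(2, 3, 224, 224))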