prasb committed
Commit a0bd5c5 · verified · 1 parent: 3bfbab8

Add files using upload-large-folder tool

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/decathlon_datalist.cpython-38.pyc +0 -0
  2. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_dataset.cpython-38.pyc +0 -0
  3. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_reader.cpython-38.pyc +0 -0
  4. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/__init__.cpython-38.pyc +0 -0
  5. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/aspp.cpython-38.pyc +0 -0
  6. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/backbone_fpn_utils.cpython-38.pyc +0 -0
  7. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/crf.cpython-38.pyc +0 -0
  8. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/fcn.cpython-38.pyc +0 -0
  9. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/feature_pyramid_network.cpython-38.pyc +0 -0
  10. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/patchembedding.cpython-38.pyc +0 -0
  11. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/segresnet_block.cpython-38.pyc +0 -0
  12. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/unetr_block.cpython-38.pyc +0 -0
  13. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/upsample.cpython-38.pyc +0 -0
  14. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/__init__.cpython-38.pyc +0 -0
  15. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/convutils.cpython-38.pyc +0 -0
  16. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/factories.cpython-38.pyc +0 -0
  17. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/simplelayers.cpython-38.pyc +0 -0
  18. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/weight_init.cpython-38.pyc +0 -0
  19. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/__init__.cpython-38.pyc +0 -0
  20. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/decorators.cpython-38.pyc +0 -0
  21. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/deprecate_utils.cpython-38.pyc +0 -0
  22. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/jupyter_utils.cpython-38.pyc +0 -0
  23. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/module.cpython-38.pyc +0 -0
  24. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/profiling.cpython-38.pyc +0 -0
  25. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/type_conversion.cpython-38.pyc +0 -0
  26. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/__init__.cpython-38.pyc +0 -0
  27. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/version.cpython-38.pyc +0 -0
  28. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/__init__.py +94 -0
  29. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_builder.py +484 -0
  30. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_factory.py +127 -0
  31. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features.py +368 -0
  32. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features_fx.py +141 -0
  33. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_hub.py +402 -0
  34. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_prune.py +113 -0
  35. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/beit.py +621 -0
  36. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byoanet.py +455 -0
  37. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byobnet.py +2245 -0
  38. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/coat.py +804 -0
  39. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convit.py +430 -0
  40. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/crossvit.py +627 -0
  41. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/cspnet.py +1106 -0
  42. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/deit.py +416 -0
  43. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/dla.py +515 -0
  44. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/eva.py +1109 -0
  45. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/factory.py +4 -0
  46. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/features.py +4 -0
  47. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/gcvit.py +592 -0
  48. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/ghostnet.py +432 -0
  49. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_v4.py +325 -0
  50. my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/levit.py +933 -0
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/decathlon_datalist.cpython-38.pyc ADDED
Binary file (8.88 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_dataset.cpython-38.pyc ADDED
Binary file (5.13 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_reader.cpython-38.pyc ADDED
Binary file (41.2 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.1 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/aspp.cpython-38.pyc ADDED
Binary file (4.2 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/backbone_fpn_utils.cpython-38.pyc ADDED
Binary file (5 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/crf.cpython-38.pyc ADDED
Binary file (3.92 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/fcn.cpython-38.pyc ADDED
Binary file (7.53 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/feature_pyramid_network.cpython-38.pyc ADDED
Binary file (8.21 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/patchembedding.cpython-38.pyc ADDED
Binary file (7.19 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/segresnet_block.cpython-38.pyc ADDED
Binary file (2.83 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/unetr_block.cpython-38.pyc ADDED
Binary file (5.84 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/upsample.cpython-38.pyc ADDED
Binary file (9.48 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (1.24 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/convutils.cpython-38.pyc ADDED
Binary file (7.22 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/factories.cpython-38.pyc ADDED
Binary file (14.1 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/simplelayers.cpython-38.pyc ADDED
Binary file (20.1 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/weight_init.cpython-38.pyc ADDED
Binary file (2.01 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (2.81 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/decorators.cpython-38.pyc ADDED
Binary file (2.92 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/deprecate_utils.cpython-38.pyc ADDED
Binary file (7.3 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/jupyter_utils.cpython-38.pyc ADDED
Binary file (13.1 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/module.cpython-38.pyc ADDED
Binary file (18.1 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/profiling.cpython-38.pyc ADDED
Binary file (3.53 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/type_conversion.cpython-38.pyc ADDED
Binary file (11.7 kB).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (571 Bytes).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/version.cpython-38.pyc ADDED
Binary file (177 Bytes).
 
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/__init__.py ADDED
@@ -0,0 +1,94 @@
+ from .beit import *
+ from .byoanet import *
+ from .byobnet import *
+ from .cait import *
+ from .coat import *
+ from .convit import *
+ from .convmixer import *
+ from .convnext import *
+ from .crossvit import *
+ from .cspnet import *
+ from .davit import *
+ from .deit import *
+ from .densenet import *
+ from .dla import *
+ from .dpn import *
+ from .edgenext import *
+ from .efficientformer import *
+ from .efficientformer_v2 import *
+ from .efficientnet import *
+ from .efficientvit_mit import *
+ from .efficientvit_msra import *
+ from .eva import *
+ from .fastvit import *
+ from .focalnet import *
+ from .gcvit import *
+ from .ghostnet import *
+ from .hardcorenas import *
+ from .hgnet import *
+ from .hrnet import *
+ from .inception_next import *
+ from .inception_resnet_v2 import *
+ from .inception_v3 import *
+ from .inception_v4 import *
+ from .levit import *
+ from .maxxvit import *
+ from .metaformer import *
+ from .mlp_mixer import *
+ from .mobilenetv3 import *
+ from .mobilevit import *
+ from .mvitv2 import *
+ from .nasnet import *
+ from .nest import *
+ from .nextvit import *
+ from .nfnet import *
+ from .pit import *
+ from .pnasnet import *
+ from .pvt_v2 import *
+ from .regnet import *
+ from .repghost import *
+ from .repvit import *
+ from .res2net import *
+ from .resnest import *
+ from .resnet import *
+ from .resnetv2 import *
+ from .rexnet import *
+ from .selecsls import *
+ from .senet import *
+ from .sequencer import *
+ from .sknet import *
+ from .swin_transformer import *
+ from .swin_transformer_v2 import *
+ from .swin_transformer_v2_cr import *
+ from .tiny_vit import *
+ from .tnt import *
+ from .tresnet import *
+ from .twins import *
+ from .vgg import *
+ from .visformer import *
+ from .vision_transformer import *
+ from .vision_transformer_hybrid import *
+ from .vision_transformer_relpos import *
+ from .vision_transformer_sam import *
+ from .volo import *
+ from .vovnet import *
+ from .xception import *
+ from .xception_aligned import *
+ from .xcit import *
+
+ from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
+     set_pretrained_download_progress, set_pretrained_check_hash
+ from ._factory import create_model, parse_model_name, safe_model_name
+ from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
+ from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
+     register_notrace_module, is_notrace_module, get_notrace_modules, \
+     register_notrace_function, is_notrace_function, get_notrace_functions
+ from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
+ from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
+ from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
+     group_modules, group_parameters, checkpoint_seq, adapt_input_conv
+ from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
+ from ._prune import adapt_model_from_string
+ from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
+     register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
+     is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
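The star imports above populate `timm.models` with every architecture's registered entrypoints, while the `_registry`/`_factory` re-exports form the public lookup API. A minimal sketch of how these re-exported names are typically used (model names are illustrative; availability depends on the installed timm version):

```py
import timm

# list_models / create_model are the re-exports from ._registry / ._factory above
effnets = timm.list_models('efficientnet*')  # glob filter over registered names
model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=10)
print(model.num_classes)  # 10
```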
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_builder.py ADDED
@@ -0,0 +1,484 @@
+ """ EfficientNet, MobileNetV3, etc Builder
+
+ Assembles EfficientNet and related network feature blocks from string definitions.
+ Handles stride, dilation calculations, and selects feature extraction points.
+
+ Hacked together by / Copyright 2019, Ross Wightman
+ """
+
+ import logging
+ import math
+ import re
+ from copy import deepcopy
+ from functools import partial
+ from typing import Any, Dict, List
+
+ import torch.nn as nn
+
+ from ._efficientnet_blocks import *
+ from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
+
+ __all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
+     'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
+
+ _logger = logging.getLogger(__name__)
+
+
+ _DEBUG_BUILDER = False
+
+ # Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
+ # papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
+ # NOTE: momentum varies btw .99 and .9997 depending on source
+ # .99 in official TF TPU impl
+ # .9997 (/w .999 in search space) for paper
+ BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+ BN_EPS_TF_DEFAULT = 1e-3
+ _BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+ BlockArgs = List[List[Dict[str, Any]]]
+
+
+ def get_bn_args_tf():
+     return _BN_ARGS_TF.copy()
+
+
+ def resolve_bn_args(kwargs):
+     bn_args = {}
+     bn_momentum = kwargs.pop('bn_momentum', None)
+     if bn_momentum is not None:
+         bn_args['momentum'] = bn_momentum
+     bn_eps = kwargs.pop('bn_eps', None)
+     if bn_eps is not None:
+         bn_args['eps'] = bn_eps
+     return bn_args
+
+
+ def resolve_act_layer(kwargs, default='relu'):
+     return get_act_layer(kwargs.pop('act_layer', default))
+
+
+ def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+     """Round number of filters based on depth multiplier."""
+     if not multiplier:
+         return channels
+     return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
+
+
+ def _log_info_if(msg, condition):
+     if condition:
+         _logger.info(msg)
+
+
+ def _parse_ksize(ss):
+     if ss.isdigit():
+         return int(ss)
+     else:
+         return [int(k) for k in ss.split('.')]
+
+
+ def _decode_block_str(block_str):
+     """ Decode block definition string
+
+     Gets a list of block arg (dicts) through a string notation of arguments.
+     E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
+
+     All args can exist in any order with the exception of the leading string which
+     is assumed to indicate the block type.
+
+     leading string - block type (
+       ir = InvertedResidual, ds = DepthwiseSep, dsa = DepthwiseSep with pw act, cn = ConvBnAct)
+     r - number of repeat blocks,
+     k - kernel size,
+     s - strides (1-9),
+     e - expansion ratio,
+     c - output channels,
+     se - squeeze/excitation ratio
+     n - activation fn ('re', 'r6', 'hs', or 'sw')
+     Args:
+         block_str: a string representation of block arguments.
+     Returns:
+         A list of block args (dicts)
+     Raises:
+         ValueError: if the string def is not properly specified (TODO)
+     """
+     assert isinstance(block_str, str)
+     ops = block_str.split('_')
+     block_type = ops[0]  # take the block type off the front
+     ops = ops[1:]
+     options = {}
+     skip = None
+     for op in ops:
+         # string options being checked on individual basis, combine if they grow
+         if op == 'noskip':
+             skip = False  # force no skip connection
+         elif op == 'skip':
+             skip = True  # force a skip connection
+         elif op.startswith('n'):
+             # activation fn
+             key = op[0]
+             v = op[1:]
+             if v == 're':
+                 value = get_act_layer('relu')
+             elif v == 'r6':
+                 value = get_act_layer('relu6')
+             elif v == 'hs':
+                 value = get_act_layer('hard_swish')
+             elif v == 'sw':
+                 value = get_act_layer('swish')  # aka SiLU
+             elif v == 'mi':
+                 value = get_act_layer('mish')
+             else:
+                 continue
+             options[key] = value
+         else:
+             # all numeric options
+             splits = re.split(r'(\d.*)', op)
+             if len(splits) >= 2:
+                 key, value = splits[:2]
+                 options[key] = value
+
+     # if act_layer is None, the model default (passed to model init) will be used
+     act_layer = options['n'] if 'n' in options else None
+     exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+     pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+     force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
+     num_repeat = int(options['r'])
+
+     # each type of block has different valid arguments, fill accordingly
+     block_args = dict(
+         block_type=block_type,
+         out_chs=int(options['c']),
+         stride=int(options['s']),
+         act_layer=act_layer,
+     )
+     if block_type == 'ir':
+         block_args.update(dict(
+             dw_kernel_size=_parse_ksize(options['k']),
+             exp_kernel_size=exp_kernel_size,
+             pw_kernel_size=pw_kernel_size,
+             exp_ratio=float(options['e']),
+             se_ratio=float(options['se']) if 'se' in options else 0.,
+             noskip=skip is False,
+         ))
+         if 'cc' in options:
+             block_args['num_experts'] = int(options['cc'])
+     elif block_type == 'ds' or block_type == 'dsa':
+         block_args.update(dict(
+             dw_kernel_size=_parse_ksize(options['k']),
+             pw_kernel_size=pw_kernel_size,
+             se_ratio=float(options['se']) if 'se' in options else 0.,
+             pw_act=block_type == 'dsa',
+             noskip=block_type == 'dsa' or skip is False,
+         ))
+     elif block_type == 'er':
+         block_args.update(dict(
+             exp_kernel_size=_parse_ksize(options['k']),
+             pw_kernel_size=pw_kernel_size,
+             exp_ratio=float(options['e']),
+             force_in_chs=force_in_chs,
+             se_ratio=float(options['se']) if 'se' in options else 0.,
+             noskip=skip is False,
+         ))
+     elif block_type == 'cn':
+         block_args.update(dict(
+             kernel_size=int(options['k']),
+             skip=skip is True,
+         ))
+     else:
+         assert False, 'Unknown block type (%s)' % block_type
+     if 'gs' in options:
+         block_args['group_size'] = options['gs']
+
+     return block_args, num_repeat
+
+
+ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+     """ Per-stage depth scaling
+     Scales the block repeats in each stage. This depth scaling impl maintains
+     compatibility with the EfficientNet scaling method, while allowing sensible
+     scaling for other models that may have multiple block arg definitions in each stage.
+     """
+
+     # We scale the total repeat count for each stage, there may be multiple
+     # block arg defs per stage so we need to sum.
+     num_repeat = sum(repeats)
+     if depth_trunc == 'round':
+         # Truncating to int by rounding allows stages with few repeats to remain
+         # proportionally smaller for longer. This is a good choice when stage definitions
+         # include single repeat stages that we'd prefer to keep that way as long as possible
+         num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+     else:
+         # The default for EfficientNet truncates repeats to int via 'ceil'.
+         # Any multiplier > 1.0 will result in an increased depth for every stage.
+         num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
+
+     # Proportionally distribute repeat count scaling to each block definition in the stage.
+     # Allocation is done in reverse as it results in the first block being less likely to be scaled.
+     # The first block makes less sense to repeat in most of the arch definitions.
+     repeats_scaled = []
+     for r in repeats[::-1]:
+         rs = max(1, round((r / num_repeat * num_repeat_scaled)))
+         repeats_scaled.append(rs)
+         num_repeat -= r
+         num_repeat_scaled -= rs
+     repeats_scaled = repeats_scaled[::-1]
+
+     # Apply the calculated scaling to each block arg in the stage
+     sa_scaled = []
+     for ba, rep in zip(stack_args, repeats_scaled):
+         sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
+     return sa_scaled
+
+
+ def decode_arch_def(
+         arch_def,
+         depth_multiplier=1.0,
+         depth_trunc='ceil',
+         experts_multiplier=1,
+         fix_first_last=False,
+         group_size=None,
+ ):
+     """ Decode block architecture definition strings -> block kwargs
+
+     Args:
+         arch_def: architecture definition strings, list of list of strings
+         depth_multiplier: network depth multiplier
+         depth_trunc: network depth truncation mode when applying multiplier
+         experts_multiplier: CondConv experts multiplier
+         fix_first_last: fix first and last block depths when multiplier is applied
+         group_size: group size override for all blocks that weren't explicitly set in arch string
+
+     Returns:
+         list of list of block kwargs
+     """
+     arch_args = []
+     if isinstance(depth_multiplier, tuple):
+         assert len(depth_multiplier) == len(arch_def)
+     else:
+         depth_multiplier = (depth_multiplier,) * len(arch_def)
+     for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
+         assert isinstance(block_strings, list)
+         stack_args = []
+         repeats = []
+         for block_str in block_strings:
+             assert isinstance(block_str, str)
+             ba, rep = _decode_block_str(block_str)
+             if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
+                 ba['num_experts'] *= experts_multiplier
+             if group_size is not None:
+                 ba.setdefault('group_size', group_size)
+             stack_args.append(ba)
+             repeats.append(rep)
+         if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
+             arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
+         else:
+             arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
+     return arch_args
+
+
+ class EfficientNetBuilder:
+     """ Build Trunk Blocks
+
+     This ended up being somewhat of a cross between
+     https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
+     and
+     https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
+
+     """
+     def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
+                  act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
+         self.output_stride = output_stride
+         self.pad_type = pad_type
+         self.round_chs_fn = round_chs_fn
+         self.se_from_exp = se_from_exp  # calculate se channel reduction from expanded (mid) chs
+         self.act_layer = act_layer
+         self.norm_layer = norm_layer
+         self.se_layer = get_attn(se_layer)
+         try:
+             self.se_layer(8, rd_ratio=1.0)  # test if attn layer accepts rd_ratio arg
+             self.se_has_ratio = True
+         except TypeError:
+             self.se_has_ratio = False
+         self.drop_path_rate = drop_path_rate
+         if feature_location == 'depthwise':
+             # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
+             _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
+             feature_location = 'expansion'
+         self.feature_location = feature_location
+         assert feature_location in ('bottleneck', 'expansion', '')
+         self.verbose = _DEBUG_BUILDER
+
+         # state updated during build, consumed by model
+         self.in_chs = None
+         self.features = []
+
+     def _make_block(self, ba, block_idx, block_count):
+         drop_path_rate = self.drop_path_rate * block_idx / block_count
+         bt = ba.pop('block_type')
+         ba['in_chs'] = self.in_chs
+         ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
+         if 'force_in_chs' in ba and ba['force_in_chs']:
+             # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
+             ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
+         ba['pad_type'] = self.pad_type
+         # block act fn overrides the model default
+         ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
+         assert ba['act_layer'] is not None
+         ba['norm_layer'] = self.norm_layer
+         ba['drop_path_rate'] = drop_path_rate
+         if bt != 'cn':
+             se_ratio = ba.pop('se_ratio')
+             if se_ratio and self.se_layer is not None:
+                 if not self.se_from_exp:
+                     # adjust se_ratio by expansion ratio if calculating se channels from block input
+                     se_ratio /= ba.get('exp_ratio', 1.0)
+                 if self.se_has_ratio:
+                     ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
+                 else:
+                     ba['se_layer'] = self.se_layer
+
+         if bt == 'ir':
+             _log_info_if('  InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+             block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
+         elif bt == 'ds' or bt == 'dsa':
+             _log_info_if('  DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+             block = DepthwiseSeparableConv(**ba)
+         elif bt == 'er':
+             _log_info_if('  EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+             block = EdgeResidual(**ba)
+         elif bt == 'cn':
+             _log_info_if('  ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+             block = ConvBnAct(**ba)
+         else:
+             assert False, 'Unknown block type (%s) while building model.' % bt
+
+         self.in_chs = ba['out_chs']  # update in_chs for arg of next block
+         return block
+
+     def __call__(self, in_chs, model_block_args):
+         """ Build the blocks
+         Args:
+             in_chs: Number of input-channels passed to first block
+             model_block_args: A list of lists, outer list defines stages, inner
+                 list contains strings defining block configuration(s)
+         Return:
+              List of block stacks (each stack wrapped in nn.Sequential)
+         """
+         _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
+         self.in_chs = in_chs
+         total_block_count = sum([len(x) for x in model_block_args])
+         total_block_idx = 0
+         current_stride = 2
+         current_dilation = 1
+         stages = []
+         if model_block_args[0][0]['stride'] > 1:
+             # if the first block starts with a stride, we need to extract first level feat from stem
+             feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
+             self.features.append(feature_info)
+
+         # outer list of block_args defines the stacks
+         for stack_idx, stack_args in enumerate(model_block_args):
+             last_stack = stack_idx + 1 == len(model_block_args)
+             _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
+             assert isinstance(stack_args, list)
+
+             blocks = []
+             # each stack (stage of blocks) contains a list of block arguments
+             for block_idx, block_args in enumerate(stack_args):
+                 last_block = block_idx + 1 == len(stack_args)
+                 _log_info_if(' Block: {}'.format(block_idx), self.verbose)
+
+                 assert block_args['stride'] in (1, 2)
+                 if block_idx >= 1:   # only the first block in any stack can have a stride > 1
+                     block_args['stride'] = 1
+
+                 extract_features = False
+                 if last_block:
+                     next_stack_idx = stack_idx + 1
+                     extract_features = next_stack_idx >= len(model_block_args) or \
+                         model_block_args[next_stack_idx][0]['stride'] > 1
+
+                 next_dilation = current_dilation
+                 if block_args['stride'] > 1:
+                     next_output_stride = current_stride * block_args['stride']
+                     if next_output_stride > self.output_stride:
+                         next_dilation = current_dilation * block_args['stride']
+                         block_args['stride'] = 1
+                         _log_info_if('  Converting stride to dilation to maintain output_stride=={}'.format(
+                             self.output_stride), self.verbose)
+                     else:
+                         current_stride = next_output_stride
+                 block_args['dilation'] = current_dilation
+                 if next_dilation != current_dilation:
+                     current_dilation = next_dilation
+
+                 # create the block
+                 block = self._make_block(block_args, total_block_idx, total_block_count)
+                 blocks.append(block)
+
+                 # stash feature module name and channel info for model feature extraction
+                 if extract_features:
+                     feature_info = dict(
+                         stage=stack_idx + 1,
+                         reduction=current_stride,
+                         **block.feature_info(self.feature_location),
+                     )
+                     leaf_name = feature_info.get('module', '')
+                     if leaf_name:
+                         feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
+                     else:
+                         assert last_block
+                         feature_info['module'] = f'blocks.{stack_idx}'
+                     self.features.append(feature_info)
+
+                 total_block_idx += 1  # incr global block idx (across all stacks)
+             stages.append(nn.Sequential(*blocks))
+         return stages
+
+
+ def _init_weight_goog(m, n='', fix_group_fanout=True):
+     """ Weight initialization as per Tensorflow official implementations.
+
+     Args:
+         m (nn.Module): module to init
+         n (str): module name
+         fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
+
+     Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
+     * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
+     * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+     """
+     if isinstance(m, CondConv2d):
+         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+         if fix_group_fanout:
+             fan_out //= m.groups
+         init_weight_fn = get_condconv_initializer(
+             lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
+         init_weight_fn(m.weight)
+         if m.bias is not None:
+             nn.init.zeros_(m.bias)
+     elif isinstance(m, nn.Conv2d):
+         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+         if fix_group_fanout:
+             fan_out //= m.groups
+         nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
+         if m.bias is not None:
+             nn.init.zeros_(m.bias)
+     elif isinstance(m, nn.BatchNorm2d):
+         nn.init.ones_(m.weight)
+         nn.init.zeros_(m.bias)
+     elif isinstance(m, nn.Linear):
+         fan_out = m.weight.size(0)  # fan-out
+         fan_in = 0
+         if 'routing_fn' in n:
+             fan_in = m.weight.size(1)
+         init_range = 1.0 / math.sqrt(fan_in + fan_out)
+         nn.init.uniform_(m.weight, -init_range, init_range)
+         nn.init.zeros_(m.bias)
+
+
+ def efficientnet_init_weights(model: nn.Module, init_fn=None):
+     init_fn = init_fn or _init_weight_goog
+     for n, m in model.named_modules():
+         init_fn(m, n)
+
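The block-string notation documented in `_decode_block_str` can be exercised end-to-end through `decode_arch_def`. A short sketch under the definitions above (the arch strings are illustrative, in the style of EfficientNet/MobileNetV3 defs; the private module path is the one added in this commit):

```py
from timm.models._efficientnet_builder import decode_arch_def

arch_def = [
    ['ds_r1_k3_s1_e1_c16_se0.25'],   # stage 0: one depthwise-separable block, 16 out_chs
    ['ir_r2_k3_s2_e6_c24_se0.25'],   # stage 1: inverted residual, repeated twice, stride 2
]
block_args = decode_arch_def(arch_def, depth_multiplier=1.0)

# outer list = stages, inner list = one kwargs dict per block repeat
print([len(stage) for stage in block_args])  # [1, 2]
print(block_args[1][0]['block_type'])        # 'ir'
print(block_args[1][0]['exp_ratio'])         # 6.0
```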
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_factory.py ADDED
@@ -0,0 +1,127 @@
+ import os
+ from typing import Any, Dict, Optional, Union
+ from urllib.parse import urlsplit
+
+ from timm.layers import set_layer_config
+ from ._helpers import load_checkpoint
+ from ._hub import load_model_config_from_hf
+ from ._pretrained import PretrainedCfg
+ from ._registry import is_model, model_entrypoint, split_model_name_tag
+
+
+ __all__ = ['parse_model_name', 'safe_model_name', 'create_model']
+
+
+ def parse_model_name(model_name: str):
+     if model_name.startswith('hf_hub'):
+         # NOTE for backwards compat, deprecate hf_hub use
+         model_name = model_name.replace('hf_hub', 'hf-hub')
+     parsed = urlsplit(model_name)
+     assert parsed.scheme in ('', 'timm', 'hf-hub')
+     if parsed.scheme == 'hf-hub':
+         # FIXME may use fragment as revision, currently `@` in URI path
+         return parsed.scheme, parsed.path
+     else:
+         model_name = os.path.split(parsed.path)[-1]
+         return 'timm', model_name
+
+
+ def safe_model_name(model_name: str, remove_source: bool = True):
+     # return a filename / path safe model name
+     def make_safe(name):
+         return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
+     if remove_source:
+         model_name = parse_model_name(model_name)[-1]
+     return make_safe(model_name)
+
+
+ def create_model(
+         model_name: str,
+         pretrained: bool = False,
+         pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
+         pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
+         checkpoint_path: str = '',
+         scriptable: Optional[bool] = None,
+         exportable: Optional[bool] = None,
+         no_jit: Optional[bool] = None,
+         **kwargs,
+ ):
+     """Create a model.
+
+     Lookup model's entrypoint function and pass relevant args to create a new model.
+
+     <Tip>
+         **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
+         and then the model class __init__(). kwargs values set to None are pruned before passing.
+     </Tip>
+
+     Args:
+         model_name: Name of model to instantiate.
+         pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+         pretrained_cfg: Pass in an external pretrained_cfg for model.
+         pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
+         checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+         scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
+         exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
+         no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
+
+     Keyword Args:
+         drop_rate (float): Classifier dropout rate for training.
+         drop_path_rate (float): Stochastic depth drop rate for training.
+         global_pool (str): Classifier global pooling type.
+
+     Example:
+
+     ```py
+     >>> from timm import create_model
+
+     >>> # Create a MobileNetV3-Large model with no pretrained weights.
+     >>> model = create_model('mobilenetv3_large_100')
+
+     >>> # Create a MobileNetV3-Large model with pretrained weights.
+     >>> model = create_model('mobilenetv3_large_100', pretrained=True)
+     >>> model.num_classes
+     1000
+
+     >>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
+     >>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
+     >>> model.num_classes
+     10
+     ```
+     """
+     # Parameters that aren't supported by all models or are intended to only override model defaults if set
+     # should default to None in command line args/cfg. Remove them if they are present and not set so that
+     # non-supporting models don't break and default args remain in effect.
+     kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+     model_source, model_name = parse_model_name(model_name)
+     if model_source == 'hf-hub':
+         assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
+         # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+         # load model weights + pretrained_cfg from Hugging Face hub.
+         pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+         if model_args:
+             for k, v in model_args.items():
+                 kwargs.setdefault(k, v)
+     else:
+         model_name, pretrained_tag = split_model_name_tag(model_name)
+         if pretrained_tag and not pretrained_cfg:
+             # a valid pretrained_cfg argument takes priority over tag in model name
+             pretrained_cfg = pretrained_tag
+
+     if not is_model(model_name):
+         raise RuntimeError('Unknown model (%s)' % model_name)
+
+     create_fn = model_entrypoint(model_name)
+     with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+         model = create_fn(
+             pretrained=pretrained,
+             pretrained_cfg=pretrained_cfg,
+             pretrained_cfg_overlay=pretrained_cfg_overlay,
+             **kwargs,
+         )
+
+     if checkpoint_path:
+         load_checkpoint(model, checkpoint_path)
+
+     return model
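The name parsing above determines whether weights and config come from the timm registry or the Hugging Face Hub. A quick sketch of the `parse_model_name` / `safe_model_name` behavior (model names illustrative):

```py
from timm.models._factory import parse_model_name, safe_model_name

print(parse_model_name('resnet50'))                      # ('timm', 'resnet50')
print(parse_model_name('hf-hub:timm/resnet50.a1_in1k'))  # ('hf-hub', 'timm/resnet50.a1_in1k')

# non-alphanumeric characters collapse to '_' for checkpoint/file naming
print(safe_model_name('hf-hub:timm/resnet50.a1_in1k'))   # 'timm_resnet50_a1_in1k'
```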
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features.py ADDED
@@ -0,0 +1,368 @@
+ """ PyTorch Feature Extraction Helpers
+
+ A collection of classes, functions, modules to help extract features from models
+ and provide a common interface for describing them.
+
+ The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
+ https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+ from collections import OrderedDict, defaultdict
+ from copy import deepcopy
+ from functools import partial
+ from typing import Dict, List, Sequence, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+ from torch.utils.checkpoint import checkpoint
+
+ from timm.layers import Format
+
+
+ __all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
+
+
+ class FeatureInfo:
+
+     def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
+         prev_reduction = 1
+         for i, fi in enumerate(feature_info):
+             # sanity check the mandatory fields, there may be additional fields depending on the model
+             assert 'num_chs' in fi and fi['num_chs'] > 0
+             assert 'reduction' in fi and fi['reduction'] >= prev_reduction
+             prev_reduction = fi['reduction']
+             assert 'module' in fi
+             fi.setdefault('index', i)
+         self.out_indices = out_indices
+         self.info = feature_info
+
+     def from_other(self, out_indices: Tuple[int]):
+         return FeatureInfo(deepcopy(self.info), out_indices)
+
+     def get(self, key, idx=None):
+         """ Get value by key at specified index (indices)
+         if idx == None, returns value for key at each output index
+         if idx is an integer, return value for that feature module index (ignoring output indices)
+         if idx is a list/tuple, return value for each module index (ignoring output indices)
+         """
+         if idx is None:
+             return [self.info[i][key] for i in self.out_indices]
+         if isinstance(idx, (tuple, list)):
+             return [self.info[i][key] for i in idx]
+         else:
+             return self.info[idx][key]
+
+     def get_dicts(self, keys=None, idx=None):
+         """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
+         """
+         if idx is None:
+             if keys is None:
+                 return [self.info[i] for i in self.out_indices]
+             else:
+                 return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
+         if isinstance(idx, (tuple, list)):
+             return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
+         else:
+             return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
+
+     def channels(self, idx=None):
+         """ feature channels accessor
+         """
+         return self.get('num_chs', idx)
+
+     def reduction(self, idx=None):
+         """ feature reduction (output stride) accessor
+         """
+         return self.get('reduction', idx)
+
+     def module_name(self, idx=None):
+         """ feature module name accessor
+         """
+         return self.get('module', idx)
+
+     def __getitem__(self, item):
+         return self.info[item]
+
+     def __len__(self):
+         return len(self.info)
+
+
+ class FeatureHooks:
+     """ Feature Hook Helper
+
+     This module helps with the setup and extraction of hooks for extracting features from
+     internal nodes in a model by node name.
+
+     FIXME This works well in eager Python but needs redesign for torchscript.
+     """
+
+     def __init__(
+             self,
+             hooks: Sequence[str],
+             named_modules: dict,
+             out_map: Sequence[Union[int, str]] = None,
+             default_hook_type: str = 'forward',
+     ):
+         # setup feature hooks
+         self._feature_outputs = defaultdict(OrderedDict)
+         modules = {k: v for k, v in named_modules}
+         for i, h in enumerate(hooks):
+             hook_name = h['module']
+             m = modules[hook_name]
+             hook_id = out_map[i] if out_map else hook_name
+             hook_fn = partial(self._collect_output_hook, hook_id)
+             hook_type = h.get('hook_type', default_hook_type)
+             if hook_type == 'forward_pre':
+                 m.register_forward_pre_hook(hook_fn)
+             elif hook_type == 'forward':
+                 m.register_forward_hook(hook_fn)
+             else:
+                 assert False, "Unsupported hook type"
+
+     def _collect_output_hook(self, hook_id, *args):
+         x = args[-1]  # tensor we want is last argument, output for fwd, input for fwd_pre
+         if isinstance(x, tuple):
+             x = x[0]  # unwrap input tuple
+         self._feature_outputs[x.device][hook_id] = x
+
+     def get_output(self, device) -> Dict[str, torch.tensor]:
+         output = self._feature_outputs[device]
+         self._feature_outputs[device] = OrderedDict()  # clear after reading
+         return output
+
+
+ def _module_list(module, flatten_sequential=False):
+     # a yield/iter would be better for this but wouldn't be compatible with torchscript
+     ml = []
+     for name, module in module.named_children():
+         if flatten_sequential and isinstance(module, nn.Sequential):
+             # first level of Sequential containers is flattened into containing model
+             for child_name, child_module in module.named_children():
+                 combined = [name, child_name]
+                 ml.append(('_'.join(combined), '.'.join(combined), child_module))
+         else:
+             ml.append((name, name, module))
+     return ml
+
+
+ def _get_feature_info(net, out_indices):
+     feature_info = getattr(net, 'feature_info')
+     if isinstance(feature_info, FeatureInfo):
+         return feature_info.from_other(out_indices)
+     elif isinstance(feature_info, (list, tuple)):
+         return FeatureInfo(net.feature_info, out_indices)
+     else:
+         assert False, "Provided feature_info is not valid"
+
+
+ def _get_return_layers(feature_info, out_map):
+     module_names = feature_info.module_name()
+     return_layers = {}
+     for i, name in enumerate(module_names):
+         return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+     return return_layers
+
+
+ class FeatureDictNet(nn.ModuleDict):
+     """ Feature extractor with OrderedDict return
+
+     Wrap a model and extract features as specified by the out indices, the network is
+     partially re-built from contained modules.
+
+     There is a strong assumption that the modules have been registered into the model in the same
+     order as they are used. There should be no reuse of the same nn.Module more than once, including
+     trivial modules like `self.relu = nn.ReLU`.
+
+     Only submodules that are directly assigned to the model class (`model.feature1`) or at most
+     one Sequential container deep (`model.features.1`, with flatten_sequential=True) can be captured.
+     All Sequential containers that are directly assigned to the original model will have their
+     modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
+     """
+     def __init__(
+             self,
+             model: nn.Module,
+             out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+             out_map: Sequence[Union[int, str]] = None,
+             output_fmt: str = 'NCHW',
+             feature_concat: bool = False,
+             flatten_sequential: bool = False,
+     ):
+         """
+         Args:
+             model: Model from which to extract features.
+             out_indices: Output indices of the model features to extract.
+             out_map: Return id mapping for each output index, otherwise str(index) is used.
+             feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                 first element e.g. `x[0]`
+             flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+         """
+         super(FeatureDictNet, self).__init__()
+         self.feature_info = _get_feature_info(model, out_indices)
+         self.output_fmt = Format(output_fmt)
+         self.concat = feature_concat
+         self.grad_checkpointing = False
+         self.return_layers = {}
+
+         return_layers = _get_return_layers(self.feature_info, out_map)
+         modules = _module_list(model, flatten_sequential=flatten_sequential)
+         remaining = set(return_layers.keys())
+         layers = OrderedDict()
+         for new_name, old_name, module in modules:
+             layers[new_name] = module
+             if old_name in remaining:
+                 # return id has to be consistently str type for torchscript
+                 self.return_layers[new_name] = str(return_layers[old_name])
+                 remaining.remove(old_name)
+             if not remaining:
+                 break
+         assert not remaining and len(self.return_layers) == len(return_layers), \
+             f'Return layers ({remaining}) are not present in model'
+         self.update(layers)
+
+     def set_grad_checkpointing(self, enable: bool = True):
+         self.grad_checkpointing = enable
+
+     def _collect(self, x) -> (Dict[str, torch.Tensor]):
+         out = OrderedDict()
+         for i, (name, module) in enumerate(self.items()):
+             if self.grad_checkpointing and not torch.jit.is_scripting():
+                 # Skipping checkpoint of first module because need a gradient at input
+                 # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                 # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                 first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                 x = module(x) if first_or_last_module else checkpoint(module, x)
+             else:
+                 x = module(x)
+
+             if name in self.return_layers:
+                 out_id = self.return_layers[name]
+                 if isinstance(x, (tuple, list)):
+                     # If model tap is a tuple or list, concat or select first element
+                     # FIXME this may need to be more generic / flexible for some nets
+                     out[out_id] = torch.cat(x, 1) if self.concat else x[0]
+                 else:
+                     out[out_id] = x
+         return out
+
+     def forward(self, x) -> Dict[str, torch.Tensor]:
+         return self._collect(x)
+
+
+ class FeatureListNet(FeatureDictNet):
+     """ Feature extractor with list return
+
+     A specialization of FeatureDictNet that always returns features as a list (values() of dict).
+     """
+     def __init__(
+             self,
+             model: nn.Module,
+             out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+             output_fmt: str = 'NCHW',
+             feature_concat: bool = False,
+             flatten_sequential: bool = False,
+     ):
+         """
+         Args:
+             model: Model from which to extract features.
+             out_indices: Output indices of the model features to extract.
+             feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+                 first element e.g. `x[0]`
+             flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+         """
+         super().__init__(
+             model,
+             out_indices=out_indices,
+             output_fmt=output_fmt,
+             feature_concat=feature_concat,
+             flatten_sequential=flatten_sequential,
+         )
+
+     def forward(self, x) -> (List[torch.Tensor]):
+         return list(self._collect(x).values())
+
+
+ class FeatureHookNet(nn.ModuleDict):
+     """ FeatureHookNet
+
+     Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
+
+     If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
+     network in any way.
+
+     If `no_rewrite` is False, the model will be re-written as in the
+     FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
+
+     FIXME this does not currently work with Torchscript, see FeatureHooks class
+     """
+     def __init__(
+             self,
+             model: nn.Module,
+             out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+             out_map: Sequence[Union[int, str]] = None,
+             return_dict: bool = False,
+             output_fmt: str = 'NCHW',
+             no_rewrite: bool = False,
+             flatten_sequential: bool = False,
+             default_hook_type: str = 'forward',
+     ):
+         """
+
+         Args:
+             model: Model from which to extract features.
+             out_indices: Output indices of the model features to extract.
+             out_map: Return id mapping for each output index, otherwise str(index) is used.
+             return_dict: Output features as a dict.
+             no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
+                 flatten_sequential arg must also be False if this is set True.
+             flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
+             default_hook_type: The default hook type to use if not specified in model.feature_info.
+         """
+         super().__init__()
+         assert not torch.jit.is_scripting()
+         self.feature_info = _get_feature_info(model, out_indices)
+         self.return_dict = return_dict
+         self.output_fmt = Format(output_fmt)
+         self.grad_checkpointing = False
+
+         layers = OrderedDict()
+         hooks = []
+         if no_rewrite:
+             assert not flatten_sequential
+             if hasattr(model, 'reset_classifier'):  # make sure classifier is removed?
+                 model.reset_classifier(0)
+             layers['body'] = model
+             hooks.extend(self.feature_info.get_dicts())
+         else:
+             modules = _module_list(model, flatten_sequential=flatten_sequential)
+             remaining = {
+                 f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+                 for f in self.feature_info.get_dicts()
+             }
+             for new_name, old_name, module in modules:
+                 layers[new_name] = module
+                 for fn, fm in module.named_modules(prefix=old_name):
+                     if fn in remaining:
+                         hooks.append(dict(module=fn, hook_type=remaining[fn]))
+                         del remaining[fn]
+                 if not remaining:
+                     break
+             assert not remaining, f'Return layers ({remaining}) are not present in model'
+         self.update(layers)
+         self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
+
+     def set_grad_checkpointing(self, enable: bool = True):
+         self.grad_checkpointing = enable
+
+     def forward(self, x):
+         for i, (name, module) in enumerate(self.items()):
+             if self.grad_checkpointing and not torch.jit.is_scripting():
+                 # Skipping checkpoint of first module because need a gradient at input
+                 # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+                 # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+                 first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+                 x = module(x) if first_or_last_module else checkpoint(module, x)
+             else:
+                 x = module(x)
+         out = self.hooks.get_output(x.device)
+         return out if self.return_dict else list(out.values())
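In practice these wrappers are usually reached through `create_model(..., features_only=True)`, which wraps the trunk in a `FeatureListNet` (or `FeatureHookNet` for some models); the `FeatureInfo` accessors above then describe each tap. A hedged sketch (the channel/reduction values shown are what a resnet18 would report; they vary per model):

```py
import torch
import timm

net = timm.create_model('resnet18', features_only=True, out_indices=(1, 2, 3))
print(net.feature_info.channels())   # e.g. [64, 128, 256]
print(net.feature_info.reduction())  # e.g. [4, 8, 16]

feats = net(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])      # one feature map per out_index
```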
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features_fx.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ PyTorch FX Based Feature Extraction Helpers
2
+ Using https://pytorch.org/vision/stable/feature_extraction.html
3
+ """
4
+ from typing import Callable, List, Dict, Union, Type
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from ._features import _get_feature_info, _get_return_layers
10
+
11
+ try:
12
+ from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
13
+ has_fx_feature_extraction = True
14
+ except ImportError:
15
+ has_fx_feature_extraction = False
16
+
17
+ # Layers we want to treat as leaf modules
18
+ from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
19
+ from timm.layers.non_local_attn import BilinearAttnTransform
20
+ from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
21
+ from timm.layers.norm_act import (
22
+ BatchNormAct2d,
23
+ SyncBatchNormAct,
24
+ FrozenBatchNormAct2d,
25
+ GroupNormAct,
26
+ GroupNorm1Act,
27
+ LayerNormAct,
28
+ LayerNormAct2d
29
+ )
30
+
31
+ __all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
32
+ 'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
33
+ 'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
34
+
35
+
36
+ # NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
37
+ # BUT modules from timm.models should use the registration mechanism below
38
+ _leaf_modules = {
39
+ BilinearAttnTransform, # reason: flow control t <= 1
40
+ # Reason: get_same_padding has a max which raises a control flow error
41
+ Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
42
+ CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]),
43
+ BatchNormAct2d,
44
+ SyncBatchNormAct,
45
+ FrozenBatchNormAct2d,
46
+ GroupNormAct,
47
+ GroupNorm1Act,
48
+ LayerNormAct,
49
+ LayerNormAct2d,
50
+ }
51
+
52
+ try:
53
+ from timm.layers import InplaceAbn
54
+ _leaf_modules.add(InplaceAbn)
55
+ except ImportError:
56
+ pass
57
+
58
+
59
+ def register_notrace_module(module: Type[nn.Module]):
60
+ """
61
+ Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
62
+ """
63
+ _leaf_modules.add(module)
64
+ return module
65
+
66
+
67
+ def is_notrace_module(module: Type[nn.Module]):
68
+ return module in _leaf_modules
69
+
70
+
71
+ def get_notrace_modules():
72
+ return list(_leaf_modules)
73
+
74
+
75
+ # Functions we want to autowrap (treat them as leaves)
76
+ _autowrap_functions = set()
77
+
78
+
79
+ def register_notrace_function(func: Callable):
80
+ """
81
+ Decorator for functions which ought not to be traced through
82
+ """
83
+ _autowrap_functions.add(func)
84
+ return func
85
+
86
+
87
+ def is_notrace_function(func: Callable):
88
+ return func in _autowrap_functions
89
+
90
+
91
+ def get_notrace_functions():
92
+ return list(_autowrap_functions)
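As a concrete (hypothetical) illustration of the registration mechanism above: a function with data-dependent control flow, which FX symbolic tracing cannot handle, can be marked as a leaf so it is wrapped rather than traced:

import torch
import torch.nn.functional as F
from timm.models._features_fx import register_notrace_function

@register_notrace_function
def pad_to_even(x: torch.Tensor) -> torch.Tensor:
    # data-dependent branch; tracing through this would fail without autowrap
    if x.shape[-1] % 2:
        x = F.pad(x, (0, 1))
    return x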
93
+
94
+
95
+ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
96
+ assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
97
+ return _create_feature_extractor(
98
+ model, return_nodes,
99
+ tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
100
+ )
101
+
102
+
103
+ class FeatureGraphNet(nn.Module):
104
+ """ A FX Graph based feature extractor that works with the model feature_info metadata
105
+ """
106
+ def __init__(self, model, out_indices, out_map=None):
107
+ super().__init__()
108
+ assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
109
+ self.feature_info = _get_feature_info(model, out_indices)
110
+ if out_map is not None:
111
+ assert len(out_map) == len(out_indices)
112
+ return_nodes = _get_return_layers(self.feature_info, out_map)
113
+ self.graph_module = create_feature_extractor(model, return_nodes)
114
+
115
+ def forward(self, x):
116
+ return list(self.graph_module(x).values())
117
+
118
+
119
+ class GraphExtractNet(nn.Module):
120
+ """ A standalone feature extraction wrapper that maps dict -> list or single tensor
121
+ NOTE:
122
+ * create_feature_extractor can be used directly if dictionary output is desired
123
+ * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
124
+ metadata for builtin feature extraction mode
126
+
127
+ Args:
128
+ model: model to extract features from
129
+ return_nodes: node names to return features from (dict or list)
130
+ squeeze_out: if only one output, and output in list format, flatten to single tensor
131
+ """
132
+ def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
133
+ super().__init__()
134
+ self.squeeze_out = squeeze_out
135
+ self.graph_module = create_feature_extractor(model, return_nodes)
136
+
137
+ def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
138
+ out = list(self.graph_module(x).values())
139
+ if self.squeeze_out and len(out) == 1:
140
+ return out[0]
141
+ return out
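A short usage sketch of the FX-based path above, assuming PyTorch 1.10+ / torchvision 0.11+ are installed; the torchvision resnet18 backbone and the 'layer1'/'layer4' node names are illustrative:

import torch
from torchvision.models import resnet18
from timm.models._features_fx import GraphExtractNet

net = GraphExtractNet(resnet18(), return_nodes=['layer1', 'layer4'], squeeze_out=False)
feats = net(torch.randn(1, 3, 224, 224))
print([f.shape for f in feats])  # two feature maps, one per requested node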
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_hub.py ADDED
@@ -0,0 +1,402 @@
1
+ import hashlib
2
+ import json
3
+ import logging
4
+ import os
5
+ from functools import partial
6
+ from pathlib import Path
7
+ from tempfile import TemporaryDirectory
8
+ from typing import Iterable, Optional, Union
9
+
10
+ import torch
11
+ from torch.hub import HASH_REGEX, download_url_to_file, urlparse
12
+
13
+ try:
14
+ from torch.hub import get_dir
15
+ except ImportError:
16
+ from torch.hub import _get_torch_home as get_dir
17
+
18
+ try:
19
+ import safetensors.torch
20
+ _has_safetensors = True
21
+ except ImportError:
22
+ _has_safetensors = False
23
+
24
+ try:
25
+ from typing import Literal
26
+ except ImportError:
27
+ from typing_extensions import Literal
28
+
29
+ from timm import __version__
30
+ from timm.models._pretrained import filter_pretrained_cfg
31
+
32
+ try:
33
+ from huggingface_hub import (
34
+ create_repo, get_hf_file_metadata,
35
+ hf_hub_download, hf_hub_url,
36
+ repo_type_and_id_from_hf_id, upload_folder)
37
+ from huggingface_hub.utils import EntryNotFoundError
38
+ hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
39
+ _has_hf_hub = True
40
+ except ImportError:
41
+ hf_hub_download = None
42
+ _has_hf_hub = False
43
+
44
+ _logger = logging.getLogger(__name__)
45
+
46
+ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
47
+ 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
48
+
49
+ # Default name for a weights file hosted on the Huggingface Hub.
50
+ HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
51
+ HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
52
+ HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
53
+ HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
54
+
55
+
56
+ def get_cache_dir(child_dir=''):
57
+ """
58
+ Returns the location of the directory where models are cached (and creates it if necessary).
59
+ """
60
+ # Issue warning to move data if old env is set
61
+ if os.getenv('TORCH_MODEL_ZOO'):
62
+ _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
63
+
64
+ hub_dir = get_dir()
65
+ child_dir = () if not child_dir else (child_dir,)
66
+ model_dir = os.path.join(hub_dir, 'checkpoints', *child_dir)
67
+ os.makedirs(model_dir, exist_ok=True)
68
+ return model_dir
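For example (actual paths depend on TORCH_HOME; the values shown are typical defaults):

get_cache_dir()          # e.g. ~/.cache/torch/hub/checkpoints
get_cache_dir('my_exp')  # e.g. ~/.cache/torch/hub/checkpoints/my_exp, created if missing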
69
+
70
+
71
+ def download_cached_file(url, check_hash=True, progress=False):
72
+ if isinstance(url, (list, tuple)):
73
+ url, filename = url
74
+ else:
75
+ parts = urlparse(url)
76
+ filename = os.path.basename(parts.path)
77
+ cached_file = os.path.join(get_cache_dir(), filename)
78
+ if not os.path.exists(cached_file):
79
+ _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
80
+ hash_prefix = None
81
+ if check_hash:
82
+ r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
83
+ hash_prefix = r.group(1) if r else None
84
+ download_url_to_file(url, cached_file, hash_prefix, progress=progress)
85
+ return cached_file
86
+
87
+
88
+ def check_cached_file(url, check_hash=True):
89
+ if isinstance(url, (list, tuple)):
90
+ url, filename = url
91
+ else:
92
+ parts = urlparse(url)
93
+ filename = os.path.basename(parts.path)
94
+ cached_file = os.path.join(get_cache_dir(), filename)
95
+ if os.path.exists(cached_file):
96
+ if check_hash:
97
+ r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
98
+ hash_prefix = r.group(1) if r else None
99
+ if hash_prefix:
100
+ with open(cached_file, 'rb') as f:
101
+ hd = hashlib.sha256(f.read()).hexdigest()
102
+ if hd[:len(hash_prefix)] != hash_prefix:
103
+ return False
104
+ return True
105
+ return False
106
+
107
+
108
+ def has_hf_hub(necessary=False):
109
+ if not _has_hf_hub and necessary:
110
+ # if no HF Hub module installed, and it is necessary to continue, raise error
111
+ raise RuntimeError(
112
+ 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
113
+ return _has_hf_hub
114
+
115
+
116
+ def hf_split(hf_id: str):
117
+ # FIXME may change '@' -> '#' so the revision is parsed as a fragment in a URI model name scheme
118
+ rev_split = hf_id.split('@')
119
+ assert 0 < len(rev_split) <= 2, 'hf_hub id should only contain one @ character to identify revision.'
120
+ hf_model_id = rev_split[0]
121
+ hf_revision = rev_split[-1] if len(rev_split) > 1 else None
122
+ return hf_model_id, hf_revision
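Concretely (model ids are illustrative):

hf_split('timm/resnet50.a1_in1k@main')  # -> ('timm/resnet50.a1_in1k', 'main')
hf_split('timm/resnet50.a1_in1k')       # -> ('timm/resnet50.a1_in1k', None)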
123
+
124
+
125
+ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
126
+ with open(json_file, "r", encoding="utf-8") as reader:
127
+ text = reader.read()
128
+ return json.loads(text)
129
+
130
+
131
+ def download_from_hf(model_id: str, filename: str):
132
+ hf_model_id, hf_revision = hf_split(model_id)
133
+ return hf_hub_download(hf_model_id, filename, revision=hf_revision)
134
+
135
+
136
+ def load_model_config_from_hf(model_id: str):
137
+ assert has_hf_hub(True)
138
+ cached_file = download_from_hf(model_id, 'config.json')
139
+
140
+ hf_config = load_cfg_from_json(cached_file)
141
+ if 'pretrained_cfg' not in hf_config:
142
+ # old form, pull pretrain_cfg out of the base dict
143
+ pretrained_cfg = hf_config
144
+ hf_config = {}
145
+ hf_config['architecture'] = pretrained_cfg.pop('architecture')
146
+ hf_config['num_features'] = pretrained_cfg.pop('num_features', None)
147
+ if 'labels' in pretrained_cfg: # deprecated name for 'label_names'
148
+ pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
149
+ hf_config['pretrained_cfg'] = pretrained_cfg
150
+
151
+ # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
152
+ pretrained_cfg = hf_config['pretrained_cfg']
153
+ pretrained_cfg['hf_hub_id'] = model_id # insert hf_hub id for pretrained weight load during model creation
154
+ pretrained_cfg['source'] = 'hf-hub'
155
+
156
+ # model should be created with the base config's num_classes if it exists
157
+ if 'num_classes' in hf_config:
158
+ pretrained_cfg['num_classes'] = hf_config['num_classes']
159
+
160
+ # label meta-data in base config overrides saved pretrained_cfg on load
161
+ if 'label_names' in hf_config:
162
+ pretrained_cfg['label_names'] = hf_config.pop('label_names')
163
+ if 'label_descriptions' in hf_config:
164
+ pretrained_cfg['label_descriptions'] = hf_config.pop('label_descriptions')
165
+
166
+ model_args = hf_config.get('model_args', {})
167
+ model_name = hf_config['architecture']
168
+ return pretrained_cfg, model_name, model_args
169
+
170
+
171
+ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
172
+ assert has_hf_hub(True)
173
+ hf_model_id, hf_revision = hf_split(model_id)
174
+
175
+ # Look for .safetensors alternatives and load from it if it exists
176
+ if _has_safetensors:
177
+ for safe_filename in _get_safe_alternatives(filename):
178
+ try:
179
+ cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
180
+ _logger.info(
181
+ f"[{model_id}] Safe alternative available for '{filename}' "
182
+ f"(as '{safe_filename}'). Loading weights using safetensors.")
183
+ return safetensors.torch.load_file(cached_safe_file, device="cpu")
184
+ except EntryNotFoundError:
185
+ pass
186
+
187
+ # Otherwise, load using torch.load
188
+ cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
189
+ _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
190
+ return torch.load(cached_file, map_location='cpu')
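A hypothetical call (the model id is illustrative; requires huggingface_hub, plus safetensors for the safe path):

state_dict = load_state_dict_from_hf('timm/resnet50.a1_in1k')
# prefers a published .safetensors file, otherwise falls back to torch.load on the .bin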
191
+
192
+
193
+ def save_config_for_hf(
194
+ model,
195
+ config_path: str,
196
+ model_config: Optional[dict] = None,
197
+ model_args: Optional[dict] = None
198
+ ):
199
+ model_config = model_config or {}
200
+ hf_config = {}
201
+ pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
202
+ # set some values at root config level
203
+ hf_config['architecture'] = pretrained_cfg.pop('architecture')
204
+ hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)
205
+
206
+ # NOTE these attr saved for informational purposes, do not impact model build
207
+ hf_config['num_features'] = model_config.pop('num_features', model.num_features)
208
+ global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
209
+ if isinstance(global_pool_type, str) and global_pool_type:
210
+ hf_config['global_pool'] = global_pool_type
211
+
212
+ # Save class label info
213
+ if 'labels' in model_config:
214
+ _logger.warning(
215
+ "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
216
+ " Renaming provided 'labels' field to 'label_names'.")
217
+ model_config.setdefault('label_names', model_config.pop('labels'))
218
+
219
+ label_names = model_config.pop('label_names', None)
220
+ if label_names:
221
+ assert isinstance(label_names, (dict, list, tuple))
222
+ # map label id (classifier index) -> unique label name (i.e. synset for ImageNet, MID for OpenImages)
223
+ # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
224
+ hf_config['label_names'] = label_names
225
+
226
+ label_descriptions = model_config.pop('label_descriptions', None)
227
+ if label_descriptions:
228
+ assert isinstance(label_descriptions, dict)
229
+ # maps label names -> descriptions
230
+ hf_config['label_descriptions'] = label_descriptions
231
+
232
+ if model_args:
233
+ hf_config['model_args'] = model_args
234
+
235
+ hf_config['pretrained_cfg'] = pretrained_cfg
236
+ hf_config.update(model_config)
237
+
238
+ with config_path.open('w') as f:
239
+ json.dump(hf_config, f, indent=2)
240
+
241
+
242
+ def save_for_hf(
243
+ model,
244
+ save_directory: str,
245
+ model_config: Optional[dict] = None,
246
+ model_args: Optional[dict] = None,
247
+ safe_serialization: Union[bool, Literal["both"]] = False,
248
+ ):
249
+ assert has_hf_hub(True)
250
+ save_directory = Path(save_directory)
251
+ save_directory.mkdir(exist_ok=True, parents=True)
252
+
253
+ # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
254
+ tensors = model.state_dict()
255
+ if safe_serialization is True or safe_serialization == "both":
256
+ assert _has_safetensors, "`pip install safetensors` to use .safetensors"
257
+ safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
258
+ if safe_serialization is False or safe_serialization == "both":
259
+ torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
260
+
261
+ config_path = save_directory / 'config.json'
262
+ save_config_for_hf(
263
+ model,
264
+ config_path,
265
+ model_config=model_config,
266
+ model_args=model_args,
267
+ )
268
+
269
+
270
+ def push_to_hf_hub(
271
+ model: torch.nn.Module,
272
+ repo_id: str,
273
+ commit_message: str = 'Add model',
274
+ token: Optional[str] = None,
275
+ revision: Optional[str] = None,
276
+ private: bool = False,
277
+ create_pr: bool = False,
278
+ model_config: Optional[dict] = None,
279
+ model_card: Optional[dict] = None,
280
+ model_args: Optional[dict] = None,
281
+ safe_serialization: Union[bool, Literal["both"]] = False,
282
+ ):
283
+ """
284
+ Arguments:
285
+ (...)
286
+ safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
287
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
288
+ Can be set to `"both"` in order to push both safe and unsafe weights.
289
+ """
290
+ # Create repo if it doesn't exist yet
291
+ repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
292
+
293
+ # Infer complete repo_id from repo_url
294
+ # Can be different from the input `repo_id` if repo_owner was implicit
295
+ _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
296
+ repo_id = f"{repo_owner}/{repo_name}"
297
+
298
+ # Check if README file already exist in repo
299
+ try:
300
+ get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
301
+ has_readme = True
302
+ except EntryNotFoundError:
303
+ has_readme = False
304
+
305
+ # Dump model and push to Hub
306
+ with TemporaryDirectory() as tmpdir:
307
+ # Save model weights and config.
308
+ save_for_hf(
309
+ model,
310
+ tmpdir,
311
+ model_config=model_config,
312
+ model_args=model_args,
313
+ safe_serialization=safe_serialization,
314
+ )
315
+
316
+ # Add readme if it does not exist
317
+ if not has_readme:
318
+ model_card = model_card or {}
319
+ model_name = repo_id.split('/')[-1]
320
+ readme_path = Path(tmpdir) / "README.md"
321
+ readme_text = generate_readme(model_card, model_name)
322
+ readme_path.write_text(readme_text)
323
+
324
+ # Upload model and return
325
+ return upload_folder(
326
+ repo_id=repo_id,
327
+ folder_path=tmpdir,
328
+ revision=revision,
329
+ create_pr=create_pr,
330
+ commit_message=commit_message,
331
+ )
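A minimal end-to-end sketch (the repo id is a placeholder; requires a huggingface_hub auth token or a prior login):

import timm

model = timm.create_model('resnet18', pretrained=True)
push_to_hf_hub(
    model,
    'my-user/resnet18-demo',
    model_card={'license': 'apache-2.0', 'details': {'Dataset': 'imagenet-1k'}},
    safe_serialization='both',  # upload both .bin and .safetensors weights
)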
332
+
333
+
334
+ def generate_readme(model_card: dict, model_name: str):
335
+ readme_text = "---\n"
336
+ readme_text += "tags:\n- image-classification\n- timm\n"
337
+ readme_text += "library_name: timm\n"
338
+ readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
339
+ if 'details' in model_card and 'Dataset' in model_card['details']:
340
+ readme_text += 'datasets:\n'
341
+ if isinstance(model_card['details']['Dataset'], (tuple, list)):
342
+ for d in model_card['details']['Dataset']:
343
+ readme_text += f"- {d.lower()}\n"
344
+ else:
345
+ readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
346
+ if 'Pretrain Dataset' in model_card['details']:
347
+ if isinstance(model_card['details']['Pretrain Dataset'], (tuple, list)):
348
+ for d in model_card['details']['Pretrain Dataset']:
349
+ readme_text += f"- {d.lower()}\n"
350
+ else:
351
+ readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
352
+ readme_text += "---\n"
353
+ readme_text += f"# Model card for {model_name}\n"
354
+ if 'description' in model_card:
355
+ readme_text += f"\n{model_card['description']}\n"
356
+ if 'details' in model_card:
357
+ readme_text += f"\n## Model Details\n"
358
+ for k, v in model_card['details'].items():
359
+ if isinstance(v, (list, tuple)):
360
+ readme_text += f"- **{k}:**\n"
361
+ for vi in v:
362
+ readme_text += f" - {vi}\n"
363
+ elif isinstance(v, dict):
364
+ readme_text += f"- **{k}:**\n"
365
+ for ki, vi in v.items():
366
+ readme_text += f" - {ki}: {vi}\n"
367
+ else:
368
+ readme_text += f"- **{k}:** {v}\n"
369
+ if 'usage' in model_card:
370
+ readme_text += f"\n## Model Usage\n"
371
+ readme_text += model_card['usage']
372
+ readme_text += '\n'
373
+
374
+ if 'comparison' in model_card:
375
+ readme_text += f"\n## Model Comparison\n"
376
+ readme_text += model_card['comparison']
377
+ readme_text += '\n'
378
+
379
+ if 'citation' in model_card:
380
+ readme_text += f"\n## Citation\n"
381
+ if not isinstance(model_card['citation'], (list, tuple)):
382
+ citations = [model_card['citation']]
383
+ else:
384
+ citations = model_card['citation']
385
+ for c in citations:
386
+ readme_text += f"```bibtex\n{c}\n```\n"
387
+ return readme_text
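For instance, a minimal model_card dict yields a short README (output abbreviated):

card = {'description': 'A toy example.', 'details': {'Dataset': 'imagenet-1k'}}
text = generate_readme(card, 'resnet18-demo')
# YAML front matter with tags / license / datasets, then '# Model card for resnet18-demo', details, etc.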
388
+
389
+
390
+ def _get_safe_alternatives(filename: str) -> Iterable[str]:
391
+ """Returns potential safetensors alternatives for a given filename.
392
+
393
+ Use case:
394
+ When downloading a model from the Huggingface Hub, we first check whether a .safetensors file exists and, if so, use it.
395
+ Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
396
+ """
397
+ if filename == HF_WEIGHTS_NAME:
398
+ yield HF_SAFE_WEIGHTS_NAME
399
+ if filename == HF_OPEN_CLIP_WEIGHTS_NAME:
400
+ yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
401
+ if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
402
+ yield filename[:-4] + ".safetensors"
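The generator's behavior, concretely:

list(_get_safe_alternatives('pytorch_model.bin'))  # ['model.safetensors']
list(_get_safe_alternatives('my_weights.bin'))     # ['my_weights.safetensors']
list(_get_safe_alternatives('weights.pth'))        # [] (not a known name and not a .bin file)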
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_prune.py ADDED
@@ -0,0 +1,113 @@
1
+ import os
2
+ import pkgutil
3
+ from copy import deepcopy
4
+
5
+ from torch import nn as nn
6
+
7
+ from timm.layers import Conv2dSame, BatchNormAct2d, Linear
8
+
9
+ __all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
10
+
11
+
12
+ def extract_layer(model, layer):
13
+ layer = layer.split('.')
14
+ module = model
15
+ if hasattr(model, 'module') and layer[0] != 'module':
16
+ module = model.module
17
+ if not hasattr(model, 'module') and layer[0] == 'module':
18
+ layer = layer[1:]
19
+ for l in layer:
20
+ if hasattr(module, l):
21
+ if not l.isdigit():
22
+ module = getattr(module, l)
23
+ else:
24
+ module = module[int(l)]
25
+ else:
26
+ return module
27
+ return module
28
+
29
+
30
+ def set_layer(model, layer, val):
31
+ layer = layer.split('.')
32
+ module = model
33
+ if hasattr(model, 'module') and layer[0] != 'module':
34
+ module = model.module
35
+ lst_index = 0
36
+ module2 = module
37
+ for l in layer:
38
+ if hasattr(module2, l):
39
+ if not l.isdigit():
40
+ module2 = getattr(module2, l)
41
+ else:
42
+ module2 = module2[int(l)]
43
+ lst_index += 1
44
+ lst_index -= 1
45
+ for l in layer[:lst_index]:
46
+ if not l.isdigit():
47
+ module = getattr(module, l)
48
+ else:
49
+ module = module[int(l)]
50
+ l = layer[lst_index]
51
+ setattr(module, l, val)
52
+
53
+
54
+ def adapt_model_from_string(parent_module, model_string):
55
+ separator = '***'
56
+ state_dict = {}
57
+ lst_shape = model_string.split(separator)
58
+ for k in lst_shape:
59
+ k = k.split(':')
60
+ key = k[0]
61
+ shape = k[1][1:-1].split(',')
62
+ if shape[0] != '':
63
+ state_dict[key] = [int(i) for i in shape]
64
+
65
+ new_module = deepcopy(parent_module)
66
+ for n, m in parent_module.named_modules():
67
+ old_module = extract_layer(parent_module, n)
68
+ if isinstance(old_module, nn.Conv2d) or isinstance(old_module, Conv2dSame):
69
+ if isinstance(old_module, Conv2dSame):
70
+ conv = Conv2dSame
71
+ else:
72
+ conv = nn.Conv2d
73
+ s = state_dict[n + '.weight']
74
+ in_channels = s[1]
75
+ out_channels = s[0]
76
+ g = 1
77
+ if old_module.groups > 1:
78
+ in_channels = out_channels
79
+ g = in_channels
80
+ new_conv = conv(
81
+ in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
82
+ bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
83
+ groups=g, stride=old_module.stride)
84
+ set_layer(new_module, n, new_conv)
85
+ elif isinstance(old_module, BatchNormAct2d):
86
+ new_bn = BatchNormAct2d(
87
+ state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
88
+ affine=old_module.affine, track_running_stats=True)
89
+ new_bn.drop = old_module.drop
90
+ new_bn.act = old_module.act
91
+ set_layer(new_module, n, new_bn)
92
+ elif isinstance(old_module, nn.BatchNorm2d):
93
+ new_bn = nn.BatchNorm2d(
94
+ num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
95
+ affine=old_module.affine, track_running_stats=True)
96
+ set_layer(new_module, n, new_bn)
97
+ elif isinstance(old_module, nn.Linear):
98
+ # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
99
+ num_features = state_dict[n + '.weight'][1]
100
+ new_fc = Linear(
101
+ in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
102
+ set_layer(new_module, n, new_fc)
103
+ if hasattr(new_module, 'num_features'):
104
+ new_module.num_features = num_features
105
+ new_module.eval()
106
+ parent_module.eval()
107
+
108
+ return new_module
109
+
110
+
111
+ def adapt_model_from_file(parent_module, model_variant):
112
+ adapt_data = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
113
+ return adapt_model_from_string(parent_module, adapt_data.decode('utf-8').strip())
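To make the expected model_string format concrete, a toy sketch (module names and shapes are hypothetical; each Conv2d / BatchNorm2d / Linear in the model needs a matching 'name.weight:[shape]' entry, with entries separated by '***'):

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.fc = nn.Linear(8, 10)

spec = 'conv.weight:[4, 3, 3, 3]***bn.weight:[4]***fc.weight:[10, 4]'
pruned = adapt_model_from_string(Tiny(), spec)  # conv/bn shrunk to 4 channels, fc in_features to 4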
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/beit.py ADDED
@@ -0,0 +1,621 @@
1
+ """ BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
2
+
3
+ Model from official source: https://github.com/microsoft/unilm/tree/master/beit
4
+
5
+ @inproceedings{beit,
6
+ title={{BEiT}: {BERT} Pre-Training of Image Transformers},
7
+ author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
8
+ booktitle={International Conference on Learning Representations},
9
+ year={2022},
10
+ url={https://openreview.net/forum?id=p-BhZSz59o4}
11
+ }
12
+
13
+ BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2
14
+
15
+ @article{beitv2,
16
+ title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
17
+ author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
18
+ year={2022},
19
+ eprint={2208.06366},
20
+ archivePrefix={arXiv},
21
+ primaryClass={cs.CV}
22
+ }
23
+
24
+ At this point only the 1k fine-tuned classification weights and model configs have been added,
25
+ see original source above for pre-training models and procedure.
26
+
27
+ Modifications by / Copyright 2021 Ross Wightman, original copyrights below
28
+ """
29
+ # --------------------------------------------------------
30
+ # BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
31
+ # Github source: https://github.com/microsoft/unilm/tree/master/beit
32
+ # Copyright (c) 2021 Microsoft
33
+ # Licensed under The MIT License [see LICENSE for details]
34
+ # By Hangbo Bao
35
+ # Based on timm and DeiT code bases
36
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm
37
+ # https://github.com/facebookresearch/deit/
38
+ # https://github.com/facebookresearch/dino
39
+ # --------------------------------------------------------
40
+
41
+ import math
42
+ from typing import Callable, Optional, Tuple, Union
43
+
44
+ import torch
45
+ import torch.nn as nn
46
+ import torch.nn.functional as F
47
+ from torch.utils.checkpoint import checkpoint
48
+
49
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
50
+ from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
51
+ from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
52
+
53
+
54
+ from ._builder import build_model_with_cfg
55
+ from ._registry import generate_default_cfgs, register_model
56
+ from .vision_transformer import checkpoint_filter_fn
57
+
58
+ __all__ = ['Beit']
59
+
60
+
61
+ def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
62
+ num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
63
+ # cls to token & token to cls & cls to cls
64
+ # get pair-wise relative position index for each token inside the window
65
+ window_area = window_size[0] * window_size[1]
66
+ coords = torch.stack(ndgrid(torch.arange(window_size[0]), torch.arange(window_size[1]))) # 2, Wh, Ww
67
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
68
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
69
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
70
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
71
+ relative_coords[:, :, 1] += window_size[1] - 1
72
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
73
+ relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
74
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
75
+ relative_position_index[0, 0:] = num_relative_distance - 3
76
+ relative_position_index[0:, 0] = num_relative_distance - 2
77
+ relative_position_index[0, 0] = num_relative_distance - 1
78
+ return relative_position_index
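A quick sanity check of the index table (a 2x2 patch window plus the cls token):

idx = gen_relative_position_index((2, 2))
print(idx.shape)  # torch.Size([5, 5]); entries index a table of (2*2-1)*(2*2-1) + 3 = 12 biases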
79
+
80
+
81
+ class Attention(nn.Module):
82
+ fused_attn: torch.jit.Final[bool]
83
+
84
+ def __init__(
85
+ self,
86
+ dim: int,
87
+ num_heads: int = 8,
88
+ qkv_bias: bool = False,
89
+ attn_drop: float = 0.,
90
+ proj_drop: float = 0.,
91
+ window_size: Optional[Tuple[int, int]] = None,
92
+ attn_head_dim: Optional[int] = None,
93
+ ):
94
+ super().__init__()
95
+ self.num_heads = num_heads
96
+ head_dim = dim // num_heads
97
+ if attn_head_dim is not None:
98
+ head_dim = attn_head_dim
99
+ all_head_dim = head_dim * self.num_heads
100
+ self.scale = head_dim ** -0.5
101
+ self.fused_attn = use_fused_attn()
102
+
103
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
104
+ if qkv_bias:
105
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
106
+ self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
107
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
108
+ else:
109
+ self.q_bias = None
110
+ self.k_bias = None
111
+ self.v_bias = None
112
+
113
+ if window_size:
114
+ self.window_size = window_size
115
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
116
+ self.relative_position_bias_table = nn.Parameter(
117
+ torch.zeros(self.num_relative_distance, num_heads)) # 2*Wh-1 * 2*Ww-1, nH
118
+ self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
119
+ else:
120
+ self.window_size = None
121
+ self.relative_position_bias_table = None
122
+ self.relative_position_index = None
123
+
124
+ self.attn_drop = nn.Dropout(attn_drop)
125
+ self.proj = nn.Linear(all_head_dim, dim)
126
+ self.proj_drop = nn.Dropout(proj_drop)
127
+
128
+ def _get_rel_pos_bias(self):
129
+ relative_position_bias = self.relative_position_bias_table[
130
+ self.relative_position_index.view(-1)].view(
131
+ self.window_size[0] * self.window_size[1] + 1,
132
+ self.window_size[0] * self.window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
133
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
134
+ return relative_position_bias.unsqueeze(0)
135
+
136
+ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
137
+ B, N, C = x.shape
138
+
139
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
140
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
141
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
142
+ q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
143
+
144
+ if self.fused_attn:
145
+ rel_pos_bias = None
146
+ if self.relative_position_bias_table is not None:
147
+ rel_pos_bias = self._get_rel_pos_bias()
148
+ if shared_rel_pos_bias is not None:
149
+ rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
150
+ elif shared_rel_pos_bias is not None:
151
+ rel_pos_bias = shared_rel_pos_bias
152
+
153
+ x = F.scaled_dot_product_attention(
154
+ q, k, v,
155
+ attn_mask=rel_pos_bias,
156
+ dropout_p=self.attn_drop.p if self.training else 0.,
157
+ )
158
+ else:
159
+ q = q * self.scale
160
+ attn = (q @ k.transpose(-2, -1))
161
+
162
+ if self.relative_position_bias_table is not None:
163
+ attn = attn + self._get_rel_pos_bias()
164
+ if shared_rel_pos_bias is not None:
165
+ attn = attn + shared_rel_pos_bias
166
+
167
+ attn = attn.softmax(dim=-1)
168
+ attn = self.attn_drop(attn)
169
+ x = attn @ v
170
+
171
+ x = x.transpose(1, 2).reshape(B, N, C)
172
+ x = self.proj(x)
173
+ x = self.proj_drop(x)
174
+ return x
175
+
176
+
177
+ class Block(nn.Module):
178
+
179
+ def __init__(
180
+ self,
181
+ dim: int,
182
+ num_heads: int,
183
+ qkv_bias: bool = False,
184
+ mlp_ratio: float = 4.,
185
+ scale_mlp: bool = False,
186
+ swiglu_mlp: bool = False,
187
+ proj_drop: float = 0.,
188
+ attn_drop: float = 0.,
189
+ drop_path: float = 0.,
190
+ init_values: Optional[float] = None,
191
+ act_layer: Callable = nn.GELU,
192
+ norm_layer: Callable = LayerNorm,
193
+ window_size: Optional[Tuple[int, int]] = None,
194
+ attn_head_dim: Optional[int] = None,
195
+ ):
196
+ super().__init__()
197
+ self.norm1 = norm_layer(dim)
198
+ self.attn = Attention(
199
+ dim,
200
+ num_heads=num_heads,
201
+ qkv_bias=qkv_bias,
202
+ attn_drop=attn_drop,
203
+ proj_drop=proj_drop,
204
+ window_size=window_size,
205
+ attn_head_dim=attn_head_dim,
206
+ )
207
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
208
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
209
+
210
+ self.norm2 = norm_layer(dim)
211
+ if swiglu_mlp:
212
+ self.mlp = SwiGLU(
213
+ in_features=dim,
214
+ hidden_features=int(dim * mlp_ratio),
215
+ norm_layer=norm_layer if scale_mlp else None,
216
+ drop=proj_drop,
217
+ )
218
+ else:
219
+ self.mlp = Mlp(
220
+ in_features=dim,
221
+ hidden_features=int(dim * mlp_ratio),
222
+ act_layer=act_layer,
223
+ norm_layer=norm_layer if scale_mlp else None,
224
+ drop=proj_drop,
225
+ )
226
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
227
+
228
+ if init_values:
229
+ self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
230
+ self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
231
+ else:
232
+ self.gamma_1, self.gamma_2 = None, None
233
+
234
+ def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
235
+ if self.gamma_1 is None:
236
+ x = x + self.drop_path1(self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
237
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
238
+ else:
239
+ x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias))
240
+ x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
241
+ return x
242
+
243
+
244
+ class RelativePositionBias(nn.Module):
245
+
246
+ def __init__(self, window_size, num_heads):
247
+ super().__init__()
248
+ self.window_size = window_size
249
+ self.window_area = window_size[0] * window_size[1]
250
+ num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
251
+ self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
252
+ # trunc_normal_(self.relative_position_bias_table, std=.02)
253
+ self.register_buffer("relative_position_index", gen_relative_position_index(window_size))
254
+
255
+ def forward(self):
256
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
257
+ self.window_area + 1, self.window_area + 1, -1) # Wh*Ww,Wh*Ww,nH
258
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
259
+
260
+
261
+ class Beit(nn.Module):
262
+ """ Vision Transformer with support for patch or hybrid CNN input stage
263
+ """
264
+
265
+ def __init__(
266
+ self,
267
+ img_size: Union[int, Tuple[int, int]] = 224,
268
+ patch_size: Union[int, Tuple[int, int]] = 16,
269
+ in_chans: int = 3,
270
+ num_classes: int = 1000,
271
+ global_pool: str = 'avg',
272
+ embed_dim: int = 768,
273
+ depth: int = 12,
274
+ num_heads: int = 12,
275
+ qkv_bias: bool = True,
276
+ mlp_ratio: float = 4.,
277
+ swiglu_mlp: bool = False,
278
+ scale_mlp: bool = False,
279
+ drop_rate: float = 0.,
280
+ pos_drop_rate: float = 0.,
281
+ proj_drop_rate: float = 0.,
282
+ attn_drop_rate: float = 0.,
283
+ drop_path_rate: float = 0.,
284
+ norm_layer: Callable = LayerNorm,
285
+ init_values: Optional[float] = None,
286
+ use_abs_pos_emb: bool = True,
287
+ use_rel_pos_bias: bool = False,
288
+ use_shared_rel_pos_bias: bool = False,
289
+ head_init_scale: float = 0.001,
290
+ ):
291
+ super().__init__()
292
+ self.num_classes = num_classes
293
+ self.global_pool = global_pool
294
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
295
+ self.num_prefix_tokens = 1
296
+ self.grad_checkpointing = False
297
+
298
+ self.patch_embed = PatchEmbed(
299
+ img_size=img_size,
300
+ patch_size=patch_size,
301
+ in_chans=in_chans,
302
+ embed_dim=embed_dim,
303
+ )
304
+ num_patches = self.patch_embed.num_patches
305
+
306
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
307
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
308
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
309
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
310
+
311
+ if use_shared_rel_pos_bias:
312
+ self.rel_pos_bias = RelativePositionBias(
313
+ window_size=self.patch_embed.grid_size,
314
+ num_heads=num_heads,
315
+ )
316
+ else:
317
+ self.rel_pos_bias = None
318
+
319
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
320
+ self.blocks = nn.ModuleList([
321
+ Block(
322
+ dim=embed_dim,
323
+ num_heads=num_heads,
324
+ qkv_bias=qkv_bias,
325
+ mlp_ratio=mlp_ratio,
326
+ scale_mlp=scale_mlp,
327
+ swiglu_mlp=swiglu_mlp,
328
+ proj_drop=proj_drop_rate,
329
+ attn_drop=attn_drop_rate,
330
+ drop_path=dpr[i],
331
+ norm_layer=norm_layer,
332
+ init_values=init_values,
333
+ window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
334
+ )
335
+ for i in range(depth)])
336
+
337
+ use_fc_norm = self.global_pool == 'avg'
338
+ self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
339
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
340
+ self.head_drop = nn.Dropout(drop_rate)
341
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
342
+
343
+ self.apply(self._init_weights)
344
+ if self.pos_embed is not None:
345
+ trunc_normal_(self.pos_embed, std=.02)
346
+ trunc_normal_(self.cls_token, std=.02)
347
+
348
+ self.fix_init_weight()
349
+ if isinstance(self.head, nn.Linear):
350
+ trunc_normal_(self.head.weight, std=.02)
351
+ self.head.weight.data.mul_(head_init_scale)
352
+ self.head.bias.data.mul_(head_init_scale)
353
+
354
+ def fix_init_weight(self):
355
+ def rescale(param, layer_id):
356
+ param.div_(math.sqrt(2.0 * layer_id))
357
+
358
+ for layer_id, layer in enumerate(self.blocks):
359
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
360
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
361
+
362
+ def _init_weights(self, m):
363
+ if isinstance(m, nn.Linear):
364
+ trunc_normal_(m.weight, std=.02)
365
+ if isinstance(m, nn.Linear) and m.bias is not None:
366
+ nn.init.constant_(m.bias, 0)
367
+ elif isinstance(m, nn.LayerNorm):
368
+ nn.init.constant_(m.bias, 0)
369
+ nn.init.constant_(m.weight, 1.0)
370
+
371
+ @torch.jit.ignore
372
+ def no_weight_decay(self):
373
+ nwd = {'pos_embed', 'cls_token'}
374
+ for n, _ in self.named_parameters():
375
+ if 'relative_position_bias_table' in n:
376
+ nwd.add(n)
377
+ return nwd
378
+
379
+ @torch.jit.ignore
380
+ def set_grad_checkpointing(self, enable=True):
381
+ self.grad_checkpointing = enable
382
+
383
+ @torch.jit.ignore
384
+ def group_matcher(self, coarse=False):
385
+ matcher = dict(
386
+ stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias', # stem and embed
387
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
388
+ )
389
+ return matcher
390
+
391
+ @torch.jit.ignore
392
+ def get_classifier(self):
393
+ return self.head
394
+
395
+ def reset_classifier(self, num_classes, global_pool=None):
396
+ self.num_classes = num_classes
397
+ if global_pool is not None:
398
+ self.global_pool = global_pool
399
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
400
+
401
+ def forward_features(self, x):
402
+ x = self.patch_embed(x)
403
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
404
+ if self.pos_embed is not None:
405
+ x = x + self.pos_embed
406
+ x = self.pos_drop(x)
407
+
408
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
409
+ for blk in self.blocks:
410
+ if self.grad_checkpointing and not torch.jit.is_scripting():
411
+ x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
412
+ else:
413
+ x = blk(x, shared_rel_pos_bias=rel_pos_bias)
414
+ x = self.norm(x)
415
+ return x
416
+
417
+ def forward_head(self, x, pre_logits: bool = False):
418
+ if self.global_pool:
419
+ x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
420
+ x = self.fc_norm(x)
421
+ x = self.head_drop(x)
422
+ return x if pre_logits else self.head(x)
423
+
424
+ def forward(self, x):
425
+ x = self.forward_features(x)
426
+ x = self.forward_head(x)
427
+ return x
428
+
429
+
430
+ def _cfg(url='', **kwargs):
431
+ return {
432
+ 'url': url,
433
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
434
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
435
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
436
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
437
+ **kwargs
438
+ }
439
+
440
+
441
+ default_cfgs = generate_default_cfgs({
442
+ 'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
443
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
444
+ hf_hub_id='timm/'),
445
+ 'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
446
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
447
+ hf_hub_id='timm/',
448
+ input_size=(3, 384, 384), crop_pct=1.0,
449
+ ),
450
+ 'beit_base_patch16_224.in22k_ft_in22k': _cfg(
451
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
452
+ hf_hub_id='timm/',
453
+ num_classes=21841,
454
+ ),
455
+ 'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
456
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
457
+ hf_hub_id='timm/'),
458
+ 'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
459
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
460
+ hf_hub_id='timm/',
461
+ input_size=(3, 384, 384), crop_pct=1.0,
462
+ ),
463
+ 'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
464
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
465
+ hf_hub_id='timm/',
466
+ input_size=(3, 512, 512), crop_pct=1.0,
467
+ ),
468
+ 'beit_large_patch16_224.in22k_ft_in22k': _cfg(
469
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
470
+ hf_hub_id='timm/',
471
+ num_classes=21841,
472
+ ),
473
+
474
+ 'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
475
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
476
+ hf_hub_id='timm/',
477
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
478
+ ),
479
+ 'beitv2_base_patch16_224.in1k_ft_in1k': _cfg(
480
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft1k.pth',
481
+ hf_hub_id='timm/',
482
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
483
+ ),
484
+ 'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
485
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
486
+ hf_hub_id='timm/',
487
+ num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
488
+ ),
489
+ 'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
490
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
491
+ hf_hub_id='timm/',
492
+ crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
493
+ ),
494
+ 'beitv2_large_patch16_224.in1k_ft_in1k': _cfg(
495
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft1k.pth',
496
+ hf_hub_id='timm/',
497
+ crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
498
+ ),
499
+ 'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
500
+ #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
501
+ hf_hub_id='timm/',
502
+ num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
503
+ ),
504
+ })
505
+
506
+
507
+ def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
508
+ state_dict = state_dict.get('model', state_dict)
509
+ state_dict = state_dict.get('module', state_dict)
510
+ # beit v2 didn't strip module
511
+
512
+ out_dict = {}
513
+ for k, v in state_dict.items():
514
+ if 'relative_position_index' in k:
515
+ continue
516
+ if 'patch_embed.proj.weight' in k:
517
+ O, I, H, W = model.patch_embed.proj.weight.shape
518
+ if v.shape[-1] != W or v.shape[-2] != H:
519
+ v = resample_patch_embed(
520
+ v,
521
+ (H, W),
522
+ interpolation=interpolation,
523
+ antialias=antialias,
524
+ verbose=True,
525
+ )
526
+ elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
527
+ # To resize pos embedding when using model at different size from pretrained weights
528
+ num_prefix_tokens = 1
529
+ v = resample_abs_pos_embed(
530
+ v,
531
+ new_size=model.patch_embed.grid_size,
532
+ num_prefix_tokens=num_prefix_tokens,
533
+ interpolation=interpolation,
534
+ antialias=antialias,
535
+ verbose=True,
536
+ )
537
+ elif k.endswith('relative_position_bias_table'):
538
+ m = model.get_submodule(k[:-29])
539
+ if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
540
+ v = resize_rel_pos_bias_table(
541
+ v,
542
+ new_window_size=m.window_size,
543
+ new_bias_shape=m.relative_position_bias_table.shape,
544
+ )
545
+ out_dict[k] = v
546
+ return out_dict
547
+
548
+
549
+ def _create_beit(variant, pretrained=False, **kwargs):
550
+ if kwargs.get('features_only', None):
551
+ raise RuntimeError('features_only not implemented for BEiT models.')
552
+
553
+ model = build_model_with_cfg(
554
+ Beit, variant, pretrained,
555
+ # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
556
+ pretrained_filter_fn=_beit_checkpoint_filter_fn,
557
+ **kwargs)
558
+ return model
559
+
560
+
561
+ @register_model
562
+ def beit_base_patch16_224(pretrained=False, **kwargs) -> Beit:
563
+ model_args = dict(
564
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
565
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
566
+ model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
567
+ return model
568
+
569
+
570
+ @register_model
571
+ def beit_base_patch16_384(pretrained=False, **kwargs) -> Beit:
572
+ model_args = dict(
573
+ img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
574
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
575
+ model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
576
+ return model
577
+
578
+
579
+ @register_model
580
+ def beit_large_patch16_224(pretrained=False, **kwargs) -> Beit:
581
+ model_args = dict(
582
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16,
583
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
584
+ model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
585
+ return model
586
+
587
+
588
+ @register_model
589
+ def beit_large_patch16_384(pretrained=False, **kwargs) -> Beit:
590
+ model_args = dict(
591
+ img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
592
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
593
+ model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
594
+ return model
595
+
596
+
597
+ @register_model
598
+ def beit_large_patch16_512(pretrained=False, **kwargs) -> Beit:
599
+ model_args = dict(
600
+ img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
601
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
602
+ model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
603
+ return model
604
+
605
+
606
+ @register_model
607
+ def beitv2_base_patch16_224(pretrained=False, **kwargs) -> Beit:
608
+ model_args = dict(
609
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
610
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
611
+ model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
612
+ return model
613
+
614
+
615
+ @register_model
616
+ def beitv2_large_patch16_224(pretrained=False, **kwargs) -> Beit:
617
+ model_args = dict(
618
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16,
619
+ use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
620
+ model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
621
+ return model
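Any of the registered variants above can be instantiated through the timm factory, e.g. (illustrative; pretrained=True would pull the hub weights listed in default_cfgs):

import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False, num_classes=10)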
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byoanet.py ADDED
@@ -0,0 +1,455 @@
1
+ """ Bring-Your-Own-Attention Network
2
+
3
+ A flexible network w/ dataclass based config for stacking NN blocks including
4
+ self-attention (or similar) layers.
5
+
6
+ Currently used to implement experimental variants of:
7
+ * Bottleneck Transformers
8
+ * Lambda ResNets
9
+ * HaloNets
10
+
11
+ Consider all of the models definitions here as experimental WIP and likely to change.
12
+
13
+ Hacked together by / copyright Ross Wightman, 2021.
14
+ """
15
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
+ from ._builder import build_model_with_cfg
17
+ from ._registry import register_model, generate_default_cfgs
18
+ from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
19
+
20
+ __all__ = []
21
+
22
+
23
+ model_cfgs = dict(
24
+
25
+ botnet26t=ByoModelCfg(
26
+ blocks=(
27
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
28
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
29
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
30
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
31
+ ),
32
+ stem_chs=64,
33
+ stem_type='tiered',
34
+ stem_pool='maxpool',
35
+ fixed_input_size=True,
36
+ self_attn_layer='bottleneck',
37
+ self_attn_kwargs=dict()
38
+ ),
39
+ sebotnet33ts=ByoModelCfg(
40
+ blocks=(
41
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
42
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
43
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
44
+ ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
45
+ ),
46
+ stem_chs=64,
47
+ stem_type='tiered',
48
+ stem_pool='',
49
+ act_layer='silu',
50
+ num_features=1280,
51
+ attn_layer='se',
52
+ self_attn_layer='bottleneck',
53
+ self_attn_kwargs=dict()
54
+ ),
55
+ botnet50ts=ByoModelCfg(
56
+ blocks=(
57
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
58
+ interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
59
+ interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
60
+ interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
61
+ ),
62
+ stem_chs=64,
63
+ stem_type='tiered',
64
+ stem_pool='maxpool',
65
+ act_layer='silu',
66
+ fixed_input_size=True,
67
+ self_attn_layer='bottleneck',
68
+ self_attn_kwargs=dict()
69
+ ),
70
+ eca_botnext26ts=ByoModelCfg(
71
+ blocks=(
72
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
73
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
74
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
75
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
76
+ ),
77
+ stem_chs=64,
78
+ stem_type='tiered',
79
+ stem_pool='maxpool',
80
+ fixed_input_size=True,
81
+ act_layer='silu',
82
+ attn_layer='eca',
83
+ self_attn_layer='bottleneck',
84
+ self_attn_kwargs=dict(dim_head=16)
85
+ ),
86
+
87
+ halonet_h1=ByoModelCfg(
88
+ blocks=(
89
+ ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
90
+ ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
91
+ ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
92
+ ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
93
+ ),
94
+ stem_chs=64,
95
+ stem_type='7x7',
96
+ stem_pool='maxpool',
97
+
98
+ self_attn_layer='halo',
99
+ self_attn_kwargs=dict(block_size=8, halo_size=3),
100
+ ),
101
+ halonet26t=ByoModelCfg(
102
+ blocks=(
103
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
104
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
105
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
106
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
107
+ ),
108
+ stem_chs=64,
109
+ stem_type='tiered',
110
+ stem_pool='maxpool',
111
+ self_attn_layer='halo',
112
+ self_attn_kwargs=dict(block_size=8, halo_size=2)
113
+ ),
114
+ sehalonet33ts=ByoModelCfg(
115
+ blocks=(
116
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
117
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
118
+ interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
119
+ ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
120
+ ),
121
+ stem_chs=64,
122
+ stem_type='tiered',
123
+ stem_pool='',
124
+ act_layer='silu',
125
+ num_features=1280,
126
+ attn_layer='se',
127
+ self_attn_layer='halo',
128
+ self_attn_kwargs=dict(block_size=8, halo_size=3)
129
+ ),
130
+ halonet50ts=ByoModelCfg(
131
+ blocks=(
132
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
133
+ interleave_blocks(
134
+ types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
135
+ self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
136
+ interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
137
+ interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
138
+ ),
139
+ stem_chs=64,
140
+ stem_type='tiered',
141
+ stem_pool='maxpool',
142
+ act_layer='silu',
143
+ self_attn_layer='halo',
144
+ self_attn_kwargs=dict(block_size=8, halo_size=3)
145
+ ),
146
+ eca_halonext26ts=ByoModelCfg(
147
+ blocks=(
148
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
149
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
150
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
151
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
152
+ ),
153
+ stem_chs=64,
154
+ stem_type='tiered',
155
+ stem_pool='maxpool',
156
+ act_layer='silu',
157
+ attn_layer='eca',
158
+ self_attn_layer='halo',
159
+ self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
160
+ ),
161
+
162
+ lambda_resnet26t=ByoModelCfg(
163
+ blocks=(
164
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
165
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
166
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
167
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
168
+ ),
169
+ stem_chs=64,
170
+ stem_type='tiered',
171
+ stem_pool='maxpool',
172
+ self_attn_layer='lambda',
173
+ self_attn_kwargs=dict(r=9)
174
+ ),
175
+ lambda_resnet50ts=ByoModelCfg(
176
+ blocks=(
177
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
178
+ interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
179
+ interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
180
+ interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
181
+ ),
182
+ stem_chs=64,
183
+ stem_type='tiered',
184
+ stem_pool='maxpool',
185
+ act_layer='silu',
186
+ self_attn_layer='lambda',
187
+ self_attn_kwargs=dict(r=9)
188
+ ),
189
+ lambda_resnet26rpt_256=ByoModelCfg(
190
+ blocks=(
191
+ ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
192
+ ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
193
+ interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
194
+ ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
195
+ ),
196
+ stem_chs=64,
197
+ stem_type='tiered',
198
+ stem_pool='maxpool',
199
+ self_attn_layer='lambda',
200
+ self_attn_kwargs=dict(r=None)
201
+ ),
202
+
203
+ # experimental
204
+ haloregnetz_b=ByoModelCfg(
205
+ blocks=(
206
+ ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
207
+ ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
208
+ interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
209
+ ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
210
+ ),
211
+ stem_chs=32,
212
+ stem_pool='',
213
+ downsample='',
214
+ num_features=1536,
215
+ act_layer='silu',
216
+ attn_layer='se',
217
+ attn_kwargs=dict(rd_ratio=0.25),
218
+ block_kwargs=dict(bottle_in=True, linear_out=True),
219
+ self_attn_layer='halo',
220
+ self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
221
+ ),
222
+
223
+ # experimental
224
+ lamhalobotnet50ts=ByoModelCfg(
225
+ blocks=(
226
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
227
+ interleave_blocks(
228
+ types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
229
+ self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
230
+ interleave_blocks(
231
+ types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
232
+ self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
233
+ interleave_blocks(
234
+ types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
235
+ self_attn_layer='bottleneck', self_attn_kwargs=dict()),
236
+ ),
237
+ stem_chs=64,
238
+ stem_type='tiered',
239
+ stem_pool='',
240
+ act_layer='silu',
241
+ ),
242
+ halo2botnet50ts=ByoModelCfg(
243
+ blocks=(
244
+ ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
245
+ interleave_blocks(
246
+ types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
247
+ self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
248
+ interleave_blocks(
249
+ types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
250
+ self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
251
+ interleave_blocks(
252
+ types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
253
+ self_attn_layer='bottleneck', self_attn_kwargs=dict()),
254
+ ),
255
+ stem_chs=64,
256
+ stem_type='tiered',
257
+ stem_pool='',
258
+ act_layer='silu',
259
+ ),
260
+ )
261
+
262
+
263
+ def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
264
+ return build_model_with_cfg(
265
+ ByobNet, variant, pretrained,
266
+ model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
267
+ feature_cfg=dict(flatten_sequential=True),
268
+ **kwargs,
269
+ )
270
+
271
+
272
+ def _cfg(url='', **kwargs):
273
+ return {
274
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
275
+ 'crop_pct': 0.95, 'interpolation': 'bicubic',
276
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
277
+ 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
278
+ 'fixed_input_size': False, 'min_input_size': (3, 224, 224),
279
+ **kwargs
280
+ }
281
+
282
+
283
+ default_cfgs = generate_default_cfgs({
284
+ # Bottleneck Transformer (BotNet) weights
285
+ 'botnet26t_256.c1_in1k': _cfg(
286
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
287
+ hf_hub_id='timm/',
288
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
289
+ 'sebotnet33ts_256.a1h_in1k': _cfg(
290
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
291
+ hf_hub_id='timm/',
292
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
293
+ 'botnet50ts_256.untrained': _cfg(
294
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
295
+ 'eca_botnext26ts_256.c1_in1k': _cfg(
296
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
297
+ hf_hub_id='timm/',
298
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
299
+
300
+ 'halonet_h1.untrained': _cfg(input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
301
+ 'halonet26t.a1h_in1k': _cfg(
302
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
303
+ hf_hub_id='timm/',
304
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
305
+ 'sehalonet33ts.ra2_in1k': _cfg(
306
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
307
+ hf_hub_id='timm/',
308
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
309
+ 'halonet50ts.a1h_in1k': _cfg(
310
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
311
+ hf_hub_id='timm/',
312
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
313
+ 'eca_halonext26ts.c1_in1k': _cfg(
314
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
315
+ hf_hub_id='timm/',
316
+ input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
317
+
318
+ 'lambda_resnet26t.c1_in1k': _cfg(
319
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
320
+ hf_hub_id='timm/',
321
+ min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
322
+ 'lambda_resnet50ts.a1h_in1k': _cfg(
323
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
324
+ hf_hub_id='timm/',
325
+ min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
326
+ 'lambda_resnet26rpt_256.c1_in1k': _cfg(
327
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
328
+ hf_hub_id='timm/',
329
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
330
+
331
+ 'haloregnetz_b.ra3_in1k': _cfg(
332
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
333
+ hf_hub_id='timm/',
334
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
335
+ first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),
336
+
337
+ 'lamhalobotnet50ts_256.a1h_in1k': _cfg(
338
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
339
+ hf_hub_id='timm/',
340
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
341
+ 'halo2botnet50ts_256.a1h_in1k': _cfg(
342
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
343
+ hf_hub_id='timm/',
344
+ fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
345
+ })
346
+
347
+
348
+ @register_model
349
+ def botnet26t_256(pretrained=False, **kwargs) -> ByobNet:
350
+ """ Bottleneck Transformer w/ ResNet26-T backbone.
351
+ """
352
+ kwargs.setdefault('img_size', 256)
353
+ return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)
354
+
355
+
356
+ @register_model
357
+ def sebotnet33ts_256(pretrained=False, **kwargs) -> ByobNet:
358
+ """ Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU,
359
+ """
360
+ return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)
361
+
362
+
363
+ @register_model
364
+ def botnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
365
+ """ Bottleneck Transformer w/ ResNet50-T backbone, silu act.
366
+ """
367
+ kwargs.setdefault('img_size', 256)
368
+ return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
369
+
370
+
371
+ @register_model
372
+ def eca_botnext26ts_256(pretrained=False, **kwargs) -> ByobNet:
373
+ """ Bottleneck Transformer w/ ResNet26-T backbone, silu act.
374
+ """
375
+ kwargs.setdefault('img_size', 256)
376
+ return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
377
+
378
+
379
+ @register_model
380
+ def halonet_h1(pretrained=False, **kwargs) -> ByobNet:
381
+ """ HaloNet-H1. Halo attention in all stages as per the paper.
382
+ NOTE: This runs very slowly!
383
+ """
384
+ return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)
385
+
386
+
387
+ @register_model
388
+ def halonet26t(pretrained=False, **kwargs) -> ByobNet:
389
+ """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
390
+ """
391
+ return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)
392
+
393
+
394
+ @register_model
395
+ def sehalonet33ts(pretrained=False, **kwargs) -> ByobNet:
396
+ """ HaloNet w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, 1-2 Halo in stage 2,3,4.
397
+ """
398
+ return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs)
399
+
400
+
401
+ @register_model
402
+ def halonet50ts(pretrained=False, **kwargs) -> ByobNet:
403
+ """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
404
+ """
405
+ return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
406
+
407
+
408
+ @register_model
409
+ def eca_halonext26ts(pretrained=False, **kwargs) -> ByobNet:
410
+ """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
411
+ """
412
+ return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
413
+
414
+
415
+ @register_model
416
+ def lambda_resnet26t(pretrained=False, **kwargs) -> ByobNet:
417
+ """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
418
+ """
419
+ return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)
420
+
421
+
422
+ @register_model
423
+ def lambda_resnet50ts(pretrained=False, **kwargs) -> ByobNet:
424
+ """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
425
+ """
426
+ return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)
427
+
428
+
429
+ @register_model
430
+ def lambda_resnet26rpt_256(pretrained=False, **kwargs) -> ByobNet:
431
+ """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
432
+ """
433
+ kwargs.setdefault('img_size', 256)
434
+ return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)
435
+
436
+
437
+ @register_model
438
+ def haloregnetz_b(pretrained=False, **kwargs) -> ByobNet:
439
+ """ Halo + RegNetZ
440
+ """
441
+ return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)
442
+
443
+
444
+ @register_model
445
+ def lamhalobotnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
446
+ """ Combo Attention (Lambda + Halo + Bot) Network
447
+ """
448
+ return _create_byoanet('lamhalobotnet50ts_256', 'lamhalobotnet50ts', pretrained=pretrained, **kwargs)
449
+
450
+
451
+ @register_model
452
+ def halo2botnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
453
+ """ Combo Attention (Halo + Halo + Bot) Network
454
+ """
455
+ return _create_byoanet('halo2botnet50ts_256', 'halo2botnet50ts', pretrained=pretrained, **kwargs)
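A hedged usage sketch, illustrative only (assumes a timm install that includes the registrations above; the name 'botnet26t_256' comes straight from the @register_model functions):

import timm
import torch

# create_model resolves any name registered above via @register_model;
# botnet26t_256 is fixed-input-size, img_size defaults to 256 (kwargs.setdefault above).
model = timm.create_model('botnet26t_256', pretrained=False, num_classes=1000)
model.eval()

x = torch.randn(1, 3, 256, 256)  # matches the default_cfgs input_size for this variant
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # torch.Size([1, 1000])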
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byobnet.py ADDED
@@ -0,0 +1,2245 @@
1
+ """ Bring-Your-Own-Blocks Network
2
+
3
+ A flexible network w/ dataclass based config for stacking those NN blocks.
4
+
5
+ This model is currently used to implement the following networks:
6
+
7
+ GPU Efficient (ResNets) - gernet_l/m/s (original versions were called 'genet', but that name was already used by the SENet author).
8
+ Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
9
+ Code and weights: https://github.com/idstcv/GPU-Efficient-Networks, licensed Apache 2.0
10
+
11
+ RepVGG - repvgg_*
12
+ Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
13
+ Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT
14
+
15
+ MobileOne - mobileone_*
16
+ Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040
17
+ Code and weights: https://github.com/apple/ml-mobileone, licensed MIT
18
+
19
+ In all cases the models have been modified to fit within the design of ByobNet. I've remapped
20
+ the original weights and verified accuracies.
21
+
22
+ For GPU Efficient nets, I used the original names for the blocks since they were for the most part
23
+ the same as the original residual blocks in ResNe(X)t, DarkNet, and other existing models. Note that some
24
+ changes introduced in RegNet were also present in the stem and bottleneck blocks for this model.
25
+
26
+ A significant number of different network archs can be implemented here, including variants of the
27
+ above nets that include attention.
28
+
29
+ Hacked together by / copyright Ross Wightman, 2021.
30
+ """
31
+ import math
32
+ from dataclasses import dataclass, field, replace
33
+ from functools import partial
34
+ from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
35
+
36
+ import torch
37
+ import torch.nn as nn
38
+
39
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
40
+ from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
41
+ create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
42
+ from ._builder import build_model_with_cfg
43
+ from ._manipulate import named_apply, checkpoint_seq
44
+ from ._registry import generate_default_cfgs, register_model
45
+
46
+ __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
47
+
48
+
49
+ @dataclass
50
+ class ByoBlockCfg:
51
+ type: Union[str, nn.Module]
52
+ d: int # block depth (number of block repeats in stage)
53
+ c: int # number of output channels for each block in stage
54
+ s: int = 2 # stride of stage (first block)
55
+ gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1
56
+ br: float = 1. # bottleneck-ratio of blocks in stage
57
+
58
+ # NOTE: these config items override the model cfgs that are applied to all blocks by default
59
+ attn_layer: Optional[str] = None
60
+ attn_kwargs: Optional[Dict[str, Any]] = None
61
+ self_attn_layer: Optional[str] = None
62
+ self_attn_kwargs: Optional[Dict[str, Any]] = None
63
+ block_kwargs: Optional[Dict[str, Any]] = None
64
+
65
+
66
+ @dataclass
67
+ class ByoModelCfg:
68
+ blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...]
69
+ downsample: str = 'conv1x1'
70
+ stem_type: str = '3x3'
71
+ stem_pool: Optional[str] = 'maxpool'
72
+ stem_chs: int = 32
73
+ width_factor: float = 1.0
74
+ num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
75
+ zero_init_last: bool = True # zero init last weight (usually bn) in residual path
76
+ fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation
77
+
78
+ act_layer: str = 'relu'
79
+ norm_layer: str = 'batchnorm'
80
+
81
+ # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
82
+ attn_layer: Optional[str] = None
83
+ attn_kwargs: dict = field(default_factory=lambda: dict())
84
+ self_attn_layer: Optional[str] = None
85
+ self_attn_kwargs: dict = field(default_factory=lambda: dict())
86
+ block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict())
87
+
88
+
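A minimal sketch of how these two dataclasses compose (the config below is hypothetical, not one of the registered variants):

# Hypothetical two-stage config: plain bottleneck blocks, ResNet-style 7x7 stem.
tiny_cfg = ByoModelCfg(
    blocks=(
        ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),  # 2 blocks, 256 ch, stride 1
        ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),  # 2 blocks, 512 ch, stride 2
    ),
    stem_chs=64,
    stem_type='7x7',
    stem_pool='maxpool',
)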
89
+ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
90
+ c = (64, 128, 256, 512)
91
+ group_size = 0
92
+ if groups > 0:
93
+ group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0
94
+ bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
95
+ return bcfg
96
+
97
+
98
+ def _mobileone_bcfg(d=(2, 8, 10, 1), wf=(1., 1., 1., 1.), se_blocks=(), num_conv_branches=1):
99
+ c = (64, 128, 256, 512)
100
+ prev_c = min(64, c[0] * wf[0])
101
+ se_blocks = se_blocks or (0,) * len(d)
102
+ bcfg = []
103
+ for d, c, w, se in zip(d, c, wf, se_blocks):
104
+ scfg = []
105
+ for i in range(d):
106
+ out_c = c * w
107
+ bk = dict(num_conv_branches=num_conv_branches)
108
+ ak = {}
109
+ if i >= d - se:
110
+ ak['attn_layer'] = 'se'
111
+ scfg += [ByoBlockCfg(type='one', d=1, c=prev_c, gs=1, block_kwargs=bk, **ak)] # depthwise block
112
+ scfg += [ByoBlockCfg(
113
+ type='one', d=1, c=out_c, gs=0, block_kwargs=dict(kernel_size=1, **bk), **ak)] # pointwise block
114
+ prev_c = out_c
115
+ bcfg += [scfg]
116
+ return bcfg
117
+
118
+
119
+ def interleave_blocks(
120
+ types: Tuple[str, str], d,
121
+ every: Union[int, List[int]] = 1,
122
+ first: bool = False,
123
+ **kwargs,
124
+ ) -> Tuple[ByoBlockCfg]:
125
+ """ interleave 2 block types in stack
126
+ """
127
+ assert len(types) == 2
128
+ if isinstance(every, int):
129
+ every = list(range(0 if first else every, d, every + 1))
130
+ if not every:
131
+ every = [d - 1]
132
+ every = set(every)  # set for O(1) membership checks below
133
+ blocks = []
134
+ for i in range(d):
135
+ block_type = types[1] if i in every else types[0]
136
+ blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)]
137
+ return tuple(blocks)
138
+
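A worked example of the expansion above, using the values from the botnet50ts stage-2 config (every=4, d=4):

# every=4 (int), first=False -> list(range(4, 4, 5)) == [], falling back to [d - 1] == [3],
# so only the last of the four blocks gets the self-attention type:
stage = interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25)
print([b.type for b in stage])  # ['bottle', 'bottle', 'bottle', 'self_attn']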
139
+
140
+ def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
141
+ if not isinstance(stage_blocks_cfg, Sequence):
142
+ stage_blocks_cfg = (stage_blocks_cfg,)
143
+ block_cfgs = []
144
+ for i, cfg in enumerate(stage_blocks_cfg):
145
+ block_cfgs += [replace(cfg, d=1) for _ in range(cfg.d)]
146
+ return block_cfgs
147
+
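For example, a single stage entry of depth 3 expands into three per-block configs of depth 1:

cfgs = expand_blocks_cfg(ByoBlockCfg(type='bottle', d=3, c=256))
print(len(cfgs), [c.d for c in cfgs])  # 3 [1, 1, 1]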
148
+
149
+ def num_groups(group_size, channels):
150
+ if not group_size: # 0 or None
151
+ return 1 # normal conv with 1 group
152
+ else:
153
+ # NOTE group_size == 1 -> depthwise conv
154
+ assert channels % group_size == 0
155
+ return channels // group_size
156
+
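Concrete values for the group-size convention used throughout the configs above:

print(num_groups(0, 1024))   # 1  -> normal conv (gs == 0 disables grouping)
print(num_groups(16, 1024))  # 64 -> grouped conv w/ 16 channels per group
print(num_groups(1, 64))     # 64 -> depthwise conv (one channel per group)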
157
+
158
+ @dataclass
159
+ class LayerFn:
160
+ conv_norm_act: Callable = ConvNormAct
161
+ norm_act: Callable = BatchNormAct2d
162
+ act: Callable = nn.ReLU
163
+ attn: Optional[Callable] = None
164
+ self_attn: Optional[Callable] = None
165
+
166
+
167
+ class DownsampleAvg(nn.Module):
168
+ def __init__(
169
+ self,
170
+ in_chs: int,
171
+ out_chs: int,
172
+ stride: int = 1,
173
+ dilation: int = 1,
174
+ apply_act: bool = False,
175
+ layers: LayerFn = None,
176
+ ):
177
+ """ AvgPool Downsampling as in 'D' ResNet variants."""
178
+ super(DownsampleAvg, self).__init__()
179
+ layers = layers or LayerFn()
180
+ avg_stride = stride if dilation == 1 else 1
181
+ if stride > 1 or dilation > 1:
182
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
183
+ self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
184
+ else:
185
+ self.pool = nn.Identity()
186
+ self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)
187
+
188
+ def forward(self, x):
189
+ return self.conv(self.pool(x))
190
+
191
+
192
+ def create_shortcut(
193
+ downsample_type: str,
194
+ in_chs: int,
195
+ out_chs: int,
196
+ stride: int,
197
+ dilation: Tuple[int, int],
198
+ layers: LayerFn,
199
+ **kwargs,
200
+ ):
201
+ assert downsample_type in ('avg', 'conv1x1', '')
202
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
203
+ if not downsample_type:
204
+ return None # no shortcut
205
+ elif downsample_type == 'avg':
206
+ return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs)
207
+ else:
208
+ return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, stride=stride, dilation=dilation[0], **kwargs)
209
+ else:
210
+ return nn.Identity() # identity shortcut
211
+
212
+
213
+ class BasicBlock(nn.Module):
214
+ """ ResNet Basic Block - kxk + kxk
215
+ """
216
+
217
+ def __init__(
218
+ self,
219
+ in_chs: int,
220
+ out_chs: int,
221
+ kernel_size: int = 3,
222
+ stride: int = 1,
223
+ dilation: Tuple[int, int] = (1, 1),
224
+ group_size: Optional[int] = None,
225
+ bottle_ratio: float = 1.0,
226
+ downsample: str = 'avg',
227
+ attn_last: bool = True,
228
+ linear_out: bool = False,
229
+ layers: LayerFn = None,
230
+ drop_block: Callable = None,
231
+ drop_path_rate: float = 0.,
232
+ ):
233
+ super(BasicBlock, self).__init__()
234
+ layers = layers or LayerFn()
235
+ mid_chs = make_divisible(out_chs * bottle_ratio)
236
+ groups = num_groups(group_size, mid_chs)
237
+
238
+ self.shortcut = create_shortcut(
239
+ downsample, in_chs, out_chs,
240
+ stride=stride, dilation=dilation, apply_act=False, layers=layers,
241
+ )
242
+
243
+ self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
244
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
245
+ self.conv2_kxk = layers.conv_norm_act(
246
+ mid_chs, out_chs, kernel_size,
247
+ dilation=dilation[1], groups=groups, drop_layer=drop_block, apply_act=False,
248
+ )
249
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
250
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
251
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
252
+
253
+ def init_weights(self, zero_init_last: bool = False):
254
+ if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
255
+ nn.init.zeros_(self.conv2_kxk.bn.weight)
256
+ for attn in (self.attn, self.attn_last):
257
+ if hasattr(attn, 'reset_parameters'):
258
+ attn.reset_parameters()
259
+
260
+ def forward(self, x):
261
+ shortcut = x
262
+ x = self.conv1_kxk(x)
263
+ x = self.conv2_kxk(x)
264
+ x = self.attn(x)
265
+ x = self.drop_path(x)
266
+ if self.shortcut is not None:
267
+ x = x + self.shortcut(shortcut)
268
+ return self.act(x)
269
+
270
+
271
+ class BottleneckBlock(nn.Module):
272
+ """ ResNet-like Bottleneck Block - 1x1 - kxk - 1x1
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ in_chs: int,
278
+ out_chs: int,
279
+ kernel_size: int = 3,
280
+ stride: int = 1,
281
+ dilation: Tuple[int, int] = (1, 1),
282
+ bottle_ratio: float = 1.,
283
+ group_size: Optional[int] = None,
284
+ downsample: str = 'avg',
285
+ attn_last: bool = False,
286
+ linear_out: bool = False,
287
+ extra_conv: bool = False,
288
+ bottle_in: bool = False,
289
+ layers: LayerFn = None,
290
+ drop_block: Callable = None,
291
+ drop_path_rate: float = 0.,
292
+ ):
293
+ super(BottleneckBlock, self).__init__()
294
+ layers = layers or LayerFn()
295
+ mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
296
+ groups = num_groups(group_size, mid_chs)
297
+
298
+ self.shortcut = create_shortcut(
299
+ downsample, in_chs, out_chs,
300
+ stride=stride, dilation=dilation, apply_act=False, layers=layers,
301
+ )
302
+
303
+ self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
304
+ self.conv2_kxk = layers.conv_norm_act(
305
+ mid_chs, mid_chs, kernel_size,
306
+ stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
307
+ )
308
+ if extra_conv:
309
+ self.conv2b_kxk = layers.conv_norm_act(
310
+ mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups)
311
+ else:
312
+ self.conv2b_kxk = nn.Identity()
313
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
314
+ self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
315
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
316
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
317
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
318
+
319
+ def init_weights(self, zero_init_last: bool = False):
320
+ if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
321
+ nn.init.zeros_(self.conv3_1x1.bn.weight)
322
+ for attn in (self.attn, self.attn_last):
323
+ if hasattr(attn, 'reset_parameters'):
324
+ attn.reset_parameters()
325
+
326
+ def forward(self, x):
327
+ shortcut = x
328
+ x = self.conv1_1x1(x)
329
+ x = self.conv2_kxk(x)
330
+ x = self.conv2b_kxk(x)
331
+ x = self.attn(x)
332
+ x = self.conv3_1x1(x)
333
+ x = self.attn_last(x)
334
+ x = self.drop_path(x)
335
+ if self.shortcut is not None:
336
+ x = x + self.shortcut(shortcut)
337
+ return self.act(x)
338
+
339
+
340
+ class DarkBlock(nn.Module):
341
+ """ DarkNet-like (1x1 + 3x3 w/ stride) block
342
+
343
+ The GE-Net impl included a 1x1 + 3x3 block in their search space. It was not used in the feature models.
344
+ This block is pretty much a DarkNet block (also DenseNet) hence the name. Neither DarkNet nor DenseNet
345
+ uses strides within the block (external 3x3 or maxpool downsampling is done in front of the block repeats).
346
+
347
+ If one does want to use a lot of these blocks w/ stride, I'd recommend using the EdgeBlock (3x3 w/ stride + 1x1)
348
+ for better compute efficiency.
349
+ """
350
+
351
+ def __init__(
352
+ self,
353
+ in_chs: int,
354
+ out_chs: int,
355
+ kernel_size: int = 3,
356
+ stride: int = 1,
357
+ dilation: Tuple[int, int] = (1, 1),
358
+ bottle_ratio: float = 1.0,
359
+ group_size: Optional[int] = None,
360
+ downsample: str = 'avg',
361
+ attn_last: bool = True,
362
+ linear_out: bool = False,
363
+ layers: LayerFn = None,
364
+ drop_block: Callable = None,
365
+ drop_path_rate: float = 0.,
366
+ ):
367
+ super(DarkBlock, self).__init__()
368
+ layers = layers or LayerFn()
369
+ mid_chs = make_divisible(out_chs * bottle_ratio)
370
+ groups = num_groups(group_size, mid_chs)
371
+
372
+ self.shortcut = create_shortcut(
373
+ downsample, in_chs, out_chs,
374
+ stride=stride, dilation=dilation, apply_act=False, layers=layers,
375
+ )
376
+
377
+ self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
378
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
379
+ self.conv2_kxk = layers.conv_norm_act(
380
+ mid_chs, out_chs, kernel_size,
381
+ stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
382
+ )
383
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
384
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
385
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
386
+
387
+ def init_weights(self, zero_init_last: bool = False):
388
+ if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
389
+ nn.init.zeros_(self.conv2_kxk.bn.weight)
390
+ for attn in (self.attn, self.attn_last):
391
+ if hasattr(attn, 'reset_parameters'):
392
+ attn.reset_parameters()
393
+
394
+ def forward(self, x):
395
+ shortcut = x
396
+ x = self.conv1_1x1(x)
397
+ x = self.attn(x)
398
+ x = self.conv2_kxk(x)
399
+ x = self.attn_last(x)
400
+ x = self.drop_path(x)
401
+ if self.shortcut is not None:
402
+ x = x + self.shortcut(shortcut)
403
+ return self.act(x)
404
+
405
+
406
+ class EdgeBlock(nn.Module):
407
+ """ EdgeResidual-like (3x3 + 1x1) block
408
+
409
+ A two layer block like DarkBlock, but with the order of the 3x3 and 1x1 convs reversed.
410
+ Very similar to the EfficientNet Edge-Residual block, but this block ends with activations, is
411
+ intended to be used with either expansion or bottleneck contraction, and can use DW/group/non-grouped convs.
412
+
413
+ FIXME is there a more common 3x3 + 1x1 conv block to name this after?
414
+ """
415
+
416
+ def __init__(
417
+ self,
418
+ in_chs: int,
419
+ out_chs: int,
420
+ kernel_size: int = 3,
421
+ stride: int = 1,
422
+ dilation: Tuple[int, int] = (1, 1),
423
+ bottle_ratio: float = 1.0,
424
+ group_size: Optional[int] = None,
425
+ downsample: str = 'avg',
426
+ attn_last: bool = False,
427
+ linear_out: bool = False,
428
+ layers: LayerFn = None,
429
+ drop_block: Callable = None,
430
+ drop_path_rate: float = 0.,
431
+ ):
432
+ super(EdgeBlock, self).__init__()
433
+ layers = layers or LayerFn()
434
+ mid_chs = make_divisible(out_chs * bottle_ratio)
435
+ groups = num_groups(group_size, mid_chs)
436
+
437
+ self.shortcut = create_shortcut(
438
+ downsample, in_chs, out_chs,
439
+ stride=stride, dilation=dilation, apply_act=False, layers=layers,
440
+ )
441
+
442
+ self.conv1_kxk = layers.conv_norm_act(
443
+ in_chs, mid_chs, kernel_size,
444
+ stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
445
+ )
446
+ self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
447
+ self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
448
+ self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
449
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
450
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
451
+
452
+ def init_weights(self, zero_init_last: bool = False):
453
+ if zero_init_last and self.shortcut is not None and getattr(self.conv2_1x1.bn, 'weight', None) is not None:
454
+ nn.init.zeros_(self.conv2_1x1.bn.weight)
455
+ for attn in (self.attn, self.attn_last):
456
+ if hasattr(attn, 'reset_parameters'):
457
+ attn.reset_parameters()
458
+
459
+ def forward(self, x):
460
+ shortcut = x
461
+ x = self.conv1_kxk(x)
462
+ x = self.attn(x)
463
+ x = self.conv2_1x1(x)
464
+ x = self.attn_last(x)
465
+ x = self.drop_path(x)
466
+ if self.shortcut is not None:
467
+ x = x + self.shortcut(shortcut)
468
+ return self.act(x)
469
+
470
+
471
+ class RepVggBlock(nn.Module):
472
+ """ RepVGG Block.
473
+
474
+ Adapted from impl at https://github.com/DingXiaoH/RepVGG
475
+ """
476
+
477
+ def __init__(
478
+ self,
479
+ in_chs: int,
480
+ out_chs: int,
481
+ kernel_size: int = 3,
482
+ stride: int = 1,
483
+ dilation: Tuple[int, int] = (1, 1),
484
+ bottle_ratio: float = 1.0,
485
+ group_size: Optional[int] = None,
486
+ downsample: str = '',
487
+ layers: LayerFn = None,
488
+ drop_block: Callable = None,
489
+ drop_path_rate: float = 0.,
490
+ inference_mode: bool = False
491
+ ):
492
+ super(RepVggBlock, self).__init__()
493
+ self.groups = groups = num_groups(group_size, in_chs)
494
+ layers = layers or LayerFn()
495
+
496
+ if inference_mode:
497
+ self.reparam_conv = nn.Conv2d(
498
+ in_channels=in_chs,
499
+ out_channels=out_chs,
500
+ kernel_size=kernel_size,
501
+ stride=stride,
502
+ dilation=dilation,
503
+ groups=groups,
504
+ bias=True,
505
+ )
506
+ else:
507
+ self.reparam_conv = None
508
+ use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
509
+ self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
510
+ self.conv_kxk = layers.conv_norm_act(
511
+ in_chs, out_chs, kernel_size,
512
+ stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
513
+ )
514
+ self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
515
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
516
+
517
+ self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
518
+ self.act = layers.act(inplace=True)
519
+
520
+ def init_weights(self, zero_init_last: bool = False):
521
+ # NOTE this init overrides the base model init with specific changes for the block type
522
+ for m in self.modules():
523
+ if isinstance(m, nn.BatchNorm2d):
524
+ nn.init.normal_(m.weight, .1, .1)
525
+ nn.init.normal_(m.bias, 0, .1)
526
+ if hasattr(self.attn, 'reset_parameters'):
527
+ self.attn.reset_parameters()
528
+
529
+ def forward(self, x):
530
+ if self.reparam_conv is not None:
531
+ return self.act(self.attn(self.reparam_conv(x)))
532
+
533
+ if self.identity is None:
534
+ x = self.conv_1x1(x) + self.conv_kxk(x)
535
+ else:
536
+ identity = self.identity(x)
537
+ x = self.conv_1x1(x) + self.conv_kxk(x)
538
+ x = self.drop_path(x) # not in the paper / official impl, experimental
539
+ x += identity
540
+ x = self.attn(x) # no attn in the paper / official impl, experimental
541
+ return self.act(x)
542
+
543
+ def reparameterize(self):
544
+ """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
545
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
546
+ architecture used at training time to obtain a plain CNN-like structure
547
+ for inference.
548
+ """
549
+ if self.reparam_conv is not None:
550
+ return
551
+
552
+ kernel, bias = self._get_kernel_bias()
553
+ self.reparam_conv = nn.Conv2d(
554
+ in_channels=self.conv_kxk.conv.in_channels,
555
+ out_channels=self.conv_kxk.conv.out_channels,
556
+ kernel_size=self.conv_kxk.conv.kernel_size,
557
+ stride=self.conv_kxk.conv.stride,
558
+ padding=self.conv_kxk.conv.padding,
559
+ dilation=self.conv_kxk.conv.dilation,
560
+ groups=self.conv_kxk.conv.groups,
561
+ bias=True,
562
+ )
563
+ self.reparam_conv.weight.data = kernel
564
+ self.reparam_conv.bias.data = bias
565
+
566
+ # Delete un-used branches
567
+ for name, para in self.named_parameters():
568
+ if 'reparam_conv' in name:
569
+ continue
570
+ para.detach_()
571
+ self.__delattr__('conv_kxk')
572
+ self.__delattr__('conv_1x1')
573
+ self.__delattr__('identity')
574
+ self.__delattr__('drop_path')
575
+
576
+ def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
577
+ """ Method to obtain re-parameterized kernel and bias.
578
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
579
+ """
580
+ # get weights and bias of scale branch
581
+ kernel_1x1 = 0
582
+ bias_1x1 = 0
583
+ if self.conv_1x1 is not None:
584
+ kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1)
585
+ # Pad scale branch kernel to match conv branch kernel size.
586
+ pad = self.conv_kxk.conv.kernel_size[0] // 2
587
+ kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad])
588
+
589
+ # get weights and bias of skip branch
590
+ kernel_identity = 0
591
+ bias_identity = 0
592
+ if self.identity is not None:
593
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
594
+
595
+ # get weights and bias of conv branches
596
+ kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)
597
+
598
+ kernel_final = kernel_conv + kernel_1x1 + kernel_identity
599
+ bias_final = bias_conv + bias_1x1 + bias_identity
600
+ return kernel_final, bias_final
601
+
602
+ def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
603
+ """ Method to fuse batchnorm layer with preceeding conv layer.
604
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
605
+ """
606
+ if isinstance(branch, ConvNormAct):
607
+ kernel = branch.conv.weight
608
+ running_mean = branch.bn.running_mean
609
+ running_var = branch.bn.running_var
610
+ gamma = branch.bn.weight
611
+ beta = branch.bn.bias
612
+ eps = branch.bn.eps
613
+ else:
614
+ assert isinstance(branch, nn.BatchNorm2d)
615
+ if not hasattr(self, 'id_tensor'):
616
+ in_chs = self.conv_kxk.conv.in_channels
617
+ input_dim = in_chs // self.groups
618
+ kernel_size = self.conv_kxk.conv.kernel_size
619
+ kernel_value = torch.zeros_like(self.conv_kxk.conv.weight)
620
+ for i in range(in_chs):
621
+ kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
622
+ self.id_tensor = kernel_value
623
+ kernel = self.id_tensor
624
+ running_mean = branch.running_mean
625
+ running_var = branch.running_var
626
+ gamma = branch.weight
627
+ beta = branch.bias
628
+ eps = branch.eps
629
+ std = (running_var + eps).sqrt()
630
+ t = (gamma / std).reshape(-1, 1, 1, 1)
631
+ return kernel * t, beta - running_mean * gamma / std
632
+
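The fusion above applies W' = W * (gamma / sigma) per output channel and b' = beta - mu * gamma / sigma, with sigma = sqrt(running_var + eps). A self-contained sketch (independent of the classes above) verifying that identity numerically:

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()
# give the BN non-trivial (frozen) statistics and affine params
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
bn.weight.data.uniform_(0.5, 1.5)
bn.bias.data.uniform_(-1, 1)

std = (bn.running_var + bn.eps).sqrt()
fused = nn.Conv2d(8, 8, 3, padding=1, bias=True)
fused.weight.data = conv.weight * (bn.weight / std).reshape(-1, 1, 1, 1)
fused.bias.data = bn.bias - bn.running_mean * bn.weight / std

x = torch.randn(2, 8, 16, 16)
with torch.no_grad():
    print(torch.allclose(bn(conv(x)), fused(x), atol=1e-5))  # True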
633
+
634
+ class MobileOneBlock(nn.Module):
635
+ """ MobileOne building block.
636
+
637
+ This block has a multi-branched architecture at train-time
638
+ and plain-CNN style architecture at inference time.
639
+ For more details, please refer to our paper:
640
+ `An Improved One millisecond Mobile Backbone` -
641
+ https://arxiv.org/pdf/2206.04040.pdf
642
+ """
643
+
644
+ def __init__(
645
+ self,
646
+ in_chs: int,
647
+ out_chs: int,
648
+ kernel_size: int = 3,
649
+ stride: int = 1,
650
+ dilation: Tuple[int, int] = (1, 1),
651
+ bottle_ratio: float = 1.0, # unused
652
+ group_size: Optional[int] = None,
653
+ downsample: str = '', # unused
654
+ inference_mode: bool = False,
655
+ num_conv_branches: int = 1,
656
+ layers: LayerFn = None,
657
+ drop_block: Callable = None,
658
+ drop_path_rate: float = 0.,
659
+ ) -> None:
660
+ """ Construct a MobileOneBlock module.
661
+ """
662
+ super(MobileOneBlock, self).__init__()
663
+ self.num_conv_branches = num_conv_branches
664
+ self.groups = groups = num_groups(group_size, in_chs)
665
+ layers = layers or LayerFn()
666
+
667
+ if inference_mode:
668
+ self.reparam_conv = nn.Conv2d(
669
+ in_channels=in_chs,
670
+ out_channels=out_chs,
671
+ kernel_size=kernel_size,
672
+ stride=stride,
673
+ dilation=dilation,
674
+ groups=groups,
675
+ bias=True)
676
+ else:
677
+ self.reparam_conv = None
678
+
679
+ # Re-parameterizable skip connection
680
+ use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
681
+ self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
682
+
683
+ # Re-parameterizable conv branches
684
+ convs = []
685
+ for _ in range(self.num_conv_branches):
686
+ convs.append(layers.conv_norm_act(
687
+ in_chs, out_chs, kernel_size=kernel_size,
688
+ stride=stride, groups=groups, apply_act=False))
689
+ self.conv_kxk = nn.ModuleList(convs)
690
+
691
+ # Re-parameterizable scale branch
692
+ self.conv_scale = None
693
+ if kernel_size > 1:
694
+ self.conv_scale = layers.conv_norm_act(
695
+ in_chs, out_chs, kernel_size=1,
696
+ stride=stride, groups=groups, apply_act=False)
697
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
698
+
699
+ self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
700
+ self.act = layers.act(inplace=True)
701
+
702
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
703
+ """ Apply forward pass. """
704
+ # Inference mode forward pass.
705
+ if self.reparam_conv is not None:
706
+ return self.act(self.attn(self.reparam_conv(x)))
707
+
708
+ # Multi-branched train-time forward pass.
709
+ # Skip branch output
710
+ identity_out = 0
711
+ if self.identity is not None:
712
+ identity_out = self.identity(x)
713
+
714
+ # Scale branch output
715
+ scale_out = 0
716
+ if self.conv_scale is not None:
717
+ scale_out = self.conv_scale(x)
718
+
719
+ # Other branches
720
+ out = scale_out
721
+ for ck in self.conv_kxk:
722
+ out += ck(x)
723
+ out = self.drop_path(out)
724
+ out += identity_out
725
+
726
+ return self.act(self.attn(out))
727
+
728
+ def reparameterize(self):
729
+ """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
730
+ https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
731
+ architecture used at training time to obtain a plain CNN-like structure
732
+ for inference.
733
+ """
734
+ if self.reparam_conv is not None:
735
+ return
736
+
737
+ kernel, bias = self._get_kernel_bias()
738
+ self.reparam_conv = nn.Conv2d(
739
+ in_channels=self.conv_kxk[0].conv.in_channels,
740
+ out_channels=self.conv_kxk[0].conv.out_channels,
741
+ kernel_size=self.conv_kxk[0].conv.kernel_size,
742
+ stride=self.conv_kxk[0].conv.stride,
743
+ padding=self.conv_kxk[0].conv.padding,
744
+ dilation=self.conv_kxk[0].conv.dilation,
745
+ groups=self.conv_kxk[0].conv.groups,
746
+ bias=True)
747
+ self.reparam_conv.weight.data = kernel
748
+ self.reparam_conv.bias.data = bias
749
+
750
+ # Delete un-used branches
751
+ for name, para in self.named_parameters():
752
+ if 'reparam_conv' in name:
753
+ continue
754
+ para.detach_()
755
+ self.__delattr__('conv_kxk')
756
+ self.__delattr__('conv_scale')
757
+ self.__delattr__('identity')
758
+ self.__delattr__('drop_path')
759
+
760
+ def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
761
+ """ Method to obtain re-parameterized kernel and bias.
762
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
763
+ """
764
+ # get weights and bias of scale branch
765
+ kernel_scale = 0
766
+ bias_scale = 0
767
+ if self.conv_scale is not None:
768
+ kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
769
+ # Pad scale branch kernel to match conv branch kernel size.
770
+ pad = self.conv_kxk[0].conv.kernel_size[0] // 2
771
+ kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])
772
+
773
+ # get weights and bias of skip branch
774
+ kernel_identity = 0
775
+ bias_identity = 0
776
+ if self.identity is not None:
777
+ kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)
778
+
779
+ # get weights and bias of conv branches
780
+ kernel_conv = 0
781
+ bias_conv = 0
782
+ for ix in range(self.num_conv_branches):
783
+ _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
784
+ kernel_conv += _kernel
785
+ bias_conv += _bias
786
+
787
+ kernel_final = kernel_conv + kernel_scale + kernel_identity
788
+ bias_final = bias_conv + bias_scale + bias_identity
789
+ return kernel_final, bias_final
790
+
791
+ def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
792
+ """ Method to fuse batchnorm layer with preceeding conv layer.
793
+ Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
794
+ """
795
+ if isinstance(branch, ConvNormAct):
796
+ kernel = branch.conv.weight
797
+ running_mean = branch.bn.running_mean
798
+ running_var = branch.bn.running_var
799
+ gamma = branch.bn.weight
800
+ beta = branch.bn.bias
801
+ eps = branch.bn.eps
802
+ else:
803
+ assert isinstance(branch, nn.BatchNorm2d)
804
+ if not hasattr(self, 'id_tensor'):
805
+ in_chs = self.conv_kxk[0].conv.in_channels
806
+ input_dim = in_chs // self.groups
807
+ kernel_size = self.conv_kxk[0].conv.kernel_size
808
+ kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight)
809
+ for i in range(in_chs):
810
+ kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
811
+ self.id_tensor = kernel_value
812
+ kernel = self.id_tensor
813
+ running_mean = branch.running_mean
814
+ running_var = branch.running_var
815
+ gamma = branch.weight
816
+ beta = branch.bias
817
+ eps = branch.eps
818
+ std = (running_var + eps).sqrt()
819
+ t = (gamma / std).reshape(-1, 1, 1, 1)
820
+ return kernel * t, beta - running_mean * gamma / std
821
+
822
+
823
+ class SelfAttnBlock(nn.Module):
824
+ """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
825
+ """
826
+
827
+ def __init__(
828
+ self,
829
+ in_chs: int,
830
+ out_chs: int,
831
+ kernel_size: int = 3,
832
+ stride: int = 1,
833
+ dilation: Tuple[int, int] = (1, 1),
834
+ bottle_ratio: float = 1.,
835
+ group_size: Optional[int] = None,
836
+ downsample: str = 'avg',
837
+ extra_conv: bool = False,
838
+ linear_out: bool = False,
839
+ bottle_in: bool = False,
840
+ post_attn_na: bool = True,
841
+ feat_size: Optional[Tuple[int, int]] = None,
842
+ layers: LayerFn = None,
843
+ drop_block: Callable = None,
844
+ drop_path_rate: float = 0.,
845
+ ):
846
+ super(SelfAttnBlock, self).__init__()
847
+ assert layers is not None
848
+ mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
849
+ groups = num_groups(group_size, mid_chs)
850
+
851
+ self.shortcut = create_shortcut(
852
+ downsample, in_chs, out_chs,
853
+ stride=stride, dilation=dilation, apply_act=False, layers=layers,
854
+ )
855
+
856
+ self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
857
+ if extra_conv:
858
+ self.conv2_kxk = layers.conv_norm_act(
859
+ mid_chs, mid_chs, kernel_size,
860
+ stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
861
+ )
862
+ stride = 1 # striding done via conv if enabled
863
+ else:
864
+ self.conv2_kxk = nn.Identity()
865
+ opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
866
+ # FIXME need to dilate self attn to have dilated network support, moop moop
867
+ self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
868
+ self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
869
+ self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
870
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
871
+ self.act = nn.Identity() if linear_out else layers.act(inplace=True)
872
+
873
+ def init_weights(self, zero_init_last: bool = False):
874
+ if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
875
+ nn.init.zeros_(self.conv3_1x1.bn.weight)
876
+ if hasattr(self.self_attn, 'reset_parameters'):
877
+ self.self_attn.reset_parameters()
878
+
879
+ def forward(self, x):
880
+ shortcut = x
881
+ x = self.conv1_1x1(x)
882
+ x = self.conv2_kxk(x)
883
+ x = self.self_attn(x)
884
+ x = self.post_attn(x)
885
+ x = self.conv3_1x1(x)
886
+ x = self.drop_path(x)
887
+ if self.shortcut is not None:
888
+ x = x + self.shortcut(shortcut)
889
+ return self.act(x)
890
+
891
+
892
+ _block_registry = dict(
893
+ basic=BasicBlock,
894
+ bottle=BottleneckBlock,
895
+ dark=DarkBlock,
896
+ edge=EdgeBlock,
897
+ rep=RepVggBlock,
898
+ one=MobileOneBlock,
899
+ self_attn=SelfAttnBlock,
900
+ )
901
+
902
+
903
+ def register_block(block_type: str, block_fn: nn.Module):
904
+ _block_registry[block_type] = block_fn
905
+
906
+
907
+ def create_block(block: Union[str, nn.Module], **kwargs):
908
+ if isinstance(block, (nn.Module, partial)):
909
+ return block(**kwargs)
910
+ assert block in _block_registry, f'Unknown block type ({block})'
911
+ return _block_registry[block](**kwargs)
912
+
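register_block lets downstream code extend the table above; a sketch with a hypothetical block class (MyBlock is illustrative, not part of timm):

class MyBlock(nn.Module):
    def __init__(self, in_chs, out_chs, **kwargs):  # **kwargs absorbs stride/dilation/layers/etc.
        super().__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, 3, padding=1)

    def forward(self, x):
        return self.conv(x)

register_block('my_block', MyBlock)
blk = create_block('my_block', in_chs=64, out_chs=64)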
913
+
914
+ class Stem(nn.Sequential):
915
+
916
+ def __init__(
917
+ self,
918
+ in_chs: int,
919
+ out_chs: int,
920
+ kernel_size: int = 3,
921
+ stride: int = 4,
922
+ pool: str = 'maxpool',
923
+ num_rep: int = 3,
924
+ num_act: Optional[int] = None,
925
+ chs_decay: float = 0.5,
926
+ layers: LayerFn = None,
927
+ ):
928
+ super().__init__()
929
+ assert stride in (2, 4)
930
+ layers = layers or LayerFn()
931
+
932
+ if isinstance(out_chs, (list, tuple)):
933
+ num_rep = len(out_chs)
934
+ stem_chs = out_chs
935
+ else:
936
+ stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]
937
+
938
+ self.stride = stride
939
+ self.feature_info = [] # track intermediate features
940
+ prev_feat = ''
941
+ stem_strides = [2] + [1] * (num_rep - 1)
942
+ if stride == 4 and not pool:
943
+ # set last conv in stack to be strided if stride == 4 and no pooling layer
944
+ stem_strides[-1] = 2
945
+
946
+ num_act = num_rep if num_act is None else num_act
947
+ # if num_act < num_rep, first convs in stack won't have bn + act
948
+ stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
949
+ prev_chs = in_chs
950
+ curr_stride = 1
951
+ for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
952
+ layer_fn = layers.conv_norm_act if na else create_conv2d
953
+ conv_name = f'conv{i + 1}'
954
+ if i > 0 and s > 1:
955
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
956
+ self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
957
+ prev_chs = ch
958
+ curr_stride *= s
959
+ prev_feat = conv_name
960
+
961
+ if pool and 'max' in pool.lower():
962
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
963
+ self.add_module('pool', nn.MaxPool2d(3, 2, 1))
964
+ curr_stride *= 2
965
+ prev_feat = 'pool'
966
+
967
+ self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
968
+ assert curr_stride == stride
969
+
+
+
+def create_byob_stem(
+        in_chs: int,
+        out_chs: int,
+        stem_type: str = '',
+        pool_type: str = '',
+        feat_prefix: str = 'stem',
+        layers: LayerFn = None,
+):
+    layers = layers or LayerFn()
+    assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', 'one', '7x7', '3x3')
+    if 'quad' in stem_type:
+        # based on NFNet stem, stack of 4 3x3 convs
+        num_act = 2 if 'quad2' in stem_type else None
+        stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
+    elif 'tiered' in stem_type:
+        # 3x3 stack of 3 convs as in my ResNet-T
+        stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
+    elif 'deep' in stem_type:
+        # 3x3 stack of 3 convs as in ResNet-D
+        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
+    elif 'rep' in stem_type:
+        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
+    elif 'one' in stem_type:
+        stem = MobileOneBlock(in_chs, out_chs, kernel_size=3, stride=2, layers=layers)
+    elif '7x7' in stem_type:
+        # 7x7 stem conv as in ResNet
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
+    else:
+        # 3x3 stem conv as in RegNet is the default
+        if pool_type:
+            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
+        else:
+            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)
+
+    if isinstance(stem, Stem):
+        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
+    else:
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+    return stem, feature_info
+
+
+def reduce_feat_size(feat_size, stride=2):
+    return None if feat_size is None else tuple([s // stride for s in feat_size])
+
+
+def override_kwargs(block_kwargs, model_kwargs):
+    """ Override model level attn/self-attn/block kwargs w/ block level.
+
+    NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs
+    for the block if set to anything that isn't None.
+
+    i.e. an empty block_kwargs dict will remove kwargs set at the model level for that block.
+    """
+    out_kwargs = block_kwargs if block_kwargs is not None else model_kwargs
+    return out_kwargs or {}  # make sure None isn't returned
+
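+
+# Hedged illustration (added, not in the original source) of the replace-not-merge rule:
+#   override_kwargs(None, dict(rd_ratio=0.25))                -> {'rd_ratio': 0.25}  # model-level kept
+#   override_kwargs({}, dict(rd_ratio=0.25))                  -> {}                  # block-level {} wipes it
+#   override_kwargs(dict(rd_ratio=0.5), dict(rd_ratio=0.25))  -> {'rd_ratio': 0.5}   # full replace, no merge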
+
+
+def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg):
+    layer_fns = block_kwargs['layers']
+
+    # override attn layer / args with block local config
+    attn_set = block_cfg.attn_layer is not None
+    if attn_set or block_cfg.attn_kwargs is not None:
+        # override attn layer config
+        if attn_set and not block_cfg.attn_layer:
+            # empty string for attn_layer type will disable attn for this block
+            attn_layer = None
+        else:
+            attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
+            attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
+            attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
+        layer_fns = replace(layer_fns, attn=attn_layer)
+
+    # override self-attn layer / args with block local cfg
+    self_attn_set = block_cfg.self_attn_layer is not None
+    if self_attn_set or block_cfg.self_attn_kwargs is not None:
+        # override self-attn layer config
+        if self_attn_set and not block_cfg.self_attn_layer:  # self_attn_layer == ''
+            # empty string for self_attn_layer type will disable attn for this block
+            self_attn_layer = None
+        else:
+            self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
+            self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
+            self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
+                if self_attn_layer is not None else None
+        layer_fns = replace(layer_fns, self_attn=self_attn_layer)
+
+    block_kwargs['layers'] = layer_fns
+
+    # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set
+    block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs))
+
+
+def create_byob_stages(
+        cfg: ByoModelCfg,
+        drop_path_rate: float,
+        output_stride: int,
+        stem_feat: Dict[str, Any],
+        feat_size: Optional[int] = None,
+        layers: Optional[LayerFn] = None,
+        block_kwargs_fn: Optional[Callable] = update_block_kwargs,
+):
+    layers = layers or LayerFn()
+    feature_info = []
+    block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
+    depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
+    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+    dilation = 1
+    net_stride = stem_feat['reduction']
+    prev_chs = stem_feat['num_chs']
+    prev_feat = stem_feat
+    stages = []
+    for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
+        stride = stage_block_cfgs[0].s
+        if stride != 1 and prev_feat:
+            feature_info.append(prev_feat)
+        if net_stride >= output_stride and stride > 1:
+            dilation *= stride
+            stride = 1
+        net_stride *= stride
+        first_dilation = 1 if dilation in (1, 2) else 2
+
+        blocks = []
+        for block_idx, block_cfg in enumerate(stage_block_cfgs):
+            out_chs = make_divisible(block_cfg.c * cfg.width_factor)
+            group_size = block_cfg.gs
+            if isinstance(group_size, Callable):
+                group_size = group_size(out_chs, block_idx)
+            block_kwargs = dict(  # Blocks used in this model must accept these arguments
+                in_chs=prev_chs,
+                out_chs=out_chs,
+                stride=stride if block_idx == 0 else 1,
+                dilation=(first_dilation, dilation),
+                group_size=group_size,
+                bottle_ratio=block_cfg.br,
+                downsample=cfg.downsample,
+                drop_path_rate=dpr[stage_idx][block_idx],
+                layers=layers,
+            )
+            if block_cfg.type in ('self_attn',):
+                # add feat_size arg for blocks that support/need it
+                block_kwargs['feat_size'] = feat_size
+            block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg)
+            blocks += [create_block(block_cfg.type, **block_kwargs)]
+            first_dilation = dilation
+            prev_chs = out_chs
+            if stride > 1 and block_idx == 0:
+                feat_size = reduce_feat_size(feat_size, stride)
+
+        stages += [nn.Sequential(*blocks)]
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+
+    feature_info.append(prev_feat)
+    return nn.Sequential(*stages), feature_info
+
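+
+# Hedged worked example (added for illustration) of the drop-path schedule above:
+# with drop_path_rate=0.1 and per-stage depths [2, 4],
+# torch.linspace(0, 0.1, 6).split([2, 4]) gives per-block rates
+# ([0.0, 0.02], [0.04, 0.06, 0.08, 0.1]), i.e. stochastic depth grows linearly
+# with block index across the whole network.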
+
+
+def get_layer_fns(cfg: ByoModelCfg):
+    act = get_act_layer(cfg.act_layer)
+    norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
+    conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
+    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
+    self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
+    layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
+    return layer_fn
+
+
+class ByobNet(nn.Module):
+    """ 'Bring-your-own-blocks' Net
+
+    A flexible network backbone that allows building model stem + blocks via
+    dataclass cfg definition w/ factory functions for module instantiation.
+
+    Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
+    """
+    def __init__(
+            self,
+            cfg: ByoModelCfg,
+            num_classes: int = 1000,
+            in_chans: int = 3,
+            global_pool: str = 'avg',
+            output_stride: int = 32,
+            img_size: Optional[Union[int, Tuple[int, int]]] = None,
+            drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            zero_init_last: bool = True,
+            **kwargs,
+    ):
+        """
+        Args:
+            cfg: Model architecture configuration.
+            num_classes: Number of classifier classes.
+            in_chans: Number of input channels.
+            global_pool: Global pooling type.
+            output_stride: Output stride of network, one of (8, 16, 32).
+            img_size: Image size for fixed image size models (i.e. self-attn).
+            drop_rate: Classifier dropout rate.
+            drop_path_rate: Stochastic depth drop-path rate.
+            zero_init_last: Zero-init last weight of residual path.
+            **kwargs: Extra kwargs overlayed onto cfg.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        self.grad_checkpointing = False
+
+        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
+        layers = get_layer_fns(cfg)
+        if cfg.fixed_input_size:
+            assert img_size is not None, 'img_size argument is required for fixed input size model'
+        feat_size = to_2tuple(img_size) if img_size is not None else None
+
+        self.feature_info = []
+        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
+        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
+        self.feature_info.extend(stem_feat[:-1])
+        feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
+
+        self.stages, stage_feat = create_byob_stages(
+            cfg,
+            drop_path_rate,
+            output_stride,
+            stem_feat[-1],
+            layers=layers,
+            feat_size=feat_size,
+        )
+        self.feature_info.extend(stage_feat[:-1])
+
+        prev_chs = stage_feat[-1]['num_chs']
+        if cfg.num_features:
+            self.num_features = int(round(cfg.width_factor * cfg.num_features))
+            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
+        else:
+            self.num_features = prev_chs
+            self.final_conv = nn.Identity()
+        self.feature_info += [
+            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+
+        self.head = ClassifierHead(
+            self.num_features,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=self.drop_rate,
+        )
+
+        # init weights
+        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^stem',
+            blocks=[
+                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
+                (r'^final_conv', (99999,))
+            ]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.head.reset(num_classes, global_pool)
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stages, x)
+        else:
+            x = self.stages(x)
+        x = self.final_conv(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        return self.head(x, pre_logits=pre_logits)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
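+
+# Hedged usage sketch (added for illustration, not part of the original file):
+# constructing a ByobNet directly from one of the configs defined below.
+#
+#   model = ByobNet(model_cfgs['resnet51q'], num_classes=10)
+#   logits = model(torch.randn(1, 3, 256, 256))          # -> [1, 10]
+#   feats = model.forward_features(torch.randn(1, 3, 256, 256))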
+
+
+def _init_weights(module, name='', zero_init_last=False):
+    if isinstance(module, nn.Conv2d):
+        fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+        fan_out //= module.groups
+        module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+        if module.bias is not None:
+            module.bias.data.zero_()
+    elif isinstance(module, nn.Linear):
+        nn.init.normal_(module.weight, mean=0.0, std=0.01)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+    elif isinstance(module, nn.BatchNorm2d):
+        nn.init.ones_(module.weight)
+        nn.init.zeros_(module.bias)
+    elif hasattr(module, 'init_weights'):
+        module.init_weights(zero_init_last=zero_init_last)
+
+
+model_cfgs = dict(
+    gernet_l=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
+            ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
+            ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
+            ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
+            ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
+        ),
+        stem_chs=32,
+        stem_pool=None,
+        num_features=2560,
+    ),
+    gernet_m=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
+            ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
+            ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
+            ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
+            ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
+        ),
+        stem_chs=32,
+        stem_pool=None,
+        num_features=2560,
+    ),
+    gernet_s=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
+            ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
+            ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
+            ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
+            ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
+        ),
+        stem_chs=13,
+        stem_pool=None,
+        num_features=1920,
+    ),
+
+    repvgg_a0=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)),
+        stem_type='rep',
+        stem_chs=48,
+    ),
+    repvgg_a1=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_a2=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_b0=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_b1=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_b1g4=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_b2=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_b2g4=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_b3=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_b3g4=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4),
+        stem_type='rep',
+        stem_chs=64,
+    ),
+    repvgg_d2se=ByoModelCfg(
+        blocks=_rep_vgg_bcfg(d=(8, 14, 24, 1), wf=(2.5, 2.5, 2.5, 5.)),
+        stem_type='rep',
+        stem_chs=64,
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.0625, rd_divisor=1),
+    ),
+
+    # 4 x conv stem w/ 2 act, no maxpool, 2,4,6,4 repeats, group size 32 in first 3 blocks
+    # DW convs in last block, 2048 pre-FC, silu act
+    resnet51q=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
+        ),
+        stem_chs=128,
+        stem_type='quad2',
+        stem_pool=None,
+        num_features=2048,
+        act_layer='silu',
+    ),
+
+    # 4 x conv stem w/ 4 act, no maxpool, 1,4,6,4 repeats, edge block first, group size 32 in next 2 blocks
+    # DW convs in last block, 4 conv for each bottle block, 2048 pre-FC, silu act
+    resnet61q=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
+        ),
+        stem_chs=128,
+        stem_type='quad',
+        stem_pool=None,
+        num_features=2048,
+        act_layer='silu',
+        block_kwargs=dict(extra_conv=True),
+    ),
+
+    # A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
+    # and a tiered stem w/ maxpool
+    resnext26ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+    ),
+    gcresnext26ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+        attn_layer='gca',
+    ),
+    seresnext26ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+        attn_layer='se',
+    ),
+    eca_resnext26ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+        attn_layer='eca',
+    ),
+    bat_resnext26ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+        attn_layer='bat',
+        attn_kwargs=dict(block_size=8)
+    ),
+
+    # ResNet-32 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, no pre-fc feat layer, tiered stem w/o maxpool
+    resnet32ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        num_features=0,
+        act_layer='silu',
+    ),
+
+    # ResNet-33 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, 1280 pre-FC feat, tiered stem w/o maxpool
+    resnet33ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        num_features=1280,
+        act_layer='silu',
+    ),
+
+    # A series of ResNet-33 (2, 3, 3, 2) models w/ one of GC, SE, ECA attn, no groups, SiLU act, 1280 pre-FC feat
+    # and a tiered stem w/ no maxpool
+    gcresnet33ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        num_features=1280,
+        act_layer='silu',
+        attn_layer='gca',
+    ),
+    seresnet33ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        num_features=1280,
+        act_layer='silu',
+        attn_layer='se',
+    ),
+    eca_resnet33ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
+            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        num_features=1280,
+        act_layer='silu',
+        attn_layer='eca',
+    ),
+
+    gcresnet50t=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        attn_layer='gca',
+    ),
+
+    gcresnext50ts=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, gs=32, br=0.25),
+            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, gs=32, br=0.25),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='maxpool',
+        act_layer='silu',
+        attn_layer='gca',
+    ),
+
+    # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
+    regnetz_b16=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        downsample='',
+        num_features=1536,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+    regnetz_c16=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
+            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
+            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
+            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        downsample='',
+        num_features=1536,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+    regnetz_d32=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
+            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
+            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4),
+            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        downsample='',
+        num_features=1792,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+    regnetz_d8=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        downsample='',
+        num_features=1792,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+    regnetz_e8=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=96, s=1, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=8, c=192, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=16, c=384, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=8, br=4),
+        ),
+        stem_chs=64,
+        stem_type='tiered',
+        stem_pool='',
+        downsample='',
+        num_features=2048,
+        act_layer='silu',
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+
+    # experimental EvoNorm configs
+    regnetz_b16_evos=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
+            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        downsample='',
+        num_features=1536,
+        act_layer='silu',
+        norm_layer=partial(EvoNorm2dS0a, group_size=16),
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+    regnetz_c16_evos=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
+            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
+            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
+            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
+        ),
+        stem_chs=32,
+        stem_pool='',
+        downsample='',
+        num_features=1536,
+        act_layer='silu',
+        norm_layer=partial(EvoNorm2dS0a, group_size=16),
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+    regnetz_d8_evos=ByoModelCfg(
+        blocks=(
+            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
+            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
+        ),
+        stem_chs=64,
+        stem_type='deep',
+        stem_pool='',
+        downsample='',
+        num_features=1792,
+        act_layer='silu',
+        norm_layer=partial(EvoNorm2dS0a, group_size=16),
+        attn_layer='se',
+        attn_kwargs=dict(rd_ratio=0.25),
+        block_kwargs=dict(bottle_in=True, linear_out=True),
+    ),
+
+    mobileone_s0=ByoModelCfg(
+        blocks=_mobileone_bcfg(wf=(0.75, 1.0, 1.0, 2.), num_conv_branches=4),
+        stem_type='one',
+        stem_chs=48,
+    ),
+    mobileone_s1=ByoModelCfg(
+        blocks=_mobileone_bcfg(wf=(1.5, 1.5, 2.0, 2.5)),
+        stem_type='one',
+        stem_chs=64,
+    ),
+    mobileone_s2=ByoModelCfg(
+        blocks=_mobileone_bcfg(wf=(1.5, 2.0, 2.5, 4.0)),
+        stem_type='one',
+        stem_chs=64,
+    ),
+    mobileone_s3=ByoModelCfg(
+        blocks=_mobileone_bcfg(wf=(2.0, 2.5, 3.0, 4.0)),
+        stem_type='one',
+        stem_chs=64,
+    ),
+    mobileone_s4=ByoModelCfg(
+        blocks=_mobileone_bcfg(wf=(3.0, 3.5, 3.5, 4.0), se_blocks=(0, 0, 5, 1)),
+        stem_type='one',
+        stem_chs=64,
+    ),
+)
+
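+
+# Note (added for clarity, inferred from create_byob_stages() above): the
+# ByoBlockCfg shorthand reads d = block repeat depth, c = base channels (scaled
+# by width_factor), s = stage stride, gs = group size (0 -> normal conv,
+# 1 -> depthwise), br = bottleneck ratio.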
+
+
+def _create_byobnet(variant, pretrained=False, **kwargs):
+    return build_model_with_cfg(
+        ByobNet, variant, pretrained,
+        model_cfg=model_cfgs[variant],
+        feature_cfg=dict(flatten_sequential=True),
+        **kwargs)
+
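+
+# Hedged usage note (added): in practice these variants are usually created via
+# the timm registry rather than by calling _create_byobnet directly, e.g.
+#
+#   import timm
+#   model = timm.create_model('repvgg_b0', pretrained=False, num_classes=10)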
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+def _cfgr(url='', **kwargs):
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+        'crop_pct': 0.9, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # GPU-Efficient (ResNet) weights
+    'gernet_s.idstcv_in1k': _cfg(hf_hub_id='timm/'),
+    'gernet_m.idstcv_in1k': _cfg(hf_hub_id='timm/'),
+    'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)),
+
+    # RepVGG weights
+    'repvgg_a0.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_a1.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_a2.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_b0.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_b1.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_b1g4.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_b2.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_b2g4.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_b3.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_b3g4.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+    'repvgg_d2se.rvgg_in1k': _cfg(
+        hf_hub_id='timm/',
+        first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit',
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
+    ),
+
+    # experimental ResNet configs
+    'resnet51q.ra2_in1k': _cfg(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
+        first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnet61q.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+    # ResNeXt-26 models with different attention in Bottleneck blocks
+    'resnext26ts.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'seresnext26ts.ch_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'gcresnext26ts.ch_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'eca_resnext26ts.ch_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'bat_resnext26ts.ch_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth',
+        min_input_size=(3, 256, 256)),
+
+    # ResNet-32 / 33 models with different attention in Bottleneck blocks
+    'resnet32ts.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'resnet33ts.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'gcresnet33ts.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'seresnet33ts.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'eca_resnet33ts.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+    'gcresnet50t.ra2_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+    'gcresnext50ts.ch_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth',
+        test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+    # custom `timm` specific RegNetZ inspired models w/ different sizing from paper
+    'regnetz_b16.ra3_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth',
+        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.94, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'regnetz_c16.ra3_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth',
+        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+    'regnetz_d32.ra3_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320)),
+    'regnetz_d8.ra3_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d8_bh-afc03c55.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+    'regnetz_e8.ra3_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+
+    'regnetz_b16_evos.untrained': _cfgr(
+        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.95, test_input_size=(3, 288, 288)),
+    'regnetz_c16_evos.ch_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_c16_evos_ch-d8311942.pth',
+        first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        crop_pct=0.95, test_input_size=(3, 320, 320)),
+    'regnetz_d8_evos.ch_in1k': _cfgr(
+        hf_hub_id='timm/',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+
+    'mobileone_s0.apple_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.875,
+        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+    ),
+    'mobileone_s1.apple_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.9,
+        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+    ),
+    'mobileone_s2.apple_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.9,
+        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+    ),
+    'mobileone_s3.apple_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.9,
+        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+    ),
+    'mobileone_s4.apple_in1k': _cfg(
+        hf_hub_id='timm/',
+        crop_pct=0.9,
+        first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+    ),
+})
+
+
+@register_model
+def gernet_l(pretrained=False, **kwargs) -> ByobNet:
+    """ GEResNet-Large (GENet-Large from official impl)
+    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+    """
+    return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gernet_m(pretrained=False, **kwargs) -> ByobNet:
+    """ GEResNet-Medium (GENet-Normal from official impl)
+    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+    """
+    return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gernet_s(pretrained=False, **kwargs) -> ByobNet:
+    """ GEResNet-Small (GENet-Small from official impl)
+    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+    """
+    return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_a0(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-A0
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_a1(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-A1
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_a2(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-A2
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b0(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-B0
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b1(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-B1
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b1g4(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-B1g4
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b2(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-B2
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b2g4(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-B2g4
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b3(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-B3
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_b3g4(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-B3g4
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def repvgg_d2se(pretrained=False, **kwargs) -> ByobNet:
+    """ RepVGG-D2se
+    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+    """
+    return _create_byobnet('repvgg_d2se', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet51q(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('resnet51q', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet61q(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('resnet61q', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnext26ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('resnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnext26ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('gcresnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def seresnext26ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('seresnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_resnext26ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('eca_resnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def bat_resnext26ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('bat_resnext26ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet32ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('resnet32ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def resnet33ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('resnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnet33ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('gcresnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def seresnet33ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('seresnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def eca_resnet33ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('eca_resnet33ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnet50t(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def gcresnext50ts(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('gcresnext50ts', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_b16(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_b16', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_c16(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_c16', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d32(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_d32', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d8(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_d8', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_e8(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_b16_evos(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_b16_evos', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_c16_evos(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_c16_evos', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def regnetz_d8_evos(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def mobileone_s0(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('mobileone_s0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def mobileone_s1(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('mobileone_s1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def mobileone_s2(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('mobileone_s2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def mobileone_s3(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('mobileone_s3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
+    """
+    """
+    return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/coat.py ADDED
@@ -0,0 +1,804 @@
+"""
+CoaT architecture.
+
+Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399
+
+Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
+
+Modified from timm/models/vision_transformer.py
+"""
+from functools import partial
+from typing import Tuple, List, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, LayerNorm
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['CoaT']
+
+
+class ConvRelPosEnc(nn.Module):
+    """ Convolutional relative position encoding. """
+    def __init__(self, head_chs, num_heads, window):
+        """
+        Initialization.
+            Ch: Channels per head.
+            h: Number of heads.
+            window: Window size(s) in convolutional relative positional encoding. It can have two forms:
+                1. An integer of window size, which assigns all attention heads the same window size
+                   in ConvRelPosEnc.
+                2. A dict mapping window size to #attention head splits
+                   (e.g. {window size 1: #attention head split 1, window size 2: #attention head split 2}).
+                   It will apply different window sizes to the attention head splits.
+        """
+        super().__init__()
+
+        if isinstance(window, int):
+            # Set the same window size for all attention heads.
+            window = {window: num_heads}
+            self.window = window
+        elif isinstance(window, dict):
+            self.window = window
+        else:
+            raise ValueError()
+
+        self.conv_list = nn.ModuleList()
+        self.head_splits = []
+        for cur_window, cur_head_split in window.items():
+            dilation = 1
+            # Determine padding size.
+            # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
+            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
+            cur_conv = nn.Conv2d(
+                cur_head_split * head_chs,
+                cur_head_split * head_chs,
+                kernel_size=(cur_window, cur_window),
+                padding=(padding_size, padding_size),
+                dilation=(dilation, dilation),
+                groups=cur_head_split * head_chs,
+            )
+            self.conv_list.append(cur_conv)
+            self.head_splits.append(cur_head_split)
+        self.channel_splits = [x * head_chs for x in self.head_splits]
+
+    def forward(self, q, v, size: Tuple[int, int]):
+        B, num_heads, N, C = q.shape
+        H, W = size
+        _assert(N == 1 + H * W, '')
+
+        # Convolutional relative position encoding.
+        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
+        v_img = v[:, :, 1:, :]  # [B, h, H*W, Ch]
+
+        v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W)
+        v_img_list = torch.split(v_img, self.channel_splits, dim=1)  # Split according to channels
+        conv_v_img_list = []
+        for i, conv in enumerate(self.conv_list):
+            conv_v_img_list.append(conv(v_img_list[i]))
+        conv_v_img = torch.cat(conv_v_img_list, dim=1)
+        conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2)
+
+        EV_hat = q_img * conv_v_img
+        EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0))  # [B, h, N, Ch].
+        return EV_hat
+
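+
+# Hedged illustration (added): per the docstring above, window={3: 2, 5: 3} with
+# num_heads=5 splits the heads into a group of 2 using a 3x3 depthwise conv and
+# a group of 3 using a 5x5 depthwise conv for their relative position encodings.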
+
+
+class FactorAttnConvRelPosEnc(nn.Module):
+    """ Factorized attention with convolutional relative position encoding class. """
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            qkv_bias=False,
+            attn_drop=0.,
+            proj_drop=0.,
+            shared_crpe=None,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)  # Note: attn_drop is actually not used.
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        # Shared convolutional relative position encoding.
+        self.crpe = shared_crpe
+
+    def forward(self, x, size: Tuple[int, int]):
+        B, N, C = x.shape
+
+        # Generate Q, K, V.
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)  # [B, h, N, Ch]
+
+        # Factorized attention.
+        k_softmax = k.softmax(dim=2)
+        factor_att = k_softmax.transpose(-1, -2) @ v
+        factor_att = q @ factor_att
+
+        # Convolutional relative position encoding.
+        crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]
+
+        # Merge and reshape.
+        x = self.scale * factor_att + crpe
+        x = x.transpose(1, 2).reshape(B, N, C)  # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]
+
+        # Output projection.
+        x = self.proj(x)
+        x = self.proj_drop(x)
+
+        return x
+
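+
+# Note (added for clarity): the factorized form computes softmax(K)^T @ V first,
+# a [Ch x Ch] matrix per head, then Q @ (softmax(K)^T @ V), so the cost is
+# O(N * Ch^2) per head instead of the O(N^2 * Ch) of standard attention.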
+
+
+class ConvPosEnc(nn.Module):
+    """ Convolutional Position Encoding.
+        Note: This module is similar to the conditional position encoding in CPVT.
+    """
+    def __init__(self, dim, k=3):
+        super(ConvPosEnc, self).__init__()
+        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
+
+    def forward(self, x, size: Tuple[int, int]):
+        B, N, C = x.shape
+        H, W = size
+        _assert(N == 1 + H * W, '')
+
+        # Extract CLS token and image tokens.
+        cls_token, img_tokens = x[:, :1], x[:, 1:]  # [B, 1, C], [B, H*W, C]
+
+        # Depthwise convolution.
+        feat = img_tokens.transpose(1, 2).view(B, C, H, W)
+        x = self.proj(feat) + feat
+        x = x.flatten(2).transpose(1, 2)
+
+        # Combine with CLS token.
+        x = torch.cat((cls_token, x), dim=1)
+
+        return x
+
+
+class SerialBlock(nn.Module):
+    """ Serial block class.
+        Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
+    def __init__(
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            qkv_bias=False,
+            proj_drop=0.,
+            attn_drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=nn.LayerNorm,
+            shared_cpe=None,
+            shared_crpe=None,
+    ):
+        super().__init__()
+
+        # Conv-Attention.
+        self.cpe = shared_cpe
+
+        self.norm1 = norm_layer(dim)
+        self.factoratt_crpe = FactorAttnConvRelPosEnc(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+            shared_crpe=shared_crpe,
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        # MLP.
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=proj_drop,
+        )
+
+    def forward(self, x, size: Tuple[int, int]):
+        # Conv-Attention.
+        x = self.cpe(x, size)
+        cur = self.norm1(x)
+        cur = self.factoratt_crpe(cur, size)
+        x = x + self.drop_path(cur)
+
+        # MLP.
+        cur = self.norm2(x)
+        cur = self.mlp(cur)
+        x = x + self.drop_path(cur)
+
+        return x
+
225
+
226
+ class ParallelBlock(nn.Module):
227
+ """ Parallel block class. """
228
+ def __init__(
229
+ self,
230
+ dims,
231
+ num_heads,
232
+ mlp_ratios=(),
233
+ qkv_bias=False,
234
+ proj_drop=0.,
235
+ attn_drop=0.,
236
+ drop_path=0.,
237
+ act_layer=nn.GELU,
238
+ norm_layer=nn.LayerNorm,
239
+ shared_crpes=None,
240
+ ):
241
+ super().__init__()
242
+
243
+ # Conv-Attention.
244
+ self.norm12 = norm_layer(dims[1])
245
+ self.norm13 = norm_layer(dims[2])
246
+ self.norm14 = norm_layer(dims[3])
247
+ self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
248
+ dims[1],
249
+ num_heads=num_heads,
250
+ qkv_bias=qkv_bias,
251
+ attn_drop=attn_drop,
252
+ proj_drop=proj_drop,
253
+ shared_crpe=shared_crpes[1],
254
+ )
255
+ self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
256
+ dims[2],
257
+ num_heads=num_heads,
258
+ qkv_bias=qkv_bias,
259
+ attn_drop=attn_drop,
260
+ proj_drop=proj_drop,
261
+ shared_crpe=shared_crpes[2],
262
+ )
263
+ self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
264
+ dims[3],
265
+ num_heads=num_heads,
266
+ qkv_bias=qkv_bias,
267
+ attn_drop=attn_drop,
268
+ proj_drop=proj_drop,
269
+ shared_crpe=shared_crpes[3],
270
+ )
271
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
272
+
273
+ # MLP.
274
+ self.norm22 = norm_layer(dims[1])
275
+ self.norm23 = norm_layer(dims[2])
276
+ self.norm24 = norm_layer(dims[3])
277
+ # In parallel block, we assume dimensions are the same and share the linear transformation.
278
+ assert dims[1] == dims[2] == dims[3]
279
+ assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
280
+ mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
281
+ self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
282
+ in_features=dims[1],
283
+ hidden_features=mlp_hidden_dim,
284
+ act_layer=act_layer,
285
+ drop=proj_drop,
286
+ )
287
+
288
+ def upsample(self, x, factor: float, size: Tuple[int, int]):
289
+ """ Feature map up-sampling. """
290
+ return self.interpolate(x, scale_factor=factor, size=size)
291
+
292
+ def downsample(self, x, factor: float, size: Tuple[int, int]):
293
+ """ Feature map down-sampling. """
294
+ return self.interpolate(x, scale_factor=1.0/factor, size=size)
295
+
296
+ def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
297
+ """ Feature map interpolation. """
298
+ B, N, C = x.shape
299
+ H, W = size
300
+ _assert(N == 1 + H * W, '')
301
+
302
+ cls_token = x[:, :1, :]
303
+ img_tokens = x[:, 1:, :]
304
+
305
+ img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
306
+ img_tokens = F.interpolate(
307
+ img_tokens,
308
+ scale_factor=scale_factor,
309
+ recompute_scale_factor=False,
310
+ mode='bilinear',
311
+ align_corners=False,
312
+ )
313
+ img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
314
+
315
+ out = torch.cat((cls_token, img_tokens), dim=1)
316
+
317
+ return out
318
+
319
+ def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
320
+ _, S2, S3, S4 = sizes
321
+ cur2 = self.norm12(x2)
322
+ cur3 = self.norm13(x3)
323
+ cur4 = self.norm14(x4)
324
+ cur2 = self.factoratt_crpe2(cur2, size=S2)
325
+ cur3 = self.factoratt_crpe3(cur3, size=S3)
326
+ cur4 = self.factoratt_crpe4(cur4, size=S4)
327
+ upsample3_2 = self.upsample(cur3, factor=2., size=S3)
328
+ upsample4_3 = self.upsample(cur4, factor=2., size=S4)
329
+ upsample4_2 = self.upsample(cur4, factor=4., size=S4)
330
+ downsample2_3 = self.downsample(cur2, factor=2., size=S2)
331
+ downsample3_4 = self.downsample(cur3, factor=2., size=S3)
332
+ downsample2_4 = self.downsample(cur2, factor=4., size=S2)
333
+ cur2 = cur2 + upsample3_2 + upsample4_2
334
+ cur3 = cur3 + upsample4_3 + downsample2_3
335
+ cur4 = cur4 + downsample3_4 + downsample2_4
336
+ x2 = x2 + self.drop_path(cur2)
337
+ x3 = x3 + self.drop_path(cur3)
338
+ x4 = x4 + self.drop_path(cur4)
339
+
340
+ # MLP.
341
+ cur2 = self.norm22(x2)
342
+ cur3 = self.norm23(x3)
343
+ cur4 = self.norm24(x4)
344
+ cur2 = self.mlp2(cur2)
345
+ cur3 = self.mlp3(cur3)
346
+ cur4 = self.mlp4(cur4)
347
+ x2 = x2 + self.drop_path(cur2)
348
+ x3 = x3 + self.drop_path(cur3)
349
+ x4 = x4 + self.drop_path(cur4)
350
+
351
+ return x1, x2, x3, x4
352
+
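+ # NOTE: a worked view of the feature exchange in ParallelBlock.forward:
+ # stages are separated by stride-2 patch embeddings, so S2 is 2x the
+ # resolution of S3 and 4x that of S4; cur3 is upsampled by 2 into cur2,
+ # cur4 by 2 into cur3 and by 4 into cur2, and the symmetric downsamples
+ # run the other way, so every scale receives every other scale once.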
353
+
354
+ class CoaT(nn.Module):
355
+ """ CoaT class. """
356
+ def __init__(
357
+ self,
358
+ img_size=224,
359
+ patch_size=16,
360
+ in_chans=3,
361
+ num_classes=1000,
362
+ embed_dims=(64, 128, 320, 512),
363
+ serial_depths=(3, 4, 6, 3),
364
+ parallel_depth=0,
365
+ num_heads=8,
366
+ mlp_ratios=(4, 4, 4, 4),
367
+ qkv_bias=True,
368
+ drop_rate=0.,
369
+ proj_drop_rate=0.,
370
+ attn_drop_rate=0.,
371
+ drop_path_rate=0.,
372
+ norm_layer=LayerNorm,
373
+ return_interm_layers=False,
374
+ out_features=None,
375
+ crpe_window=None,
376
+ global_pool='token',
377
+ ):
378
+ super().__init__()
379
+ assert global_pool in ('token', 'avg')
380
+ crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
381
+ self.return_interm_layers = return_interm_layers
382
+ self.out_features = out_features
383
+ self.embed_dims = embed_dims
384
+ self.num_features = embed_dims[-1]
385
+ self.num_classes = num_classes
386
+ self.global_pool = global_pool
387
+
388
+ # Patch embeddings.
389
+ img_size = to_2tuple(img_size)
390
+ self.patch_embed1 = PatchEmbed(
391
+ img_size=img_size, patch_size=patch_size, in_chans=in_chans,
392
+ embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
393
+ self.patch_embed2 = PatchEmbed(
394
+ img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
395
+ embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
396
+ self.patch_embed3 = PatchEmbed(
397
+ img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
398
+ embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
399
+ self.patch_embed4 = PatchEmbed(
400
+ img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
401
+ embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
402
+
403
+ # Class tokens.
404
+ self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
405
+ self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
406
+ self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
407
+ self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))
408
+
409
+ # Convolutional position encodings.
410
+ self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
411
+ self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
412
+ self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
413
+ self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)
414
+
415
+ # Convolutional relative position encodings.
416
+ self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window)
417
+ self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window)
418
+ self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window)
419
+ self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window)
420
+
421
+ # Disable stochastic depth.
422
+ dpr = drop_path_rate
423
+ assert dpr == 0.0
424
+ skwargs = dict(
425
+ num_heads=num_heads,
426
+ qkv_bias=qkv_bias,
427
+ proj_drop=proj_drop_rate,
428
+ attn_drop=attn_drop_rate,
429
+ drop_path=dpr,
430
+ norm_layer=norm_layer,
431
+ )
432
+
433
+ # Serial blocks 1.
434
+ self.serial_blocks1 = nn.ModuleList([
435
+ SerialBlock(
436
+ dim=embed_dims[0],
437
+ mlp_ratio=mlp_ratios[0],
438
+ shared_cpe=self.cpe1,
439
+ shared_crpe=self.crpe1,
440
+ **skwargs,
441
+ )
442
+ for _ in range(serial_depths[0])]
443
+ )
444
+
445
+ # Serial blocks 2.
446
+ self.serial_blocks2 = nn.ModuleList([
447
+ SerialBlock(
448
+ dim=embed_dims[1],
449
+ mlp_ratio=mlp_ratios[1],
450
+ shared_cpe=self.cpe2,
451
+ shared_crpe=self.crpe2,
452
+ **skwargs,
453
+ )
454
+ for _ in range(serial_depths[1])]
455
+ )
456
+
457
+ # Serial blocks 3.
458
+ self.serial_blocks3 = nn.ModuleList([
459
+ SerialBlock(
460
+ dim=embed_dims[2],
461
+ mlp_ratio=mlp_ratios[2],
462
+ shared_cpe=self.cpe3,
463
+ shared_crpe=self.crpe3,
464
+ **skwargs,
465
+ )
466
+ for _ in range(serial_depths[2])]
467
+ )
468
+
469
+ # Serial blocks 4.
470
+ self.serial_blocks4 = nn.ModuleList([
471
+ SerialBlock(
472
+ dim=embed_dims[3],
473
+ mlp_ratio=mlp_ratios[3],
474
+ shared_cpe=self.cpe4,
475
+ shared_crpe=self.crpe4,
476
+ **skwargs,
477
+ )
478
+ for _ in range(serial_depths[3])]
479
+ )
480
+
481
+ # Parallel blocks.
482
+ self.parallel_depth = parallel_depth
483
+ if self.parallel_depth > 0:
484
+ self.parallel_blocks = nn.ModuleList([
485
+ ParallelBlock(
486
+ dims=embed_dims,
487
+ mlp_ratios=mlp_ratios,
488
+ shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4),
489
+ **skwargs,
490
+ )
491
+ for _ in range(parallel_depth)]
492
+ )
493
+ else:
494
+ self.parallel_blocks = None
495
+
496
+ # Classification head(s).
497
+ if not self.return_interm_layers:
498
+ if self.parallel_blocks is not None:
499
+ self.norm2 = norm_layer(embed_dims[1])
500
+ self.norm3 = norm_layer(embed_dims[2])
501
+ else:
502
+ self.norm2 = self.norm3 = None
503
+ self.norm4 = norm_layer(embed_dims[3])
504
+
505
+ if self.parallel_depth > 0:
506
+ # CoaT series: Aggregate features of last three scales for classification.
507
+ assert embed_dims[1] == embed_dims[2] == embed_dims[3]
508
+ self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
509
+ self.head_drop = nn.Dropout(drop_rate)
510
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
511
+ else:
512
+ # CoaT-Lite series: Use feature of last scale for classification.
513
+ self.aggregate = None
514
+ self.head_drop = nn.Dropout(drop_rate)
515
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
516
+
517
+ # Initialize weights.
518
+ trunc_normal_(self.cls_token1, std=.02)
519
+ trunc_normal_(self.cls_token2, std=.02)
520
+ trunc_normal_(self.cls_token3, std=.02)
521
+ trunc_normal_(self.cls_token4, std=.02)
522
+ self.apply(self._init_weights)
523
+
524
+ def _init_weights(self, m):
525
+ if isinstance(m, nn.Linear):
526
+ trunc_normal_(m.weight, std=.02)
527
+ if isinstance(m, nn.Linear) and m.bias is not None:
528
+ nn.init.constant_(m.bias, 0)
529
+ elif isinstance(m, nn.LayerNorm):
530
+ nn.init.constant_(m.bias, 0)
531
+ nn.init.constant_(m.weight, 1.0)
532
+
533
+ @torch.jit.ignore
534
+ def no_weight_decay(self):
535
+ return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}
536
+
537
+ @torch.jit.ignore
538
+ def set_grad_checkpointing(self, enable=True):
539
+ assert not enable, 'gradient checkpointing not supported'
540
+
541
+ @torch.jit.ignore
542
+ def group_matcher(self, coarse=False):
543
+ matcher = dict(
544
+ stem1=r'^cls_token1|patch_embed1|crpe1|cpe1',
545
+ serial_blocks1=r'^serial_blocks1\.(\d+)',
546
+ stem2=r'^cls_token2|patch_embed2|crpe2|cpe2',
547
+ serial_blocks2=r'^serial_blocks2\.(\d+)',
548
+ stem3=r'^cls_token3|patch_embed3|crpe3|cpe3',
549
+ serial_blocks3=r'^serial_blocks3\.(\d+)',
550
+ stem4=r'^cls_token4|patch_embed4|crpe4|cpe4',
551
+ serial_blocks4=r'^serial_blocks4\.(\d+)',
552
+ parallel_blocks=[ # FIXME (partially?) overlap parallel w/ serial blocks??
553
+ (r'^parallel_blocks\.(\d+)', None),
554
+ (r'^norm|aggregate', (99999,)),
555
+ ]
556
+ )
557
+ return matcher
558
+
559
+ @torch.jit.ignore
560
+ def get_classifier(self):
561
+ return self.head
562
+
563
+ def reset_classifier(self, num_classes, global_pool=None):
564
+ self.num_classes = num_classes
565
+ if global_pool is not None:
566
+ assert global_pool in ('token', 'avg')
567
+ self.global_pool = global_pool
568
+ self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
569
+
570
+ def forward_features(self, x0):
571
+ B = x0.shape[0]
572
+
573
+ # Serial blocks 1.
574
+ x1 = self.patch_embed1(x0)
575
+ H1, W1 = self.patch_embed1.grid_size
576
+ x1 = insert_cls(x1, self.cls_token1)
577
+ for blk in self.serial_blocks1:
578
+ x1 = blk(x1, size=(H1, W1))
579
+ x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
580
+
581
+ # Serial blocks 2.
582
+ x2 = self.patch_embed2(x1_nocls)
583
+ H2, W2 = self.patch_embed2.grid_size
584
+ x2 = insert_cls(x2, self.cls_token2)
585
+ for blk in self.serial_blocks2:
586
+ x2 = blk(x2, size=(H2, W2))
587
+ x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
588
+
589
+ # Serial blocks 3.
590
+ x3 = self.patch_embed3(x2_nocls)
591
+ H3, W3 = self.patch_embed3.grid_size
592
+ x3 = insert_cls(x3, self.cls_token3)
593
+ for blk in self.serial_blocks3:
594
+ x3 = blk(x3, size=(H3, W3))
595
+ x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
596
+
597
+ # Serial blocks 4.
598
+ x4 = self.patch_embed4(x3_nocls)
599
+ H4, W4 = self.patch_embed4.grid_size
600
+ x4 = insert_cls(x4, self.cls_token4)
601
+ for blk in self.serial_blocks4:
602
+ x4 = blk(x4, size=(H4, W4))
603
+ x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
604
+
605
+ # Only serial blocks: Early return.
606
+ if self.parallel_blocks is None:
607
+ if not torch.jit.is_scripting() and self.return_interm_layers:
608
+ # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
609
+ feat_out = {}
610
+ if 'x1_nocls' in self.out_features:
611
+ feat_out['x1_nocls'] = x1_nocls
612
+ if 'x2_nocls' in self.out_features:
613
+ feat_out['x2_nocls'] = x2_nocls
614
+ if 'x3_nocls' in self.out_features:
615
+ feat_out['x3_nocls'] = x3_nocls
616
+ if 'x4_nocls' in self.out_features:
617
+ feat_out['x4_nocls'] = x4_nocls
618
+ return feat_out
619
+ else:
620
+ # Return features for classification.
621
+ x4 = self.norm4(x4)
622
+ return x4
623
+
624
+ # Parallel blocks.
625
+ for blk in self.parallel_blocks:
626
+ x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
627
+ x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
628
+
629
+ if not torch.jit.is_scripting() and self.return_interm_layers:
630
+ # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
631
+ feat_out = {}
632
+ if 'x1_nocls' in self.out_features:
633
+ x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
634
+ feat_out['x1_nocls'] = x1_nocls
635
+ if 'x2_nocls' in self.out_features:
636
+ x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
637
+ feat_out['x2_nocls'] = x2_nocls
638
+ if 'x3_nocls' in self.out_features:
639
+ x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
640
+ feat_out['x3_nocls'] = x3_nocls
641
+ if 'x4_nocls' in self.out_features:
642
+ x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
643
+ feat_out['x4_nocls'] = x4_nocls
644
+ return feat_out
645
+ else:
646
+ x2 = self.norm2(x2)
647
+ x3 = self.norm3(x3)
648
+ x4 = self.norm4(x4)
649
+ return [x2, x3, x4]
650
+
651
+ def forward_head(self, x_feat: Union[torch.Tensor, List[torch.Tensor]], pre_logits: bool = False):
652
+ if isinstance(x_feat, list):
653
+ assert self.aggregate is not None
654
+ if self.global_pool == 'avg':
655
+ x = torch.cat([xl[:, 1:].mean(dim=1, keepdim=True) for xl in x_feat], dim=1) # [B, 3, C]
656
+ else:
657
+ x = torch.stack([xl[:, 0] for xl in x_feat], dim=1) # [B, 3, C]
658
+ x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C]
659
+ else:
660
+ x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
661
+ x = self.head_drop(x)
662
+ return x if pre_logits else self.head(x)
663
+
664
+ def forward(self, x) -> torch.Tensor:
665
+ if not torch.jit.is_scripting() and self.return_interm_layers:
666
+ # Return intermediate features (for down-stream tasks).
667
+ return self.forward_features(x)
668
+ else:
669
+ # Return features for classification.
670
+ x_feat = self.forward_features(x)
671
+ x = self.forward_head(x_feat)
672
+ return x
673
+
674
+
675
+ def insert_cls(x, cls_token):
676
+ """ Insert CLS token. """
677
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
678
+ x = torch.cat((cls_tokens, x), dim=1)
679
+ return x
680
+
681
+
682
+ def remove_cls(x):
683
+ """ Remove CLS token. """
684
+ return x[:, 1:, :]
685
+
686
+
687
+ def checkpoint_filter_fn(state_dict, model):
688
+ out_dict = {}
689
+ state_dict = state_dict.get('model', state_dict)
690
+ for k, v in state_dict.items():
691
+ # original model had unused norm layers, removing them requires filtering pretrained checkpoints
692
+ if k.startswith('norm1') or \
693
+ (k.startswith('norm2') and getattr(model, 'norm2', None) is None) or \
694
+ (k.startswith('norm3') and getattr(model, 'norm3', None) is None) or \
695
+ (k.startswith('norm4') and getattr(model, 'norm4', None) is None) or \
696
+ (k.startswith('aggregate') and getattr(model, 'aggregate', None) is None) or \
697
+ (k.startswith('head') and getattr(model, 'head', None) is None):
698
+ continue
699
+ out_dict[k] = v
700
+ return out_dict
701
+
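+ # NOTE: a concrete case of the filtering above, assuming a CoaT-Lite variant
+ # (parallel_depth=0): norm2/norm3/aggregate resolve to None on the model, so
+ # any matching checkpoint keys (e.g. 'norm2.weight', 'aggregate.weight') are
+ # dropped rather than failing the state_dict load; 'norm1*' keys are always
+ # dropped since that norm layer was removed.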
702
+
703
+ def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
704
+ if kwargs.get('features_only', None):
705
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
706
+
707
+ model = build_model_with_cfg(
708
+ CoaT,
709
+ variant,
710
+ pretrained,
711
+ pretrained_filter_fn=checkpoint_filter_fn,
712
+ **kwargs,
713
+ )
714
+ return model
715
+
716
+
717
+ def _cfg_coat(url='', **kwargs):
718
+ return {
719
+ 'url': url,
720
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
721
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
722
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
723
+ 'first_conv': 'patch_embed1.proj', 'classifier': 'head',
724
+ **kwargs
725
+ }
726
+
727
+
728
+ default_cfgs = generate_default_cfgs({
729
+ 'coat_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
730
+ 'coat_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
731
+ 'coat_small.in1k': _cfg_coat(hf_hub_id='timm/'),
732
+ 'coat_lite_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
733
+ 'coat_lite_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
734
+ 'coat_lite_small.in1k': _cfg_coat(hf_hub_id='timm/'),
735
+ 'coat_lite_medium.in1k': _cfg_coat(hf_hub_id='timm/'),
736
+ 'coat_lite_medium_384.in1k': _cfg_coat(
737
+ hf_hub_id='timm/',
738
+ input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash',
739
+ ),
740
+ })
741
+
742
+
743
+ @register_model
744
+ def coat_tiny(pretrained=False, **kwargs) -> CoaT:
745
+ model_cfg = dict(
746
+ patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6)
747
+ model = _create_coat('coat_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
748
+ return model
749
+
750
+
751
+ @register_model
752
+ def coat_mini(pretrained=False, **kwargs) -> CoaT:
753
+ model_cfg = dict(
754
+ patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6)
755
+ model = _create_coat('coat_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
756
+ return model
757
+
758
+
759
+ @register_model
760
+ def coat_small(pretrained=False, **kwargs) -> CoaT:
761
+ model_cfg = dict(
762
+ patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6)
763
+ model = _create_coat('coat_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
764
+ return model
765
+
766
+
767
+ @register_model
768
+ def coat_lite_tiny(pretrained=False, **kwargs) -> CoaT:
769
+ model_cfg = dict(
770
+ patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
771
+ model = _create_coat('coat_lite_tiny', pretrained=pretrained, **dict(model_cfg, **kwargs))
772
+ return model
773
+
774
+
775
+ @register_model
776
+ def coat_lite_mini(pretrained=False, **kwargs) -> CoaT:
777
+ model_cfg = dict(
778
+ patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
779
+ model = _create_coat('coat_lite_mini', pretrained=pretrained, **dict(model_cfg, **kwargs))
780
+ return model
781
+
782
+
783
+ @register_model
784
+ def coat_lite_small(pretrained=False, **kwargs) -> CoaT:
785
+ model_cfg = dict(
786
+ patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], mlp_ratios=[8, 8, 4, 4])
787
+ model = _create_coat('coat_lite_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
788
+ return model
789
+
790
+
791
+ @register_model
792
+ def coat_lite_medium(pretrained=False, **kwargs) -> CoaT:
793
+ model_cfg = dict(
794
+ patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
795
+ model = _create_coat('coat_lite_medium', pretrained=pretrained, **dict(model_cfg, **kwargs))
796
+ return model
797
+
798
+
799
+ @register_model
800
+ def coat_lite_medium_384(pretrained=False, **kwargs) -> CoaT:
801
+ model_cfg = dict(
802
+ img_size=384, patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
803
+ model = _create_coat('coat_lite_medium_384', pretrained=pretrained, **dict(model_cfg, **kwargs))
804
+ return model
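+
+ # A minimal usage sketch of the entrypoints above (illustrative only; the
+ # input size follows fixed_input_size=True in the default cfgs):
+ #
+ #   model = coat_lite_tiny(pretrained=False)
+ #   x = torch.randn(1, 3, 224, 224)
+ #   logits = model(x)  # -> [1, 1000]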
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convit.py ADDED
@@ -0,0 +1,430 @@
1
+ """ ConViT Model
2
+
3
+ @article{d2021convit,
4
+ title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
5
+ author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
6
+ journal={arXiv preprint arXiv:2103.10697},
7
+ year={2021}
8
+ }
9
+
10
+ Paper link: https://arxiv.org/abs/2103.10697
11
+ Original code: https://github.com/facebookresearch/convit, original copyright below
12
+
13
+ Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
14
+ """
15
+ # Copyright (c) 2015-present, Facebook, Inc.
16
+ # All rights reserved.
17
+ #
18
+ # This source code is licensed under the CC-by-NC license found in the
19
+ # LICENSE file in the root directory of this source tree.
20
+ #
21
+ '''These modules are adapted from those of timm, see
22
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
23
+ '''
24
+
25
+ from functools import partial
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+
30
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
31
+ from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
32
+ from ._builder import build_model_with_cfg
33
+ from ._features_fx import register_notrace_module
34
+ from ._registry import register_model, generate_default_cfgs
35
+ from .vision_transformer_hybrid import HybridEmbed
36
+
37
+
38
+ __all__ = ['ConVit']
39
+
40
+
41
+ @register_notrace_module # reason: FX can't symbolically trace control flow in forward method
42
+ class GPSA(nn.Module):
43
+ def __init__(
44
+ self,
45
+ dim,
46
+ num_heads=8,
47
+ qkv_bias=False,
48
+ attn_drop=0.,
49
+ proj_drop=0.,
50
+ locality_strength=1.,
51
+ ):
52
+ super().__init__()
53
+ self.num_heads = num_heads
54
+ self.dim = dim
55
+ head_dim = dim // num_heads
56
+ self.scale = head_dim ** -0.5
57
+ self.locality_strength = locality_strength
58
+
59
+ self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
60
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
61
+
62
+ self.attn_drop = nn.Dropout(attn_drop)
63
+ self.proj = nn.Linear(dim, dim)
64
+ self.pos_proj = nn.Linear(3, num_heads)
65
+ self.proj_drop = nn.Dropout(proj_drop)
66
+ self.gating_param = nn.Parameter(torch.ones(self.num_heads))
67
+ self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None
68
+
69
+ def forward(self, x):
70
+ B, N, C = x.shape
71
+ if self.rel_indices is None or self.rel_indices.shape[1] != N:
72
+ self.rel_indices = self.get_rel_indices(N)
73
+ attn = self.get_attention(x)
74
+ v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
75
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
76
+ x = self.proj(x)
77
+ x = self.proj_drop(x)
78
+ return x
79
+
80
+ def get_attention(self, x):
81
+ B, N, C = x.shape
82
+ qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
83
+ q, k = qk[0], qk[1]
84
+ pos_score = self.rel_indices.expand(B, -1, -1, -1)
85
+ pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
86
+ patch_score = (q @ k.transpose(-2, -1)) * self.scale
87
+ patch_score = patch_score.softmax(dim=-1)
88
+ pos_score = pos_score.softmax(dim=-1)
89
+
90
+ gating = self.gating_param.view(1, -1, 1, 1)
91
+ attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
92
+ attn /= attn.sum(dim=-1).unsqueeze(-1)
93
+ attn = self.attn_drop(attn)
94
+ return attn
95
+
96
+ def get_attention_map(self, x, return_map=False):
97
+ attn_map = self.get_attention(x).mean(0) # average over batch
98
+ distances = self.rel_indices.squeeze()[:, :, -1] ** .5
99
+ dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
100
+ if return_map:
101
+ return dist, attn_map
102
+ else:
103
+ return dist
104
+
105
+ def local_init(self):
106
+ self.v.weight.data.copy_(torch.eye(self.dim))
107
+ locality_distance = 1 # max(1,1/locality_strength**.5)
108
+
109
+ kernel_size = int(self.num_heads ** .5)
110
+ center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
111
+ for h1 in range(kernel_size):
112
+ for h2 in range(kernel_size):
113
+ position = h1 + kernel_size * h2
114
+ self.pos_proj.weight.data[position, 2] = -1
115
+ self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
116
+ self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
117
+ self.pos_proj.weight.data *= self.locality_strength
118
+
119
+ def get_rel_indices(self, num_patches: int) -> torch.Tensor:
120
+ img_size = int(num_patches ** .5)
121
+ rel_indices = torch.zeros(1, num_patches, num_patches, 3)
122
+ ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
123
+ indx = ind.repeat(img_size, img_size)
124
+ indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
125
+ indd = indx ** 2 + indy ** 2
126
+ rel_indices[:, :, :, 2] = indd.unsqueeze(0)
127
+ rel_indices[:, :, :, 1] = indy.unsqueeze(0)
128
+ rel_indices[:, :, :, 0] = indx.unsqueeze(0)
129
+ device = self.qk.weight.device
130
+ return rel_indices.to(device)
131
+
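+ # NOTE: the gating in GPSA.get_attention combines content and position per
+ # head: with sigma = sigmoid(gating_param), attn = (1 - sigma) * softmax(q @
+ # k^T * scale) + sigma * softmax(pos_proj(rel_indices)), renormalised so each
+ # row sums to 1; local_init() sets pos_proj so each head initially favours
+ # one fixed spatial offset, i.e. a convolution-like prior that training can
+ # gate away.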
132
+
133
+ class MHSA(nn.Module):
134
+ def __init__(
135
+ self,
136
+ dim,
137
+ num_heads=8,
138
+ qkv_bias=False,
139
+ attn_drop=0.,
140
+ proj_drop=0.,
141
+ ):
142
+ super().__init__()
143
+ self.num_heads = num_heads
144
+ head_dim = dim // num_heads
145
+ self.scale = head_dim ** -0.5
146
+
147
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
148
+ self.attn_drop = nn.Dropout(attn_drop)
149
+ self.proj = nn.Linear(dim, dim)
150
+ self.proj_drop = nn.Dropout(proj_drop)
151
+
152
+ def get_attention_map(self, x, return_map=False):
153
+ B, N, C = x.shape
154
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
155
+ q, k, v = qkv[0], qkv[1], qkv[2]
156
+ attn_map = (q @ k.transpose(-2, -1)) * self.scale
157
+ attn_map = attn_map.softmax(dim=-1).mean(0)
158
+
159
+ img_size = int(N ** .5)
160
+ ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
161
+ indx = ind.repeat(img_size, img_size)
162
+ indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
163
+ indd = indx ** 2 + indy ** 2
164
+ distances = indd ** .5
165
+ distances = distances.to(x.device)
166
+
167
+ dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
168
+ if return_map:
169
+ return dist, attn_map
170
+ else:
171
+ return dist
172
+
173
+ def forward(self, x):
174
+ B, N, C = x.shape
175
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
176
+ q, k, v = qkv.unbind(0)
177
+
178
+ attn = (q @ k.transpose(-2, -1)) * self.scale
179
+ attn = attn.softmax(dim=-1)
180
+ attn = self.attn_drop(attn)
181
+
182
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
183
+ x = self.proj(x)
184
+ x = self.proj_drop(x)
185
+ return x
186
+
187
+
188
+ class Block(nn.Module):
189
+
190
+ def __init__(
191
+ self,
192
+ dim,
193
+ num_heads,
194
+ mlp_ratio=4.,
195
+ qkv_bias=False,
196
+ proj_drop=0.,
197
+ attn_drop=0.,
198
+ drop_path=0.,
199
+ act_layer=nn.GELU,
200
+ norm_layer=LayerNorm,
201
+ use_gpsa=True,
202
+ locality_strength=1.,
203
+ ):
204
+ super().__init__()
205
+ self.norm1 = norm_layer(dim)
206
+ self.use_gpsa = use_gpsa
207
+ if self.use_gpsa:
208
+ self.attn = GPSA(
209
+ dim,
210
+ num_heads=num_heads,
211
+ qkv_bias=qkv_bias,
212
+ attn_drop=attn_drop,
213
+ proj_drop=proj_drop,
214
+ locality_strength=locality_strength,
215
+ )
216
+ else:
217
+ self.attn = MHSA(
218
+ dim,
219
+ num_heads=num_heads,
220
+ qkv_bias=qkv_bias,
221
+ attn_drop=attn_drop,
222
+ proj_drop=proj_drop,
223
+ )
224
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
225
+ self.norm2 = norm_layer(dim)
226
+ mlp_hidden_dim = int(dim * mlp_ratio)
227
+ self.mlp = Mlp(
228
+ in_features=dim,
229
+ hidden_features=mlp_hidden_dim,
230
+ act_layer=act_layer,
231
+ drop=proj_drop,
232
+ )
233
+
234
+ def forward(self, x):
235
+ x = x + self.drop_path(self.attn(self.norm1(x)))
236
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
237
+ return x
238
+
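+ # NOTE: Block only selects between GPSA and MHSA; in ConVit below the first
+ # local_up_to_layer blocks run GPSA over patch tokens alone, and the CLS
+ # token is concatenated exactly when the first MHSA block is reached (see
+ # forward_features), since GPSA's positional prior assumes a square patch grid.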
239
+
240
+ class ConVit(nn.Module):
241
+ """ Vision Transformer with support for patch or hybrid CNN input stage
242
+ """
243
+
244
+ def __init__(
245
+ self,
246
+ img_size=224,
247
+ patch_size=16,
248
+ in_chans=3,
249
+ num_classes=1000,
250
+ global_pool='token',
251
+ embed_dim=768,
252
+ depth=12,
253
+ num_heads=12,
254
+ mlp_ratio=4.,
255
+ qkv_bias=False,
256
+ drop_rate=0.,
257
+ pos_drop_rate=0.,
258
+ proj_drop_rate=0.,
259
+ attn_drop_rate=0.,
260
+ drop_path_rate=0.,
261
+ hybrid_backbone=None,
262
+ norm_layer=LayerNorm,
263
+ local_up_to_layer=3,
264
+ locality_strength=1.,
265
+ use_pos_embed=True,
266
+ ):
267
+ super().__init__()
268
+ assert global_pool in ('', 'avg', 'token')
269
+ embed_dim *= num_heads
270
+ self.num_classes = num_classes
271
+ self.global_pool = global_pool
272
+ self.local_up_to_layer = local_up_to_layer
273
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
274
+ self.locality_strength = locality_strength
275
+ self.use_pos_embed = use_pos_embed
276
+
277
+ if hybrid_backbone is not None:
278
+ self.patch_embed = HybridEmbed(
279
+ hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
280
+ else:
281
+ self.patch_embed = PatchEmbed(
282
+ img_size=img_size,
283
+ patch_size=patch_size,
284
+ in_chans=in_chans,
285
+ embed_dim=embed_dim,
286
+ )
287
+ num_patches = self.patch_embed.num_patches
288
+ self.num_patches = num_patches
289
+
290
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
291
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
292
+
293
+ if self.use_pos_embed:
294
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
295
+ trunc_normal_(self.pos_embed, std=.02)
296
+
297
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
298
+ self.blocks = nn.ModuleList([
299
+ Block(
300
+ dim=embed_dim,
301
+ num_heads=num_heads,
302
+ mlp_ratio=mlp_ratio,
303
+ qkv_bias=qkv_bias,
304
+ proj_drop=proj_drop_rate,
305
+ attn_drop=attn_drop_rate,
306
+ drop_path=dpr[i],
307
+ norm_layer=norm_layer,
308
+ use_gpsa=i < local_up_to_layer,
309
+ locality_strength=locality_strength,
310
+ ) for i in range(depth)])
311
+ self.norm = norm_layer(embed_dim)
312
+
313
+ # Classifier head
314
+ self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
315
+ self.head_drop = nn.Dropout(drop_rate)
316
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
317
+
318
+ trunc_normal_(self.cls_token, std=.02)
319
+ self.apply(self._init_weights)
320
+ for n, m in self.named_modules():
321
+ if hasattr(m, 'local_init'):
322
+ m.local_init()
323
+
324
+ def _init_weights(self, m):
325
+ if isinstance(m, nn.Linear):
326
+ trunc_normal_(m.weight, std=.02)
327
+ if isinstance(m, nn.Linear) and m.bias is not None:
328
+ nn.init.constant_(m.bias, 0)
329
+ elif isinstance(m, nn.LayerNorm):
330
+ nn.init.constant_(m.bias, 0)
331
+ nn.init.constant_(m.weight, 1.0)
332
+
333
+ @torch.jit.ignore
334
+ def no_weight_decay(self):
335
+ return {'pos_embed', 'cls_token'}
336
+
337
+ @torch.jit.ignore
338
+ def group_matcher(self, coarse=False):
339
+ return dict(
340
+ stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
341
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
342
+ )
343
+
344
+ @torch.jit.ignore
345
+ def set_grad_checkpointing(self, enable=True):
346
+ assert not enable, 'gradient checkpointing not supported'
347
+
348
+ @torch.jit.ignore
349
+ def get_classifier(self):
350
+ return self.head
351
+
352
+ def reset_classifier(self, num_classes, global_pool=None):
353
+ self.num_classes = num_classes
354
+ if global_pool is not None:
355
+ assert global_pool in ('', 'token', 'avg')
356
+ self.global_pool = global_pool
357
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
358
+
359
+ def forward_features(self, x):
360
+ x = self.patch_embed(x)
361
+ if self.use_pos_embed:
362
+ x = x + self.pos_embed
363
+ x = self.pos_drop(x)
364
+ cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
365
+ for u, blk in enumerate(self.blocks):
366
+ if u == self.local_up_to_layer:
367
+ x = torch.cat((cls_tokens, x), dim=1)
368
+ x = blk(x)
369
+ x = self.norm(x)
370
+ return x
371
+
372
+ def forward_head(self, x, pre_logits: bool = False):
373
+ if self.global_pool:
374
+ x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
375
+ x = self.head_drop(x)
376
+ return x if pre_logits else self.head(x)
377
+
378
+ def forward(self, x):
379
+ x = self.forward_features(x)
380
+ x = self.forward_head(x)
381
+ return x
382
+
383
+
384
+ def _create_convit(variant, pretrained=False, **kwargs):
385
+ if kwargs.get('features_only', None):
386
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
387
+
388
+ return build_model_with_cfg(ConVit, variant, pretrained, **kwargs)
389
+
390
+
391
+ def _cfg(url='', **kwargs):
392
+ return {
393
+ 'url': url,
394
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
395
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
396
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
397
+ **kwargs
398
+ }
399
+
400
+
401
+ default_cfgs = generate_default_cfgs({
402
+ # ConViT
403
+ 'convit_tiny.fb_in1k': _cfg(hf_hub_id='timm/'),
404
+ 'convit_small.fb_in1k': _cfg(hf_hub_id='timm/'),
405
+ 'convit_base.fb_in1k': _cfg(hf_hub_id='timm/')
406
+ })
407
+
408
+
409
+ @register_model
410
+ def convit_tiny(pretrained=False, **kwargs) -> ConVit:
411
+ model_args = dict(
412
+ local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=4)
413
+ model = _create_convit(variant='convit_tiny', pretrained=pretrained, **dict(model_args, **kwargs))
414
+ return model
415
+
416
+
417
+ @register_model
418
+ def convit_small(pretrained=False, **kwargs) -> ConVit:
419
+ model_args = dict(
420
+ local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=9)
421
+ model = _create_convit(variant='convit_small', pretrained=pretrained, **dict(model_args, **kwargs))
422
+ return model
423
+
424
+
425
+ @register_model
426
+ def convit_base(pretrained=False, **kwargs) -> ConVit:
427
+ model_args = dict(
428
+ local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=16)
429
+ model = _create_convit(variant='convit_base', pretrained=pretrained, **dict(model_args, **kwargs))
430
+ return model
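+
+ # A worked note on the entrypoints above: ConVit multiplies embed_dim by
+ # num_heads, so convit_tiny (48 x 4 heads) is a 192-dim model, convit_small
+ # (48 x 9) 432-dim, and convit_base (48 x 16) 768-dim; a minimal sketch:
+ #
+ #   model = convit_tiny(pretrained=False)
+ #   logits = model(torch.randn(1, 3, 224, 224))  # -> [1, 1000]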
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/crossvit.py ADDED
@@ -0,0 +1,627 @@
1
+ """ CrossViT Model
2
+
3
+ @inproceedings{
4
+ chen2021crossvit,
5
+ title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
6
+ author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
7
+ booktitle={International Conference on Computer Vision (ICCV)},
8
+ year={2021}
9
+ }
10
+
11
+ Paper link: https://arxiv.org/abs/2103.14899
12
+ Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
13
+
14
+ NOTE: model names have been renamed from the originals to reflect the actual input resolution: all *_224 -> *_240 and *_384 -> *_408
15
+
16
+ Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
17
+ """
18
+
19
+ # Copyright IBM All Rights Reserved.
20
+ # SPDX-License-Identifier: Apache-2.0
21
+
22
+
23
+ """
24
+ Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
25
+
26
+ """
27
+ from functools import partial
28
+ from typing import List
29
+ from typing import Tuple
30
+
31
+ import torch
32
+ import torch.hub
33
+ import torch.nn as nn
34
+
35
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
36
+ from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
37
+ from ._builder import build_model_with_cfg
38
+ from ._features_fx import register_notrace_function
39
+ from ._registry import register_model, generate_default_cfgs
40
+ from .vision_transformer import Block
41
+
42
+ __all__ = ['CrossVit'] # model_registry will add each entrypoint fn to this
43
+
44
+
45
+ class PatchEmbed(nn.Module):
46
+ """ Image to Patch Embedding
47
+ """
48
+
49
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
50
+ super().__init__()
51
+ img_size = to_2tuple(img_size)
52
+ patch_size = to_2tuple(patch_size)
53
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
54
+ self.img_size = img_size
55
+ self.patch_size = patch_size
56
+ self.num_patches = num_patches
57
+ if multi_conv:
58
+ if patch_size[0] == 12:
59
+ self.proj = nn.Sequential(
60
+ nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
61
+ nn.ReLU(inplace=True),
62
+ nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
63
+ nn.ReLU(inplace=True),
64
+ nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
65
+ )
66
+ elif patch_size[0] == 16:
67
+ self.proj = nn.Sequential(
68
+ nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
69
+ nn.ReLU(inplace=True),
70
+ nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
71
+ nn.ReLU(inplace=True),
72
+ nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
73
+ )
74
+ else:
75
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
76
+
77
+ def forward(self, x):
78
+ B, C, H, W = x.shape
79
+ # FIXME look at relaxing size constraints
80
+ _assert(H == self.img_size[0],
81
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
82
+ _assert(W == self.img_size[1],
83
+ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
84
+ x = self.proj(x).flatten(2).transpose(1, 2)
85
+ return x
86
+
87
+
88
+ class CrossAttention(nn.Module):
89
+ def __init__(
90
+ self,
91
+ dim,
92
+ num_heads=8,
93
+ qkv_bias=False,
94
+ attn_drop=0.,
95
+ proj_drop=0.,
96
+ ):
97
+ super().__init__()
98
+ self.num_heads = num_heads
99
+ head_dim = dim // num_heads
100
+ # NOTE: the scale factor was wrong in my original version; it can be set manually to stay compatible with previous weights
101
+ self.scale = head_dim ** -0.5
102
+
103
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
104
+ self.wk = nn.Linear(dim, dim, bias=qkv_bias)
105
+ self.wv = nn.Linear(dim, dim, bias=qkv_bias)
106
+ self.attn_drop = nn.Dropout(attn_drop)
107
+ self.proj = nn.Linear(dim, dim)
108
+ self.proj_drop = nn.Dropout(proj_drop)
109
+
110
+ def forward(self, x):
111
+ B, N, C = x.shape
112
+ # B1C -> B1H(C/H) -> BH1(C/H)
113
+ q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
114
+ # BNC -> BNH(C/H) -> BHN(C/H)
115
+ k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
116
+ # BNC -> BNH(C/H) -> BHN(C/H)
117
+ v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
118
+
119
+ attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
120
+ attn = attn.softmax(dim=-1)
121
+ attn = self.attn_drop(attn)
122
+
123
+ x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
124
+ x = self.proj(x)
125
+ x = self.proj_drop(x)
126
+ return x
127
+
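+ # NOTE: in CrossAttention above only the (branch-swapped) CLS token forms the
+ # query, so q is [B, h, 1, C/h] and attn is [B, h, 1, N]: one token attends
+ # over all patch tokens of the other branch at O(N) cost per head, and the
+ # result updates just that CLS token (see CrossAttentionBlock.forward).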
128
+
129
+ class CrossAttentionBlock(nn.Module):
130
+
131
+ def __init__(
132
+ self,
133
+ dim,
134
+ num_heads,
135
+ mlp_ratio=4.,
136
+ qkv_bias=False,
137
+ proj_drop=0.,
138
+ attn_drop=0.,
139
+ drop_path=0.,
140
+ act_layer=nn.GELU,
141
+ norm_layer=nn.LayerNorm,
142
+ ):
143
+ super().__init__()
144
+ self.norm1 = norm_layer(dim)
145
+ self.attn = CrossAttention(
146
+ dim,
147
+ num_heads=num_heads,
148
+ qkv_bias=qkv_bias,
149
+ attn_drop=attn_drop,
150
+ proj_drop=proj_drop,
151
+ )
152
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
153
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
154
+
155
+ def forward(self, x):
156
+ x = x[:, 0:1, ...] + self.drop_path(self.attn(self.norm1(x)))
157
+ return x
158
+
159
+
160
+ class MultiScaleBlock(nn.Module):
161
+
162
+ def __init__(
163
+ self,
164
+ dim,
165
+ patches,
166
+ depth,
167
+ num_heads,
168
+ mlp_ratio,
169
+ qkv_bias=False,
170
+ proj_drop=0.,
171
+ attn_drop=0.,
172
+ drop_path=0.,
173
+ act_layer=nn.GELU,
174
+ norm_layer=nn.LayerNorm,
175
+ ):
176
+ super().__init__()
177
+
178
+ num_branches = len(dim)
179
+ self.num_branches = num_branches
180
+ # different branches could have different embedding sizes; the first one is the base
181
+ self.blocks = nn.ModuleList()
182
+ for d in range(num_branches):
183
+ tmp = []
184
+ for i in range(depth[d]):
185
+ tmp.append(Block(
186
+ dim=dim[d],
187
+ num_heads=num_heads[d],
188
+ mlp_ratio=mlp_ratio[d],
189
+ qkv_bias=qkv_bias,
190
+ proj_drop=proj_drop,
191
+ attn_drop=attn_drop,
192
+ drop_path=drop_path[i],
193
+ norm_layer=norm_layer,
194
+ ))
195
+ if len(tmp) != 0:
196
+ self.blocks.append(nn.Sequential(*tmp))
197
+
198
+ if len(self.blocks) == 0:
199
+ self.blocks = None
200
+
201
+ self.projs = nn.ModuleList()
202
+ for d in range(num_branches):
203
+ if dim[d] == dim[(d + 1) % num_branches] and False:
204
+ tmp = [nn.Identity()]
205
+ else:
206
+ tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
207
+ self.projs.append(nn.Sequential(*tmp))
208
+
209
+ self.fusion = nn.ModuleList()
210
+ for d in range(num_branches):
211
+ d_ = (d + 1) % num_branches
212
+ nh = num_heads[d_]
213
+ if depth[-1] == 0: # backward compatibility:
214
+ self.fusion.append(
215
+ CrossAttentionBlock(
216
+ dim=dim[d_],
217
+ num_heads=nh,
218
+ mlp_ratio=mlp_ratio[d],
219
+ qkv_bias=qkv_bias,
220
+ proj_drop=proj_drop,
221
+ attn_drop=attn_drop,
222
+ drop_path=drop_path[-1],
223
+ norm_layer=norm_layer,
224
+ ))
225
+ else:
226
+ tmp = []
227
+ for _ in range(depth[-1]):
228
+ tmp.append(CrossAttentionBlock(
229
+ dim=dim[d_],
230
+ num_heads=nh,
231
+ mlp_ratio=mlp_ratio[d],
232
+ qkv_bias=qkv_bias,
233
+ proj_drop=proj_drop,
234
+ attn_drop=attn_drop,
235
+ drop_path=drop_path[-1],
236
+ norm_layer=norm_layer,
237
+ ))
238
+ self.fusion.append(nn.Sequential(*tmp))
239
+
240
+ self.revert_projs = nn.ModuleList()
241
+ for d in range(num_branches):
242
+ if dim[(d + 1) % num_branches] == dim[d] and False:
243
+ tmp = [nn.Identity()]
244
+ else:
245
+ tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
246
+ nn.Linear(dim[(d + 1) % num_branches], dim[d])]
247
+ self.revert_projs.append(nn.Sequential(*tmp))
248
+
249
+ def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
250
+
251
+ outs_b = []
252
+ for i, block in enumerate(self.blocks):
253
+ outs_b.append(block(x[i]))
254
+
255
+ # only take the cls token out
256
+ proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
257
+ for i, proj in enumerate(self.projs):
258
+ proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))
259
+
260
+ # cross attention
261
+ outs = []
262
+ for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
263
+ tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
264
+ tmp = fusion(tmp)
265
+ reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
266
+ tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
267
+ outs.append(tmp)
268
+ return outs
269
+
270
+
271
+ def _compute_num_patches(img_size, patches):
272
+ return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
273
+
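+ # NOTE: worked numbers for the default crossvit_tiny_240 branches below:
+ # (240x240, patch 12) -> 20 * 20 = 400 patches and (224x224, patch 16)
+ # -> 14 * 14 = 196 patches, each branch then prefixed with its own CLS token.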
274
+
275
+ @register_notrace_function
276
+ def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False): # annotations for torchscript
277
+ """
278
+ Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
279
+ Args:
280
+ x (Tensor): input image
281
+ ss (tuple[int, int]): height and width to scale to
282
+ crop_scale (bool): whether to crop instead of interpolate to achieve the desired scale. Defaults to False
283
+ Returns:
284
+ Tensor: the "scaled" image batch tensor
285
+ """
286
+ H, W = x.shape[-2:]
287
+ if H != ss[0] or W != ss[1]:
288
+ if crop_scale and ss[0] <= H and ss[1] <= W:
289
+ cu, cl = int(round((H - ss[0]) / 2.)), int(round((W - ss[1]) / 2.))
290
+ x = x[:, :, cu:cu + ss[0], cl:cl + ss[1]]
291
+ else:
292
+ x = torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
293
+ return x
294
+
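+ # NOTE: a concrete instance of scale_image, assuming crossvit_tiny_240's
+ # img_scale=(1.0, 224/240): a 240x240 batch passes through unchanged for the
+ # first branch and is bicubic-resized to 224x224 for the second, or centre-
+ # cropped (cu = cl = 8) when crop_scale=True since the target fits inside.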
295
+
296
+ class CrossVit(nn.Module):
297
+ """ Vision Transformer with support for patch or hybrid CNN input stage
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ img_size=224,
303
+ img_scale=(1.0, 1.0),
304
+ patch_size=(8, 16),
305
+ in_chans=3,
306
+ num_classes=1000,
307
+ embed_dim=(192, 384),
308
+ depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
309
+ num_heads=(6, 12),
310
+ mlp_ratio=(2., 2., 4.),
311
+ multi_conv=False,
312
+ crop_scale=False,
313
+ qkv_bias=True,
314
+ drop_rate=0.,
315
+ pos_drop_rate=0.,
316
+ proj_drop_rate=0.,
317
+ attn_drop_rate=0.,
318
+ drop_path_rate=0.,
319
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
320
+ global_pool='token',
321
+ ):
322
+ super().__init__()
323
+ assert global_pool in ('token', 'avg')
324
+
325
+ self.num_classes = num_classes
326
+ self.global_pool = global_pool
327
+ self.img_size = to_2tuple(img_size)
328
+ img_scale = to_2tuple(img_scale)
329
+ self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
330
+ self.crop_scale = crop_scale # crop instead of interpolate for scale
331
+ num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
332
+ self.num_branches = len(patch_size)
333
+ self.embed_dim = embed_dim
334
+ self.num_features = sum(embed_dim)
335
+ self.patch_embed = nn.ModuleList()
336
+
337
+ # hard-coded for torch jit script
338
+ for i in range(self.num_branches):
339
+ setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
340
+ setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
341
+
342
+ for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
343
+ self.patch_embed.append(
344
+ PatchEmbed(
345
+ img_size=im_s,
346
+ patch_size=p,
347
+ in_chans=in_chans,
348
+ embed_dim=d,
349
+ multi_conv=multi_conv,
350
+ ))
351
+
352
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
353
+
354
+ total_depth = sum([sum(x[-2:]) for x in depth])
355
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)] # stochastic depth decay rule
356
+ dpr_ptr = 0
357
+ self.blocks = nn.ModuleList()
358
+ for idx, block_cfg in enumerate(depth):
359
+ curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
360
+ dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
361
+ blk = MultiScaleBlock(
362
+ embed_dim,
363
+ num_patches,
364
+ block_cfg,
365
+ num_heads=num_heads,
366
+ mlp_ratio=mlp_ratio,
367
+ qkv_bias=qkv_bias,
368
+ proj_drop=proj_drop_rate,
369
+ attn_drop=attn_drop_rate,
370
+ drop_path=dpr_,
371
+ norm_layer=norm_layer,
372
+ )
373
+ dpr_ptr += curr_depth
374
+ self.blocks.append(blk)
375
+
376
+ self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
377
+ self.head_drop = nn.Dropout(drop_rate)
378
+ self.head = nn.ModuleList([
379
+ nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
380
+ for i in range(self.num_branches)])
381
+
382
+ for i in range(self.num_branches):
383
+ trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
384
+ trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)
385
+
386
+ self.apply(self._init_weights)
387
+
388
+ def _init_weights(self, m):
389
+ if isinstance(m, nn.Linear):
390
+ trunc_normal_(m.weight, std=.02)
391
+ if isinstance(m, nn.Linear) and m.bias is not None:
392
+ nn.init.constant_(m.bias, 0)
393
+ elif isinstance(m, nn.LayerNorm):
394
+ nn.init.constant_(m.bias, 0)
395
+ nn.init.constant_(m.weight, 1.0)
396
+
397
+ @torch.jit.ignore
398
+ def no_weight_decay(self):
399
+ out = set()
400
+ for i in range(self.num_branches):
401
+ out.add(f'cls_token_{i}')
402
+ pe = getattr(self, f'pos_embed_{i}', None)
403
+ if pe is not None and pe.requires_grad:
404
+ out.add(f'pos_embed_{i}')
405
+ return out
406
+
407
+ @torch.jit.ignore
408
+ def group_matcher(self, coarse=False):
409
+ return dict(
410
+ stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
411
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
412
+ )
413
+
414
+ @torch.jit.ignore
415
+ def set_grad_checkpointing(self, enable=True):
416
+ assert not enable, 'gradient checkpointing not supported'
417
+
418
+ @torch.jit.ignore
419
+ def get_classifier(self):
420
+ return self.head
421
+
422
+ def reset_classifier(self, num_classes, global_pool=None):
423
+ self.num_classes = num_classes
424
+ if global_pool is not None:
425
+ assert global_pool in ('token', 'avg')
426
+ self.global_pool = global_pool
427
+ self.head = nn.ModuleList(
428
+ [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
429
+ range(self.num_branches)])
430
+
431
+ def forward_features(self, x) -> List[torch.Tensor]:
432
+ B = x.shape[0]
433
+ xs = []
434
+ for i, patch_embed in enumerate(self.patch_embed):
435
+ x_ = x
436
+ ss = self.img_size_scaled[i]
437
+ x_ = scale_image(x_, ss, self.crop_scale)
438
+ x_ = patch_embed(x_)
439
+ cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
440
+ cls_tokens = cls_tokens.expand(B, -1, -1)
441
+ x_ = torch.cat((cls_tokens, x_), dim=1)
442
+ pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1 # hard-coded for torch jit script
443
+ x_ = x_ + pos_embed
444
+ x_ = self.pos_drop(x_)
445
+ xs.append(x_)
446
+
447
+ for i, blk in enumerate(self.blocks):
448
+ xs = blk(xs)
449
+
450
+ # NOTE: was before the branch-token section; moved here to ensure all branch tokens go through layer norm
451
+ xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
452
+ return xs
453
+
454
+ def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
455
+ xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
456
+ xs = [self.head_drop(x) for x in xs]
457
+ if pre_logits or isinstance(self.head[0], nn.Identity):
458
+ return torch.cat([x for x in xs], dim=1)
459
+ return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)
460
+
461
+ def forward(self, x):
462
+ xs = self.forward_features(x)
463
+ x = self.forward_head(xs)
464
+ return x
465
+
466
+
467
+ def _create_crossvit(variant, pretrained=False, **kwargs):
468
+ if kwargs.get('features_only', None):
469
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
470
+
471
+ def pretrained_filter_fn(state_dict):
472
+ new_state_dict = {}
473
+ for key in state_dict.keys():
474
+ if 'pos_embed' in key or 'cls_token' in key:
475
+ new_key = key.replace(".", "_")
476
+ else:
477
+ new_key = key
478
+ new_state_dict[new_key] = state_dict[key]
479
+ return new_state_dict
480
+
481
+ return build_model_with_cfg(
482
+ CrossVit,
483
+ variant,
484
+ pretrained,
485
+ pretrained_filter_fn=pretrained_filter_fn,
486
+ **kwargs,
487
+ )
488
+
489
+
490
+ def _cfg(url='', **kwargs):
491
+ return {
492
+ 'url': url,
493
+ 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
494
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
495
+ 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
496
+ 'classifier': ('head.0', 'head.1'),
497
+ **kwargs
498
+ }
499
+
500
+
501
+ default_cfgs = generate_default_cfgs({
502
+ 'crossvit_15_240.in1k': _cfg(hf_hub_id='timm/'),
503
+ 'crossvit_15_dagger_240.in1k': _cfg(
504
+ hf_hub_id='timm/',
505
+ first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
506
+ ),
507
+ 'crossvit_15_dagger_408.in1k': _cfg(
508
+ hf_hub_id='timm/',
509
+ input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
510
+ ),
511
+ 'crossvit_18_240.in1k': _cfg(hf_hub_id='timm/'),
512
+ 'crossvit_18_dagger_240.in1k': _cfg(
513
+ hf_hub_id='timm/',
514
+ first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
515
+ ),
516
+ 'crossvit_18_dagger_408.in1k': _cfg(
517
+ hf_hub_id='timm/',
518
+ input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
519
+ ),
520
+ 'crossvit_9_240.in1k': _cfg(hf_hub_id='timm/'),
521
+ 'crossvit_9_dagger_240.in1k': _cfg(
522
+ hf_hub_id='timm/',
523
+ first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
524
+ ),
525
+ 'crossvit_base_240.in1k': _cfg(hf_hub_id='timm/'),
526
+ 'crossvit_small_240.in1k': _cfg(hf_hub_id='timm/'),
527
+ 'crossvit_tiny_240.in1k': _cfg(hf_hub_id='timm/'),
528
+ })
529
+
530
+
531
+ @register_model
532
+ def crossvit_tiny_240(pretrained=False, **kwargs) -> CrossVit:
533
+ model_args = dict(
534
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
535
+ num_heads=[3, 3], mlp_ratio=[4, 4, 1])
536
+ model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs))
537
+ return model
538
+
539
+
540
+ @register_model
541
+ def crossvit_small_240(pretrained=False, **kwargs) -> CrossVit:
542
+ model_args = dict(
543
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
544
+ num_heads=[6, 6], mlp_ratio=[4, 4, 1])
545
+ model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs))
546
+ return model
547
+
548
+
549
+ @register_model
550
+ def crossvit_base_240(pretrained=False, **kwargs) -> CrossVit:
551
+ model_args = dict(
552
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
553
+ num_heads=[12, 12], mlp_ratio=[4, 4, 1])
554
+ model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs))
555
+ return model
556
+
557
+
558
+ @register_model
559
+ def crossvit_9_240(pretrained=False, **kwargs) -> CrossVit:
560
+ model_args = dict(
561
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
562
+ num_heads=[4, 4], mlp_ratio=[3, 3, 1])
563
+ model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs))
564
+ return model
565
+
566
+
567
+ @register_model
568
+ def crossvit_15_240(pretrained=False, **kwargs) -> CrossVit:
569
+ model_args = dict(
570
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
571
+ num_heads=[6, 6], mlp_ratio=[3, 3, 1])
572
+ model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs))
573
+ return model
574
+
575
+
576
+ @register_model
577
+ def crossvit_18_240(pretrained=False, **kwargs) -> CrossVit:
578
+ model_args = dict(
579
+ img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
580
+ num_heads=[7, 7], mlp_ratio=[3, 3, 1])
581
+ model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs))
582
+ return model
583
+
584
+
585
+ @register_model
586
+ def crossvit_9_dagger_240(pretrained=False, **kwargs) -> CrossVit:
587
+ model_args = dict(
588
+ img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
589
+ num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True)
590
+ model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
591
+ return model
592
+
593
+
594
+ @register_model
595
+ def crossvit_15_dagger_240(pretrained=False, **kwargs) -> CrossVit:
596
+ model_args = dict(
597
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
598
+ num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
599
+ model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
600
+ return model
601
+
602
+
603
+ @register_model
604
+ def crossvit_15_dagger_408(pretrained=False, **kwargs) -> CrossVit:
605
+ model_args = dict(
606
+ img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
607
+ num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True)
608
+ model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
609
+ return model
610
+
611
+
612
+ @register_model
613
+ def crossvit_18_dagger_240(pretrained=False, **kwargs) -> CrossVit:
614
+ model_args = dict(
615
+ img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
616
+ num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
617
+ model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
618
+ return model
619
+
620
+
621
+ @register_model
622
+ def crossvit_18_dagger_408(pretrained=False, **kwargs) -> CrossVit:
623
+ model_args = dict(
624
+ img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
625
+ num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True)
626
+ model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
627
+ return model
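A quick smoke test for the CrossViT entrypoints above (a minimal sketch, assuming a timm install in which these models are registered; pretrained=False, so no weights are downloaded):

    import torch
    import timm

    # build an untrained CrossViT-Tiny; the default cfg expects a fixed 240x240 input
    model = timm.create_model('crossvit_tiny_240', pretrained=False).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 240, 240))
    print(logits.shape)  # torch.Size([1, 1000])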
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/cspnet.py ADDED
@@ -0,0 +1,1106 @@
1
+ """PyTorch CspNet
2
+
3
+ A PyTorch implementation of Cross Stage Partial Networks including:
4
+ * CSPResNet50
5
+ * CSPResNeXt50
6
+ * CSPDarkNet53
7
+ * and DarkNet53 for good measure
8
+
9
+ Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
10
+
11
+ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks
12
+
13
+ Hacked together by / Copyright 2020 Ross Wightman
14
+ """
15
+ from dataclasses import dataclass, asdict, replace
16
+ from functools import partial
17
+ from typing import Any, Dict, Optional, Tuple, Union
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
23
+ from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
24
+ from ._builder import build_model_with_cfg
25
+ from ._manipulate import named_apply, MATCH_PREV_GROUP
26
+ from ._registry import register_model, generate_default_cfgs
27
+
28
+ __all__ = ['CspNet'] # model_registry will add each entrypoint fn to this
29
+
30
+
31
+ @dataclass
32
+ class CspStemCfg:
33
+ out_chs: Union[int, Tuple[int, ...]] = 32
34
+ stride: Union[int, Tuple[int, ...]] = 2
35
+ kernel_size: int = 3
36
+ padding: Union[int, str] = ''
37
+ pool: Optional[str] = ''
38
+
39
+
40
+ def _pad_arg(x, n):
41
+ # pads an argument tuple to specified n by padding with last value
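+ # e.g. _pad_arg(2, 4) -> (2, 2, 2, 2); _pad_arg((1, 2), 4) -> (1, 2, 2, 2)  (illustrative note, not in upstream source)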
42
+ if not isinstance(x, (tuple, list)):
43
+ x = (x,)
44
+ curr_n = len(x)
45
+ pad_n = n - curr_n
46
+ if pad_n <= 0:
47
+ return x[:n]
48
+ return tuple(x + (x[-1],) * pad_n)
49
+
50
+
51
+ @dataclass
52
+ class CspStagesCfg:
53
+ depth: Tuple[int, ...] = (3, 3, 5, 2) # block depth (number of block repeats in stages)
54
+ out_chs: Tuple[int, ...] = (128, 256, 512, 1024) # number of output channels for blocks in stage
55
+ stride: Union[int, Tuple[int, ...]] = 2 # stride of stage
56
+ groups: Union[int, Tuple[int, ...]] = 1 # num kxk conv groups
57
+ block_ratio: Union[float, Tuple[float, ...]] = 1.0
58
+ bottle_ratio: Union[float, Tuple[float, ...]] = 1. # bottleneck-ratio of blocks in stage
59
+ avg_down: Union[bool, Tuple[bool, ...]] = False
60
+ attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
61
+ attn_kwargs: Optional[Union[Dict, Tuple[Dict]]] = None
62
+ stage_type: Union[str, Tuple[str, ...]] = 'csp' # stage type ('csp', 'cs3', 'dark')
63
+ block_type: Union[str, Tuple[str, ...]] = 'bottle' # block type for stages ('bottle', 'dark', 'edge')
64
+
65
+ # cross-stage only
66
+ expand_ratio: Union[float, Tuple[float, ...]] = 1.0
67
+ cross_linear: Union[bool, Tuple[bool, ...]] = False
68
+ down_growth: Union[bool, Tuple[bool, ...]] = False
69
+
70
+ def __post_init__(self):
71
+ n = len(self.depth)
72
+ assert len(self.out_chs) == n
73
+ self.stride = _pad_arg(self.stride, n)
74
+ self.groups = _pad_arg(self.groups, n)
75
+ self.block_ratio = _pad_arg(self.block_ratio, n)
76
+ self.bottle_ratio = _pad_arg(self.bottle_ratio, n)
77
+ self.avg_down = _pad_arg(self.avg_down, n)
78
+ self.attn_layer = _pad_arg(self.attn_layer, n)
79
+ self.attn_kwargs = _pad_arg(self.attn_kwargs, n)
80
+ self.stage_type = _pad_arg(self.stage_type, n)
81
+ self.block_type = _pad_arg(self.block_type, n)
82
+
83
+ self.expand_ratio = _pad_arg(self.expand_ratio, n)
84
+ self.cross_linear = _pad_arg(self.cross_linear, n)
85
+ self.down_growth = _pad_arg(self.down_growth, n)
86
+
87
+
88
+ @dataclass
89
+ class CspModelCfg:
90
+ stem: CspStemCfg
91
+ stages: CspStagesCfg
92
+ zero_init_last: bool = True # zero init last weight (usually bn) in residual path
93
+ act_layer: str = 'leaky_relu'
94
+ norm_layer: str = 'batchnorm'
95
+ aa_layer: Optional[str] = None # FIXME support string factory for this
96
+
97
+
98
+ def _cs3_cfg(
99
+ width_multiplier=1.0,
100
+ depth_multiplier=1.0,
101
+ avg_down=False,
102
+ act_layer='silu',
103
+ focus=False,
104
+ attn_layer=None,
105
+ attn_kwargs=None,
106
+ bottle_ratio=1.0,
107
+ block_type='dark',
108
+ ):
109
+ if focus:
110
+ stem_cfg = CspStemCfg(
111
+ out_chs=make_divisible(64 * width_multiplier),
112
+ kernel_size=6, stride=2, padding=2, pool='')
113
+ else:
114
+ stem_cfg = CspStemCfg(
115
+ out_chs=tuple([make_divisible(c * width_multiplier) for c in (32, 64)]),
116
+ kernel_size=3, stride=2, pool='')
117
+ return CspModelCfg(
118
+ stem=stem_cfg,
119
+ stages=CspStagesCfg(
120
+ out_chs=tuple([make_divisible(c * width_multiplier) for c in (128, 256, 512, 1024)]),
121
+ depth=tuple([int(d * depth_multiplier) for d in (3, 6, 9, 3)]),
122
+ stride=2,
123
+ bottle_ratio=bottle_ratio,
124
+ block_ratio=0.5,
125
+ avg_down=avg_down,
126
+ attn_layer=attn_layer,
127
+ attn_kwargs=attn_kwargs,
128
+ stage_type='cs3',
129
+ block_type=block_type,
130
+ ),
131
+ act_layer=act_layer,
132
+ )
133
+
134
+
135
+ class BottleneckBlock(nn.Module):
136
+ """ ResNe(X)t Bottleneck Block
137
+ """
138
+
139
+ def __init__(
140
+ self,
141
+ in_chs,
142
+ out_chs,
143
+ dilation=1,
144
+ bottle_ratio=0.25,
145
+ groups=1,
146
+ act_layer=nn.ReLU,
147
+ norm_layer=nn.BatchNorm2d,
148
+ attn_last=False,
149
+ attn_layer=None,
150
+ drop_block=None,
151
+ drop_path=0.
152
+ ):
153
+ super(BottleneckBlock, self).__init__()
154
+ mid_chs = int(round(out_chs * bottle_ratio))
155
+ ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
156
+ attn_last = attn_layer is not None and attn_last
157
+ attn_first = attn_layer is not None and not attn_last
158
+
159
+ self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
160
+ self.conv2 = ConvNormAct(
161
+ mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
162
+ drop_layer=drop_block, **ckwargs)
163
+ self.attn2 = attn_layer(mid_chs, act_layer=act_layer) if attn_first else nn.Identity()
164
+ self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
165
+ self.attn3 = attn_layer(out_chs, act_layer=act_layer) if attn_last else nn.Identity()
166
+ self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
167
+ self.act3 = create_act_layer(act_layer)
168
+
169
+ def zero_init_last(self):
170
+ nn.init.zeros_(self.conv3.bn.weight)
171
+
172
+ def forward(self, x):
173
+ shortcut = x
174
+ x = self.conv1(x)
175
+ x = self.conv2(x)
176
+ x = self.attn2(x)
177
+ x = self.conv3(x)
178
+ x = self.attn3(x)
179
+ x = self.drop_path(x) + shortcut
180
+ # FIXME a partial shortcut would be needed to match the original handling of the first block; not used in this impl
181
+ #x[:, :shortcut.size(1)] += shortcut
182
+ x = self.act3(x)
183
+ return x
184
+
185
+
186
+ class DarkBlock(nn.Module):
187
+ """ DarkNet Block
188
+ """
189
+
190
+ def __init__(
191
+ self,
192
+ in_chs,
193
+ out_chs,
194
+ dilation=1,
195
+ bottle_ratio=0.5,
196
+ groups=1,
197
+ act_layer=nn.ReLU,
198
+ norm_layer=nn.BatchNorm2d,
199
+ attn_layer=None,
200
+ drop_block=None,
201
+ drop_path=0.
202
+ ):
203
+ super(DarkBlock, self).__init__()
204
+ mid_chs = int(round(out_chs * bottle_ratio))
205
+ ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
206
+
207
+ self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
208
+ self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity()
209
+ self.conv2 = ConvNormAct(
210
+ mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
211
+ drop_layer=drop_block, **ckwargs)
212
+ self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
213
+
214
+ def zero_init_last(self):
215
+ nn.init.zeros_(self.conv2.bn.weight)
216
+
217
+ def forward(self, x):
218
+ shortcut = x
219
+ x = self.conv1(x)
220
+ x = self.attn(x)
221
+ x = self.conv2(x)
222
+ x = self.drop_path(x) + shortcut
223
+ return x
224
+
225
+
226
+ class EdgeBlock(nn.Module):
227
+ """ EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output)
228
+ """
229
+
230
+ def __init__(
231
+ self,
232
+ in_chs,
233
+ out_chs,
234
+ dilation=1,
235
+ bottle_ratio=0.5,
236
+ groups=1,
237
+ act_layer=nn.ReLU,
238
+ norm_layer=nn.BatchNorm2d,
239
+ attn_layer=None,
240
+ drop_block=None,
241
+ drop_path=0.
242
+ ):
243
+ super(EdgeBlock, self).__init__()
244
+ mid_chs = int(round(out_chs * bottle_ratio))
245
+ ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
246
+
247
+ self.conv1 = ConvNormAct(
248
+ in_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
249
+ drop_layer=drop_block, **ckwargs)
250
+ self.attn = attn_layer(mid_chs, act_layer=act_layer) if attn_layer is not None else nn.Identity()
251
+ self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **ckwargs)
252
+ self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
253
+
254
+ def zero_init_last(self):
255
+ nn.init.zeros_(self.conv2.bn.weight)
256
+
257
+ def forward(self, x):
258
+ shortcut = x
259
+ x = self.conv1(x)
260
+ x = self.attn(x)
261
+ x = self.conv2(x)
262
+ x = self.drop_path(x) + shortcut
263
+ return x
264
+
265
+
266
+ class CrossStage(nn.Module):
267
+ """Cross Stage."""
268
+ def __init__(
269
+ self,
270
+ in_chs,
271
+ out_chs,
272
+ stride,
273
+ dilation,
274
+ depth,
275
+ block_ratio=1.,
276
+ bottle_ratio=1.,
277
+ expand_ratio=1.,
278
+ groups=1,
279
+ first_dilation=None,
280
+ avg_down=False,
281
+ down_growth=False,
282
+ cross_linear=False,
283
+ block_dpr=None,
284
+ block_fn=BottleneckBlock,
285
+ **block_kwargs,
286
+ ):
287
+ super(CrossStage, self).__init__()
288
+ first_dilation = first_dilation or dilation
289
+ down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
290
+ self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
291
+ block_out_chs = int(round(out_chs * block_ratio))
292
+ conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
293
+ aa_layer = block_kwargs.pop('aa_layer', None)
294
+
295
+ if stride != 1 or first_dilation != dilation:
296
+ if avg_down:
297
+ self.conv_down = nn.Sequential(
298
+ nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
299
+ ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
300
+ )
301
+ else:
302
+ self.conv_down = ConvNormActAa(
303
+ in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
304
+ aa_layer=aa_layer, **conv_kwargs)
305
+ prev_chs = down_chs
306
+ else:
307
+ self.conv_down = nn.Identity()
308
+ prev_chs = in_chs
309
+
310
+ # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
311
+ # there is also special case for the first stage for some of the model that results in uneven split
312
+ # across the two paths. I did it this way for simplicity for now.
313
+ self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
314
+ prev_chs = exp_chs // 2 # output of conv_exp is always split in two
315
+
316
+ self.blocks = nn.Sequential()
317
+ for i in range(depth):
318
+ self.blocks.add_module(str(i), block_fn(
319
+ in_chs=prev_chs,
320
+ out_chs=block_out_chs,
321
+ dilation=dilation,
322
+ bottle_ratio=bottle_ratio,
323
+ groups=groups,
324
+ drop_path=block_dpr[i] if block_dpr is not None else 0.,
325
+ **block_kwargs,
326
+ ))
327
+ prev_chs = block_out_chs
328
+
329
+ # transition convs
330
+ self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
331
+ self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
332
+
333
+ def forward(self, x):
334
+ x = self.conv_down(x)
335
+ x = self.conv_exp(x)
336
+ xs, xb = x.split(self.expand_chs // 2, dim=1)
337
+ xb = self.blocks(xb)
338
+ xb = self.conv_transition_b(xb).contiguous()
339
+ out = self.conv_transition(torch.cat([xs, xb], dim=1))
340
+ return out
341
+
342
+
343
+ class CrossStage3(nn.Module):
344
+ """Cross Stage 3.
345
+ Similar to CrossStage, but with only one transition conv for the output.
346
+ """
347
+ def __init__(
348
+ self,
349
+ in_chs,
350
+ out_chs,
351
+ stride,
352
+ dilation,
353
+ depth,
354
+ block_ratio=1.,
355
+ bottle_ratio=1.,
356
+ expand_ratio=1.,
357
+ groups=1,
358
+ first_dilation=None,
359
+ avg_down=False,
360
+ down_growth=False,
361
+ cross_linear=False,
362
+ block_dpr=None,
363
+ block_fn=BottleneckBlock,
364
+ **block_kwargs,
365
+ ):
366
+ super(CrossStage3, self).__init__()
367
+ first_dilation = first_dilation or dilation
368
+ down_chs = out_chs if down_growth else in_chs # grow downsample channels to output channels
369
+ self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
370
+ block_out_chs = int(round(out_chs * block_ratio))
371
+ conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
372
+ aa_layer = block_kwargs.pop('aa_layer', None)
373
+
374
+ if stride != 1 or first_dilation != dilation:
375
+ if avg_down:
376
+ self.conv_down = nn.Sequential(
377
+ nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
378
+ ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
379
+ )
380
+ else:
381
+ self.conv_down = ConvNormActAa(
382
+ in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
383
+ aa_layer=aa_layer, **conv_kwargs)
384
+ prev_chs = down_chs
385
+ else:
386
+ self.conv_down = nn.Identity()  # identity (not None) so forward() can call conv_down unconditionally, as in CrossStage
387
+ prev_chs = in_chs
388
+
389
+ # expansion conv
390
+ self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
391
+ prev_chs = exp_chs // 2 # expanded output is split in 2 for blocks and cross stage
392
+
393
+ self.blocks = nn.Sequential()
394
+ for i in range(depth):
395
+ self.blocks.add_module(str(i), block_fn(
396
+ in_chs=prev_chs,
397
+ out_chs=block_out_chs,
398
+ dilation=dilation,
399
+ bottle_ratio=bottle_ratio,
400
+ groups=groups,
401
+ drop_path=block_dpr[i] if block_dpr is not None else 0.,
402
+ **block_kwargs,
403
+ ))
404
+ prev_chs = block_out_chs
405
+
406
+ # transition convs
407
+ self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
408
+
409
+ def forward(self, x):
410
+ x = self.conv_down(x)
411
+ x = self.conv_exp(x)
412
+ x1, x2 = x.split(self.expand_chs // 2, dim=1)
413
+ x1 = self.blocks(x1)
414
+ out = self.conv_transition(torch.cat([x1, x2], dim=1))
415
+ return out
416
+
417
+
418
+ class DarkStage(nn.Module):
419
+ """DarkNet stage."""
420
+
421
+ def __init__(
422
+ self,
423
+ in_chs,
424
+ out_chs,
425
+ stride,
426
+ dilation,
427
+ depth,
428
+ block_ratio=1.,
429
+ bottle_ratio=1.,
430
+ groups=1,
431
+ first_dilation=None,
432
+ avg_down=False,
433
+ block_fn=BottleneckBlock,
434
+ block_dpr=None,
435
+ **block_kwargs,
436
+ ):
437
+ super(DarkStage, self).__init__()
438
+ first_dilation = first_dilation or dilation
439
+ conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
440
+ aa_layer = block_kwargs.pop('aa_layer', None)
441
+
442
+ if avg_down:
443
+ self.conv_down = nn.Sequential(
444
+ nn.AvgPool2d(2) if stride == 2 else nn.Identity(), # FIXME dilation handling
445
+ ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
446
+ )
447
+ else:
448
+ self.conv_down = ConvNormActAa(
449
+ in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
450
+ aa_layer=aa_layer, **conv_kwargs)
451
+
452
+ prev_chs = out_chs
453
+ block_out_chs = int(round(out_chs * block_ratio))
454
+ self.blocks = nn.Sequential()
455
+ for i in range(depth):
456
+ self.blocks.add_module(str(i), block_fn(
457
+ in_chs=prev_chs,
458
+ out_chs=block_out_chs,
459
+ dilation=dilation,
460
+ bottle_ratio=bottle_ratio,
461
+ groups=groups,
462
+ drop_path=block_dpr[i] if block_dpr is not None else 0.,
463
+ **block_kwargs
464
+ ))
465
+ prev_chs = block_out_chs
466
+
467
+ def forward(self, x):
468
+ x = self.conv_down(x)
469
+ x = self.blocks(x)
470
+ return x
471
+
472
+
473
+ def create_csp_stem(
474
+ in_chans=3,
475
+ out_chs=32,
476
+ kernel_size=3,
477
+ stride=2,
478
+ pool='',
479
+ padding='',
480
+ act_layer=nn.ReLU,
481
+ norm_layer=nn.BatchNorm2d,
482
+ aa_layer=None,
483
+ ):
484
+ stem = nn.Sequential()
485
+ feature_info = []
486
+ if not isinstance(out_chs, (tuple, list)):
487
+ out_chs = [out_chs]
488
+ stem_depth = len(out_chs)
489
+ assert stem_depth
490
+ assert stride in (1, 2, 4)
491
+ prev_feat = None
492
+ prev_chs = in_chans
493
+ last_idx = stem_depth - 1
494
+ stem_stride = 1
495
+ for i, chs in enumerate(out_chs):
496
+ conv_name = f'conv{i + 1}'
497
+ conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
498
+ if conv_stride > 1 and prev_feat is not None:
499
+ feature_info.append(prev_feat)
500
+ stem.add_module(conv_name, ConvNormAct(
501
+ prev_chs, chs, kernel_size,
502
+ stride=conv_stride,
503
+ padding=padding if i == 0 else '',
504
+ act_layer=act_layer,
505
+ norm_layer=norm_layer,
506
+ ))
507
+ stem_stride *= conv_stride
508
+ prev_chs = chs
509
+ prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
510
+ if pool:
511
+ assert stride > 2
512
+ if prev_feat is not None:
513
+ feature_info.append(prev_feat)
514
+ if aa_layer is not None:
515
+ stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
516
+ stem.add_module('aa', aa_layer(channels=prev_chs, stride=2))
517
+ pool_name = 'aa'
518
+ else:
519
+ stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
520
+ pool_name = 'pool'
521
+ stem_stride *= 2
522
+ prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
523
+ feature_info.append(prev_feat)
524
+ return stem, feature_info
525
+
526
+
527
+ def _get_stage_fn(stage_args):
528
+ stage_type = stage_args.pop('stage_type')
529
+ assert stage_type in ('dark', 'csp', 'cs3')
530
+ if stage_type == 'dark':
531
+ stage_args.pop('expand_ratio', None)
532
+ stage_args.pop('cross_linear', None)
533
+ stage_args.pop('down_growth', None)
534
+ stage_fn = DarkStage
535
+ elif stage_type == 'csp':
536
+ stage_fn = CrossStage
537
+ else:
538
+ stage_fn = CrossStage3
539
+ return stage_fn, stage_args
540
+
541
+
542
+ def _get_block_fn(stage_args):
543
+ block_type = stage_args.pop('block_type')
544
+ assert block_type in ('dark', 'edge', 'bottle')
545
+ if block_type == 'dark':
546
+ return DarkBlock, stage_args
547
+ elif block_type == 'edge':
548
+ return EdgeBlock, stage_args
549
+ else:
550
+ return BottleneckBlock, stage_args
551
+
552
+
553
+ def _get_attn_fn(stage_args):
554
+ attn_layer = stage_args.pop('attn_layer')
555
+ attn_kwargs = stage_args.pop('attn_kwargs', None) or {}
556
+ if attn_layer is not None:
557
+ attn_layer = get_attn(attn_layer)
558
+ if attn_kwargs:
559
+ attn_layer = partial(attn_layer, **attn_kwargs)
560
+ return attn_layer, stage_args
561
+
562
+
563
+ def create_csp_stages(
564
+ cfg: CspModelCfg,
565
+ drop_path_rate: float,
566
+ output_stride: int,
567
+ stem_feat: Dict[str, Any],
568
+ ):
569
+ cfg_dict = asdict(cfg.stages)
570
+ num_stages = len(cfg.stages.depth)
571
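+ # stochastic depth: ramp the drop-path rate linearly over all blocks, then split the schedule per stage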
+ cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
572
+ [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
573
+ stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
574
+ block_kwargs = dict(
575
+ act_layer=cfg.act_layer,
576
+ norm_layer=cfg.norm_layer,
577
+ )
578
+
579
+ dilation = 1
580
+ net_stride = stem_feat['reduction']
581
+ prev_chs = stem_feat['num_chs']
582
+ prev_feat = stem_feat
583
+ feature_info = []
584
+ stages = []
585
+ for stage_idx, stage_args in enumerate(stage_args):
586
+ stage_fn, stage_args = _get_stage_fn(stage_args)
587
+ block_fn, stage_args = _get_block_fn(stage_args)
588
+ attn_fn, stage_args = _get_attn_fn(stage_args)
589
+ stride = stage_args.pop('stride')
590
+ if stride != 1 and prev_feat:
591
+ feature_info.append(prev_feat)
592
+ if net_stride >= output_stride and stride > 1:
593
+ dilation *= stride
594
+ stride = 1
595
+ net_stride *= stride
596
+ first_dilation = 1 if dilation in (1, 2) else 2
597
+
598
+ stages += [stage_fn(
599
+ prev_chs,
600
+ **stage_args,
601
+ stride=stride,
602
+ first_dilation=first_dilation,
603
+ dilation=dilation,
604
+ block_fn=block_fn,
605
+ aa_layer=cfg.aa_layer,
606
+ attn_layer=attn_fn, # will be passed through stage as block_kwargs
607
+ **block_kwargs,
608
+ )]
609
+ prev_chs = stage_args['out_chs']
610
+ prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
611
+
612
+ feature_info.append(prev_feat)
613
+ return nn.Sequential(*stages), feature_info
614
+
615
+
616
+ class CspNet(nn.Module):
617
+ """Cross Stage Partial base model.
618
+
619
+ Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
620
+ Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks
621
+
622
+ NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
623
+ darknet impl. I did it this way for simplicity and fewer special cases.
624
+ """
625
+
626
+ def __init__(
627
+ self,
628
+ cfg: CspModelCfg,
629
+ in_chans=3,
630
+ num_classes=1000,
631
+ output_stride=32,
632
+ global_pool='avg',
633
+ drop_rate=0.,
634
+ drop_path_rate=0.,
635
+ zero_init_last=True,
636
+ **kwargs,
637
+ ):
638
+ """
639
+ Args:
640
+ cfg (CspModelCfg): Model architecture configuration
641
+ in_chans (int): Number of input channels (default: 3)
642
+ num_classes (int): Number of classifier classes (default: 1000)
643
+ output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
644
+ global_pool (str): Global pooling type (default: 'avg')
645
+ drop_rate (float): Dropout rate (default: 0.)
646
+ drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
647
+ zero_init_last (bool): Zero-init last weight of residual path
648
+ kwargs (dict): Extra kwargs overlayed onto cfg
649
+ """
650
+ super().__init__()
651
+ self.num_classes = num_classes
652
+ self.drop_rate = drop_rate
653
+ assert output_stride in (8, 16, 32)
654
+
655
+ cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
656
+ layer_args = dict(
657
+ act_layer=cfg.act_layer,
658
+ norm_layer=cfg.norm_layer,
659
+ aa_layer=cfg.aa_layer
660
+ )
661
+ self.feature_info = []
662
+
663
+ # Construct the stem
664
+ self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args)
665
+ self.feature_info.extend(stem_feat_info[:-1])
666
+
667
+ # Construct the stages
668
+ self.stages, stage_feat_info = create_csp_stages(
669
+ cfg,
670
+ drop_path_rate=drop_path_rate,
671
+ output_stride=output_stride,
672
+ stem_feat=stem_feat_info[-1],
673
+ )
674
+ prev_chs = stage_feat_info[-1]['num_chs']
675
+ self.feature_info.extend(stage_feat_info)
676
+
677
+ # Construct the head
678
+ self.num_features = prev_chs
679
+ self.head = ClassifierHead(
680
+ in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
681
+
682
+ named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
683
+
684
+ @torch.jit.ignore
685
+ def group_matcher(self, coarse=False):
686
+ matcher = dict(
687
+ stem=r'^stem',
688
+ blocks=r'^stages\.(\d+)' if coarse else [
689
+ (r'^stages\.(\d+)\.blocks\.(\d+)', None),
690
+ (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP), # map to last block in stage
691
+ (r'^stages\.(\d+)', (0,)),
692
+ ]
693
+ )
694
+ return matcher
695
+
696
+ @torch.jit.ignore
697
+ def set_grad_checkpointing(self, enable=True):
698
+ assert not enable, 'gradient checkpointing not supported'
699
+
700
+ @torch.jit.ignore
701
+ def get_classifier(self):
702
+ return self.head.fc
703
+
704
+ def reset_classifier(self, num_classes, global_pool='avg'):
705
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
706
+
707
+ def forward_features(self, x):
708
+ x = self.stem(x)
709
+ x = self.stages(x)
710
+ return x
711
+
712
+ def forward_head(self, x, pre_logits: bool = False):
713
+ return self.head(x, pre_logits=pre_logits)
714
+
715
+ def forward(self, x):
716
+ x = self.forward_features(x)
717
+ x = self.forward_head(x)
718
+ return x
719
+
720
+
721
+ def _init_weights(module, name, zero_init_last=False):
722
+ if isinstance(module, nn.Conv2d):
723
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
724
+ if module.bias is not None:
725
+ nn.init.zeros_(module.bias)
726
+ elif isinstance(module, nn.Linear):
727
+ nn.init.normal_(module.weight, mean=0.0, std=0.01)
728
+ if module.bias is not None:
729
+ nn.init.zeros_(module.bias)
730
+ elif zero_init_last and hasattr(module, 'zero_init_last'):
731
+ module.zero_init_last()
732
+
733
+
734
+ model_cfgs = dict(
735
+ cspresnet50=CspModelCfg(
736
+ stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
737
+ stages=CspStagesCfg(
738
+ depth=(3, 3, 5, 2),
739
+ out_chs=(128, 256, 512, 1024),
740
+ stride=(1, 2),
741
+ expand_ratio=2.,
742
+ bottle_ratio=0.5,
743
+ cross_linear=True,
744
+ ),
745
+ ),
746
+ cspresnet50d=CspModelCfg(
747
+ stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
748
+ stages=CspStagesCfg(
749
+ depth=(3, 3, 5, 2),
750
+ out_chs=(128, 256, 512, 1024),
751
+ stride=(1,) + (2,),
752
+ expand_ratio=2.,
753
+ bottle_ratio=0.5,
754
+ block_ratio=1.,
755
+ cross_linear=True,
756
+ ),
757
+ ),
758
+ cspresnet50w=CspModelCfg(
759
+ stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
760
+ stages=CspStagesCfg(
761
+ depth=(3, 3, 5, 2),
762
+ out_chs=(256, 512, 1024, 2048),
763
+ stride=(1,) + (2,),
764
+ expand_ratio=1.,
765
+ bottle_ratio=0.25,
766
+ block_ratio=0.5,
767
+ cross_linear=True,
768
+ ),
769
+ ),
770
+ cspresnext50=CspModelCfg(
771
+ stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
772
+ stages=CspStagesCfg(
773
+ depth=(3, 3, 5, 2),
774
+ out_chs=(256, 512, 1024, 2048),
775
+ stride=(1,) + (2,),
776
+ groups=32,
777
+ expand_ratio=1.,
778
+ bottle_ratio=1.,
779
+ block_ratio=0.5,
780
+ cross_linear=True,
781
+ ),
782
+ ),
783
+ cspdarknet53=CspModelCfg(
784
+ stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
785
+ stages=CspStagesCfg(
786
+ depth=(1, 2, 8, 8, 4),
787
+ out_chs=(64, 128, 256, 512, 1024),
788
+ stride=2,
789
+ expand_ratio=(2.,) + (1.,),
790
+ bottle_ratio=(0.5,) + (1.,),
791
+ block_ratio=(1.,) + (0.5,),
792
+ down_growth=True,
793
+ block_type='dark',
794
+ ),
795
+ ),
796
+ darknet17=CspModelCfg(
797
+ stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
798
+ stages=CspStagesCfg(
799
+ depth=(1,) * 5,
800
+ out_chs=(64, 128, 256, 512, 1024),
801
+ stride=(2,),
802
+ bottle_ratio=(0.5,),
803
+ block_ratio=(1.,),
804
+ stage_type='dark',
805
+ block_type='dark',
806
+ ),
807
+ ),
808
+ darknet21=CspModelCfg(
809
+ stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
810
+ stages=CspStagesCfg(
811
+ depth=(1, 1, 1, 2, 2),
812
+ out_chs=(64, 128, 256, 512, 1024),
813
+ stride=(2,),
814
+ bottle_ratio=(0.5,),
815
+ block_ratio=(1.,),
816
+ stage_type='dark',
817
+ block_type='dark',
818
+
819
+ ),
820
+ ),
821
+ sedarknet21=CspModelCfg(
822
+ stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
823
+ stages=CspStagesCfg(
824
+ depth=(1, 1, 1, 2, 2),
825
+ out_chs=(64, 128, 256, 512, 1024),
826
+ stride=2,
827
+ bottle_ratio=0.5,
828
+ block_ratio=1.,
829
+ attn_layer='se',
830
+ stage_type='dark',
831
+ block_type='dark',
832
+
833
+ ),
834
+ ),
835
+ darknet53=CspModelCfg(
836
+ stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
837
+ stages=CspStagesCfg(
838
+ depth=(1, 2, 8, 8, 4),
839
+ out_chs=(64, 128, 256, 512, 1024),
840
+ stride=2,
841
+ bottle_ratio=0.5,
842
+ block_ratio=1.,
843
+ stage_type='dark',
844
+ block_type='dark',
845
+ ),
846
+ ),
847
+ darknetaa53=CspModelCfg(
848
+ stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
849
+ stages=CspStagesCfg(
850
+ depth=(1, 2, 8, 8, 4),
851
+ out_chs=(64, 128, 256, 512, 1024),
852
+ stride=2,
853
+ bottle_ratio=0.5,
854
+ block_ratio=1.,
855
+ avg_down=True,
856
+ stage_type='dark',
857
+ block_type='dark',
858
+ ),
859
+ ),
860
+
861
+ cs3darknet_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5),
862
+ cs3darknet_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67),
863
+ cs3darknet_l=_cs3_cfg(),
864
+ cs3darknet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33),
865
+
866
+ cs3darknet_focus_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
867
+ cs3darknet_focus_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
868
+ cs3darknet_focus_l=_cs3_cfg(focus=True),
869
+ cs3darknet_focus_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),
870
+
871
+ cs3sedarknet_l=_cs3_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
872
+ cs3sedarknet_x=_cs3_cfg(attn_layer='se', width_multiplier=1.25, depth_multiplier=1.33),
873
+
874
+ cs3sedarknet_xdw=CspModelCfg(
875
+ stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
876
+ stages=CspStagesCfg(
877
+ depth=(3, 6, 12, 4),
878
+ out_chs=(256, 512, 1024, 2048),
879
+ stride=2,
880
+ groups=(1, 1, 256, 512),
881
+ bottle_ratio=0.5,
882
+ block_ratio=0.5,
883
+ attn_layer='se',
884
+ ),
885
+ act_layer='silu',
886
+ ),
887
+
888
+ cs3edgenet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge'),
889
+ cs3se_edgenet_x=_cs3_cfg(
890
+ width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge',
891
+ attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
892
+ )
893
+
894
+
895
+ def _create_cspnet(variant, pretrained=False, **kwargs):
896
+ if variant.startswith('darknet') or variant.startswith('cspdarknet'):
897
+ # NOTE: DarkNet is one of the few models with stride==1 features, hence 6 out_indices [0..5]
898
+ default_out_indices = (0, 1, 2, 3, 4, 5)
899
+ else:
900
+ default_out_indices = (0, 1, 2, 3, 4)
901
+ out_indices = kwargs.pop('out_indices', default_out_indices)
902
+ return build_model_with_cfg(
903
+ CspNet, variant, pretrained,
904
+ model_cfg=model_cfgs[variant],
905
+ feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
906
+ **kwargs)
907
+
908
+
909
+ def _cfg(url='', **kwargs):
910
+ return {
911
+ 'url': url,
912
+ 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
913
+ 'crop_pct': 0.887, 'interpolation': 'bilinear',
914
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
915
+ 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
916
+ **kwargs
917
+ }
918
+
919
+
920
+ default_cfgs = generate_default_cfgs({
921
+ 'cspresnet50.ra_in1k': _cfg(
922
+ hf_hub_id='timm/',
923
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'),
924
+ 'cspresnet50d.untrained': _cfg(),
925
+ 'cspresnet50w.untrained': _cfg(),
926
+ 'cspresnext50.ra_in1k': _cfg(
927
+ hf_hub_id='timm/',
928
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
929
+ ),
930
+ 'cspdarknet53.ra_in1k': _cfg(
931
+ hf_hub_id='timm/',
932
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),
933
+
934
+ 'darknet17.untrained': _cfg(),
935
+ 'darknet21.untrained': _cfg(),
936
+ 'sedarknet21.untrained': _cfg(),
937
+ 'darknet53.c2ns_in1k': _cfg(
938
+ hf_hub_id='timm/',
939
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
940
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
941
+ 'darknetaa53.c2ns_in1k': _cfg(
942
+ hf_hub_id='timm/',
943
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
944
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
945
+
946
+ 'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
947
+ 'cs3darknet_m.c2ns_in1k': _cfg(
948
+ hf_hub_id='timm/',
949
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
950
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95,
951
+ ),
952
+ 'cs3darknet_l.c2ns_in1k': _cfg(
953
+ hf_hub_id='timm/',
954
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
955
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
956
+ 'cs3darknet_x.c2ns_in1k': _cfg(
957
+ hf_hub_id='timm/',
958
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
959
+ interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
960
+
961
+ 'cs3darknet_focus_s.untrained': _cfg(interpolation='bicubic'),
962
+ 'cs3darknet_focus_m.c2ns_in1k': _cfg(
963
+ hf_hub_id='timm/',
964
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
965
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
966
+ 'cs3darknet_focus_l.c2ns_in1k': _cfg(
967
+ hf_hub_id='timm/',
968
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
969
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
970
+ 'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),
971
+
972
+ 'cs3sedarknet_l.c2ns_in1k': _cfg(
973
+ hf_hub_id='timm/',
974
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
975
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
976
+ 'cs3sedarknet_x.c2ns_in1k': _cfg(
977
+ hf_hub_id='timm/',
978
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
979
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
980
+
981
+ 'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),
982
+
983
+ 'cs3edgenet_x.c2_in1k': _cfg(
984
+ hf_hub_id='timm/',
985
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
986
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
987
+ 'cs3se_edgenet_x.c2ns_in1k': _cfg(
988
+ hf_hub_id='timm/',
989
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
990
+ interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
991
+ })
992
+
993
+
994
+ @register_model
995
+ def cspresnet50(pretrained=False, **kwargs) -> CspNet:
996
+ return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs)
997
+
998
+
999
+ @register_model
1000
+ def cspresnet50d(pretrained=False, **kwargs) -> CspNet:
1001
+ return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs)
1002
+
1003
+
1004
+ @register_model
1005
+ def cspresnet50w(pretrained=False, **kwargs) -> CspNet:
1006
+ return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs)
1007
+
1008
+
1009
+ @register_model
1010
+ def cspresnext50(pretrained=False, **kwargs) -> CspNet:
1011
+ return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs)
1012
+
1013
+
1014
+ @register_model
1015
+ def cspdarknet53(pretrained=False, **kwargs) -> CspNet:
1016
+ return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs)
1017
+
1018
+
1019
+ @register_model
1020
+ def darknet17(pretrained=False, **kwargs) -> CspNet:
1021
+ return _create_cspnet('darknet17', pretrained=pretrained, **kwargs)
1022
+
1023
+
1024
+ @register_model
1025
+ def darknet21(pretrained=False, **kwargs) -> CspNet:
1026
+ return _create_cspnet('darknet21', pretrained=pretrained, **kwargs)
1027
+
1028
+
1029
+ @register_model
1030
+ def sedarknet21(pretrained=False, **kwargs) -> CspNet:
1031
+ return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs)
1032
+
1033
+
1034
+ @register_model
1035
+ def darknet53(pretrained=False, **kwargs) -> CspNet:
1036
+ return _create_cspnet('darknet53', pretrained=pretrained, **kwargs)
1037
+
1038
+
1039
+ @register_model
1040
+ def darknetaa53(pretrained=False, **kwargs) -> CspNet:
1041
+ return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs)
1042
+
1043
+
1044
+ @register_model
1045
+ def cs3darknet_s(pretrained=False, **kwargs) -> CspNet:
1046
+ return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs)
1047
+
1048
+
1049
+ @register_model
1050
+ def cs3darknet_m(pretrained=False, **kwargs) -> CspNet:
1051
+ return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs)
1052
+
1053
+
1054
+ @register_model
1055
+ def cs3darknet_l(pretrained=False, **kwargs) -> CspNet:
1056
+ return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs)
1057
+
1058
+
1059
+ @register_model
1060
+ def cs3darknet_x(pretrained=False, **kwargs) -> CspNet:
1061
+ return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs)
1062
+
1063
+
1064
+ @register_model
1065
+ def cs3darknet_focus_s(pretrained=False, **kwargs) -> CspNet:
1066
+ return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs)
1067
+
1068
+
1069
+ @register_model
1070
+ def cs3darknet_focus_m(pretrained=False, **kwargs) -> CspNet:
1071
+ return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs)
1072
+
1073
+
1074
+ @register_model
1075
+ def cs3darknet_focus_l(pretrained=False, **kwargs) -> CspNet:
1076
+ return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs)
1077
+
1078
+
1079
+ @register_model
1080
+ def cs3darknet_focus_x(pretrained=False, **kwargs) -> CspNet:
1081
+ return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)
1082
+
1083
+
1084
+ @register_model
1085
+ def cs3sedarknet_l(pretrained=False, **kwargs) -> CspNet:
1086
+ return _create_cspnet('cs3sedarknet_l', pretrained=pretrained, **kwargs)
1087
+
1088
+
1089
+ @register_model
1090
+ def cs3sedarknet_x(pretrained=False, **kwargs) -> CspNet:
1091
+ return _create_cspnet('cs3sedarknet_x', pretrained=pretrained, **kwargs)
1092
+
1093
+
1094
+ @register_model
1095
+ def cs3sedarknet_xdw(pretrained=False, **kwargs) -> CspNet:
1096
+ return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)
1097
+
1098
+
1099
+ @register_model
1100
+ def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
1101
+ return _create_cspnet('cs3edgenet_x', pretrained=pretrained, **kwargs)
1102
+
1103
+
1104
+ @register_model
1105
+ def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
1106
+ return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)
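Because _create_cspnet passes a feature_cfg through build_model_with_cfg, these models can also be built as feature backbones (a minimal sketch, assuming a timm install; the darknet variants default to six out_indices, including the stride==1 stem feature):

    import torch
    import timm

    # untrained darknet53 as a multi-scale feature extractor
    model = timm.create_model('darknet53', pretrained=False, features_only=True).eval()
    with torch.no_grad():
        feats = model(torch.randn(1, 3, 256, 256))
    print([tuple(f.shape) for f in feats])  # one tensor per out_index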
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/deit.py ADDED
@@ -0,0 +1,416 @@
1
+ """ DeiT - Data-efficient Image Transformers
2
+
3
+ DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
4
+
5
+ paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
6
+
7
+ paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
8
+
9
+ Modifications copyright 2021, Ross Wightman
10
+ """
11
+ # Copyright (c) 2015-present, Facebook, Inc.
12
+ # All rights reserved.
13
+ from functools import partial
14
+ from typing import Sequence, Union
15
+
16
+ import torch
17
+ from torch import nn as nn
18
+
19
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20
+ from timm.layers import resample_abs_pos_embed
21
+ from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
22
+ from ._builder import build_model_with_cfg
23
+ from ._manipulate import checkpoint_seq
24
+ from ._registry import generate_default_cfgs, register_model, register_model_deprecations
25
+
26
+ __all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
27
+
28
+
29
+ class VisionTransformerDistilled(VisionTransformer):
30
+ """ Vision Transformer w/ Distillation Token and Head
31
+
32
+ Distillation token & head support for `DeiT: Data-efficient Image Transformers`
33
+ - https://arxiv.org/abs/2012.12877
34
+ """
35
+
36
+ def __init__(self, *args, **kwargs):
37
+ weight_init = kwargs.pop('weight_init', '')
38
+ super().__init__(*args, **kwargs, weight_init='skip')
39
+ assert self.global_pool in ('token',)
40
+
41
+ self.num_prefix_tokens = 2
42
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
43
+ self.pos_embed = nn.Parameter(
44
+ torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
45
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
46
+ self.distilled_training = False # must set this True to train w/ distillation token
47
+
48
+ self.init_weights(weight_init)
49
+
50
+ def init_weights(self, mode=''):
51
+ trunc_normal_(self.dist_token, std=.02)
52
+ super().init_weights(mode=mode)
53
+
54
+ @torch.jit.ignore
55
+ def group_matcher(self, coarse=False):
56
+ return dict(
57
+ stem=r'^cls_token|pos_embed|patch_embed|dist_token',
58
+ blocks=[
59
+ (r'^blocks\.(\d+)', None),
60
+ (r'^norm', (99999,))] # final norm w/ last block
61
+ )
62
+
63
+ @torch.jit.ignore
64
+ def get_classifier(self):
65
+ return self.head, self.head_dist
66
+
67
+ def reset_classifier(self, num_classes, global_pool=None):
68
+ self.num_classes = num_classes
69
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
70
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
71
+
72
+ @torch.jit.ignore
73
+ def set_distilled_training(self, enable=True):
74
+ self.distilled_training = enable
75
+
76
+ def _pos_embed(self, x):
77
+ if self.dynamic_img_size:
78
+ B, H, W, C = x.shape
79
+ pos_embed = resample_abs_pos_embed(
80
+ self.pos_embed,
81
+ (H, W),
82
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
83
+ )
84
+ x = x.view(B, -1, C)
85
+ else:
86
+ pos_embed = self.pos_embed
87
+ if self.no_embed_class:
88
+ # deit-3, updated JAX (big vision)
89
+ # position embedding does not overlap with class token, add then concat
90
+ x = x + pos_embed
91
+ x = torch.cat((
92
+ self.cls_token.expand(x.shape[0], -1, -1),
93
+ self.dist_token.expand(x.shape[0], -1, -1),
94
+ x),
95
+ dim=1)
96
+ else:
97
+ # original timm, JAX, and deit vit impl
98
+ # pos_embed has entry for class token, concat then add
99
+ x = torch.cat((
100
+ self.cls_token.expand(x.shape[0], -1, -1),
101
+ self.dist_token.expand(x.shape[0], -1, -1),
102
+ x),
103
+ dim=1)
104
+ x = x + pos_embed
105
+ return self.pos_drop(x)
106
+
107
+ def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
108
+ x, x_dist = x[:, 0], x[:, 1]
109
+ if pre_logits:
110
+ return (x + x_dist) / 2
111
+ x = self.head(x)
112
+ x_dist = self.head_dist(x_dist)
113
+ if self.distilled_training and self.training and not torch.jit.is_scripting():
114
+ # only return separate classification predictions when training in distilled mode
115
+ return x, x_dist
116
+ else:
117
+ # during standard train / finetune and at inference, average the two classifier predictions
118
+ return (x + x_dist) / 2
119
+
120
+
121
+ def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
122
+ if kwargs.get('features_only', None):
123
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
124
+ model_cls = VisionTransformerDistilled if distilled else VisionTransformer
125
+ model = build_model_with_cfg(
126
+ model_cls,
127
+ variant,
128
+ pretrained,
129
+ pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
130
+ **kwargs,
131
+ )
132
+ return model
133
+
134
+
135
+ def _cfg(url='', **kwargs):
136
+ return {
137
+ 'url': url,
138
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
139
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
140
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
141
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
142
+ **kwargs
143
+ }
144
+
145
+
146
+ default_cfgs = generate_default_cfgs({
147
+ # deit models (FB weights)
148
+ 'deit_tiny_patch16_224.fb_in1k': _cfg(
149
+ hf_hub_id='timm/',
150
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
151
+ 'deit_small_patch16_224.fb_in1k': _cfg(
152
+ hf_hub_id='timm/',
153
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
154
+ 'deit_base_patch16_224.fb_in1k': _cfg(
155
+ hf_hub_id='timm/',
156
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
157
+ 'deit_base_patch16_384.fb_in1k': _cfg(
158
+ hf_hub_id='timm/',
159
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
160
+ input_size=(3, 384, 384), crop_pct=1.0),
161
+
162
+ 'deit_tiny_distilled_patch16_224.fb_in1k': _cfg(
163
+ hf_hub_id='timm/',
164
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
165
+ classifier=('head', 'head_dist')),
166
+ 'deit_small_distilled_patch16_224.fb_in1k': _cfg(
167
+ hf_hub_id='timm/',
168
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
169
+ classifier=('head', 'head_dist')),
170
+ 'deit_base_distilled_patch16_224.fb_in1k': _cfg(
171
+ hf_hub_id='timm/',
172
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
173
+ classifier=('head', 'head_dist')),
174
+ 'deit_base_distilled_patch16_384.fb_in1k': _cfg(
175
+ hf_hub_id='timm/',
176
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
177
+ input_size=(3, 384, 384), crop_pct=1.0,
178
+ classifier=('head', 'head_dist')),
179
+
180
+ 'deit3_small_patch16_224.fb_in1k': _cfg(
181
+ hf_hub_id='timm/',
182
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
183
+ 'deit3_small_patch16_384.fb_in1k': _cfg(
184
+ hf_hub_id='timm/',
185
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
186
+ input_size=(3, 384, 384), crop_pct=1.0),
187
+ 'deit3_medium_patch16_224.fb_in1k': _cfg(
188
+ hf_hub_id='timm/',
189
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
190
+ 'deit3_base_patch16_224.fb_in1k': _cfg(
191
+ hf_hub_id='timm/',
192
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
193
+ 'deit3_base_patch16_384.fb_in1k': _cfg(
194
+ hf_hub_id='timm/',
195
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
196
+ input_size=(3, 384, 384), crop_pct=1.0),
197
+ 'deit3_large_patch16_224.fb_in1k': _cfg(
198
+ hf_hub_id='timm/',
199
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
200
+ 'deit3_large_patch16_384.fb_in1k': _cfg(
201
+ hf_hub_id='timm/',
202
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
203
+ input_size=(3, 384, 384), crop_pct=1.0),
204
+ 'deit3_huge_patch14_224.fb_in1k': _cfg(
205
+ hf_hub_id='timm/',
206
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
207
+
208
+ 'deit3_small_patch16_224.fb_in22k_ft_in1k': _cfg(
209
+ hf_hub_id='timm/',
210
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
211
+ crop_pct=1.0),
212
+ 'deit3_small_patch16_384.fb_in22k_ft_in1k': _cfg(
213
+ hf_hub_id='timm/',
214
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
215
+ input_size=(3, 384, 384), crop_pct=1.0),
216
+ 'deit3_medium_patch16_224.fb_in22k_ft_in1k': _cfg(
217
+ hf_hub_id='timm/',
218
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
219
+ crop_pct=1.0),
220
+ 'deit3_base_patch16_224.fb_in22k_ft_in1k': _cfg(
221
+ hf_hub_id='timm/',
222
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
223
+ crop_pct=1.0),
224
+ 'deit3_base_patch16_384.fb_in22k_ft_in1k': _cfg(
225
+ hf_hub_id='timm/',
226
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
227
+ input_size=(3, 384, 384), crop_pct=1.0),
228
+ 'deit3_large_patch16_224.fb_in22k_ft_in1k': _cfg(
229
+ hf_hub_id='timm/',
230
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
231
+ crop_pct=1.0),
232
+ 'deit3_large_patch16_384.fb_in22k_ft_in1k': _cfg(
233
+ hf_hub_id='timm/',
234
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
235
+ input_size=(3, 384, 384), crop_pct=1.0),
236
+ 'deit3_huge_patch14_224.fb_in22k_ft_in1k': _cfg(
237
+ hf_hub_id='timm/',
238
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
239
+ crop_pct=1.0),
240
+ })
241
+
242
+
243
+ @register_model
244
+ def deit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
245
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
246
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
247
+ """
248
+ model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
249
+ model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
250
+ return model
251
+
252
+
253
+ @register_model
254
+ def deit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
255
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
256
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
257
+ """
258
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
259
+ model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
260
+ return model
261
+
262
+
263
+ @register_model
264
+ def deit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
265
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
266
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
267
+ """
268
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
269
+ model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
270
+ return model
271
+
272
+
273
+ @register_model
274
+ def deit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
275
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
276
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
277
+ """
278
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
279
+ model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
280
+ return model
281
+
282
+
283
+ @register_model
284
+ def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
285
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
286
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
287
+ """
288
+ model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
289
+ model = _create_deit(
290
+ 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
291
+ return model
292
+
293
+
294
+ @register_model
295
+ def deit_small_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
296
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
297
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
298
+ """
299
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
300
+ model = _create_deit(
301
+ 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
302
+ return model
303
+
304
+
305
+ @register_model
306
+ def deit_base_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
307
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
308
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
309
+ """
310
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
311
+ model = _create_deit(
312
+ 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
313
+ return model
314
+
315
+
316
+ @register_model
317
+ def deit_base_distilled_patch16_384(pretrained=False, **kwargs) -> VisionTransformerDistilled:
318
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
319
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
320
+ """
321
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
322
+ model = _create_deit(
323
+ 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
324
+ return model
325
+
326
+
327
+ @register_model
328
+ def deit3_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
329
+ """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
330
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
331
+ """
332
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
333
+ model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
334
+ return model
335
+
336
+
337
+ @register_model
338
+ def deit3_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
339
+ """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
340
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
341
+ """
342
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
343
+ model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
344
+ return model
345
+
346
+
347
+ @register_model
348
+ def deit3_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
349
+ """ DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877).
350
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
351
+ """
352
+ model_args = dict(patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6)
353
+ model = _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
354
+ return model
355
+
356
+
357
+ @register_model
358
+ def deit3_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
359
+ """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
360
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
361
+ """
362
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
363
+ model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
364
+ return model
365
+
366
+
367
+ @register_model
368
+ def deit3_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
369
+ """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
370
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
371
+ """
372
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
373
+ model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
374
+ return model
375
+
376
+
377
+ @register_model
378
+ def deit3_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
379
+ """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
380
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
381
+ """
382
+ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
383
+ model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
384
+ return model
385
+
386
+
387
+ @register_model
388
+ def deit3_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
389
+ """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
390
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
391
+ """
392
+ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
393
+ model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
394
+ return model
395
+
396
+
397
+ @register_model
398
+ def deit3_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
399
+ """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
400
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
401
+ """
402
+ model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6)
403
+ model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
404
+ return model
405
+
406
+
407
+ register_model_deprecations(__name__, {
408
+ 'deit3_small_patch16_224_in21ft1k': 'deit3_small_patch16_224.fb_in22k_ft_in1k',
409
+ 'deit3_small_patch16_384_in21ft1k': 'deit3_small_patch16_384.fb_in22k_ft_in1k',
410
+ 'deit3_medium_patch16_224_in21ft1k': 'deit3_medium_patch16_224.fb_in22k_ft_in1k',
411
+ 'deit3_base_patch16_224_in21ft1k': 'deit3_base_patch16_224.fb_in22k_ft_in1k',
412
+ 'deit3_base_patch16_384_in21ft1k': 'deit3_base_patch16_384.fb_in22k_ft_in1k',
413
+ 'deit3_large_patch16_224_in21ft1k': 'deit3_large_patch16_224.fb_in22k_ft_in1k',
414
+ 'deit3_large_patch16_384_in21ft1k': 'deit3_large_patch16_384.fb_in22k_ft_in1k',
415
+ 'deit3_huge_patch14_224_in21ft1k': 'deit3_huge_patch14_224.fb_in22k_ft_in1k'
416
+ })
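The registry entries above are resolved through timm's model factory. A minimal usage sketch (not part of the diff; assumes the timm package from this tree is importable) building one of the DeiT-3 variants registered above:

import torch
import timm

# Build a registered variant by name; pretrained=True would pull the
# 'timm/' HF Hub weights declared in default_cfgs above.
model = timm.create_model('deit3_base_patch16_224', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])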
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/dla.py ADDED
@@ -0,0 +1,515 @@
1
+ """ Deep Layer Aggregation and DLA w/ Res2Net
2
+ DLA original adapted from Official Pytorch impl at: https://github.com/ucbdrive/dla
3
+ DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484
4
+
5
+ Res2Net additions from: https://github.com/gasvn/Res2Net/
6
+ Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
7
+ """
8
+ import math
9
+ from typing import List, Optional
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
+ from timm.layers import create_classifier
17
+ from ._builder import build_model_with_cfg
18
+ from ._registry import register_model, generate_default_cfgs
19
+
20
+ __all__ = ['DLA']
21
+
22
+
23
+ class DlaBasic(nn.Module):
24
+ """DLA Basic"""
25
+
26
+ def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
27
+ super(DlaBasic, self).__init__()
28
+ self.conv1 = nn.Conv2d(
29
+ inplanes, planes, kernel_size=3,
30
+ stride=stride, padding=dilation, bias=False, dilation=dilation)
31
+ self.bn1 = nn.BatchNorm2d(planes)
32
+ self.relu = nn.ReLU(inplace=True)
33
+ self.conv2 = nn.Conv2d(
34
+ planes, planes, kernel_size=3,
35
+ stride=1, padding=dilation, bias=False, dilation=dilation)
36
+ self.bn2 = nn.BatchNorm2d(planes)
37
+ self.stride = stride
38
+
39
+ def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
40
+ if shortcut is None:
41
+ shortcut = x
42
+
43
+ out = self.conv1(x)
44
+ out = self.bn1(out)
45
+ out = self.relu(out)
46
+
47
+ out = self.conv2(out)
48
+ out = self.bn2(out)
49
+
50
+ out += shortcut
51
+ out = self.relu(out)
52
+
53
+ return out
54
+
55
+
56
+ class DlaBottleneck(nn.Module):
57
+ """DLA/DLA-X Bottleneck"""
58
+ expansion = 2
59
+
60
+ def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64):
61
+ super(DlaBottleneck, self).__init__()
62
+ self.stride = stride
63
+ mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
64
+ mid_planes = mid_planes // self.expansion
65
+
66
+ self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
67
+ self.bn1 = nn.BatchNorm2d(mid_planes)
68
+ self.conv2 = nn.Conv2d(
69
+ mid_planes, mid_planes, kernel_size=3,
70
+ stride=stride, padding=dilation, bias=False, dilation=dilation, groups=cardinality)
71
+ self.bn2 = nn.BatchNorm2d(mid_planes)
72
+ self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False)
73
+ self.bn3 = nn.BatchNorm2d(outplanes)
74
+ self.relu = nn.ReLU(inplace=True)
75
+
76
+ def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
77
+ if shortcut is None:
78
+ shortcut = x
79
+
80
+ out = self.conv1(x)
81
+ out = self.bn1(out)
82
+ out = self.relu(out)
83
+
84
+ out = self.conv2(out)
85
+ out = self.bn2(out)
86
+ out = self.relu(out)
87
+
88
+ out = self.conv3(out)
89
+ out = self.bn3(out)
90
+
91
+ out += shortcut
92
+ out = self.relu(out)
93
+
94
+ return out
95
+
96
+
97
+ class DlaBottle2neck(nn.Module):
98
+ """ Res2Net/Res2NeXT DLA Bottleneck
99
+ Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py
100
+ """
101
+ expansion = 2
102
+
103
+ def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4):
104
+ super(DlaBottle2neck, self).__init__()
105
+ self.is_first = stride > 1
106
+ self.scale = scale
107
+ mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
108
+ mid_planes = mid_planes // self.expansion
109
+ self.width = mid_planes
110
+
111
+ self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False)
112
+ self.bn1 = nn.BatchNorm2d(mid_planes * scale)
113
+
114
+ num_scale_convs = max(1, scale - 1)
115
+ convs = []
116
+ bns = []
117
+ for _ in range(num_scale_convs):
118
+ convs.append(nn.Conv2d(
119
+ mid_planes, mid_planes, kernel_size=3,
120
+ stride=stride, padding=dilation, dilation=dilation, groups=cardinality, bias=False))
121
+ bns.append(nn.BatchNorm2d(mid_planes))
122
+ self.convs = nn.ModuleList(convs)
123
+ self.bns = nn.ModuleList(bns)
124
+ self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if self.is_first else None
125
+
126
+ self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False)
127
+ self.bn3 = nn.BatchNorm2d(outplanes)
128
+ self.relu = nn.ReLU(inplace=True)
129
+
130
+ def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
131
+ if shortcut is None:
132
+ shortcut = x
133
+
134
+ out = self.conv1(x)
135
+ out = self.bn1(out)
136
+ out = self.relu(out)
137
+
138
+ spx = torch.split(out, self.width, 1)
139
+ spo = []
140
+ sp = spx[0] # redundant, for torchscript
141
+ for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
142
+ if i == 0 or self.is_first:
143
+ sp = spx[i]
144
+ else:
145
+ sp = sp + spx[i]
146
+ sp = conv(sp)
147
+ sp = bn(sp)
148
+ sp = self.relu(sp)
149
+ spo.append(sp)
150
+ if self.scale > 1:
151
+ if self.pool is not None: # self.is_first == True, None check for torchscript
152
+ spo.append(self.pool(spx[-1]))
153
+ else:
154
+ spo.append(spx[-1])
155
+ out = torch.cat(spo, 1)
156
+
157
+ out = self.conv3(out)
158
+ out = self.bn3(out)
159
+
160
+ out += shortcut
161
+ out = self.relu(out)
162
+
163
+ return out
164
+
165
+
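As an aside, a minimal standalone sketch (not part of the diff) of the Res2Net channel-split pattern that DlaBottle2neck.forward implements above: features are split into `scale` groups, each group's conv output is added into the next group before its conv, and the last group passes through untouched (the non-strided case):

import torch
import torch.nn as nn

width, scale = 4, 4
x = torch.randn(2, width * scale, 8, 8)
convs = nn.ModuleList(nn.Conv2d(width, width, 3, padding=1) for _ in range(scale - 1))
spx = torch.split(x, width, 1)               # `scale` chunks of `width` channels
sp, spo = spx[0], []
for i, conv in enumerate(convs):
    sp = spx[i] if i == 0 else sp + spx[i]   # hierarchical residual between groups
    sp = conv(sp)
    spo.append(sp)
spo.append(spx[-1])                          # last group is passed through
out = torch.cat(spo, 1)                      # back to width * scale channels
print(out.shape)  # torch.Size([2, 16, 8, 8])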
166
+ class DlaRoot(nn.Module):
167
+ def __init__(self, in_channels, out_channels, kernel_size, shortcut):
168
+ super(DlaRoot, self).__init__()
169
+ self.conv = nn.Conv2d(
170
+ in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
171
+ self.bn = nn.BatchNorm2d(out_channels)
172
+ self.relu = nn.ReLU(inplace=True)
173
+ self.shortcut = shortcut
174
+
175
+ def forward(self, x_children: List[torch.Tensor]):
176
+ x = self.conv(torch.cat(x_children, 1))
177
+ x = self.bn(x)
178
+ if self.shortcut:
179
+ x += x_children[0]
180
+ x = self.relu(x)
181
+
182
+ return x
183
+
184
+
185
+ class DlaTree(nn.Module):
186
+ def __init__(
187
+ self,
188
+ levels,
189
+ block,
190
+ in_channels,
191
+ out_channels,
192
+ stride=1,
193
+ dilation=1,
194
+ cardinality=1,
195
+ base_width=64,
196
+ level_root=False,
197
+ root_dim=0,
198
+ root_kernel_size=1,
199
+ root_shortcut=False,
200
+ ):
201
+ super(DlaTree, self).__init__()
202
+ if root_dim == 0:
203
+ root_dim = 2 * out_channels
204
+ if level_root:
205
+ root_dim += in_channels
206
+ self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity()
207
+ self.project = nn.Identity()
208
+ cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width)
209
+ if levels == 1:
210
+ self.tree1 = block(in_channels, out_channels, stride, **cargs)
211
+ self.tree2 = block(out_channels, out_channels, 1, **cargs)
212
+ if in_channels != out_channels:
213
+ # NOTE the official impl/weights have project layers in levels > 1 case that are never
214
+ # used, I've moved the project layer here to avoid wasted params but old checkpoints will
215
+ # need strict=False while loading.
216
+ self.project = nn.Sequential(
217
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
218
+ nn.BatchNorm2d(out_channels))
219
+ self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
220
+ else:
221
+ cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
222
+ self.tree1 = DlaTree(
223
+ levels - 1,
224
+ block,
225
+ in_channels,
226
+ out_channels,
227
+ stride,
228
+ root_dim=0,
229
+ **cargs,
230
+ )
231
+ self.tree2 = DlaTree(
232
+ levels - 1,
233
+ block,
234
+ out_channels,
235
+ out_channels,
236
+ root_dim=root_dim + out_channels,
237
+ **cargs,
238
+ )
239
+ self.root = None
240
+ self.level_root = level_root
241
+ self.root_dim = root_dim
242
+ self.levels = levels
243
+
244
+ def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
245
+ if children is None:
246
+ children = []
247
+ bottom = self.downsample(x)
248
+ shortcut = self.project(bottom)
249
+ if self.level_root:
250
+ children.append(bottom)
251
+ x1 = self.tree1(x, shortcut)
252
+ if self.root is not None: # levels == 1
253
+ x2 = self.tree2(x1)
254
+ x = self.root([x2, x1] + children)
255
+ else:
256
+ children.append(x1)
257
+ x = self.tree2(x1, None, children)
258
+ return x
259
+
260
+
261
+ class DLA(nn.Module):
262
+ def __init__(
263
+ self,
264
+ levels,
265
+ channels,
266
+ output_stride=32,
267
+ num_classes=1000,
268
+ in_chans=3,
269
+ global_pool='avg',
270
+ cardinality=1,
271
+ base_width=64,
272
+ block=DlaBottle2neck,
273
+ shortcut_root=False,
274
+ drop_rate=0.0,
275
+ ):
276
+ super(DLA, self).__init__()
277
+ self.channels = channels
278
+ self.num_classes = num_classes
279
+ self.cardinality = cardinality
280
+ self.base_width = base_width
281
+ assert output_stride == 32 # FIXME support dilation
282
+
283
+ self.base_layer = nn.Sequential(
284
+ nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
285
+ nn.BatchNorm2d(channels[0]),
286
+ nn.ReLU(inplace=True),
287
+ )
288
+ self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
289
+ self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
290
+ cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
291
+ self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
292
+ self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
293
+ self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
294
+ self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs)
295
+ self.feature_info = [
296
+ dict(num_chs=channels[0], reduction=1, module='level0'), # rare to have a meaningful stride 1 level
297
+ dict(num_chs=channels[1], reduction=2, module='level1'),
298
+ dict(num_chs=channels[2], reduction=4, module='level2'),
299
+ dict(num_chs=channels[3], reduction=8, module='level3'),
300
+ dict(num_chs=channels[4], reduction=16, module='level4'),
301
+ dict(num_chs=channels[5], reduction=32, module='level5'),
302
+ ]
303
+
304
+ self.num_features = channels[-1]
305
+ self.global_pool, self.head_drop, self.fc = create_classifier(
306
+ self.num_features,
307
+ self.num_classes,
308
+ pool_type=global_pool,
309
+ use_conv=True,
310
+ drop_rate=drop_rate,
311
+ )
312
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
313
+
314
+ for m in self.modules():
315
+ if isinstance(m, nn.Conv2d):
316
+ n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
317
+ m.weight.data.normal_(0, math.sqrt(2. / n))
318
+ elif isinstance(m, nn.BatchNorm2d):
319
+ m.weight.data.fill_(1)
320
+ m.bias.data.zero_()
321
+
322
+ def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
323
+ modules = []
324
+ for i in range(convs):
325
+ modules.extend([
326
+ nn.Conv2d(
327
+ inplanes, planes, kernel_size=3,
328
+ stride=stride if i == 0 else 1,
329
+ padding=dilation, bias=False, dilation=dilation),
330
+ nn.BatchNorm2d(planes),
331
+ nn.ReLU(inplace=True)])
332
+ inplanes = planes
333
+ return nn.Sequential(*modules)
334
+
335
+ @torch.jit.ignore
336
+ def group_matcher(self, coarse=False):
337
+ matcher = dict(
338
+ stem=r'^base_layer',
339
+ blocks=r'^level(\d+)' if coarse else [
340
+ # an unusual arch, this achieves somewhat more granularity without getting super messy
341
+ (r'^level(\d+)\.tree(\d+)', None),
342
+ (r'^level(\d+)\.root', (2,)),
343
+ (r'^level(\d+)', (1,))
344
+ ]
345
+ )
346
+ return matcher
347
+
348
+ @torch.jit.ignore
349
+ def set_grad_checkpointing(self, enable=True):
350
+ assert not enable, 'gradient checkpointing not supported'
351
+
352
+ @torch.jit.ignore
353
+ def get_classifier(self):
354
+ return self.fc
355
+
356
+ def reset_classifier(self, num_classes, global_pool='avg'):
357
+ self.num_classes = num_classes
358
+ self.global_pool, self.fc = create_classifier(
359
+ self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
360
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
361
+
362
+ def forward_features(self, x):
363
+ x = self.base_layer(x)
364
+ x = self.level0(x)
365
+ x = self.level1(x)
366
+ x = self.level2(x)
367
+ x = self.level3(x)
368
+ x = self.level4(x)
369
+ x = self.level5(x)
370
+ return x
371
+
372
+ def forward_head(self, x, pre_logits: bool = False):
373
+ x = self.global_pool(x)
374
+ x = self.head_drop(x)
375
+ if pre_logits:
376
+ return self.flatten(x)
377
+ x = self.fc(x)
378
+ return self.flatten(x)
379
+
380
+ def forward(self, x):
381
+ x = self.forward_features(x)
382
+ x = self.forward_head(x)
383
+ return x
384
+
385
+
386
+ def _create_dla(variant, pretrained=False, **kwargs):
387
+ return build_model_with_cfg(
388
+ DLA,
389
+ variant,
390
+ pretrained,
391
+ pretrained_strict=False,
392
+ feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),
393
+ **kwargs,
394
+ )
395
+
396
+
397
+ def _cfg(url='', **kwargs):
398
+ return {
399
+ 'url': url,
400
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
401
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
402
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
403
+ 'first_conv': 'base_layer.0', 'classifier': 'fc',
404
+ **kwargs
405
+ }
406
+
407
+
408
+ default_cfgs = generate_default_cfgs({
409
+ 'dla34.in1k': _cfg(hf_hub_id='timm/'),
410
+ 'dla46_c.in1k': _cfg(hf_hub_id='timm/'),
411
+ 'dla46x_c.in1k': _cfg(hf_hub_id='timm/'),
412
+ 'dla60x_c.in1k': _cfg(hf_hub_id='timm/'),
413
+ 'dla60.in1k': _cfg(hf_hub_id='timm/'),
414
+ 'dla60x.in1k': _cfg(hf_hub_id='timm/'),
415
+ 'dla102.in1k': _cfg(hf_hub_id='timm/'),
416
+ 'dla102x.in1k': _cfg(hf_hub_id='timm/'),
417
+ 'dla102x2.in1k': _cfg(hf_hub_id='timm/'),
418
+ 'dla169.in1k': _cfg(hf_hub_id='timm/'),
419
+ 'dla60_res2net.in1k': _cfg(hf_hub_id='timm/'),
420
+ 'dla60_res2next.in1k': _cfg(hf_hub_id='timm/'),
421
+ })
422
+
423
+
424
+ @register_model
425
+ def dla60_res2net(pretrained=False, **kwargs) -> DLA:
426
+ model_args = dict(
427
+ levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
428
+ block=DlaBottle2neck, cardinality=1, base_width=28)
429
+ return _create_dla('dla60_res2net', pretrained, **dict(model_args, **kwargs))
430
+
431
+
432
+ @register_model
433
+ def dla60_res2next(pretrained=False, **kwargs) -> DLA:
434
+ model_args = dict(
435
+ levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
436
+ block=DlaBottle2neck, cardinality=8, base_width=4)
437
+ return _create_dla('dla60_res2next', pretrained, **dict(model_args, **kwargs))
438
+
439
+
440
+ @register_model
441
+ def dla34(pretrained=False, **kwargs) -> DLA: # DLA-34
442
+ model_args = dict(
443
+ levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], block=DlaBasic)
444
+ return _create_dla('dla34', pretrained, **dict(model_args, **kwargs))
445
+
446
+
447
+ @register_model
448
+ def dla46_c(pretrained=False, **kwargs) -> DLA: # DLA-46-C
449
+ model_args = dict(
450
+ levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], block=DlaBottleneck)
451
+ return _create_dla('dla46_c', pretrained, **dict(model_args, **kwargs))
452
+
453
+
454
+ @register_model
455
+ def dla46x_c(pretrained=False, **kwargs) -> DLA: # DLA-X-46-C
456
+ model_args = dict(
457
+ levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
458
+ block=DlaBottleneck, cardinality=32, base_width=4)
459
+ return _create_dla('dla46x_c', pretrained, **dict(model_args, **kwargs))
460
+
461
+
462
+ @register_model
463
+ def dla60x_c(pretrained=False, **kwargs) -> DLA: # DLA-X-60-C
464
+ model_args = dict(
465
+ levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
466
+ block=DlaBottleneck, cardinality=32, base_width=4)
467
+ return _create_dla('dla60x_c', pretrained, **dict(model_args, **kwargs))
468
+
469
+
470
+ @register_model
471
+ def dla60(pretrained=False, **kwargs) -> DLA: # DLA-60
472
+ model_args = dict(
473
+ levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
474
+ block=DlaBottleneck)
475
+ return _create_dla('dla60', pretrained, **dict(model_args, **kwargs))
476
+
477
+
478
+ @register_model
479
+ def dla60x(pretrained=False, **kwargs) -> DLA: # DLA-X-60
480
+ model_args = dict(
481
+ levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
482
+ block=DlaBottleneck, cardinality=32, base_width=4)
483
+ return _create_dla('dla60x', pretrained, **dict(model_args, **kwargs))
484
+
485
+
486
+ @register_model
487
+ def dla102(pretrained=False, **kwargs) -> DLA: # DLA-102
488
+ model_args = dict(
489
+ levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
490
+ block=DlaBottleneck, shortcut_root=True)
491
+ return _create_dla('dla102', pretrained, **dict(model_args, **kwargs))
492
+
493
+
494
+ @register_model
495
+ def dla102x(pretrained=False, **kwargs) -> DLA: # DLA-X-102
496
+ model_args = dict(
497
+ levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
498
+ block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True)
499
+ return _create_dla('dla102x', pretrained, **dict(model_args, **kwargs))
500
+
501
+
502
+ @register_model
503
+ def dla102x2(pretrained=False, **kwargs) -> DLA: # DLA-X-102 64
504
+ model_args = dict(
505
+ levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
506
+ block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True)
507
+ return _create_dla('dla102x2', pretrained, **dict(model_args, **kwargs))
508
+
509
+
510
+ @register_model
511
+ def dla169(pretrained=False, **kwargs) -> DLA: # DLA-169
512
+ model_args = dict(
513
+ levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
514
+ block=DlaBottleneck, shortcut_root=True)
515
+ return _create_dla('dla169', pretrained, **dict(model_args, **kwargs))
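A hedged usage sketch (not part of the diff): because _create_dla passes feature_cfg with out_indices=(1, 2, 3, 4, 5), DLA models work with timm's features_only mode, returning the level1..level5 maps listed in feature_info:

import torch
import timm

backbone = timm.create_model('dla34', pretrained=False, features_only=True)
feats = backbone(torch.randn(1, 3, 224, 224))
for f in feats:
    print(f.shape)  # channels (32, 64, 128, 256, 512) at strides 2..32 for dla34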
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/eva.py ADDED
@@ -0,0 +1,1109 @@
1
+ """ EVA
2
+
3
+ EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
4
+
5
+ @article{EVA,
6
+ title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
7
+ author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
8
+ Tiejun and Wang, Xinlong and Cao, Yue},
9
+ journal={arXiv preprint arXiv:2211.07636},
10
+ year={2022}
11
+ }
12
+
13
+ EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
14
+ @article{EVA02,
15
+ title={EVA-02: A Visual Representation for Neon Genesis},
16
+ author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
17
+ journal={arXiv preprint arXiv:2303.11331},
18
+ year={2023}
19
+ }
20
+
21
+ This file contains EVA & EVA02 model implementations evolved from BEiT, additional models in vision_transformer.py.
22
+
23
+ Modifications by / Copyright 2023 Ross Wightman, original copyrights below
24
+ """
25
+ # EVA models Copyright (c) 2022 BAAI-Vision
26
+ # EVA02 models Copyright (c) 2023 BAAI-Vision
27
+
28
+ import math
29
+ from typing import Callable, Optional, Tuple, Union
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ from torch.utils.checkpoint import checkpoint
35
+
36
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
37
+ from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
38
+ apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \
39
+ to_2tuple, use_fused_attn
40
+
41
+ from ._builder import build_model_with_cfg
42
+ from ._registry import generate_default_cfgs, register_model
43
+
44
+ __all__ = ['Eva']
45
+
46
+
47
+ class EvaAttention(nn.Module):
48
+ fused_attn: torch.jit.Final[bool]
49
+
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ num_heads: int = 8,
54
+ qkv_bias: bool = True,
55
+ qkv_fused: bool = True,
56
+ attn_drop: float = 0.,
57
+ proj_drop: float = 0.,
58
+ attn_head_dim: Optional[int] = None,
59
+ norm_layer: Optional[Callable] = None,
60
+ ):
61
+ """
62
+
63
+ Args:
64
+ dim: Input embedding dimension.
65
+ num_heads: Number of attention heads.
66
+ qkv_bias: If True, add learnable bias to the q and v projections (k bias is a fixed zero buffer).
67
+ qkv_fused: If True, use a single fused qkv projection, else separate q/k/v layers.
68
+ attn_drop: Attention dropout rate.
69
+ proj_drop: Output projection dropout rate.
70
+ attn_head_dim: Optional override of the per-head dimension.
71
+ norm_layer: Optional norm layer applied to the attention output before projection.
72
+ """
73
+ super().__init__()
74
+ self.num_heads = num_heads
75
+ head_dim = dim // num_heads
76
+ if attn_head_dim is not None:
77
+ head_dim = attn_head_dim
78
+ all_head_dim = head_dim * self.num_heads
79
+ self.scale = head_dim ** -0.5
80
+ self.fused_attn = use_fused_attn()
81
+
82
+ if qkv_fused:
83
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
84
+ self.q_proj = self.k_proj = self.v_proj = None
85
+ if qkv_bias:
86
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
87
+ self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
88
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
89
+ else:
90
+ self.q_bias = self.k_bias = self.v_bias = None
91
+ else:
92
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
93
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
94
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
95
+ self.qkv = None
96
+ self.q_bias = self.k_bias = self.v_bias = None
97
+
98
+ self.attn_drop = nn.Dropout(attn_drop)
99
+ self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity()
100
+ self.proj = nn.Linear(all_head_dim, dim)
101
+ self.proj_drop = nn.Dropout(proj_drop)
102
+
103
+ def forward(
104
+ self,
105
+ x,
106
+ rope: Optional[torch.Tensor] = None,
107
+ attn_mask: Optional[torch.Tensor] = None,
108
+ ):
109
+ B, N, C = x.shape
110
+
111
+ if self.qkv is not None:
112
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
113
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
114
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
115
+ q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
116
+ else:
117
+ q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
118
+ k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
119
+ v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
120
+
121
+ if rope is not None:
122
+ q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
123
+ k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)
124
+
125
+ if self.fused_attn:
126
+ x = F.scaled_dot_product_attention(
127
+ q, k, v,
128
+ attn_mask=attn_mask,
129
+ dropout_p=self.attn_drop.p if self.training else 0.,
130
+ )
131
+ else:
132
+ q = q * self.scale
133
+ attn = (q @ k.transpose(-2, -1))
134
+ if attn_mask is not None:
135
+ attn_mask = attn_mask.to(torch.bool)
136
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
137
+ attn = attn.softmax(dim=-1)  # mask applied before softmax so masked positions get zero weight
138
+ attn = self.attn_drop(attn)
139
+ x = attn @ v
140
+
141
+ x = x.transpose(1, 2).reshape(B, N, C)
142
+ x = self.norm(x)
143
+ x = self.proj(x)
144
+ x = self.proj_drop(x)
145
+ return x
146
+
147
+
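A toy sketch (not part of the diff) of the rope handling in EvaAttention.forward above: rotary embeddings are applied to every token except the leading class token, whose q/k rows pass through unrotated. `toy_rot` is a hypothetical stand-in for timm's apply_rot_embed_cat:

import torch

def toy_rot(t, rope):
    # hypothetical stand-in: the real apply_rot_embed_cat mixes the
    # sin/cos halves of `rope` into t; simple addition keeps shapes honest
    return t + rope

B, H, N, D = 1, 2, 5, 8       # batch, heads, tokens (incl. cls), head dim
q = torch.randn(B, H, N, D)
rope = torch.randn(N - 1, D)  # one embedding per patch token, none for cls
q = torch.cat([q[:, :, :1, :], toy_rot(q[:, :, 1:, :], rope)], dim=2)
print(q.shape)  # torch.Size([1, 2, 5, 8]) - cls row left unrotated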
148
+ class EvaBlock(nn.Module):
149
+
150
+ def __init__(
151
+ self,
152
+ dim: int,
153
+ num_heads: int,
154
+ qkv_bias: bool = True,
155
+ qkv_fused: bool = True,
156
+ mlp_ratio: float = 4.,
157
+ swiglu_mlp: bool = False,
158
+ scale_mlp: bool = False,
159
+ scale_attn_inner: bool = False,
160
+ proj_drop: float = 0.,
161
+ attn_drop: float = 0.,
162
+ drop_path: float = 0.,
163
+ init_values: Optional[float] = None,
164
+ act_layer: Callable = nn.GELU,
165
+ norm_layer: Callable = LayerNorm,
166
+ attn_head_dim: Optional[int] = None,
167
+ ):
168
+ """
169
+
170
+ Args:
171
+ dim: Input embedding dimension.
172
+ num_heads: Number of attention heads.
173
+ qkv_bias: If True, add learnable bias to the q and v projections.
174
+ qkv_fused: If True, use a single fused qkv projection.
175
+ mlp_ratio: Ratio of MLP hidden dim to embedding dim.
176
+ swiglu_mlp: If True, use a SwiGLU MLP in place of the standard GELU MLP.
177
+ scale_mlp: If True, apply an extra norm layer inside the MLP (a la NormFormer).
178
+ scale_attn_inner: If True, apply a norm layer to the attention output.
179
+ proj_drop: Projection dropout rate.
180
+ attn_drop: Attention dropout rate.
181
+ drop_path: Stochastic depth rate.
182
+ init_values: Layer-scale init value; None disables layer-scale.
183
+ act_layer: MLP activation layer.
184
+ norm_layer: Normalization layer.
185
+ attn_head_dim: Optional override of the per-head dimension.
186
+ """
187
+ super().__init__()
188
+ self.norm1 = norm_layer(dim)
189
+ self.attn = EvaAttention(
190
+ dim,
191
+ num_heads=num_heads,
192
+ qkv_bias=qkv_bias,
193
+ qkv_fused=qkv_fused,
194
+ attn_drop=attn_drop,
195
+ proj_drop=proj_drop,
196
+ attn_head_dim=attn_head_dim,
197
+ norm_layer=norm_layer if scale_attn_inner else None,
198
+ )
199
+ self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
200
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
201
+
202
+ self.norm2 = norm_layer(dim)
203
+ hidden_features = int(dim * mlp_ratio)
204
+ if swiglu_mlp:
205
+ if scale_mlp:
206
+ # when norm in SwiGLU used, an impl with separate fc for gate & x is used
207
+ self.mlp = SwiGLU(
208
+ in_features=dim,
209
+ hidden_features=hidden_features,
210
+ norm_layer=norm_layer if scale_mlp else None,
211
+ drop=proj_drop,
212
+ )
213
+ else:
214
+ # w/o any extra norm, an impl with packed weights is used, matches existing GluMLP
215
+ self.mlp = GluMlp(
216
+ in_features=dim,
217
+ hidden_features=hidden_features * 2,
218
+ norm_layer=norm_layer if scale_mlp else None,
219
+ act_layer=nn.SiLU,
220
+ gate_last=False,
221
+ drop=proj_drop,
222
+ )
223
+ else:
224
+ self.mlp = Mlp(
225
+ in_features=dim,
226
+ hidden_features=hidden_features,
227
+ act_layer=act_layer,
228
+ norm_layer=norm_layer if scale_mlp else None,
229
+ drop=proj_drop,
230
+ )
231
+ self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
232
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
233
+
234
+ def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
235
+ if self.gamma_1 is None:
236
+ x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
237
+ x = x + self.drop_path2(self.mlp(self.norm2(x)))
238
+ else:
239
+ x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
240
+ x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
241
+ return x
242
+
243
+
244
+ class EvaBlockPostNorm(nn.Module):
245
+ """ EVA block w/ post-norm and support for swiglu, MLP norm scale, ROPE. """
246
+ def __init__(
247
+ self,
248
+ dim: int,
249
+ num_heads: int,
250
+ qkv_bias: bool = True,
251
+ qkv_fused: bool = True,
252
+ mlp_ratio: float = 4.,
253
+ swiglu_mlp: bool = False,
254
+ scale_mlp: bool = False,
255
+ scale_attn_inner: bool = False,
256
+ proj_drop: float = 0.,
257
+ attn_drop: float = 0.,
258
+ drop_path: float = 0.,
259
+ init_values: Optional[float] = None, # ignore for post-norm
260
+ act_layer: Callable = nn.GELU,
261
+ norm_layer: Callable = nn.LayerNorm,
262
+ attn_head_dim: Optional[int] = None,
263
+ ):
264
+ """
265
+
266
+ Args:
267
+ dim: Input embedding dimension.
268
+ num_heads: Number of attention heads.
269
+ qkv_bias: If True, add learnable bias to the q and v projections.
270
+ qkv_fused: If True, use a single fused qkv projection.
271
+ mlp_ratio: Ratio of MLP hidden dim to embedding dim.
272
+ swiglu_mlp: If True, use a SwiGLU MLP in place of the standard GELU MLP.
273
+ scale_mlp: If True, apply an extra norm layer inside the MLP (a la NormFormer).
274
+ scale_attn_inner: If True, apply a norm layer to the attention output.
275
+ proj_drop: Projection dropout rate.
276
+ attn_drop: Attention dropout rate.
277
+ drop_path: Stochastic depth rate.
278
+ init_values: Ignored for post-norm blocks; kept for API compatibility.
279
+ act_layer: MLP activation layer.
280
+ norm_layer: Normalization layer.
281
+ attn_head_dim: Optional override of the per-head dimension.
282
+ """
283
+ super().__init__()
284
+ self.attn = EvaAttention(
285
+ dim,
286
+ num_heads=num_heads,
287
+ qkv_bias=qkv_bias,
288
+ qkv_fused=qkv_fused,
289
+ attn_drop=attn_drop,
290
+ proj_drop=proj_drop,
291
+ attn_head_dim=attn_head_dim,
292
+ norm_layer=norm_layer if scale_attn_inner else None,
293
+ )
294
+ self.norm1 = norm_layer(dim)
295
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
296
+
297
+ hidden_features = int(dim * mlp_ratio)
298
+ if swiglu_mlp:
299
+ if scale_mlp:
300
+ # when norm in SwiGLU used, an impl with separate fc for gate & x is used
301
+ self.mlp = SwiGLU(
302
+ in_features=dim,
303
+ hidden_features=hidden_features,
304
+ norm_layer=norm_layer if scale_mlp else None,
305
+ drop=proj_drop,
306
+ )
307
+ else:
308
+ # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
309
+ self.mlp = GluMlp(
310
+ in_features=dim,
311
+ hidden_features=hidden_features * 2,
312
+ norm_layer=norm_layer if scale_mlp else None,
313
+ act_layer=nn.SiLU,
314
+ gate_last=False,
315
+ drop=proj_drop,
316
+ )
317
+ else:
318
+ self.mlp = Mlp(
319
+ in_features=dim,
320
+ hidden_features=hidden_features,
321
+ act_layer=act_layer,
322
+ norm_layer=norm_layer if scale_mlp else None,
323
+ drop=proj_drop,
324
+ )
325
+ self.norm2 = norm_layer(dim)
326
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
327
+
328
+ def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
329
+ x = x + self.drop_path1(self.norm1(self.attn(x, rope=rope, attn_mask=attn_mask)))
330
+ x = x + self.drop_path2(self.norm2(self.mlp(x)))
331
+ return x
332
+
333
+
334
+ class Eva(nn.Module):
335
+ """ Eva Vision Transformer w/ Abs & Rotary Pos Embed
336
+
337
+ This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
338
+ * EVA - abs pos embed, global avg pool
339
+ * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale norm in MLP (a la NormFormer)
340
+ """
341
+
342
+ def __init__(
343
+ self,
344
+ img_size: Union[int, Tuple[int, int]] = 224,
345
+ patch_size: Union[int, Tuple[int, int]] = 16,
346
+ in_chans: int = 3,
347
+ num_classes: int = 1000,
348
+ global_pool: str = 'avg',
349
+ embed_dim: int = 768,
350
+ depth: int = 12,
351
+ num_heads: int = 12,
352
+ qkv_bias: bool = True,
353
+ qkv_fused: bool = True,
354
+ mlp_ratio: float = 4.,
355
+ swiglu_mlp: bool = False,
356
+ scale_mlp: bool = False,
357
+ scale_attn_inner: bool = False,
358
+ drop_rate: float = 0.,
359
+ pos_drop_rate: float = 0.,
360
+ patch_drop_rate: float = 0.,
361
+ proj_drop_rate: float = 0.,
362
+ attn_drop_rate: float = 0.,
363
+ drop_path_rate: float = 0.,
364
+ norm_layer: Callable = LayerNorm,
365
+ init_values: Optional[float] = None,
366
+ class_token: bool = True,
367
+ use_abs_pos_emb: bool = True,
368
+ use_rot_pos_emb: bool = False,
369
+ use_post_norm: bool = False,
370
+ dynamic_img_size: bool = False,
371
+ dynamic_img_pad: bool = False,
372
+ ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
373
+ head_init_scale: float = 0.001,
374
+ ):
375
+ """
376
+
377
+ Args:
378
+ img_size: Input image size.
379
+ patch_size: Patch size.
380
+ in_chans: Number of input image channels.
381
+ num_classes: Number of classifier classes.
382
+ global_pool: Global pooling type ('avg' or '' to use the class token).
383
+ embed_dim: Transformer embedding dimension.
384
+ depth: Number of transformer blocks.
385
+ num_heads: Number of attention heads.
386
+ qkv_bias: If True, add learnable bias to the q and v projections.
387
+ qkv_fused: If True, use a single fused qkv projection.
388
+ mlp_ratio: Ratio of MLP hidden dim to embedding dim.
389
+ swiglu_mlp: If True, use SwiGLU MLP blocks (EVA02 style).
390
+ scale_mlp: If True, apply an extra norm layer inside the MLP.
391
+ scale_attn_inner: If True, apply a norm layer to the attention output.
392
+ drop_rate: Classifier head dropout rate.
393
+ pos_drop_rate: Position embedding dropout rate.
394
+ proj_drop_rate: Projection dropout rate.
395
+ attn_drop_rate: Attention dropout rate.
396
+ drop_path_rate: Stochastic depth rate.
397
+ norm_layer: Normalization layer.
398
+ init_values: Layer-scale init value; None disables layer-scale.
399
+ class_token: If True, prepend a class token.
400
+ use_abs_pos_emb: If True, use learned absolute position embeddings.
401
+ use_rot_pos_emb: If True, use rotary position embeddings.
402
+ use_post_norm: If True, use post-norm blocks instead of pre-norm.
403
+ ref_feat_shape: Reference feature grid shape used to scale rotary embeddings.
404
+ head_init_scale: Scale applied to the classifier head weights at init.
405
+ """
406
+ super().__init__()
407
+ self.num_classes = num_classes
408
+ self.global_pool = global_pool
409
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
410
+ self.num_prefix_tokens = 1 if class_token else 0
411
+ self.dynamic_img_size = dynamic_img_size
412
+ self.grad_checkpointing = False
413
+
414
+ embed_args = {}
415
+ if dynamic_img_size:
416
+ # flatten deferred until after pos embed
417
+ embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
418
+ self.patch_embed = PatchEmbed(
419
+ img_size=img_size,
420
+ patch_size=patch_size,
421
+ in_chans=in_chans,
422
+ embed_dim=embed_dim,
423
+ dynamic_img_pad=dynamic_img_pad,
424
+ **embed_args,
425
+ )
426
+ num_patches = self.patch_embed.num_patches
427
+
428
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
429
+
430
+ self.pos_embed = nn.Parameter(
431
+ torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None
432
+ self.pos_drop = nn.Dropout(p=pos_drop_rate)
433
+ if patch_drop_rate > 0:
434
+ self.patch_drop = PatchDropout(
435
+ patch_drop_rate,
436
+ num_prefix_tokens=self.num_prefix_tokens,
437
+ return_indices=True,
438
+ )
439
+ else:
440
+ self.patch_drop = None
441
+
442
+ if use_rot_pos_emb:
443
+ ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
444
+ self.rope = RotaryEmbeddingCat(
445
+ embed_dim // num_heads,
446
+ in_pixels=False,
447
+ feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
448
+ ref_feat_shape=ref_feat_shape,
449
+ )
450
+ else:
451
+ self.rope = None
452
+
453
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
454
+ block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
455
+ self.blocks = nn.ModuleList([
456
+ block_fn(
457
+ dim=embed_dim,
458
+ num_heads=num_heads,
459
+ qkv_bias=qkv_bias,
460
+ qkv_fused=qkv_fused,
461
+ mlp_ratio=mlp_ratio,
462
+ swiglu_mlp=swiglu_mlp,
463
+ scale_mlp=scale_mlp,
464
+ scale_attn_inner=scale_attn_inner,
465
+ proj_drop=proj_drop_rate,
466
+ attn_drop=attn_drop_rate,
467
+ drop_path=dpr[i],
468
+ norm_layer=norm_layer,
469
+ init_values=init_values,
470
+ )
471
+ for i in range(depth)])
472
+
473
+ use_fc_norm = self.global_pool == 'avg'
474
+ self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
475
+ self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
476
+ self.head_drop = nn.Dropout(drop_rate)
477
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
478
+
479
+ self.apply(self._init_weights)
480
+ if self.pos_embed is not None:
481
+ trunc_normal_(self.pos_embed, std=.02)
482
+ if self.cls_token is not None:
483
+ trunc_normal_(self.cls_token, std=.02)
484
+
485
+ self.fix_init_weight()
486
+ if isinstance(self.head, nn.Linear):
487
+ trunc_normal_(self.head.weight, std=.02)
488
+ self.head.weight.data.mul_(head_init_scale)
489
+ self.head.bias.data.mul_(head_init_scale)
490
+
491
+ def fix_init_weight(self):
492
+ def rescale(param, layer_id):
493
+ param.div_(math.sqrt(2.0 * layer_id))
494
+
495
+ for layer_id, layer in enumerate(self.blocks):
496
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
497
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
498
+
499
+ def _init_weights(self, m):
500
+ if isinstance(m, nn.Linear):
501
+ trunc_normal_(m.weight, std=.02)
502
+ if m.bias is not None:
503
+ nn.init.zeros_(m.bias)
504
+
505
+ @torch.jit.ignore
506
+ def no_weight_decay(self):
507
+ nwd = {'pos_embed', 'cls_token'}
508
+ return nwd
509
+
510
+ @torch.jit.ignore
511
+ def set_grad_checkpointing(self, enable=True):
512
+ self.grad_checkpointing = enable
513
+
514
+ @torch.jit.ignore
515
+ def group_matcher(self, coarse=False):
516
+ matcher = dict(
517
+ stem=r'^cls_token|pos_embed|patch_embed', # stem and embed
518
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
519
+ )
520
+ return matcher
521
+
522
+ @torch.jit.ignore
523
+ def get_classifier(self):
524
+ return self.head
525
+
526
+ def reset_classifier(self, num_classes, global_pool=None):
527
+ self.num_classes = num_classes
528
+ if global_pool is not None:
529
+ self.global_pool = global_pool
530
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
531
+
532
+ def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
533
+ if self.dynamic_img_size:
534
+ B, H, W, C = x.shape
535
+ if self.pos_embed is not None:
536
+ pos_embed = resample_abs_pos_embed(
537
+ self.pos_embed,
538
+ (H, W),
539
+ num_prefix_tokens=self.num_prefix_tokens,
540
+ )
541
+ else:
542
+ pos_embed = None
543
+ x = x.view(B, -1, C)
544
+ rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
545
+ else:
546
+ pos_embed = self.pos_embed
547
+ rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
548
+
549
+ if self.cls_token is not None:
550
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
551
+ if pos_embed is not None:
552
+ x = x + pos_embed
553
+ x = self.pos_drop(x)
554
+
555
+ # obtain shared rotary position embedding and apply patch dropout
556
+ if self.patch_drop is not None:
557
+ x, keep_indices = self.patch_drop(x)
558
+ if rot_pos_embed is not None and keep_indices is not None:
559
+ rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
560
+ return x, rot_pos_embed
561
+
562
+ def forward_features(self, x):
563
+ x = self.patch_embed(x)
564
+ x, rot_pos_embed = self._pos_embed(x)
565
+ for blk in self.blocks:
566
+ if self.grad_checkpointing and not torch.jit.is_scripting():
567
+ x = checkpoint(blk, x, rope=rot_pos_embed)
568
+ else:
569
+ x = blk(x, rope=rot_pos_embed)
570
+ x = self.norm(x)
571
+ return x
572
+
573
+ def forward_head(self, x, pre_logits: bool = False):
574
+ if self.global_pool:
575
+ x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
576
+ x = self.fc_norm(x)
577
+ x = self.head_drop(x)
578
+ return x if pre_logits else self.head(x)
579
+
580
+ def forward(self, x):
581
+ x = self.forward_features(x)
582
+ x = self.forward_head(x)
583
+ return x
584
+
585
+
586
+ def checkpoint_filter_fn(
587
+ state_dict,
588
+ model,
589
+ interpolation='bicubic',
590
+ antialias=True,
591
+ ):
592
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
593
+ out_dict = {}
594
+ state_dict = state_dict.get('model_ema', state_dict)
595
+ state_dict = state_dict.get('model', state_dict)
596
+ state_dict = state_dict.get('module', state_dict)
597
+ state_dict = state_dict.get('state_dict', state_dict)
598
+ # prefix for loading OpenCLIP compatible weights
599
+ if 'visual.trunk.pos_embed' in state_dict:
600
+ prefix = 'visual.trunk.'
601
+ elif 'visual.pos_embed' in state_dict:
602
+ prefix = 'visual.'
603
+ else:
604
+ prefix = ''
605
+ mim_weights = prefix + 'mask_token' in state_dict
606
+ no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict
607
+
608
+ len_prefix = len(prefix)
609
+ for k, v in state_dict.items():
610
+ if prefix:
611
+ if k.startswith(prefix):
612
+ k = k[len_prefix:]
613
+ else:
614
+ continue
615
+
616
+ if 'rope' in k:
617
+ # fixed embedding no need to load buffer from checkpoint
618
+ continue
619
+
620
+ if 'patch_embed.proj.weight' in k:
621
+ _, _, H, W = model.patch_embed.proj.weight.shape
622
+ if v.shape[-1] != W or v.shape[-2] != H:
623
+ v = resample_patch_embed(
624
+ v,
625
+ (H, W),
626
+ interpolation=interpolation,
627
+ antialias=antialias,
628
+ verbose=True,
629
+ )
630
+ elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
631
+ # To resize pos embedding when using model at different size from pretrained weights
632
+ num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
633
+ v = resample_abs_pos_embed(
634
+ v,
635
+ new_size=model.patch_embed.grid_size,
636
+ num_prefix_tokens=num_prefix_tokens,
637
+ interpolation=interpolation,
638
+ antialias=antialias,
639
+ verbose=True,
640
+ )
641
+
642
+ k = k.replace('mlp.ffn_ln', 'mlp.norm')
643
+ k = k.replace('attn.inner_attn_ln', 'attn.norm')
644
+ k = k.replace('mlp.w12', 'mlp.fc1')
645
+ k = k.replace('mlp.w1', 'mlp.fc1_g')
646
+ k = k.replace('mlp.w2', 'mlp.fc1_x')
647
+ k = k.replace('mlp.w3', 'mlp.fc2')
648
+ if no_qkv:
649
+ k = k.replace('q_bias', 'q_proj.bias')
650
+ k = k.replace('v_bias', 'v_proj.bias')
651
+
652
+ if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
653
+ if k == 'norm.weight' or k == 'norm.bias':
654
+ # try moving norm -> fc norm on fine-tune, probably a better starting point than new init
655
+ k = k.replace('norm', 'fc_norm')
656
+ else:
657
+ # skip pretrain mask token & head weights
658
+ continue
659
+
660
+ out_dict[k] = v
661
+
662
+ return out_dict
663
+
664
+
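A minimal sketch (not part of the diff) of the key remapping checkpoint_filter_fn performs above: strip an OpenCLIP-style tower prefix, drop pretrain-only weights, and rename EVA layer names to their timm equivalents:

sd = {
    'visual.trunk.blocks.0.mlp.ffn_ln.weight': 'w0',
    'visual.trunk.mask_token': 'mim-only',
    'text.token_embedding.weight': 'not-vision',
}
prefix = 'visual.trunk.'
out = {}
for k, v in sd.items():
    if not k.startswith(prefix):
        continue                             # non-tower keys are skipped entirely
    k = k[len(prefix):]
    if k == 'mask_token':
        continue                             # pretrain-only weights are dropped
    k = k.replace('mlp.ffn_ln', 'mlp.norm')  # EVA name -> timm name
    out[k] = v
print(out)  # {'blocks.0.mlp.norm.weight': 'w0'}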
665
+ def _create_eva(variant, pretrained=False, **kwargs):
666
+ if kwargs.get('features_only', None):
667
+ raise RuntimeError('features_only not implemented for Eva models.')
668
+
669
+ model = build_model_with_cfg(
670
+ Eva, variant, pretrained,
671
+ pretrained_filter_fn=checkpoint_filter_fn,
672
+ **kwargs)
673
+ return model
674
+
675
+
676
+ def _cfg(url='', **kwargs):
677
+ return {
678
+ 'url': url,
679
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
680
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
681
+ 'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
682
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
683
+ 'license': 'mit', **kwargs
684
+ }
685
+
686
+
687
+ default_cfgs = generate_default_cfgs({
688
+
689
+ # EVA 01 CLIP fine-tuned on imagenet-1k
690
+ 'eva_giant_patch14_224.clip_ft_in1k': _cfg(
691
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
692
+ hf_hub_id='timm/',
693
+ ),
694
+ 'eva_giant_patch14_336.clip_ft_in1k': _cfg(
695
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
696
+ hf_hub_id='timm/',
697
+ input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
698
+
699
+ # MIM EVA 01 pretrain, ft on in22k -> in1k
700
+ 'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
701
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
702
+ hf_hub_id='timm/',
703
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
704
+ input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
705
+ 'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
706
+ # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
707
+ hf_hub_id='timm/',
708
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
709
+ input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
710
+
711
+ # in22k or m38m MIM pretrain w/ intermediate in22k fine-tune and final in1k fine-tune
712
+ 'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
713
+ # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
714
+ hf_hub_id='timm/',
715
+ input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
716
+ ),
717
+ 'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
718
+ # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
719
+ hf_hub_id='timm/',
720
+ input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
721
+ ),
722
+ 'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
723
+ hf_hub_id='timm/',
724
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
725
+ input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
726
+ ),
727
+
728
+ # in22k or m3m MIM pretrain w/ in1k fine-tune
729
+ 'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
730
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
731
+ hf_hub_id='timm/',
732
+ input_size=(3, 336, 336), crop_pct=1.0,
733
+ ),
734
+ 'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
735
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
736
+ hf_hub_id='timm/',
737
+ input_size=(3, 336, 336), crop_pct=1.0,
738
+ ),
739
+ 'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
740
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
741
+ hf_hub_id='timm/',
742
+ input_size=(3, 448, 448), crop_pct=1.0,
743
+ ),
744
+ 'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
745
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
746
+ hf_hub_id='timm/',
747
+ input_size=(3, 448, 448), crop_pct=1.0,
748
+ ),
749
+ 'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
750
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
751
+ hf_hub_id='timm/',
752
+ input_size=(3, 448, 448), crop_pct=1.0,
753
+ ),
754
+
755
+ # in22k or m3m MIM pretrain w/ in22k fine-tune
756
+ 'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
757
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
758
+ hf_hub_id='timm/',
759
+ input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
760
+ ),
761
+ 'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
762
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
763
+ hf_hub_id='timm/',
764
+ input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
765
+ ),
766
+ 'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
767
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
768
+ hf_hub_id='timm/',
769
+ input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
770
+ ),
771
+
772
+ # in22k or m38m MIM pretrain
773
+ 'eva02_tiny_patch14_224.mim_in22k': _cfg(
774
+ # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
775
+ hf_hub_id='timm/',
776
+ num_classes=0,
777
+ ),
778
+ 'eva02_small_patch14_224.mim_in22k': _cfg(
779
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
780
+ hf_hub_id='timm/',
781
+ num_classes=0,
782
+ ),
783
+ 'eva02_base_patch14_224.mim_in22k': _cfg(
784
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
785
+ hf_hub_id='timm/',
786
+ num_classes=0,
787
+ ),
788
+ 'eva02_large_patch14_224.mim_in22k': _cfg(
789
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
790
+ hf_hub_id='timm/',
791
+ num_classes=0,
792
+ ),
793
+ 'eva02_large_patch14_224.mim_m38m': _cfg(
794
+ #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
795
+ hf_hub_id='timm/',
796
+ num_classes=0,
797
+ ),
798
+
799
+ # EVA01 and EVA02 CLIP image towers
800
+ 'eva_giant_patch14_clip_224.laion400m': _cfg(
801
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
802
+ hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k', # float16 weights
803
+ hf_hub_filename='open_clip_pytorch_model.bin',
804
+ num_classes=1024,
805
+ ),
806
+ 'eva_giant_patch14_clip_224.merged2b': _cfg(
807
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
808
+ hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k', # float16 weights
809
+ hf_hub_filename='open_clip_pytorch_model.bin',
810
+ num_classes=1024,
811
+ ),
812
+ 'eva02_base_patch16_clip_224.merged2b': _cfg(
813
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
814
+ hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k', # float16 weights
815
+ hf_hub_filename='open_clip_pytorch_model.bin',
816
+ num_classes=512,
817
+ ),
818
+ 'eva02_large_patch14_clip_224.merged2b': _cfg(
819
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
820
+ hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k', # float16 weights
821
+ hf_hub_filename='open_clip_pytorch_model.bin',
822
+ num_classes=768,
823
+ ),
824
+ 'eva02_large_patch14_clip_336.merged2b': _cfg(
825
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
826
+ hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k', # float16 weights
827
+ hf_hub_filename='open_clip_pytorch_model.bin',
828
+ input_size=(3, 336, 336), crop_pct=1.0,
829
+ num_classes=768,
830
+ ),
831
+ 'eva02_enormous_patch14_clip_224.laion2b': _cfg(
832
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
833
+ hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k', # float16 weights
834
+ hf_hub_filename='open_clip_pytorch_model.bin',
835
+ num_classes=1024,
836
+ ),
837
+ 'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
838
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
839
+ hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k', # bfloat16 weights
840
+ hf_hub_filename='open_clip_pytorch_model.bin',
841
+ num_classes=1024,
842
+ ),
843
+ 'eva02_enormous_patch14_clip_224.pretrain': _cfg(
844
+ # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
845
+ num_classes=0,
846
+ ),
847
+
848
+ })
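Usage sketch (editorial note, not part of the diff): the tagged names above resolve through timm's factory; assuming timm and torch are installed, one of the EVA-02 configs registered above can be exercised like this:

import timm
import torch

# name and pretrained tag come from the default_cfgs registered above;
# pretrained=True pulls the 'timm/' hosted weights from the HF hub
model = timm.create_model('eva02_base_patch14_448.mim_in22k_ft_in22k_in1k', pretrained=True).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 448, 448))  # input_size from the cfg above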
849
+
850
+
851
+ @register_model
852
+ def eva_giant_patch14_224(pretrained=False, **kwargs) -> Eva:
853
+ """ EVA-g model https://arxiv.org/abs/2211.07636 """
854
+ model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
855
+ model = _create_eva('eva_giant_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
856
+ return model
857
+
858
+
859
+ @register_model
860
+ def eva_giant_patch14_336(pretrained=False, **kwargs) -> Eva:
861
+ """ EVA-g model https://arxiv.org/abs/2211.07636 """
862
+ model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
863
+ model = _create_eva('eva_giant_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
864
+ return model
865
+
866
+
867
+ @register_model
868
+ def eva_giant_patch14_560(pretrained=False, **kwargs) -> Eva:
869
+ """ EVA-g model https://arxiv.org/abs/2211.07636 """
870
+ model_args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
871
+ model = _create_eva('eva_giant_patch14_560', pretrained=pretrained, **dict(model_args, **kwargs))
872
+ return model
873
+
874
+
875
+ @register_model
876
+ def eva02_tiny_patch14_224(pretrained=False, **kwargs) -> Eva:
877
+ model_args = dict(
878
+ img_size=224,
879
+ patch_size=14,
880
+ embed_dim=192,
881
+ depth=12,
882
+ num_heads=3,
883
+ mlp_ratio=4 * 2 / 3,
884
+ swiglu_mlp=True,
885
+ use_rot_pos_emb=True,
886
+ ref_feat_shape=(16, 16), # 224/14
887
+ )
888
+ model = _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
889
+ return model
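A note on the recurring mlp_ratio=4 * 2 / 3 (editorial sketch, assuming standard SwiGLU parameter accounting): SwiGLU uses two input projections, so shrinking the hidden width by 2/3 keeps the MLP parameter count equal to a plain 4x-wide MLP whenever the width divides evenly:

dim = 192                                          # eva02_tiny embed_dim above
hidden = int(dim * 4 * 2 / 3)                      # 512
swiglu_params = 2 * dim * hidden + hidden * dim    # two in-projections + one out-projection
plain_params = dim * (4 * dim) + (4 * dim) * dim   # standard 4x MLP
assert swiglu_params == plain_params == 294912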
890
+
891
+
892
+ @register_model
893
+ def eva02_small_patch14_224(pretrained=False, **kwargs) -> Eva:
894
+ model_args = dict(
895
+ img_size=224,
896
+ patch_size=14,
897
+ embed_dim=384,
898
+ depth=12,
899
+ num_heads=6,
900
+ mlp_ratio=4 * 2 / 3,
901
+ swiglu_mlp=True,
902
+ use_rot_pos_emb=True,
903
+ ref_feat_shape=(16, 16), # 224/14
904
+ )
905
+ model = _create_eva('eva02_small_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
906
+ return model
907
+
908
+
909
+ @register_model
910
+ def eva02_base_patch14_224(pretrained=False, **kwargs) -> Eva:
911
+ model_args = dict(
912
+ img_size=224,
913
+ patch_size=14,
914
+ embed_dim=768,
915
+ depth=12,
916
+ num_heads=12,
917
+ qkv_fused=False,
918
+ mlp_ratio=4 * 2 / 3,
919
+ swiglu_mlp=True,
920
+ scale_mlp=True,
921
+ use_rot_pos_emb=True,
922
+ ref_feat_shape=(16, 16), # 224/14
923
+ )
924
+ model = _create_eva('eva02_base_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
925
+ return model
926
+
927
+
928
+ @register_model
929
+ def eva02_large_patch14_224(pretrained=False, **kwargs) -> Eva:
930
+ model_args = dict(
931
+ img_size=224,
932
+ patch_size=14,
933
+ embed_dim=1024,
934
+ depth=24,
935
+ num_heads=16,
936
+ mlp_ratio=4 * 2 / 3,
937
+ qkv_fused=False,
938
+ swiglu_mlp=True,
939
+ scale_mlp=True,
940
+ use_rot_pos_emb=True,
941
+ ref_feat_shape=(16, 16), # 224/14
942
+ )
943
+ model = _create_eva('eva02_large_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
944
+ return model
945
+
946
+
947
+ @register_model
948
+ def eva02_tiny_patch14_336(pretrained=False, **kwargs) -> Eva:
949
+ model_args = dict(
950
+ img_size=336,
951
+ patch_size=14,
952
+ embed_dim=192,
953
+ depth=12,
954
+ num_heads=3,
955
+ mlp_ratio=4 * 2 / 3,
956
+ swiglu_mlp=True,
957
+ use_rot_pos_emb=True,
958
+ ref_feat_shape=(16, 16), # 224/14
959
+ )
960
+ model = _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
961
+ return model
962
+
963
+
964
+ @register_model
965
+ def eva02_small_patch14_336(pretrained=False, **kwargs) -> Eva:
966
+ model_args = dict(
967
+ img_size=336,
968
+ patch_size=14,
969
+ embed_dim=384,
970
+ depth=12,
971
+ num_heads=6,
972
+ mlp_ratio=4 * 2 / 3,
973
+ swiglu_mlp=True,
974
+ use_rot_pos_emb=True,
975
+ ref_feat_shape=(16, 16), # 224/14
976
+ )
977
+ model = _create_eva('eva02_small_patch14_336', pretrained=pretrained, **dict(model_args, **kwargs))
978
+ return model
979
+
980
+
981
+ @register_model
982
+ def eva02_base_patch14_448(pretrained=False, **kwargs) -> Eva:
983
+ model_args = dict(
984
+ img_size=448,
985
+ patch_size=14,
986
+ embed_dim=768,
987
+ depth=12,
988
+ num_heads=12,
989
+ qkv_fused=False,
990
+ mlp_ratio=4 * 2 / 3,
991
+ swiglu_mlp=True,
992
+ scale_mlp=True,
993
+ use_rot_pos_emb=True,
994
+ ref_feat_shape=(16, 16), # 224/14
995
+ )
996
+ model = _create_eva('eva02_base_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
997
+ return model
998
+
999
+
1000
+ @register_model
1001
+ def eva02_large_patch14_448(pretrained=False, **kwargs) -> Eva:
1002
+ model_args = dict(
1003
+ img_size=448,
1004
+ patch_size=14,
1005
+ embed_dim=1024,
1006
+ depth=24,
1007
+ num_heads=16,
1008
+ mlp_ratio=4 * 2 / 3,
1009
+ qkv_fused=False,
1010
+ swiglu_mlp=True,
1011
+ scale_mlp=True,
1012
+ use_rot_pos_emb=True,
1013
+ ref_feat_shape=(16, 16), # 224/14
1014
+ )
1015
+ model = _create_eva('eva02_large_patch14_448', pretrained=pretrained, **dict(model_args, **kwargs))
1016
+ return model
1017
+
1018
+
1019
+ @register_model
1020
+ def eva_giant_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
1021
+ """ EVA-g CLIP model (only difference from non-CLIP is the pooling) """
1022
+ model_args = dict(
1023
+ patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408,
1024
+ global_pool=kwargs.pop('global_pool', 'token'))
1025
+ model = _create_eva('eva_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
1026
+ return model
1027
+
1028
+
1029
+ @register_model
1030
+ def eva02_base_patch16_clip_224(pretrained=False, **kwargs) -> Eva:
1031
+ """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_base """
1032
+ model_args = dict(
1033
+ img_size=224,
1034
+ patch_size=16,
1035
+ embed_dim=768,
1036
+ depth=12,
1037
+ num_heads=12,
1038
+ qkv_fused=False,
1039
+ mlp_ratio=4 * 2 / 3,
1040
+ swiglu_mlp=True,
1041
+ scale_mlp=True,
1042
+ scale_attn_inner=True,
1043
+ use_rot_pos_emb=True,
1044
+ ref_feat_shape=(16, 16), # 224/14
1045
+ global_pool=kwargs.pop('global_pool', 'token'),
1046
+ )
1047
+ model = _create_eva('eva02_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
1048
+ return model
1049
+
1050
+
1051
+ @register_model
1052
+ def eva02_large_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
1053
+ """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
1054
+ model_args = dict(
1055
+ img_size=224,
1056
+ patch_size=14,
1057
+ embed_dim=1024,
1058
+ depth=24,
1059
+ num_heads=16,
1060
+ mlp_ratio=4 * 2 / 3,
1061
+ qkv_fused=False,
1062
+ swiglu_mlp=True,
1063
+ scale_mlp=True,
1064
+ scale_attn_inner=True,
1065
+ use_rot_pos_emb=True,
1066
+ ref_feat_shape=(16, 16), # 224/14
1067
+ global_pool=kwargs.pop('global_pool', 'token'),
1068
+ )
1069
+ model = _create_eva('eva02_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
1070
+ return model
1071
+
1072
+
1073
+ @register_model
1074
+ def eva02_large_patch14_clip_336(pretrained=False, **kwargs) -> Eva:
1075
+ """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
1076
+ model_args = dict(
1077
+ img_size=336,
1078
+ patch_size=14,
1079
+ embed_dim=1024,
1080
+ depth=24,
1081
+ num_heads=16,
1082
+ mlp_ratio=4 * 2 / 3,
1083
+ qkv_fused=False,
1084
+ swiglu_mlp=True,
1085
+ scale_mlp=True,
1086
+ scale_attn_inner=True,
1087
+ use_rot_pos_emb=True,
1088
+ ref_feat_shape=(16, 16), # 224/14
1089
+ global_pool=kwargs.pop('global_pool', 'token'),
1090
+ )
1091
+ model = _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
1092
+ return model
1093
+
1094
+
1095
+ @register_model
1096
+ def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
1097
+ """ A EVA-CLIP specific variant that uses residual post-norm in blocks """
1098
+ model_args = dict(
1099
+ img_size=224,
1100
+ patch_size=14,
1101
+ embed_dim=1792,
1102
+ depth=64,
1103
+ num_heads=16,
1104
+ mlp_ratio=15360 / 1792,
1105
+ use_post_norm=True,
1106
+ global_pool=kwargs.pop('global_pool', 'token'),
1107
+ )
1108
+ model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
1109
+ return model
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/factory.py ADDED
@@ -0,0 +1,4 @@
1
+ from ._factory import *
2
+
3
+ import warnings
4
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/features.py ADDED
@@ -0,0 +1,4 @@
1
+ from ._features import *
2
+
3
+ import warnings
4
+ warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/gcvit.py ADDED
@@ -0,0 +1,592 @@
1
+ """ Global Context ViT
2
+
3
+ From scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py
4
+
5
+ Global Context Vision Transformers - https://arxiv.org/abs/2206.09959
6
+
7
+ @article{hatamizadeh2022global,
8
+ title={Global Context Vision Transformers},
9
+ author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
10
+ journal={arXiv preprint arXiv:2206.09959},
11
+ year={2022}
12
+ }
13
+
14
+ Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit.
15
+ The license for this code release is Apache 2.0 with no commercial restrictions.
16
+
17
+ However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license
18
+ (https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones...
19
+
20
+ Hacked together by / Copyright 2022, Ross Wightman
21
+ """
22
+ import math
23
+ from functools import partial
24
+ from typing import Callable, List, Optional, Tuple, Union
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.utils.checkpoint as checkpoint
29
+
30
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
31
+ from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
32
+ get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
33
+ from ._builder import build_model_with_cfg
34
+ from ._features_fx import register_notrace_function
35
+ from ._manipulate import named_apply
36
+ from ._registry import register_model, generate_default_cfgs
37
+
38
+ __all__ = ['GlobalContextVit']
39
+
40
+
41
+ class MbConvBlock(nn.Module):
42
+ """ A depthwise separable / fused mbconv style residual block with SE, `no norm.
43
+ """
44
+ def __init__(
45
+ self,
46
+ in_chs,
47
+ out_chs=None,
48
+ expand_ratio=1.0,
49
+ attn_layer='se',
50
+ bias=False,
51
+ act_layer=nn.GELU,
52
+ ):
53
+ super().__init__()
54
+ attn_kwargs = dict(act_layer=act_layer)
55
+ if isinstance(attn_layer, str) and attn_layer in ('se', 'eca'):
56
+ attn_kwargs['rd_ratio'] = 0.25
57
+ attn_kwargs['bias'] = False
58
+ attn_layer = get_attn(attn_layer)
59
+ out_chs = out_chs or in_chs
60
+ mid_chs = int(expand_ratio * in_chs)
61
+
62
+ self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias)
63
+ self.act = act_layer()
64
+ self.se = attn_layer(mid_chs, **attn_kwargs)
65
+ self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias)
66
+
67
+ def forward(self, x):
68
+ shortcut = x
69
+ x = self.conv_dw(x)
70
+ x = self.act(x)
71
+ x = self.se(x)
72
+ x = self.conv_pw(x)
73
+ x = x + shortcut
74
+ return x
75
+
76
+
77
+ class Downsample2d(nn.Module):
78
+ def __init__(
79
+ self,
80
+ dim,
81
+ dim_out=None,
82
+ reduction='conv',
83
+ act_layer=nn.GELU,
84
+ norm_layer=LayerNorm2d, # NOTE in NCHW
85
+ ):
86
+ super().__init__()
87
+ dim_out = dim_out or dim
88
+
89
+ self.norm1 = norm_layer(dim) if norm_layer is not None else nn.Identity()
90
+ self.conv_block = MbConvBlock(dim, act_layer=act_layer)
91
+ assert reduction in ('conv', 'max', 'avg')
92
+ if reduction == 'conv':
93
+ self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False)
94
+ elif reduction == 'max':
95
+ assert dim == dim_out
96
+ self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
97
+ else:
98
+ assert dim == dim_out
99
+ self.reduction = nn.AvgPool2d(kernel_size=2)
100
+ self.norm2 = norm_layer(dim_out) if norm_layer is not None else nn.Identity()
101
+
102
+ def forward(self, x):
103
+ x = self.norm1(x)
104
+ x = self.conv_block(x)
105
+ x = self.reduction(x)
106
+ x = self.norm2(x)
107
+ return x
108
+
109
+
110
+ class FeatureBlock(nn.Module):
111
+ def __init__(
112
+ self,
113
+ dim,
114
+ levels=0,
115
+ reduction='max',
116
+ act_layer=nn.GELU,
117
+ ):
118
+ super().__init__()
119
+ reductions = levels
120
+ levels = max(1, levels)
121
+ if reduction == 'avg':
122
+ pool_fn = partial(nn.AvgPool2d, kernel_size=2)
123
+ else:
124
+ pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
125
+ self.blocks = nn.Sequential()
126
+ for i in range(levels):
127
+ self.blocks.add_module(f'conv{i+1}', MbConvBlock(dim, act_layer=act_layer))
128
+ if reductions:
129
+ self.blocks.add_module(f'pool{i+1}', pool_fn())
130
+ reductions -= 1
131
+
132
+ def forward(self, x):
133
+ return self.blocks(x)
134
+
135
+
136
+ class Stem(nn.Module):
137
+ def __init__(
138
+ self,
139
+ in_chs: int = 3,
140
+ out_chs: int = 96,
141
+ act_layer: Callable = nn.GELU,
142
+ norm_layer: Callable = LayerNorm2d, # NOTE stem in NCHW
143
+ ):
144
+ super().__init__()
145
+ self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
146
+ self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)
147
+
148
+ def forward(self, x):
149
+ x = self.conv1(x)
150
+ x = self.down(x)
151
+ return x
152
+
153
+
154
+ class WindowAttentionGlobal(nn.Module):
155
+
156
+ def __init__(
157
+ self,
158
+ dim: int,
159
+ num_heads: int,
160
+ window_size: Tuple[int, int],
161
+ use_global: bool = True,
162
+ qkv_bias: bool = True,
163
+ attn_drop: float = 0.,
164
+ proj_drop: float = 0.,
165
+ ):
166
+ super().__init__()
167
+ window_size = to_2tuple(window_size)
168
+ self.window_size = window_size
169
+ self.num_heads = num_heads
170
+ self.head_dim = dim // num_heads
171
+ self.scale = self.head_dim ** -0.5
172
+ self.use_global = use_global
173
+
174
+ self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)
175
+ if self.use_global:
176
+ self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
177
+ else:
178
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
179
+ self.attn_drop = nn.Dropout(attn_drop)
180
+ self.proj = nn.Linear(dim, dim)
181
+ self.proj_drop = nn.Dropout(proj_drop)
182
+
183
+ def forward(self, x, q_global: Optional[torch.Tensor] = None):
184
+ B, N, C = x.shape
185
+ if self.use_global and q_global is not None:
186
+ _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global channel dims should be equal')
187
+
188
+ kv = self.qkv(x)
189
+ kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
190
+ k, v = kv.unbind(0)
191
+
192
+ q = q_global.repeat(B // q_global.shape[0], 1, 1, 1)
193
+ q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
194
+ else:
195
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
196
+ q, k, v = qkv.unbind(0)
197
+ q = q * self.scale
198
+
199
+ attn = q @ k.transpose(-2, -1).contiguous() # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
200
+ attn = self.rel_pos(attn)
201
+ attn = attn.softmax(dim=-1)
202
+ attn = self.attn_drop(attn)
203
+
204
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
205
+ x = self.proj(x)
206
+ x = self.proj_drop(x)
207
+ return x
208
+
209
+
210
+ def window_partition(x, window_size: Tuple[int, int]):
211
+ B, H, W, C = x.shape
212
+ x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
213
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
214
+ return windows
215
+
216
+
217
+ @register_notrace_function # reason: int argument is a Proxy
218
+ def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
219
+ H, W = img_size
220
+ C = windows.shape[-1]
221
+ x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
222
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
223
+ return x
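Round-trip sketch (editorial, assuming H and W divide evenly by the window size): window_partition regroups an NHWC tensor into non-overlapping windows and window_reverse inverts it exactly:

import torch

x = torch.randn(2, 56, 56, 96)           # B, H, W, C
w = window_partition(x, (7, 7))          # -> (2 * 8 * 8, 7, 7, 96) windows
y = window_reverse(w, (7, 7), (56, 56))  # -> back to (2, 56, 56, 96)
assert torch.equal(x, y)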
224
+
225
+
226
+ class LayerScale(nn.Module):
227
+ def __init__(self, dim, init_values=1e-5, inplace=False):
228
+ super().__init__()
229
+ self.inplace = inplace
230
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
231
+
232
+ def forward(self, x):
233
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
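Usage note (editorial, with torch and the class above in scope): LayerScale applies a learnable per-channel gain initialized near zero, so each residual branch starts almost muted:

import torch

ls = LayerScale(dim=96, init_values=1e-5)
y = ls(torch.randn(2, 49, 96))  # gamma broadcasts over the last (channel) dim
assert y.shape == (2, 49, 96)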
234
+
235
+
236
+ class GlobalContextVitBlock(nn.Module):
237
+ def __init__(
238
+ self,
239
+ dim: int,
240
+ feat_size: Tuple[int, int],
241
+ num_heads: int,
242
+ window_size: int = 7,
243
+ mlp_ratio: float = 4.,
244
+ use_global: bool = True,
245
+ qkv_bias: bool = True,
246
+ layer_scale: Optional[float] = None,
247
+ proj_drop: float = 0.,
248
+ attn_drop: float = 0.,
249
+ drop_path: float = 0.,
250
+ attn_layer: Callable = WindowAttentionGlobal,
251
+ act_layer: Callable = nn.GELU,
252
+ norm_layer: Callable = nn.LayerNorm,
253
+ ):
254
+ super().__init__()
255
+ feat_size = to_2tuple(feat_size)
256
+ window_size = to_2tuple(window_size)
257
+ self.window_size = window_size
258
+ self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1]))
259
+
260
+ self.norm1 = norm_layer(dim)
261
+ self.attn = attn_layer(
262
+ dim,
263
+ num_heads=num_heads,
264
+ window_size=window_size,
265
+ use_global=use_global,
266
+ qkv_bias=qkv_bias,
267
+ attn_drop=attn_drop,
268
+ proj_drop=proj_drop,
269
+ )
270
+ self.ls1 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
271
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
272
+
273
+ self.norm2 = norm_layer(dim)
274
+ self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
275
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
276
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
277
+
278
+ def _window_attn(self, x, q_global: Optional[torch.Tensor] = None):
279
+ B, H, W, C = x.shape
280
+ x_win = window_partition(x, self.window_size)
281
+ x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C)
282
+ attn_win = self.attn(x_win, q_global)
283
+ x = window_reverse(attn_win, self.window_size, (H, W))
284
+ return x
285
+
286
+ def forward(self, x, q_global: Optional[torch.Tensor] = None):
287
+ x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global)))
288
+ x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
289
+ return x
290
+
291
+
292
+ class GlobalContextVitStage(nn.Module):
293
+ def __init__(
294
+ self,
295
+ dim,
296
+ depth: int,
297
+ num_heads: int,
298
+ feat_size: Tuple[int, int],
299
+ window_size: Tuple[int, int],
300
+ downsample: bool = True,
301
+ global_norm: bool = False,
302
+ stage_norm: bool = False,
303
+ mlp_ratio: float = 4.,
304
+ qkv_bias: bool = True,
305
+ layer_scale: Optional[float] = None,
306
+ proj_drop: float = 0.,
307
+ attn_drop: float = 0.,
308
+ drop_path: Union[List[float], float] = 0.0,
309
+ act_layer: Callable = nn.GELU,
310
+ norm_layer: Callable = nn.LayerNorm,
311
+ norm_layer_cl: Callable = LayerNorm2d,
312
+ ):
313
+ super().__init__()
314
+ if downsample:
315
+ self.downsample = Downsample2d(
316
+ dim=dim,
317
+ dim_out=dim * 2,
318
+ norm_layer=norm_layer,
319
+ )
320
+ dim = dim * 2
321
+ feat_size = (feat_size[0] // 2, feat_size[1] // 2)
322
+ else:
323
+ self.downsample = nn.Identity()
324
+ self.feat_size = feat_size
325
+ window_size = to_2tuple(window_size)
326
+
327
+ feat_levels = int(math.log2(min(feat_size) / min(window_size)))
328
+ self.global_block = FeatureBlock(dim, feat_levels)
329
+ self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
330
+
331
+ self.blocks = nn.ModuleList([
332
+ GlobalContextVitBlock(
333
+ dim=dim,
334
+ num_heads=num_heads,
335
+ feat_size=feat_size,
336
+ window_size=window_size,
337
+ mlp_ratio=mlp_ratio,
338
+ qkv_bias=qkv_bias,
339
+ use_global=(i % 2 != 0),
340
+ layer_scale=layer_scale,
341
+ proj_drop=proj_drop,
342
+ attn_drop=attn_drop,
343
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
344
+ act_layer=act_layer,
345
+ norm_layer=norm_layer_cl,
346
+ )
347
+ for i in range(depth)
348
+ ])
349
+ self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity()
350
+ self.dim = dim
351
+ self.feat_size = feat_size
352
+ self.grad_checkpointing = False
353
+
354
+ def forward(self, x):
355
+ # input NCHW, downsample & global block are 2d conv + pooling
356
+ x = self.downsample(x)
357
+ global_query = self.global_block(x)
358
+
359
+ # reshape NCHW --> NHWC for transformer blocks
360
+ x = x.permute(0, 2, 3, 1)
361
+ global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
362
+ for blk in self.blocks:
363
+ if self.grad_checkpointing and not torch.jit.is_scripting():
364
+ x = checkpoint.checkpoint(blk, x, global_query)  # pass the global query through the checkpointed call as in the non-checkpointed branch
365
+ else:
366
+ x = blk(x, global_query)
367
+ x = self.norm(x)
368
+ x = x.permute(0, 3, 1, 2).contiguous() # back to NCHW
369
+ return x
370
+
371
+
372
+ class GlobalContextVit(nn.Module):
373
+ def __init__(
374
+ self,
375
+ in_chans: int = 3,
376
+ num_classes: int = 1000,
377
+ global_pool: str = 'avg',
378
+ img_size: Tuple[int, int] = 224,
379
+ window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
380
+ window_size: Optional[Tuple[int, ...]] = None,
381
+ embed_dim: int = 64,
382
+ depths: Tuple[int, ...] = (3, 4, 19, 5),
383
+ num_heads: Tuple[int, ...] = (2, 4, 8, 16),
384
+ mlp_ratio: float = 3.0,
385
+ qkv_bias: bool = True,
386
+ layer_scale: Optional[float] = None,
387
+ drop_rate: float = 0.,
388
+ proj_drop_rate: float = 0.,
389
+ attn_drop_rate: float = 0.,
390
+ drop_path_rate: float = 0.,
391
+ weight_init='',
392
+ act_layer: str = 'gelu',
393
+ norm_layer: str = 'layernorm2d',
394
+ norm_layer_cl: str = 'layernorm',
395
+ norm_eps: float = 1e-5,
396
+ ):
397
+ super().__init__()
398
+ act_layer = get_act_layer(act_layer)
399
+ norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
400
+ norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
401
+
402
+ img_size = to_2tuple(img_size)
403
+ feat_size = tuple(d // 4 for d in img_size) # stem reduction by 4
404
+ self.global_pool = global_pool
405
+ self.num_classes = num_classes
406
+ self.drop_rate = drop_rate
407
+ num_stages = len(depths)
408
+ self.num_features = int(embed_dim * 2 ** (num_stages - 1))
409
+ if window_size is not None:
410
+ window_size = to_ntuple(num_stages)(window_size)
411
+ else:
412
+ assert window_ratio is not None
413
+ window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
414
+
415
+ self.stem = Stem(
416
+ in_chs=in_chans,
417
+ out_chs=embed_dim,
418
+ act_layer=act_layer,
419
+ norm_layer=norm_layer
420
+ )
421
+
422
+ dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
423
+ stages = []
424
+ for i in range(num_stages):
425
+ last_stage = i == num_stages - 1
426
+ stage_scale = 2 ** max(i - 1, 0)
427
+ stages.append(GlobalContextVitStage(
428
+ dim=embed_dim * stage_scale,
429
+ depth=depths[i],
430
+ num_heads=num_heads[i],
431
+ feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale),
432
+ window_size=window_size[i],
433
+ downsample=i != 0,
434
+ stage_norm=last_stage,
435
+ mlp_ratio=mlp_ratio,
436
+ qkv_bias=qkv_bias,
437
+ layer_scale=layer_scale,
438
+ proj_drop=proj_drop_rate,
439
+ attn_drop=attn_drop_rate,
440
+ drop_path=dpr[i],
441
+ act_layer=act_layer,
442
+ norm_layer=norm_layer,
443
+ norm_layer_cl=norm_layer_cl,
444
+ ))
445
+ self.stages = nn.Sequential(*stages)
446
+
447
+ # Classifier head
448
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
449
+
450
+ if weight_init:
451
+ named_apply(partial(self._init_weights, scheme=weight_init), self)
452
+
453
+ def _init_weights(self, module, name, scheme='vit'):
454
+ # note Conv2d left as default init
455
+ if scheme == 'vit':
456
+ if isinstance(module, nn.Linear):
457
+ nn.init.xavier_uniform_(module.weight)
458
+ if module.bias is not None:
459
+ if 'mlp' in name:
460
+ nn.init.normal_(module.bias, std=1e-6)
461
+ else:
462
+ nn.init.zeros_(module.bias)
463
+ else:
464
+ if isinstance(module, nn.Linear):
465
+ nn.init.normal_(module.weight, std=.02)
466
+ if module.bias is not None:
467
+ nn.init.zeros_(module.bias)
468
+
469
+ @torch.jit.ignore
470
+ def no_weight_decay(self):
471
+ return {
472
+ k for k, _ in self.named_parameters()
473
+ if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
474
+
475
+ @torch.jit.ignore
476
+ def group_matcher(self, coarse=False):
477
+ matcher = dict(
478
+ stem=r'^stem', # stem and embed
479
+ blocks=r'^stages\.(\d+)'
480
+ )
481
+ return matcher
482
+
483
+ @torch.jit.ignore
484
+ def set_grad_checkpointing(self, enable=True):
485
+ for s in self.stages:
486
+ s.grad_checkpointing = enable
487
+
488
+ @torch.jit.ignore
489
+ def get_classifier(self):
490
+ return self.head.fc
491
+
492
+ def reset_classifier(self, num_classes, global_pool=None):
493
+ self.num_classes = num_classes
494
+ if global_pool is None:
495
+ global_pool = self.head.global_pool.pool_type
496
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
497
+
498
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
499
+ x = self.stem(x)
500
+ x = self.stages(x)
501
+ return x
502
+
503
+ def forward_head(self, x, pre_logits: bool = False):
504
+ return self.head(x, pre_logits=pre_logits)
505
+
506
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
507
+ x = self.forward_features(x)
508
+ x = self.forward_head(x)
509
+ return x
510
+
511
+
512
+ def _create_gcvit(variant, pretrained=False, **kwargs):
513
+ if kwargs.get('features_only', None):
514
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
515
+ model = build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
516
+ return model
517
+
518
+
519
+ def _cfg(url='', **kwargs):
520
+ return {
521
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
522
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
523
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
524
+ 'first_conv': 'stem.conv1', 'classifier': 'head.fc',
525
+ 'fixed_input_size': True,
526
+ **kwargs
527
+ }
528
+
529
+
530
+ default_cfgs = generate_default_cfgs({
531
+ 'gcvit_xxtiny.in1k': _cfg(
532
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'),
533
+ 'gcvit_xtiny.in1k': _cfg(
534
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'),
535
+ 'gcvit_tiny.in1k': _cfg(
536
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'),
537
+ 'gcvit_small.in1k': _cfg(
538
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'),
539
+ 'gcvit_base.in1k': _cfg(
540
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'),
541
+ })
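Smoke-test sketch (editorial, assuming timm and torch): the gcvit_* entry points below route through _create_gcvit and build_model_with_cfg; note features_only is rejected above and the input size is fixed at 224 by _cfg:

import timm
import torch

model = timm.create_model('gcvit_tiny', pretrained=False).eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
assert logits.shape == (1, 1000)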
542
+
543
+
544
+ @register_model
545
+ def gcvit_xxtiny(pretrained=False, **kwargs) -> GlobalContextVit:
546
+ model_kwargs = dict(
547
+ depths=(2, 2, 6, 2),
548
+ num_heads=(2, 4, 8, 16),
549
+ **kwargs)
550
+ return _create_gcvit('gcvit_xxtiny', pretrained=pretrained, **model_kwargs)
551
+
552
+
553
+ @register_model
554
+ def gcvit_xtiny(pretrained=False, **kwargs) -> GlobalContextVit:
555
+ model_kwargs = dict(
556
+ depths=(3, 4, 6, 5),
557
+ num_heads=(2, 4, 8, 16),
558
+ **kwargs)
559
+ return _create_gcvit('gcvit_xtiny', pretrained=pretrained, **model_kwargs)
560
+
561
+
562
+ @register_model
563
+ def gcvit_tiny(pretrained=False, **kwargs) -> GlobalContextVit:
564
+ model_kwargs = dict(
565
+ depths=(3, 4, 19, 5),
566
+ num_heads=(2, 4, 8, 16),
567
+ **kwargs)
568
+ return _create_gcvit('gcvit_tiny', pretrained=pretrained, **model_kwargs)
569
+
570
+
571
+ @register_model
572
+ def gcvit_small(pretrained=False, **kwargs) -> GlobalContextVit:
573
+ model_kwargs = dict(
574
+ depths=(3, 4, 19, 5),
575
+ num_heads=(3, 6, 12, 24),
576
+ embed_dim=96,
577
+ mlp_ratio=2,
578
+ layer_scale=1e-5,
579
+ **kwargs)
580
+ return _create_gcvit('gcvit_small', pretrained=pretrained, **model_kwargs)
581
+
582
+
583
+ @register_model
584
+ def gcvit_base(pretrained=False, **kwargs) -> GlobalContextVit:
585
+ model_kwargs = dict(
586
+ depths=(3, 4, 19, 5),
587
+ num_heads=(4, 8, 16, 32),
588
+ embed_dim=128,
589
+ mlp_ratio=2,
590
+ layer_scale=1e-5,
591
+ **kwargs)
592
+ return _create_gcvit('gcvit_base', pretrained=pretrained, **model_kwargs)
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/ghostnet.py ADDED
@@ -0,0 +1,432 @@
1
+ """
2
+ An implementation of GhostNet & GhostNetV2 Models as defined in:
3
+ GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
4
+ GhostNetV2: Enhance Cheap Operation with Long-Range Attention. https://proceedings.neurips.cc/paper_files/paper/2022/file/40b60852a4abdaa696b5a1a78da34635-Paper-Conference.pdf
5
+
6
+ The train script & code of models at:
7
+ Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
8
+ Original model: https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv2_pytorch/model/ghostnetv2_torch.py
9
+ """
10
+ import math
11
+ from functools import partial
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18
+ from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
19
+ from ._builder import build_model_with_cfg
20
+ from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
21
+ from ._manipulate import checkpoint_seq
22
+ from ._registry import register_model, generate_default_cfgs
23
+
24
+ __all__ = ['GhostNet']
25
+
26
+
27
+ _SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
28
+
29
+
30
+ class GhostModule(nn.Module):
31
+ def __init__(
32
+ self,
33
+ in_chs,
34
+ out_chs,
35
+ kernel_size=1,
36
+ ratio=2,
37
+ dw_size=3,
38
+ stride=1,
39
+ use_act=True,
40
+ act_layer=nn.ReLU,
41
+ ):
42
+ super(GhostModule, self).__init__()
43
+ self.out_chs = out_chs
44
+ init_chs = math.ceil(out_chs / ratio)
45
+ new_chs = init_chs * (ratio - 1)
46
+
47
+ self.primary_conv = nn.Sequential(
48
+ nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False),
49
+ nn.BatchNorm2d(init_chs),
50
+ act_layer(inplace=True) if use_act else nn.Identity(),
51
+ )
52
+
53
+ self.cheap_operation = nn.Sequential(
54
+ nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size//2, groups=init_chs, bias=False),
55
+ nn.BatchNorm2d(new_chs),
56
+ act_layer(inplace=True) if use_act else nn.Identity(),
57
+ )
58
+
59
+ def forward(self, x):
60
+ x1 = self.primary_conv(x)
61
+ x2 = self.cheap_operation(x1)
62
+ out = torch.cat([x1, x2], dim=1)
63
+ return out[:, :self.out_chs, :, :]
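Channel-count sketch (editorial, with torch and the class above in scope): with the default ratio=2 the primary conv emits ceil(out_chs / 2) channels, the cheap depthwise conv adds init_chs * (ratio - 1) more, and the concat is truncated back to out_chs:

import torch

m = GhostModule(in_chs=16, out_chs=64)  # init_chs=32, new_chs=32
y = m(torch.randn(1, 16, 56, 56))
assert y.shape == (1, 64, 56, 56)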
64
+
65
+
66
+ class GhostModuleV2(nn.Module):
67
+ def __init__(
68
+ self,
69
+ in_chs,
70
+ out_chs,
71
+ kernel_size=1,
72
+ ratio=2,
73
+ dw_size=3,
74
+ stride=1,
75
+ use_act=True,
76
+ act_layer=nn.ReLU,
77
+ ):
78
+ super().__init__()
79
+ self.gate_fn = nn.Sigmoid()
80
+ self.out_chs = out_chs
81
+ init_chs = math.ceil(out_chs / ratio)
82
+ new_chs = init_chs * (ratio - 1)
83
+ self.primary_conv = nn.Sequential(
84
+ nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False),
85
+ nn.BatchNorm2d(init_chs),
86
+ act_layer(inplace=True) if use_act else nn.Identity(),
87
+ )
88
+ self.cheap_operation = nn.Sequential(
89
+ nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False),
90
+ nn.BatchNorm2d(new_chs),
91
+ act_layer(inplace=True) if use_act else nn.Identity(),
92
+ )
93
+ self.short_conv = nn.Sequential(
94
+ nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False),
95
+ nn.BatchNorm2d(out_chs),
96
+ nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False),
97
+ nn.BatchNorm2d(out_chs),
98
+ nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False),
99
+ nn.BatchNorm2d(out_chs),
100
+ )
101
+
102
+ def forward(self, x):
103
+ res = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
104
+ x1 = self.primary_conv(x)
105
+ x2 = self.cheap_operation(x1)
106
+ out = torch.cat([x1, x2], dim=1)
107
+ return out[:, :self.out_chs, :, :] * F.interpolate(
108
+ self.gate_fn(res), size=(out.shape[-2], out.shape[-1]), mode='nearest')
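Gating sketch (editorial, with torch and the class above in scope): the V2 module adds a decoupled attention branch; short_conv builds a sigmoid gate from a 2x-downsampled input using 1x5 and 5x1 depthwise convs, and the gate is upsampled (nearest) back onto the ghost features:

import torch

m = GhostModuleV2(in_chs=16, out_chs=64)
y = m(torch.randn(1, 16, 56, 56))  # gate computed at 28x28, resized to 56x56
assert y.shape == (1, 64, 56, 56)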
109
+
110
+
111
+ class GhostBottleneck(nn.Module):
112
+ """ Ghost bottleneck w/ optional SE"""
113
+
114
+ def __init__(
115
+ self,
116
+ in_chs,
117
+ mid_chs,
118
+ out_chs,
119
+ dw_kernel_size=3,
120
+ stride=1,
121
+ act_layer=nn.ReLU,
122
+ se_ratio=0.,
123
+ mode='original',
124
+ ):
125
+ super(GhostBottleneck, self).__init__()
126
+ has_se = se_ratio is not None and se_ratio > 0.
127
+ self.stride = stride
128
+
129
+ # Point-wise expansion
130
+ if mode == 'original':
131
+ self.ghost1 = GhostModule(in_chs, mid_chs, use_act=True, act_layer=act_layer)
132
+ else:
133
+ self.ghost1 = GhostModuleV2(in_chs, mid_chs, use_act=True, act_layer=act_layer)
134
+
135
+ # Depth-wise convolution
136
+ if self.stride > 1:
137
+ self.conv_dw = nn.Conv2d(
138
+ mid_chs, mid_chs, dw_kernel_size, stride=stride,
139
+ padding=(dw_kernel_size-1)//2, groups=mid_chs, bias=False)
140
+ self.bn_dw = nn.BatchNorm2d(mid_chs)
141
+ else:
142
+ self.conv_dw = None
143
+ self.bn_dw = None
144
+
145
+ # Squeeze-and-excitation
146
+ self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None
147
+
148
+ # Point-wise linear projection
149
+ self.ghost2 = GhostModule(mid_chs, out_chs, use_act=False)
150
+
151
+ # shortcut
152
+ if in_chs == out_chs and self.stride == 1:
153
+ self.shortcut = nn.Sequential()
154
+ else:
155
+ self.shortcut = nn.Sequential(
156
+ nn.Conv2d(
157
+ in_chs, in_chs, dw_kernel_size, stride=stride,
158
+ padding=(dw_kernel_size-1)//2, groups=in_chs, bias=False),
159
+ nn.BatchNorm2d(in_chs),
160
+ nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
161
+ nn.BatchNorm2d(out_chs),
162
+ )
163
+
164
+ def forward(self, x):
165
+ shortcut = x
166
+
167
+ # 1st ghost bottleneck
168
+ x = self.ghost1(x)
169
+
170
+ # Depth-wise convolution
171
+ if self.conv_dw is not None:
172
+ x = self.conv_dw(x)
173
+ x = self.bn_dw(x)
174
+
175
+ # Squeeze-and-excitation
176
+ if self.se is not None:
177
+ x = self.se(x)
178
+
179
+ # 2nd ghost bottleneck
180
+ x = self.ghost2(x)
181
+
182
+ x += self.shortcut(shortcut)
183
+ return x
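Shape sketch (editorial, with torch and the class above in scope): with stride=2 the depthwise conv halves the spatial dims while the depthwise + pointwise shortcut matches both stride and channels, keeping the residual add valid:

import torch

blk = GhostBottleneck(in_chs=24, mid_chs=72, out_chs=40, dw_kernel_size=5, stride=2, se_ratio=0.25)
y = blk(torch.randn(1, 24, 56, 56))
assert y.shape == (1, 40, 28, 28)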
184
+
185
+
186
+ class GhostNet(nn.Module):
187
+ def __init__(
188
+ self,
189
+ cfgs,
190
+ num_classes=1000,
191
+ width=1.0,
192
+ in_chans=3,
193
+ output_stride=32,
194
+ global_pool='avg',
195
+ drop_rate=0.2,
196
+ version='v1',
197
+ ):
198
+ super(GhostNet, self).__init__()
199
+ # setting of inverted residual blocks
200
+ assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
201
+ self.cfgs = cfgs
202
+ self.num_classes = num_classes
203
+ self.drop_rate = drop_rate
204
+ self.grad_checkpointing = False
205
+ self.feature_info = []
206
+
207
+ # building first layer
208
+ stem_chs = make_divisible(16 * width, 4)
209
+ self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
210
+ self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='conv_stem'))
211
+ self.bn1 = nn.BatchNorm2d(stem_chs)
212
+ self.act1 = nn.ReLU(inplace=True)
213
+ prev_chs = stem_chs
214
+
215
+ # building inverted residual blocks
216
+ stages = nn.ModuleList([])
217
+ stage_idx = 0
218
+ layer_idx = 0
219
+ net_stride = 2
220
+ for cfg in self.cfgs:
221
+ layers = []
222
+ s = 1
223
+ for k, exp_size, c, se_ratio, s in cfg:
224
+ out_chs = make_divisible(c * width, 4)
225
+ mid_chs = make_divisible(exp_size * width, 4)
226
+ layer_kwargs = {}
227
+ if version == 'v2' and layer_idx > 1:
228
+ layer_kwargs['mode'] = 'attn'
229
+ layers.append(GhostBottleneck(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, **layer_kwargs))
230
+ prev_chs = out_chs
231
+ layer_idx += 1
232
+ if s > 1:
233
+ net_stride *= 2
234
+ self.feature_info.append(dict(
235
+ num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
236
+ stages.append(nn.Sequential(*layers))
237
+ stage_idx += 1
238
+
239
+ out_chs = make_divisible(exp_size * width, 4)
240
+ stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
241
+ self.pool_dim = prev_chs = out_chs
242
+
243
+ self.blocks = nn.Sequential(*stages)
244
+
245
+ # building last several layers
246
+ self.num_features = out_chs = 1280
247
+ self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
248
+ self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
249
+ self.act2 = nn.ReLU(inplace=True)
250
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
251
+ self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()
252
+
253
+ # FIXME init
254
+
255
+ @torch.jit.ignore
256
+ def group_matcher(self, coarse=False):
257
+ matcher = dict(
258
+ stem=r'^conv_stem|bn1',
259
+ blocks=[
260
+ (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
261
+ (r'conv_head', (99999,))
262
+ ]
263
+ )
264
+ return matcher
265
+
266
+ @torch.jit.ignore
267
+ def set_grad_checkpointing(self, enable=True):
268
+ self.grad_checkpointing = enable
269
+
270
+ @torch.jit.ignore
271
+ def get_classifier(self):
272
+ return self.classifier
273
+
274
+ def reset_classifier(self, num_classes, global_pool='avg'):
275
+ self.num_classes = num_classes
276
+ # cannot meaningfully change pooling of efficient head after creation
277
+ self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
278
+ self.flatten = nn.Flatten(1) if global_pool else nn.Identity() # don't flatten if pooling disabled
279
+ self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
280
+
281
+ def forward_features(self, x):
282
+ x = self.conv_stem(x)
283
+ x = self.bn1(x)
284
+ x = self.act1(x)
285
+ if self.grad_checkpointing and not torch.jit.is_scripting():
286
+ x = checkpoint_seq(self.blocks, x, flatten=True)
287
+ else:
288
+ x = self.blocks(x)
289
+ return x
290
+
291
+ def forward_head(self, x):
292
+ x = self.global_pool(x)
293
+ x = self.conv_head(x)
294
+ x = self.act2(x)
295
+ x = self.flatten(x)
296
+ if self.drop_rate > 0.:
297
+ x = F.dropout(x, p=self.drop_rate, training=self.training)
298
+ x = self.classifier(x)
299
+ return x
300
+
301
+ def forward(self, x):
302
+ x = self.forward_features(x)
303
+ x = self.forward_head(x)
304
+ return x
305
+
306
+
307
+ def checkpoint_filter_fn(state_dict, model: nn.Module):
308
+ out_dict = {}
309
+ for k, v in state_dict.items():
310
+ if 'total' in k:  # drop profiling counter buffers (e.g. thop's 'total_ops'/'total_params') present in the original checkpoints
311
+ continue
312
+ out_dict[k] = v
313
+ return out_dict
314
+
315
+
316
+ def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
317
+ """
318
+ Constructs a GhostNet model
319
+ """
320
+ cfgs = [
321
+ # k (dw kernel size), t (expansion/mid chs), c (out chs), SE (se_ratio), s (stride)
322
+ # stage1
323
+ [[3, 16, 16, 0, 1]],
324
+ # stage2
325
+ [[3, 48, 24, 0, 2]],
326
+ [[3, 72, 24, 0, 1]],
327
+ # stage3
328
+ [[5, 72, 40, 0.25, 2]],
329
+ [[5, 120, 40, 0.25, 1]],
330
+ # stage4
331
+ [[3, 240, 80, 0, 2]],
332
+ [[3, 200, 80, 0, 1],
333
+ [3, 184, 80, 0, 1],
334
+ [3, 184, 80, 0, 1],
335
+ [3, 480, 112, 0.25, 1],
336
+ [3, 672, 112, 0.25, 1]
337
+ ],
338
+ # stage5
339
+ [[5, 672, 160, 0.25, 2]],
340
+ [[5, 960, 160, 0, 1],
341
+ [5, 960, 160, 0.25, 1],
342
+ [5, 960, 160, 0, 1],
343
+ [5, 960, 160, 0.25, 1]
344
+ ]
345
+ ]
346
+ model_kwargs = dict(
347
+ cfgs=cfgs,
348
+ width=width,
349
+ **kwargs,
350
+ )
351
+ return build_model_with_cfg(
352
+ GhostNet,
353
+ variant,
354
+ pretrained,
355
+ pretrained_filter_fn=checkpoint_filter_fn,
356
+ feature_cfg=dict(flatten_sequential=True),
357
+ **model_kwargs,
358
+ )
359
+
360
+
361
+ def _cfg(url='', **kwargs):
362
+ return {
363
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
364
+ 'crop_pct': 0.875, 'interpolation': 'bicubic',
365
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
366
+ 'first_conv': 'conv_stem', 'classifier': 'classifier',
367
+ **kwargs
368
+ }
369
+
370
+
371
+ default_cfgs = generate_default_cfgs({
372
+ 'ghostnet_050.untrained': _cfg(),
373
+ 'ghostnet_100.in1k': _cfg(
374
+ hf_hub_id='timm/',
375
+ # url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'
376
+ ),
377
+ 'ghostnet_130.untrained': _cfg(),
378
+ 'ghostnetv2_100.in1k': _cfg(
379
+ hf_hub_id='timm/',
380
+ # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_10.pth.tar'
381
+ ),
382
+ 'ghostnetv2_130.in1k': _cfg(
383
+ hf_hub_id='timm/',
384
+ # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_13.pth.tar'
385
+ ),
386
+ 'ghostnetv2_160.in1k': _cfg(
387
+ hf_hub_id='timm/',
388
+ # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_16.pth.tar'
389
+ ),
390
+ })
391
+
392
+
393
+ @register_model
394
+ def ghostnet_050(pretrained=False, **kwargs) -> GhostNet:
395
+ """ GhostNet-0.5x """
396
+ model = _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
397
+ return model
398
+
399
+
400
+ @register_model
401
+ def ghostnet_100(pretrained=False, **kwargs) -> GhostNet:
402
+ """ GhostNet-1.0x """
403
+ model = _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
404
+ return model
405
+
406
+
407
+ @register_model
408
+ def ghostnet_130(pretrained=False, **kwargs) -> GhostNet:
409
+ """ GhostNet-1.3x """
410
+ model = _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
411
+ return model
412
+
413
+
414
+ @register_model
415
+ def ghostnetv2_100(pretrained=False, **kwargs) -> GhostNet:
416
+ """ GhostNetV2-1.0x """
417
+ model = _create_ghostnet('ghostnetv2_100', width=1.0, pretrained=pretrained, version='v2', **kwargs)
418
+ return model
419
+
420
+
421
+ @register_model
422
+ def ghostnetv2_130(pretrained=False, **kwargs) -> GhostNet:
423
+ """ GhostNetV2-1.3x """
424
+ model = _create_ghostnet('ghostnetv2_130', width=1.3, pretrained=pretrained, version='v2', **kwargs)
425
+ return model
426
+
427
+
428
+ @register_model
429
+ def ghostnetv2_160(pretrained=False, **kwargs) -> GhostNet:
430
+ """ GhostNetV2-1.6x """
431
+ model = _create_ghostnet('ghostnetv2_160', width=1.6, pretrained=pretrained, version='v2', **kwargs)
432
+ return model
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_v4.py ADDED
@@ -0,0 +1,325 @@
1
+ """ Pytorch Inception-V4 implementation
2
+ Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
3
+ based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
4
+ """
5
+ from functools import partial
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
11
+ from timm.layers import create_classifier, ConvNormAct
12
+ from ._builder import build_model_with_cfg
13
+ from ._registry import register_model, generate_default_cfgs
14
+
15
+ __all__ = ['InceptionV4']
16
+
17
+
18
+ class Mixed3a(nn.Module):
19
+ def __init__(self, conv_block=ConvNormAct):
20
+ super(Mixed3a, self).__init__()
21
+ self.maxpool = nn.MaxPool2d(3, stride=2)
22
+ self.conv = conv_block(64, 96, kernel_size=3, stride=2)
23
+
24
+ def forward(self, x):
25
+ x0 = self.maxpool(x)
26
+ x1 = self.conv(x)
27
+ out = torch.cat((x0, x1), 1)
28
+ return out
29
+
30
+
31
+ class Mixed4a(nn.Module):
32
+ def __init__(self, conv_block=ConvNormAct):
33
+ super(Mixed4a, self).__init__()
34
+
35
+ self.branch0 = nn.Sequential(
36
+ conv_block(160, 64, kernel_size=1, stride=1),
37
+ conv_block(64, 96, kernel_size=3, stride=1)
38
+ )
39
+
40
+ self.branch1 = nn.Sequential(
41
+ conv_block(160, 64, kernel_size=1, stride=1),
42
+ conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
43
+ conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
44
+ conv_block(64, 96, kernel_size=(3, 3), stride=1)
45
+ )
46
+
47
+ def forward(self, x):
48
+ x0 = self.branch0(x)
49
+ x1 = self.branch1(x)
50
+ out = torch.cat((x0, x1), 1)
51
+ return out
52
+
53
+
54
+ class Mixed5a(nn.Module):
55
+ def __init__(self, conv_block=ConvNormAct):
56
+ super(Mixed5a, self).__init__()
57
+ self.conv = conv_block(192, 192, kernel_size=3, stride=2)
58
+ self.maxpool = nn.MaxPool2d(3, stride=2)
59
+
60
+ def forward(self, x):
61
+ x0 = self.conv(x)
62
+ x1 = self.maxpool(x)
63
+ out = torch.cat((x0, x1), 1)
64
+ return out
65
+
66
+
67
+ class InceptionA(nn.Module):
68
+ def __init__(self, conv_block=ConvNormAct):
69
+ super(InceptionA, self).__init__()
70
+ self.branch0 = conv_block(384, 96, kernel_size=1, stride=1)
71
+
72
+ self.branch1 = nn.Sequential(
73
+ conv_block(384, 64, kernel_size=1, stride=1),
74
+ conv_block(64, 96, kernel_size=3, stride=1, padding=1)
75
+ )
76
+
77
+ self.branch2 = nn.Sequential(
78
+ conv_block(384, 64, kernel_size=1, stride=1),
79
+ conv_block(64, 96, kernel_size=3, stride=1, padding=1),
80
+ conv_block(96, 96, kernel_size=3, stride=1, padding=1)
81
+ )
82
+
83
+ self.branch3 = nn.Sequential(
84
+ nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
85
+ conv_block(384, 96, kernel_size=1, stride=1)
86
+ )
87
+
88
+ def forward(self, x):
89
+ x0 = self.branch0(x)
90
+ x1 = self.branch1(x)
91
+ x2 = self.branch2(x)
92
+ x3 = self.branch3(x)
93
+ out = torch.cat((x0, x1, x2, x3), 1)
94
+ return out
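Channel arithmetic (editorial, with torch and the class above in scope): InceptionA is shape-preserving; each of its four branches emits 96 channels at unchanged resolution, so 384 in gives 4 * 96 = 384 out:

import torch

blk = InceptionA()
y = blk(torch.randn(1, 384, 35, 35))
assert y.shape == (1, 384, 35, 35)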
95
+
96
+
97
+ class ReductionA(nn.Module):
98
+    def __init__(self, conv_block=ConvNormAct):
+        super(ReductionA, self).__init__()
+        self.branch0 = conv_block(384, 384, kernel_size=3, stride=2)
+
+        self.branch1 = nn.Sequential(
+            conv_block(384, 192, kernel_size=1, stride=1),
+            conv_block(192, 224, kernel_size=3, stride=1, padding=1),
+            conv_block(224, 256, kernel_size=3, stride=2)
+        )
+
+        self.branch2 = nn.MaxPool2d(3, stride=2)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        out = torch.cat((x0, x1, x2), 1)
+        return out
+
+
+class InceptionB(nn.Module):
+    def __init__(self, conv_block=ConvNormAct):
+        super(InceptionB, self).__init__()
+        self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1)
+
+        self.branch1 = nn.Sequential(
+            conv_block(1024, 192, kernel_size=1, stride=1),
+            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0))
+        )
+
+        self.branch2 = nn.Sequential(
+            conv_block(1024, 192, kernel_size=1, stride=1),
+            conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+            conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3))
+        )
+
+        self.branch3 = nn.Sequential(
+            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+            conv_block(1024, 128, kernel_size=1, stride=1)
+        )
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x3 = self.branch3(x)
+        out = torch.cat((x0, x1, x2, x3), 1)
+        return out
+
+
+class ReductionB(nn.Module):
+    def __init__(self, conv_block=ConvNormAct):
+        super(ReductionB, self).__init__()
+
+        self.branch0 = nn.Sequential(
+            conv_block(1024, 192, kernel_size=1, stride=1),
+            conv_block(192, 192, kernel_size=3, stride=2)
+        )
+
+        self.branch1 = nn.Sequential(
+            conv_block(1024, 256, kernel_size=1, stride=1),
+            conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
+            conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
+            conv_block(320, 320, kernel_size=3, stride=2)
+        )
+
+        self.branch2 = nn.MaxPool2d(3, stride=2)
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        out = torch.cat((x0, x1, x2), 1)
+        return out
+
+
+class InceptionC(nn.Module):
+    def __init__(self, conv_block=ConvNormAct):
+        super(InceptionC, self).__init__()
+
+        self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1)
+
+        self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1)
+        self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
+        self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+
+        self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1)
+        self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
+        self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
+        self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
+        self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))
+
+        self.branch3 = nn.Sequential(
+            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
+            conv_block(1536, 256, kernel_size=1, stride=1)
+        )
+
+    def forward(self, x):
+        x0 = self.branch0(x)
+
+        x1_0 = self.branch1_0(x)
+        x1_1a = self.branch1_1a(x1_0)
+        x1_1b = self.branch1_1b(x1_0)
+        x1 = torch.cat((x1_1a, x1_1b), 1)
+
+        x2_0 = self.branch2_0(x)
+        x2_1 = self.branch2_1(x2_0)
+        x2_2 = self.branch2_2(x2_1)
+        x2_3a = self.branch2_3a(x2_2)
+        x2_3b = self.branch2_3b(x2_2)
+        x2 = torch.cat((x2_3a, x2_3b), 1)
+
+        x3 = self.branch3(x)
+
+        out = torch.cat((x0, x1, x2, x3), 1)
+        return out
+
+
+class InceptionV4(nn.Module):
+    def __init__(
+            self,
+            num_classes=1000,
+            in_chans=3,
+            output_stride=32,
+            drop_rate=0.,
+            global_pool='avg',
+            norm_layer='batchnorm2d',
+            norm_eps=1e-3,
+            act_layer='relu',
+    ):
+        super(InceptionV4, self).__init__()
+        assert output_stride == 32
+        self.num_classes = num_classes
+        self.num_features = 1536
+        conv_block = partial(
+            ConvNormAct,
+            padding=0,
+            norm_layer=norm_layer,
+            act_layer=act_layer,
+            norm_kwargs=dict(eps=norm_eps),
+            act_kwargs=dict(inplace=True),
+        )
+
+        features = [
+            conv_block(in_chans, 32, kernel_size=3, stride=2),
+            conv_block(32, 32, kernel_size=3, stride=1),
+            conv_block(32, 64, kernel_size=3, stride=1, padding=1),
+            Mixed3a(conv_block),
+            Mixed4a(conv_block),
+            Mixed5a(conv_block),
+        ]
+        features += [InceptionA(conv_block) for _ in range(4)]
+        features += [ReductionA(conv_block)]  # Mixed6a
+        features += [InceptionB(conv_block) for _ in range(7)]
+        features += [ReductionB(conv_block)]  # Mixed7a
+        features += [InceptionC(conv_block) for _ in range(3)]
+        self.features = nn.Sequential(*features)
+        self.feature_info = [
+            dict(num_chs=64, reduction=2, module='features.2'),
+            dict(num_chs=160, reduction=4, module='features.3'),
+            dict(num_chs=384, reduction=8, module='features.9'),
+            dict(num_chs=1024, reduction=16, module='features.17'),
+            dict(num_chs=1536, reduction=32, module='features.21'),
+        ]
+        self.global_pool, self.head_drop, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(
+            stem=r'^features\.[012]\.',
+            blocks=r'^features\.(\d+)'
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, 'gradient checkpointing not supported'
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.last_linear
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.num_classes = num_classes
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
+
+    def forward_features(self, x):
+        return self.features(x)
+
+    def forward_head(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        x = self.head_drop(x)
+        return x if pre_logits else self.last_linear(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _create_inception_v4(variant, pretrained=False, **kwargs) -> InceptionV4:
+    return build_model_with_cfg(
+        InceptionV4,
+        variant,
+        pretrained,
+        feature_cfg=dict(flatten_sequential=True),
+        **kwargs,
+    )
+
+
+default_cfgs = generate_default_cfgs({
+    'inception_v4.tf_in1k': {
+        'hf_hub_id': 'timm/',
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'crop_pct': 0.875, 'interpolation': 'bicubic',
+        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
+        'first_conv': 'features.0.conv', 'classifier': 'last_linear',
+    }
+})
+
+
+@register_model
+def inception_v4(pretrained=False, **kwargs):
+    return _create_inception_v4('inception_v4', pretrained, **kwargs)
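
A minimal smoke-test sketch for the model registered above (assuming a working timm install; input and feature shapes follow the `inception_v4.tf_in1k` config and `feature_info` entries in this file):

    import timm
    import torch

    # 'inception_v4' is registered via @register_model above
    model = timm.create_model('inception_v4', pretrained=False, num_classes=10)
    model.eval()
    x = torch.randn(1, 3, 299, 299)  # input_size from the tf_in1k config
    with torch.no_grad():
        logits = model(x)                  # (1, 10)
        feats = model.forward_features(x)  # (1, 1536, 8, 8) final feature map
    print(logits.shape, feats.shape)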
my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/levit.py ADDED
@@ -0,0 +1,933 @@
+""" LeViT
+
+Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
+    - https://arxiv.org/abs/2104.01136
+
+@article{graham2021levit,
+  title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
+  author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
+  journal={arXiv preprint arXiv:2104.01136},
+  year={2021}
+}
+
+Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below.
+
+This version combines both conv/linear models and fixes torchscript compatibility.
+
+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
+"""
+
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+
+# Modified from
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# Copyright 2020 Ross Wightman, Apache-2.0 License
+from collections import OrderedDict
+from functools import partial
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
+from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['Levit']
+
+
+class ConvNorm(nn.Module):
+    def __init__(
+            self, in_chs, out_chs, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1):
+        super().__init__()
+        self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False)
+        self.bn = nn.BatchNorm2d(out_chs)
+
+        nn.init.constant_(self.bn.weight, bn_weight_init)
+
+    @torch.no_grad()
+    def fuse(self):
+        c, bn = self.linear, self.bn
+        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+        w = c.weight * w[:, None, None, None]
+        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
+        m = nn.Conv2d(
+            w.size(1), w.size(0), w.shape[2:], stride=self.linear.stride,
+            padding=self.linear.padding, dilation=self.linear.dilation, groups=self.linear.groups)
+        m.weight.data.copy_(w)
+        m.bias.data.copy_(b)
+        return m
+
+    def forward(self, x):
+        return self.bn(self.linear(x))
+
+
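ConvNorm.fuse() folds the BatchNorm into the conv: w' = w * gamma / sqrt(var + eps) and b' = beta - mean * gamma / sqrt(var + eps), using the BN running statistics. A minimal verification sketch (assuming ConvNorm is importable from this module path; the module must be in eval mode so BN applies running stats rather than batch stats):

    import torch
    from timm.models.levit import ConvNorm

    cn = ConvNorm(8, 16, kernel_size=3, padding=1).eval()  # eval: BN uses running stats
    fused = cn.fuse()                                      # single Conv2d with BN folded in
    x = torch.randn(2, 8, 14, 14)
    with torch.no_grad():
        assert torch.allclose(cn(x), fused(x), atol=1e-5)  # same outputs, one layer fewer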
+class LinearNorm(nn.Module):
+    def __init__(self, in_features, out_features, bn_weight_init=1):
+        super().__init__()
+        self.linear = nn.Linear(in_features, out_features, bias=False)
+        self.bn = nn.BatchNorm1d(out_features)
+
+        nn.init.constant_(self.bn.weight, bn_weight_init)
+
+    @torch.no_grad()
+    def fuse(self):
+        l, bn = self.linear, self.bn
+        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+        w = l.weight * w[:, None]
+        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
+        m = nn.Linear(w.size(1), w.size(0))
+        m.weight.data.copy_(w)
+        m.bias.data.copy_(b)
+        return m
+
+    def forward(self, x):
+        x = self.linear(x)
+        return self.bn(x.flatten(0, 1)).reshape_as(x)
+
+
+class NormLinear(nn.Module):
+    def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.):
+        super().__init__()
+        self.bn = nn.BatchNorm1d(in_features)
+        self.drop = nn.Dropout(drop)
+        self.linear = nn.Linear(in_features, out_features, bias=bias)
+
+        trunc_normal_(self.linear.weight, std=std)
+        if self.linear.bias is not None:
+            nn.init.constant_(self.linear.bias, 0)
+
+    @torch.no_grad()
+    def fuse(self):
+        bn, l = self.bn, self.linear
+        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
+        b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
+        w = l.weight * w[None, :]
+        if l.bias is None:
+            b = b @ self.linear.weight.T
+        else:
+            b = (l.weight @ b[:, None]).view(-1) + self.linear.bias
+        m = nn.Linear(w.size(1), w.size(0))
+        m.weight.data.copy_(w)
+        m.bias.data.copy_(b)
+        return m
+
+    def forward(self, x):
+        return self.linear(self.drop(self.bn(x)))
+
+
+class Stem8(nn.Sequential):
+    def __init__(self, in_chs, out_chs, act_layer):
+        super().__init__()
+        self.stride = 8
+
+        self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1))
+        self.add_module('act1', act_layer())
+        self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
+        self.add_module('act2', act_layer())
+        self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))
+
+
+class Stem16(nn.Sequential):
+    def __init__(self, in_chs, out_chs, act_layer):
+        super().__init__()
+        self.stride = 16
+
+        self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1))
+        self.add_module('act1', act_layer())
+        self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1))
+        self.add_module('act2', act_layer())
+        self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
+        self.add_module('act3', act_layer())
+        self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))
+
+
+class Downsample(nn.Module):
+    def __init__(self, stride, resolution, use_pool=False):
+        super().__init__()
+        self.stride = stride
+        self.resolution = to_2tuple(resolution)
+        self.pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if use_pool else None
+
+    def forward(self, x):
+        B, N, C = x.shape
+        x = x.view(B, self.resolution[0], self.resolution[1], C)
+        if self.pool is not None:
+            x = self.pool(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
+        else:
+            x = x[:, ::self.stride, ::self.stride]
+        return x.reshape(B, -1, C)
+
+
+class Attention(nn.Module):
+    attention_bias_cache: Dict[str, torch.Tensor]
+
+    def __init__(
+            self,
+            dim,
+            key_dim,
+            num_heads=8,
+            attn_ratio=4.,
+            resolution=14,
+            use_conv=False,
+            act_layer=nn.SiLU,
+    ):
+        super().__init__()
+        ln_layer = ConvNorm if use_conv else LinearNorm
+        resolution = to_2tuple(resolution)
+
+        self.use_conv = use_conv
+        self.num_heads = num_heads
+        self.scale = key_dim ** -0.5
+        self.key_dim = key_dim
+        self.key_attn_dim = key_dim * num_heads
+        self.val_dim = int(attn_ratio * key_dim)
+        self.val_attn_dim = int(attn_ratio * key_dim) * num_heads
+
+        self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2)
+        self.proj = nn.Sequential(OrderedDict([
+            ('act', act_layer()),
+            ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0))
+        ]))
+
+        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
+        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
+        rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
+        self.attention_bias_cache = {}
+
+    @torch.no_grad()
+    def train(self, mode=True):
+        super().train(mode)
+        if mode and self.attention_bias_cache:
+            self.attention_bias_cache = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if torch.jit.is_tracing() or self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.attention_bias_cache:
+                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.attention_bias_cache[device_key]
+
+    def forward(self, x):  # x is (B, C, H, W) in conv mode, (B, N, C) otherwise
+        if self.use_conv:
+            B, C, H, W = x.shape
+            q, k, v = self.qkv(x).view(
+                B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)
+
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
+            attn = attn.softmax(dim=-1)
+
+            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+        else:
+            B, N, C = x.shape
+            q, k, v = self.qkv(x).view(
+                B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
+            q = q.permute(0, 2, 1, 3)
+            k = k.permute(0, 2, 3, 1)
+            v = v.permute(0, 2, 1, 3)
+
+            attn = q @ k * self.scale + self.get_attention_biases(x.device)
+            attn = attn.softmax(dim=-1)
+
+            x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
+        x = self.proj(x)
+        return x
+
+
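The per-head bias table built above is indexed by the absolute offset between two grid positions, so every query/key pair with the same |Δrow| and |Δcol| shares one learned scalar per head. A small standalone sketch of the index construction, using torch.meshgrid in place of timm's ndgrid helper (an assumption: with indexing='ij' it produces the same grid):

    import torch

    H = W = 3  # tiny 3x3 token grid for illustration
    pos = torch.stack(torch.meshgrid(
        torch.arange(H), torch.arange(W), indexing='ij')).flatten(1)  # (2, 9)
    rel = (pos[:, :, None] - pos[:, None, :]).abs()                   # (2, 9, 9) |Δrow|, |Δcol|
    idx = rel[0] * W + rel[1]                                         # (9, 9) indices into the bias table
    print(idx.unique().numel())  # 9 distinct offsets -> only 9 learned biases per head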
+class AttentionDownsample(nn.Module):
+    attention_bias_cache: Dict[str, torch.Tensor]
+
+    def __init__(
+            self,
+            in_dim,
+            out_dim,
+            key_dim,
+            num_heads=8,
+            attn_ratio=2.0,
+            stride=2,
+            resolution=14,
+            use_conv=False,
+            use_pool=False,
+            act_layer=nn.SiLU,
+    ):
+        super().__init__()
+        resolution = to_2tuple(resolution)
+
+        self.stride = stride
+        self.resolution = resolution
+        self.num_heads = num_heads
+        self.key_dim = key_dim
+        self.key_attn_dim = key_dim * num_heads
+        self.val_dim = int(attn_ratio * key_dim)
+        self.val_attn_dim = self.val_dim * self.num_heads
+        self.scale = key_dim ** -0.5
+        self.use_conv = use_conv
+
+        if self.use_conv:
+            ln_layer = ConvNorm
+            sub_layer = partial(
+                nn.AvgPool2d,
+                kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
+        else:
+            ln_layer = LinearNorm
+            sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool)
+
+        self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim)
+        self.q = nn.Sequential(OrderedDict([
+            ('down', sub_layer(stride=stride)),
+            ('ln', ln_layer(in_dim, self.key_attn_dim))
+        ]))
+        self.proj = nn.Sequential(OrderedDict([
+            ('act', act_layer()),
+            ('ln', ln_layer(self.val_attn_dim, out_dim))
+        ]))
+
+        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
+        k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
+        q_pos = torch.stack(ndgrid(
+            torch.arange(0, resolution[0], step=stride),
+            torch.arange(0, resolution[1], step=stride)
+        )).flatten(1)
+        rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
+        rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
+        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
+
+        self.attention_bias_cache = {}  # per-device attention_biases cache
+
+    @torch.no_grad()
+    def train(self, mode=True):
+        super().train(mode)
+        if mode and self.attention_bias_cache:
+            self.attention_bias_cache = {}  # clear ab cache
+
+    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
+        if torch.jit.is_tracing() or self.training:
+            return self.attention_biases[:, self.attention_bias_idxs]
+        else:
+            device_key = str(device)
+            if device_key not in self.attention_bias_cache:
+                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
+            return self.attention_bias_cache[device_key]
+
+    def forward(self, x):
+        if self.use_conv:
+            B, C, H, W = x.shape
+            HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1
+            k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
+            q = self.q(x).view(B, self.num_heads, self.key_dim, -1)
+
+            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
+            attn = attn.softmax(dim=-1)
+
+            x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
+        else:
+            B, N, C = x.shape
+            k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
+            k = k.permute(0, 2, 3, 1)  # BHCN
+            v = v.permute(0, 2, 1, 3)  # BHNC
+            q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
+
+            attn = q @ k * self.scale + self.get_attention_biases(x.device)
+            attn = attn.softmax(dim=-1)
+
+            x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
+        x = self.proj(x)
+        return x
+
+
+class LevitMlp(nn.Module):
+    """ MLP for LeViT, with normalization and the ability to switch between conv and linear modes
+    """
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            use_conv=False,
+            act_layer=nn.SiLU,
+            drop=0.
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        ln_layer = ConvNorm if use_conv else LinearNorm
+
+        self.ln1 = ln_layer(in_features, hidden_features)
+        self.act = act_layer()
+        self.drop = nn.Dropout(drop)
+        self.ln2 = ln_layer(hidden_features, out_features, bn_weight_init=0)
+
+    def forward(self, x):
+        x = self.ln1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.ln2(x)
+        return x
+
+
+class LevitDownsample(nn.Module):
+    def __init__(
+            self,
+            in_dim,
+            out_dim,
+            key_dim,
+            num_heads=8,
+            attn_ratio=4.,
+            mlp_ratio=2.,
+            act_layer=nn.SiLU,
+            attn_act_layer=None,
+            resolution=14,
+            use_conv=False,
+            use_pool=False,
+            drop_path=0.,
+    ):
+        super().__init__()
+        attn_act_layer = attn_act_layer or act_layer
+
+        self.attn_downsample = AttentionDownsample(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            key_dim=key_dim,
+            num_heads=num_heads,
+            attn_ratio=attn_ratio,
+            act_layer=attn_act_layer,
+            resolution=resolution,
+            use_conv=use_conv,
+            use_pool=use_pool,
+        )
+
+        self.mlp = LevitMlp(
+            out_dim,
+            int(out_dim * mlp_ratio),
+            use_conv=use_conv,
+            act_layer=act_layer
+        )
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        x = self.attn_downsample(x)
+        x = x + self.drop_path(self.mlp(x))
+        return x
+
+
+class LevitBlock(nn.Module):
+    def __init__(
+            self,
+            dim,
+            key_dim,
+            num_heads=8,
+            attn_ratio=4.,
+            mlp_ratio=2.,
+            resolution=14,
+            use_conv=False,
+            act_layer=nn.SiLU,
+            attn_act_layer=None,
+            drop_path=0.,
+    ):
+        super().__init__()
+        attn_act_layer = attn_act_layer or act_layer
+
+        self.attn = Attention(
+            dim=dim,
+            key_dim=key_dim,
+            num_heads=num_heads,
+            attn_ratio=attn_ratio,
+            resolution=resolution,
+            use_conv=use_conv,
+            act_layer=attn_act_layer,
+        )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.mlp = LevitMlp(
+            dim,
+            int(dim * mlp_ratio),
+            use_conv=use_conv,
+            act_layer=act_layer
+        )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x):
+        x = x + self.drop_path1(self.attn(x))
+        x = x + self.drop_path2(self.mlp(x))
+        return x
+
+
+class LevitStage(nn.Module):
+    def __init__(
+            self,
+            in_dim,
+            out_dim,
+            key_dim,
+            depth=4,
+            num_heads=8,
+            attn_ratio=4.0,
+            mlp_ratio=4.0,
+            act_layer=nn.SiLU,
+            attn_act_layer=None,
+            resolution=14,
+            downsample='',
+            use_conv=False,
+            drop_path=0.,
+    ):
+        super().__init__()
+        resolution = to_2tuple(resolution)
+
+        if downsample:
+            self.downsample = LevitDownsample(
+                in_dim,
+                out_dim,
+                key_dim=key_dim,
+                num_heads=in_dim // key_dim,
+                attn_ratio=4.,
+                mlp_ratio=2.,
+                act_layer=act_layer,
+                attn_act_layer=attn_act_layer,
+                resolution=resolution,
+                use_conv=use_conv,
+                drop_path=drop_path,
+            )
+            resolution = [(r - 1) // 2 + 1 for r in resolution]
+        else:
+            assert in_dim == out_dim
+            self.downsample = nn.Identity()
+
+        blocks = []
+        for _ in range(depth):
+            blocks += [LevitBlock(
+                out_dim,
+                key_dim,
+                num_heads=num_heads,
+                attn_ratio=attn_ratio,
+                mlp_ratio=mlp_ratio,
+                act_layer=act_layer,
+                attn_act_layer=attn_act_layer,
+                resolution=resolution,
+                use_conv=use_conv,
+                drop_path=drop_path,
+            )]
+        self.blocks = nn.Sequential(*blocks)
+
+    def forward(self, x):
+        x = self.downsample(x)
+        x = self.blocks(x)
+        return x
+
+
+class Levit(nn.Module):
+    """ Vision Transformer with support for patch or hybrid CNN input stage
+
+    NOTE: distillation defaults to True since the pretrained weights use it; this will cause
+    problems with train scripts that don't expect tuple outputs.
+    """
+
+    def __init__(
+            self,
+            img_size=224,
+            in_chans=3,
+            num_classes=1000,
+            embed_dim=(192,),
+            key_dim=64,
+            depth=(12,),
+            num_heads=(3,),
+            attn_ratio=2.,
+            mlp_ratio=2.,
+            stem_backbone=None,
+            stem_stride=None,
+            stem_type='s16',
+            down_op='subsample',
+            act_layer='hard_swish',
+            attn_act_layer=None,
+            use_conv=False,
+            global_pool='avg',
+            drop_rate=0.,
+            drop_path_rate=0.):
+        super().__init__()
+        act_layer = get_act_layer(act_layer)
+        attn_act_layer = get_act_layer(attn_act_layer or act_layer)
+        self.use_conv = use_conv
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        self.num_features = embed_dim[-1]
+        self.embed_dim = embed_dim
+        self.drop_rate = drop_rate
+        self.grad_checkpointing = False
+        self.feature_info = []
+
+        num_stages = len(embed_dim)
+        assert len(depth) == num_stages
+        num_heads = to_ntuple(num_stages)(num_heads)
+        attn_ratio = to_ntuple(num_stages)(attn_ratio)
+        mlp_ratio = to_ntuple(num_stages)(mlp_ratio)
+
+        if stem_backbone is not None:
+            assert stem_stride >= 2
+            self.stem = stem_backbone
+            stride = stem_stride
+        else:
+            assert stem_type in ('s16', 's8')
+            if stem_type == 's16':
+                self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer)
+            else:
+                self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer)
+            stride = self.stem.stride
+        resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])
+
+        in_dim = embed_dim[0]
+        stages = []
+        for i in range(num_stages):
+            stage_stride = 2 if i > 0 else 1
+            stages += [LevitStage(
+                in_dim,
+                embed_dim[i],
+                key_dim,
+                depth=depth[i],
+                num_heads=num_heads[i],
+                attn_ratio=attn_ratio[i],
+                mlp_ratio=mlp_ratio[i],
+                act_layer=act_layer,
+                attn_act_layer=attn_act_layer,
+                resolution=resolution,
+                use_conv=use_conv,
+                downsample=down_op if stage_stride == 2 else '',
+                drop_path=drop_path_rate
+            )]
+            stride *= stage_stride
+            resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
+            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
+            in_dim = embed_dim[i]
+        self.stages = nn.Sequential(*stages)
+
+        # Classifier head
+        self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate) if num_classes > 0 else nn.Identity()
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        return {x for x in self.state_dict().keys() if 'attention_biases' in x}
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = NormLinear(
+            self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        if not self.use_conv:
+            x = x.flatten(2).transpose(1, 2)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stages, x)
+        else:
+            x = self.stages(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
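Since `features_only` is only wired up for the convolutional variants (create_levit further below raises for non-conv models), a feature-pyramid sketch would use a `levit_conv_*` model. The stage shapes here assume the default 224x224 input, the s16 stem, and the feature_info built in Levit.__init__ above:

    import timm
    import torch

    backbone = timm.create_model('levit_conv_256', features_only=True, pretrained=False)
    x = torch.randn(1, 3, 224, 224)
    for f in backbone(x):
        print(f.shape)  # 256/384/512 channels at strides 16/32/64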
+class LevitDistilled(Levit):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
+        self.distilled_training = False  # must set this True to train w/ distillation token
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head, self.head_dist
+
+    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = NormLinear(
+            self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
+        self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    @torch.jit.ignore
+    def set_distilled_training(self, enable=True):
+        self.distilled_training = enable
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool == 'avg':
+            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
+        if pre_logits:
+            return x
+        x, x_dist = self.head(x), self.head_dist(x)
+        if self.distilled_training and self.training and not torch.jit.is_scripting():
+            # only return separate classification predictions when training in distilled mode
+            return x, x_dist
+        else:
+            # during standard train / finetune, average the classifier predictions for inference
+            return (x + x_dist) / 2
+
+
+def checkpoint_filter_fn(state_dict, model):
+    if 'model' in state_dict:
+        state_dict = state_dict['model']
+
+    # filter out attn biases, should not have been persistent
+    state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}
+
+    D = model.state_dict()
+    out_dict = {}
+    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
+        if va.ndim == 4 and vb.ndim == 2:
+            vb = vb[:, :, None, None]
+        if va.shape != vb.shape:
+            # head or first-conv shapes may change for fine-tune
+            assert 'head' in ka or 'stem.conv1.linear' in ka
+        out_dict[ka] = vb
+
+    return out_dict
+
+
+model_cfgs = dict(
+    levit_128s=dict(
+        embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
+    levit_128=dict(
+        embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
+    levit_192=dict(
+        embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
+    levit_256=dict(
+        embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
+    levit_384=dict(
+        embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
+
+    # stride-8 stem experiments
+    levit_384_s8=dict(
+        embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4),
+        act_layer='silu', stem_type='s8'),
+    levit_512_s8=dict(
+        embed_dim=(512, 640, 896), key_dim=64, num_heads=(8, 10, 14), depth=(4, 4, 4),
+        act_layer='silu', stem_type='s8'),
+
+    # wider experiments
+    levit_512=dict(
+        embed_dim=(512, 768, 1024), key_dim=64, num_heads=(8, 12, 16), depth=(4, 4, 4), act_layer='silu'),
+
+    # deeper experiments
+    levit_256d=dict(
+        embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6), act_layer='silu'),
+    levit_512d=dict(
+        embed_dim=(512, 640, 768), key_dim=64, num_heads=(8, 10, 12), depth=(4, 8, 6), act_layer='silu'),
+)
+
+
+def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
+    is_conv = '_conv' in variant
+    out_indices = kwargs.pop('out_indices', (0, 1, 2))
+    if kwargs.get('features_only', None):
+        if not is_conv:
+            raise RuntimeError('features_only not implemented for LeVit in non-convolutional mode.')
+    if cfg_variant is None:
+        if variant in model_cfgs:
+            cfg_variant = variant
+        elif is_conv:
+            cfg_variant = variant.replace('_conv', '')
+
+    model_cfg = dict(model_cfgs[cfg_variant], **kwargs)
+    model = build_model_with_cfg(
+        LevitDistilled if distilled else Levit,
+        variant,
+        pretrained,
+        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **model_cfg,
+    )
+    return model
+
+
+def _cfg(url='', **kwargs):
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'stem.conv1.linear', 'classifier': ('head.linear', 'head_dist.linear'),
+        **kwargs
+    }
+
+
+default_cfgs = generate_default_cfgs({
+    # weights in nn.Linear mode
+    'levit_128s.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'levit_128.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'levit_192.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'levit_256.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'levit_384.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+
+    # weights in nn.Conv2d mode
+    'levit_conv_128s.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+        pool_size=(4, 4),
+    ),
+    'levit_conv_128.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+        pool_size=(4, 4),
+    ),
+    'levit_conv_192.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+        pool_size=(4, 4),
+    ),
+    'levit_conv_256.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+        pool_size=(4, 4),
+    ),
+    'levit_conv_384.fb_dist_in1k': _cfg(
+        hf_hub_id='timm/',
+        pool_size=(4, 4),
+    ),
+
+    'levit_384_s8.untrained': _cfg(classifier='head.linear'),
+    'levit_512_s8.untrained': _cfg(classifier='head.linear'),
+    'levit_512.untrained': _cfg(classifier='head.linear'),
+    'levit_256d.untrained': _cfg(classifier='head.linear'),
+    'levit_512d.untrained': _cfg(classifier='head.linear'),
+
+    'levit_conv_384_s8.untrained': _cfg(classifier='head.linear'),
+    'levit_conv_512_s8.untrained': _cfg(classifier='head.linear'),
+    'levit_conv_512.untrained': _cfg(classifier='head.linear'),
+    'levit_conv_256d.untrained': _cfg(classifier='head.linear'),
+    'levit_conv_512d.untrained': _cfg(classifier='head.linear'),
+})
+
+
+@register_model
+def levit_128s(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_128s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def levit_128(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_128', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def levit_192(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_192', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def levit_256(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_256', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def levit_384(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_384', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def levit_384_s8(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_384_s8', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def levit_512_s8(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_512_s8', pretrained=pretrained, distilled=False, **kwargs)
+
+
+@register_model
+def levit_512(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_512', pretrained=pretrained, distilled=False, **kwargs)
+
+
+@register_model
+def levit_256d(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_256d', pretrained=pretrained, distilled=False, **kwargs)
+
+
+@register_model
+def levit_512d(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_512d', pretrained=pretrained, distilled=False, **kwargs)
+
+
+@register_model
+def levit_conv_128s(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_128s', pretrained=pretrained, use_conv=True, **kwargs)
+
+
+@register_model
+def levit_conv_128(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_128', pretrained=pretrained, use_conv=True, **kwargs)
+
+
+@register_model
+def levit_conv_192(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_192', pretrained=pretrained, use_conv=True, **kwargs)
+
+
+@register_model
+def levit_conv_256(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_256', pretrained=pretrained, use_conv=True, **kwargs)
+
+
+@register_model
+def levit_conv_384(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_384', pretrained=pretrained, use_conv=True, **kwargs)
+
+
+@register_model
+def levit_conv_384_s8(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_384_s8', pretrained=pretrained, use_conv=True, **kwargs)
+
+
+@register_model
+def levit_conv_512_s8(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_512_s8', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
+
+
+@register_model
+def levit_conv_512(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_512', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
+
+
+@register_model
+def levit_conv_256d(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_256d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
+
+
+@register_model
+def levit_conv_512d(pretrained=False, **kwargs) -> Levit:
+    return create_levit('levit_conv_512d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
+
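
A minimal sketch of how the distilled head behavior above plays out at the call site (assumes timm is installed; model names come from the registrations above, and create_levit defaults to LevitDistilled):

    import timm
    import torch

    model = timm.create_model('levit_128s', pretrained=False)  # LevitDistilled by default
    x = torch.randn(2, 3, 224, 224)

    model.eval()
    with torch.no_grad():
        out = model(x)  # (2, 1000): eval returns (head + head_dist) / 2
    print(out.shape)

    model.set_distilled_training(True)
    model.train()
    out, out_dist = model(x)  # separate tuple only while training in distilled mode
    print(out.shape, out_dist.shape)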