diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/decathlon_datalist.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/decathlon_datalist.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..93ea1efe8320b56da67cb7d288bc79d114522691
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/decathlon_datalist.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_dataset.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_dataset.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67980831c570796af6a509aa0e7411d6ad9b8ae3
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_dataset.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_reader.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_reader.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f5ec94291852c8e1d365850ee4d60482610d3b46
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/data/__pycache__/image_reader.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/__init__.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..14a37a54c252d7f6453d506baa6344c4c57627d4
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/aspp.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/aspp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2bcad641a6a6fe140eb715626f67149d5ce6a948
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/aspp.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/backbone_fpn_utils.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/backbone_fpn_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..409d621e23e3cad41b2f0ca27134e7331c71ff8b
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/backbone_fpn_utils.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/crf.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/crf.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b309561232f03310000f0228f2c4c970e42631ea
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/crf.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/fcn.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/fcn.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..37ac8cde9299bdb38f2a4907d0cf996936b9bc4d
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/fcn.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/feature_pyramid_network.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/feature_pyramid_network.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ea3f4514b41fa550a50ab454148420f658b3ab6
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/feature_pyramid_network.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/patchembedding.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/patchembedding.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca7777948dc3355c2521e6d22ffbdfea14bbdee9
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/patchembedding.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/segresnet_block.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/segresnet_block.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3c51ec93b4885b3efeac429433e1cd6ac7263418
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/segresnet_block.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/unetr_block.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/unetr_block.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..26d2237dd0b3136ffe2733ede61a00d752017391
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/unetr_block.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/upsample.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/upsample.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7a6ba92b4c8ef8a33a4524b6355caca7e3059827
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/blocks/__pycache__/upsample.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/__init__.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0eb1cd5a0f4cf42af35c25484b313547e7af4d6f
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/convutils.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/convutils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a5758cc6eb932b781864b671601a43058289a165
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/convutils.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/factories.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/factories.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1488ca75e3a86a41e8dcd0ad494e9dbd1ae469e1
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/factories.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/simplelayers.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/simplelayers.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3590ffc3685a23a7bb3ad027b2042711536ef16b
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/simplelayers.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/weight_init.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/weight_init.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..7471f7f3091b7912144c95e88fa5d5ba151ba2bb
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/networks/layers/__pycache__/weight_init.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/__init__.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..949a706c9ed793de3dd13e59a0c5a1b019a87337
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/decorators.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/decorators.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4f1177dc42a1d54cf30c43be2a02cf50601a5831
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/decorators.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/deprecate_utils.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/deprecate_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0a598c1ef79d8b6139e279ab773df1ebe8475b3
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/deprecate_utils.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/jupyter_utils.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/jupyter_utils.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0f10299c0d8ed7f4d1818f5673854dc6d040de26
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/jupyter_utils.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/module.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/module.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9cd35e9c2780010ebab262d0fdcd939d8b1caddc
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/module.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/profiling.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/profiling.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f7f950f2e223c4052cf48f1e5ac338fa05df4002
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/profiling.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/type_conversion.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/type_conversion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..378f2431cb39dda9c4c57b836ec85d89f054e22b
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/monai/utils/__pycache__/type_conversion.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/__init__.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..073e041f4b54e50efb68d1b71ecc67d51eea9d16
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/version.cpython-38.pyc b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/version.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..4c0e0a2fc47051d7cba0c6200d6bd070cb9ee429
Binary files /dev/null and b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/__pycache__/version.cpython-38.pyc differ
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/__init__.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d09efac9210283bfbd6509552659a7a99a7fbbe
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/__init__.py
@@ -0,0 +1,94 @@
+from .beit import *
+from .byoanet import *
+from .byobnet import *
+from .cait import *
+from .coat import *
+from .convit import *
+from .convmixer import *
+from .convnext import *
+from .crossvit import *
+from .cspnet import *
+from .davit import *
+from .deit import *
+from .densenet import *
+from .dla import *
+from .dpn import *
+from .edgenext import *
+from .efficientformer import *
+from .efficientformer_v2 import *
+from .efficientnet import *
+from .efficientvit_mit import *
+from .efficientvit_msra import *
+from .eva import *
+from .fastvit import *
+from .focalnet import *
+from .gcvit import *
+from .ghostnet import *
+from .hardcorenas import *
+from .hgnet import *
+from .hrnet import *
+from .inception_next import *
+from .inception_resnet_v2 import *
+from .inception_v3 import *
+from .inception_v4 import *
+from .levit import *
+from .maxxvit import *
+from .metaformer import *
+from .mlp_mixer import *
+from .mobilenetv3 import *
+from .mobilevit import *
+from .mvitv2 import *
+from .nasnet import *
+from .nest import *
+from .nextvit import *
+from .nfnet import *
+from .pit import *
+from .pnasnet import *
+from .pvt_v2 import *
+from .regnet import *
+from .repghost import *
+from .repvit import *
+from .res2net import *
+from .resnest import *
+from .resnet import *
+from .resnetv2 import *
+from .rexnet import *
+from .selecsls import *
+from .senet import *
+from .sequencer import *
+from .sknet import *
+from .swin_transformer import *
+from .swin_transformer_v2 import *
+from .swin_transformer_v2_cr import *
+from .tiny_vit import *
+from .tnt import *
+from .tresnet import *
+from .twins import *
+from .vgg import *
+from .visformer import *
+from .vision_transformer import *
+from .vision_transformer_hybrid import *
+from .vision_transformer_relpos import *
+from .vision_transformer_sam import *
+from .volo import *
+from .vovnet import *
+from .xception import *
+from .xception_aligned import *
+from .xcit import *
+
+from ._builder import build_model_with_cfg, load_pretrained, load_custom_pretrained, resolve_pretrained_cfg, \
+ set_pretrained_download_progress, set_pretrained_check_hash
+from ._factory import create_model, parse_model_name, safe_model_name
+from ._features import FeatureInfo, FeatureHooks, FeatureHookNet, FeatureListNet, FeatureDictNet
+from ._features_fx import FeatureGraphNet, GraphExtractNet, create_feature_extractor, \
+ register_notrace_module, is_notrace_module, get_notrace_modules, \
+ register_notrace_function, is_notrace_function, get_notrace_functions
+from ._helpers import clean_state_dict, load_state_dict, load_checkpoint, remap_state_dict, resume_checkpoint
+from ._hub import load_model_config_from_hf, load_state_dict_from_hf, push_to_hf_hub
+from ._manipulate import model_parameters, named_apply, named_modules, named_modules_with_params, \
+ group_modules, group_parameters, checkpoint_seq, adapt_input_conv
+from ._pretrained import PretrainedCfg, DefaultCfg, filter_pretrained_cfg
+from ._prune import adapt_model_from_string
+from ._registry import split_model_name_tag, get_arch_name, generate_default_cfgs, register_model, \
+ register_model_deprecations, model_entrypoint, list_models, list_pretrained, get_deprecated_models, \
+ is_model, list_modules, is_model_in_modules, is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_builder.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e3161d6bec37c145ccd96550e97d1be3726aaff
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_efficientnet_builder.py
@@ -0,0 +1,484 @@
+""" EfficientNet, MobileNetV3, etc Builder
+
+Assembles EfficieNet and related network feature blocks from string definitions.
+Handles stride, dilation calculations, and selects feature extraction points.
+
+Hacked together by / Copyright 2019, Ross Wightman
+"""
+
+import logging
+import math
+import re
+from copy import deepcopy
+from functools import partial
+from typing import Any, Dict, List
+
+import torch.nn as nn
+
+from ._efficientnet_blocks import *
+from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
+
+__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
+ 'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
+
+_logger = logging.getLogger(__name__)
+
+
+_DEBUG_BUILDER = False
+
+# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
+# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
+# NOTE: momentum varies btw .99 and .9997 depending on source
+# .99 in official TF TPU impl
+# .9997 (/w .999 in search space) for paper
+BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
+BN_EPS_TF_DEFAULT = 1e-3
+_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
+
+BlockArgs = List[List[Dict[str, Any]]]
+
+
+def get_bn_args_tf():
+ return _BN_ARGS_TF.copy()
+
+
+def resolve_bn_args(kwargs):
+ bn_args = {}
+ bn_momentum = kwargs.pop('bn_momentum', None)
+ if bn_momentum is not None:
+ bn_args['momentum'] = bn_momentum
+ bn_eps = kwargs.pop('bn_eps', None)
+ if bn_eps is not None:
+ bn_args['eps'] = bn_eps
+ return bn_args
+
+
+def resolve_act_layer(kwargs, default='relu'):
+ return get_act_layer(kwargs.pop('act_layer', default))
+
+
+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+ """Round number of filters based on depth multiplier."""
+ if not multiplier:
+ return channels
+ return make_divisible(channels * multiplier, divisor, channel_min, round_limit=round_limit)
+
+
+def _log_info_if(msg, condition):
+ if condition:
+ _logger.info(msg)
+
+
+def _parse_ksize(ss):
+ if ss.isdigit():
+ return int(ss)
+ else:
+ return [int(k) for k in ss.split('.')]
+
+
+def _decode_block_str(block_str):
+ """ Decode block definition string
+
+ Gets a list of block arg (dicts) through a string notation of arguments.
+ E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
+
+ All args can exist in any order with the exception of the leading string which
+ is assumed to indicate the block type.
+
+ leading string - block type (
+ ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
+ r - number of repeat blocks,
+ k - kernel size,
+ s - strides (1-9),
+ e - expansion ratio,
+ c - output channels,
+ se - squeeze/excitation ratio
+ n - activation fn ('re', 'r6', 'hs', or 'sw')
+ Args:
+ block_str: a string representation of block arguments.
+ Returns:
+ A list of block args (dicts)
+ Raises:
+ ValueError: if the string def not properly specified (TODO)
+ """
+ assert isinstance(block_str, str)
+ ops = block_str.split('_')
+ block_type = ops[0] # take the block type off the front
+ ops = ops[1:]
+ options = {}
+ skip = None
+ for op in ops:
+ # string options being checked on individual basis, combine if they grow
+ if op == 'noskip':
+ skip = False # force no skip connection
+ elif op == 'skip':
+ skip = True # force a skip connection
+ elif op.startswith('n'):
+ # activation fn
+ key = op[0]
+ v = op[1:]
+ if v == 're':
+ value = get_act_layer('relu')
+ elif v == 'r6':
+ value = get_act_layer('relu6')
+ elif v == 'hs':
+ value = get_act_layer('hard_swish')
+ elif v == 'sw':
+ value = get_act_layer('swish') # aka SiLU
+ elif v == 'mi':
+ value = get_act_layer('mish')
+ else:
+ continue
+ options[key] = value
+ else:
+ # all numeric options
+ splits = re.split(r'(\d.*)', op)
+ if len(splits) >= 2:
+ key, value = splits[:2]
+ options[key] = value
+
+ # if act_layer is None, the model default (passed to model init) will be used
+ act_layer = options['n'] if 'n' in options else None
+ exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
+ pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
+ force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
+ num_repeat = int(options['r'])
+
+ # each type of block has different valid arguments, fill accordingly
+ block_args = dict(
+ block_type=block_type,
+ out_chs=int(options['c']),
+ stride=int(options['s']),
+ act_layer=act_layer,
+ )
+ if block_type == 'ir':
+ block_args.update(dict(
+ dw_kernel_size=_parse_ksize(options['k']),
+ exp_kernel_size=exp_kernel_size,
+ pw_kernel_size=pw_kernel_size,
+ exp_ratio=float(options['e']),
+ se_ratio=float(options['se']) if 'se' in options else 0.,
+ noskip=skip is False,
+ ))
+ if 'cc' in options:
+ block_args['num_experts'] = int(options['cc'])
+ elif block_type == 'ds' or block_type == 'dsa':
+ block_args.update(dict(
+ dw_kernel_size=_parse_ksize(options['k']),
+ pw_kernel_size=pw_kernel_size,
+ se_ratio=float(options['se']) if 'se' in options else 0.,
+ pw_act=block_type == 'dsa',
+ noskip=block_type == 'dsa' or skip is False,
+ ))
+ elif block_type == 'er':
+ block_args.update(dict(
+ exp_kernel_size=_parse_ksize(options['k']),
+ pw_kernel_size=pw_kernel_size,
+ exp_ratio=float(options['e']),
+ force_in_chs=force_in_chs,
+ se_ratio=float(options['se']) if 'se' in options else 0.,
+ noskip=skip is False,
+ ))
+ elif block_type == 'cn':
+ block_args.update(dict(
+ kernel_size=int(options['k']),
+ skip=skip is True,
+ ))
+ else:
+ assert False, 'Unknown block type (%s)' % block_type
+ if 'gs' in options:
+ block_args['group_size'] = options['gs']
+
+ return block_args, num_repeat
+
+
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+ """ Per-stage depth scaling
+ Scales the block repeats in each stage. This depth scaling impl maintains
+ compatibility with the EfficientNet scaling method, while allowing sensible
+ scaling for other models that may have multiple block arg definitions in each stage.
+ """
+
+ # We scale the total repeat count for each stage, there may be multiple
+ # block arg defs per stage so we need to sum.
+ num_repeat = sum(repeats)
+ if depth_trunc == 'round':
+ # Truncating to int by rounding allows stages with few repeats to remain
+ # proportionally smaller for longer. This is a good choice when stage definitions
+ # include single repeat stages that we'd prefer to keep that way as long as possible
+ num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+ else:
+ # The default for EfficientNet truncates repeats to int via 'ceil'.
+ # Any multiplier > 1.0 will result in an increased depth for every stage.
+ num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
+
+ # Proportionally distribute repeat count scaling to each block definition in the stage.
+ # Allocation is done in reverse as it results in the first block being less likely to be scaled.
+ # The first block makes less sense to repeat in most of the arch definitions.
+ repeats_scaled = []
+ for r in repeats[::-1]:
+ rs = max(1, round((r / num_repeat * num_repeat_scaled)))
+ repeats_scaled.append(rs)
+ num_repeat -= r
+ num_repeat_scaled -= rs
+ repeats_scaled = repeats_scaled[::-1]
+
+ # Apply the calculated scaling to each block arg in the stage
+ sa_scaled = []
+ for ba, rep in zip(stack_args, repeats_scaled):
+ sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
+ return sa_scaled
+
+
+def decode_arch_def(
+ arch_def,
+ depth_multiplier=1.0,
+ depth_trunc='ceil',
+ experts_multiplier=1,
+ fix_first_last=False,
+ group_size=None,
+):
+ """ Decode block architecture definition strings -> block kwargs
+
+ Args:
+ arch_def: architecture definition strings, list of list of strings
+ depth_multiplier: network depth multiplier
+ depth_trunc: networ depth truncation mode when applying multiplier
+ experts_multiplier: CondConv experts multiplier
+ fix_first_last: fix first and last block depths when multiplier is applied
+ group_size: group size override for all blocks that weren't explicitly set in arch string
+
+ Returns:
+ list of list of block kwargs
+ """
+ arch_args = []
+ if isinstance(depth_multiplier, tuple):
+ assert len(depth_multiplier) == len(arch_def)
+ else:
+ depth_multiplier = (depth_multiplier,) * len(arch_def)
+ for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
+ assert isinstance(block_strings, list)
+ stack_args = []
+ repeats = []
+ for block_str in block_strings:
+ assert isinstance(block_str, str)
+ ba, rep = _decode_block_str(block_str)
+ if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
+ ba['num_experts'] *= experts_multiplier
+ if group_size is not None:
+ ba.setdefault('group_size', group_size)
+ stack_args.append(ba)
+ repeats.append(rep)
+ if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
+ arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
+ else:
+ arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
+ return arch_args
+
+
+class EfficientNetBuilder:
+ """ Build Trunk Blocks
+
+ This ended up being somewhat of a cross between
+ https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
+ and
+ https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
+
+ """
+ def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
+ act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
+ self.output_stride = output_stride
+ self.pad_type = pad_type
+ self.round_chs_fn = round_chs_fn
+ self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs
+ self.act_layer = act_layer
+ self.norm_layer = norm_layer
+ self.se_layer = get_attn(se_layer)
+ try:
+ self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg
+ self.se_has_ratio = True
+ except TypeError:
+ self.se_has_ratio = False
+ self.drop_path_rate = drop_path_rate
+ if feature_location == 'depthwise':
+ # old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
+ _logger.warning("feature_location=='depthwise' is deprecated, using 'expansion'")
+ feature_location = 'expansion'
+ self.feature_location = feature_location
+ assert feature_location in ('bottleneck', 'expansion', '')
+ self.verbose = _DEBUG_BUILDER
+
+ # state updated during build, consumed by model
+ self.in_chs = None
+ self.features = []
+
+ def _make_block(self, ba, block_idx, block_count):
+ drop_path_rate = self.drop_path_rate * block_idx / block_count
+ bt = ba.pop('block_type')
+ ba['in_chs'] = self.in_chs
+ ba['out_chs'] = self.round_chs_fn(ba['out_chs'])
+ if 'force_in_chs' in ba and ba['force_in_chs']:
+ # NOTE this is a hack to work around mismatch in TF EdgeEffNet impl
+ ba['force_in_chs'] = self.round_chs_fn(ba['force_in_chs'])
+ ba['pad_type'] = self.pad_type
+ # block act fn overrides the model default
+ ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
+ assert ba['act_layer'] is not None
+ ba['norm_layer'] = self.norm_layer
+ ba['drop_path_rate'] = drop_path_rate
+ if bt != 'cn':
+ se_ratio = ba.pop('se_ratio')
+ if se_ratio and self.se_layer is not None:
+ if not self.se_from_exp:
+ # adjust se_ratio by expansion ratio if calculating se channels from block input
+ se_ratio /= ba.get('exp_ratio', 1.0)
+ if self.se_has_ratio:
+ ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
+ else:
+ ba['se_layer'] = self.se_layer
+
+ if bt == 'ir':
+ _log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
+ elif bt == 'ds' or bt == 'dsa':
+ _log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = DepthwiseSeparableConv(**ba)
+ elif bt == 'er':
+ _log_info_if(' EdgeResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = EdgeResidual(**ba)
+ elif bt == 'cn':
+ _log_info_if(' ConvBnAct {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
+ block = ConvBnAct(**ba)
+ else:
+ assert False, 'Uknkown block type (%s) while building model.' % bt
+
+ self.in_chs = ba['out_chs'] # update in_chs for arg of next block
+ return block
+
+ def __call__(self, in_chs, model_block_args):
+ """ Build the blocks
+ Args:
+ in_chs: Number of input-channels passed to first block
+ model_block_args: A list of lists, outer list defines stages, inner
+ list contains strings defining block configuration(s)
+ Return:
+ List of block stacks (each stack wrapped in nn.Sequential)
+ """
+ _log_info_if('Building model trunk with %d stages...' % len(model_block_args), self.verbose)
+ self.in_chs = in_chs
+ total_block_count = sum([len(x) for x in model_block_args])
+ total_block_idx = 0
+ current_stride = 2
+ current_dilation = 1
+ stages = []
+ if model_block_args[0][0]['stride'] > 1:
+ # if the first block starts with a stride, we need to extract first level feat from stem
+ feature_info = dict(module='bn1', num_chs=in_chs, stage=0, reduction=current_stride)
+ self.features.append(feature_info)
+
+ # outer list of block_args defines the stacks
+ for stack_idx, stack_args in enumerate(model_block_args):
+ last_stack = stack_idx + 1 == len(model_block_args)
+ _log_info_if('Stack: {}'.format(stack_idx), self.verbose)
+ assert isinstance(stack_args, list)
+
+ blocks = []
+ # each stack (stage of blocks) contains a list of block arguments
+ for block_idx, block_args in enumerate(stack_args):
+ last_block = block_idx + 1 == len(stack_args)
+ _log_info_if(' Block: {}'.format(block_idx), self.verbose)
+
+ assert block_args['stride'] in (1, 2)
+ if block_idx >= 1: # only the first block in any stack can have a stride > 1
+ block_args['stride'] = 1
+
+ extract_features = False
+ if last_block:
+ next_stack_idx = stack_idx + 1
+ extract_features = next_stack_idx >= len(model_block_args) or \
+ model_block_args[next_stack_idx][0]['stride'] > 1
+
+ next_dilation = current_dilation
+ if block_args['stride'] > 1:
+ next_output_stride = current_stride * block_args['stride']
+ if next_output_stride > self.output_stride:
+ next_dilation = current_dilation * block_args['stride']
+ block_args['stride'] = 1
+ _log_info_if(' Converting stride to dilation to maintain output_stride=={}'.format(
+ self.output_stride), self.verbose)
+ else:
+ current_stride = next_output_stride
+ block_args['dilation'] = current_dilation
+ if next_dilation != current_dilation:
+ current_dilation = next_dilation
+
+ # create the block
+ block = self._make_block(block_args, total_block_idx, total_block_count)
+ blocks.append(block)
+
+ # stash feature module name and channel info for model feature extraction
+ if extract_features:
+ feature_info = dict(
+ stage=stack_idx + 1,
+ reduction=current_stride,
+ **block.feature_info(self.feature_location),
+ )
+ leaf_name = feature_info.get('module', '')
+ if leaf_name:
+ feature_info['module'] = '.'.join([f'blocks.{stack_idx}.{block_idx}', leaf_name])
+ else:
+ assert last_block
+ feature_info['module'] = f'blocks.{stack_idx}'
+ self.features.append(feature_info)
+
+ total_block_idx += 1 # incr global block idx (across all stacks)
+ stages.append(nn.Sequential(*blocks))
+ return stages
+
+
+def _init_weight_goog(m, n='', fix_group_fanout=True):
+ """ Weight initialization as per Tensorflow official implementations.
+
+ Args:
+ m (nn.Module): module to init
+ n (str): module name
+ fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
+
+ Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
+ * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
+ * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+ """
+ if isinstance(m, CondConv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ if fix_group_fanout:
+ fan_out //= m.groups
+ init_weight_fn = get_condconv_initializer(
+ lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
+ init_weight_fn(m.weight)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ if fix_group_fanout:
+ fan_out //= m.groups
+ nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ fan_out = m.weight.size(0) # fan-out
+ fan_in = 0
+ if 'routing_fn' in n:
+ fan_in = m.weight.size(1)
+ init_range = 1.0 / math.sqrt(fan_in + fan_out)
+ nn.init.uniform_(m.weight, -init_range, init_range)
+ nn.init.zeros_(m.bias)
+
+
+def efficientnet_init_weights(model: nn.Module, init_fn=None):
+ init_fn = init_fn or _init_weight_goog
+ for n, m in model.named_modules():
+ init_fn(m, n)
+
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_factory.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff15b9a6457ca0ed648f4f0c9b2308a76666c1d
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_factory.py
@@ -0,0 +1,127 @@
+import os
+from typing import Any, Dict, Optional, Union
+from urllib.parse import urlsplit
+
+from timm.layers import set_layer_config
+from ._helpers import load_checkpoint
+from ._hub import load_model_config_from_hf
+from ._pretrained import PretrainedCfg
+from ._registry import is_model, model_entrypoint, split_model_name_tag
+
+
+__all__ = ['parse_model_name', 'safe_model_name', 'create_model']
+
+
+def parse_model_name(model_name: str):
+ if model_name.startswith('hf_hub'):
+ # NOTE for backwards compat, deprecate hf_hub use
+ model_name = model_name.replace('hf_hub', 'hf-hub')
+ parsed = urlsplit(model_name)
+ assert parsed.scheme in ('', 'timm', 'hf-hub')
+ if parsed.scheme == 'hf-hub':
+ # FIXME may use fragment as revision, currently `@` in URI path
+ return parsed.scheme, parsed.path
+ else:
+ model_name = os.path.split(parsed.path)[-1]
+ return 'timm', model_name
+
+
+def safe_model_name(model_name: str, remove_source: bool = True):
+ # return a filename / path safe model name
+ def make_safe(name):
+ return ''.join(c if c.isalnum() else '_' for c in name).rstrip('_')
+ if remove_source:
+ model_name = parse_model_name(model_name)[-1]
+ return make_safe(model_name)
+
+
+def create_model(
+ model_name: str,
+ pretrained: bool = False,
+ pretrained_cfg: Optional[Union[str, Dict[str, Any], PretrainedCfg]] = None,
+ pretrained_cfg_overlay: Optional[Dict[str, Any]] = None,
+ checkpoint_path: str = '',
+ scriptable: Optional[bool] = None,
+ exportable: Optional[bool] = None,
+ no_jit: Optional[bool] = None,
+ **kwargs,
+):
+ """Create a model.
+
+ Lookup model's entrypoint function and pass relevant args to create a new model.
+
+
+ **kwargs will be passed through entrypoint fn to ``timm.models.build_model_with_cfg()``
+ and then the model class __init__(). kwargs values set to None are pruned before passing.
+
+
+ Args:
+ model_name: Name of model to instantiate.
+ pretrained: If set to `True`, load pretrained ImageNet-1k weights.
+ pretrained_cfg: Pass in an external pretrained_cfg for model.
+ pretrained_cfg_overlay: Replace key-values in base pretrained_cfg with these.
+ checkpoint_path: Path of checkpoint to load _after_ the model is initialized.
+ scriptable: Set layer config so that model is jit scriptable (not working for all models yet).
+ exportable: Set layer config so that model is traceable / ONNX exportable (not fully impl/obeyed yet).
+ no_jit: Set layer config so that model doesn't utilize jit scripted layers (so far activations only).
+
+ Keyword Args:
+ drop_rate (float): Classifier dropout rate for training.
+ drop_path_rate (float): Stochastic depth drop rate for training.
+ global_pool (str): Classifier global pooling type.
+
+ Example:
+
+ ```py
+ >>> from timm import create_model
+
+ >>> # Create a MobileNetV3-Large model with no pretrained weights.
+ >>> model = create_model('mobilenetv3_large_100')
+
+ >>> # Create a MobileNetV3-Large model with pretrained weights.
+ >>> model = create_model('mobilenetv3_large_100', pretrained=True)
+ >>> model.num_classes
+ 1000
+
+ >>> # Create a MobileNetV3-Large model with pretrained weights and a new head with 10 classes.
+ >>> model = create_model('mobilenetv3_large_100', pretrained=True, num_classes=10)
+ >>> model.num_classes
+ 10
+ ```
+ """
+ # Parameters that aren't supported by all models or are intended to only override model defaults if set
+ # should default to None in command line args/cfg. Remove them if they are present and not set so that
+ # non-supporting models don't break and default args remain in effect.
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+ model_source, model_name = parse_model_name(model_name)
+ if model_source == 'hf-hub':
+ assert not pretrained_cfg, 'pretrained_cfg should not be set when sourcing model from Hugging Face Hub.'
+ # For model names specified in the form `hf-hub:path/architecture_name@revision`,
+ # load model weights + pretrained_cfg from Hugging Face hub.
+ pretrained_cfg, model_name, model_args = load_model_config_from_hf(model_name)
+ if model_args:
+ for k, v in model_args.items():
+ kwargs.setdefault(k, v)
+ else:
+ model_name, pretrained_tag = split_model_name_tag(model_name)
+ if pretrained_tag and not pretrained_cfg:
+ # a valid pretrained_cfg argument takes priority over tag in model name
+ pretrained_cfg = pretrained_tag
+
+ if not is_model(model_name):
+ raise RuntimeError('Unknown model (%s)' % model_name)
+
+ create_fn = model_entrypoint(model_name)
+ with set_layer_config(scriptable=scriptable, exportable=exportable, no_jit=no_jit):
+ model = create_fn(
+ pretrained=pretrained,
+ pretrained_cfg=pretrained_cfg,
+ pretrained_cfg_overlay=pretrained_cfg_overlay,
+ **kwargs,
+ )
+
+ if checkpoint_path:
+ load_checkpoint(model, checkpoint_path)
+
+ return model
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ef51809bcecc0bb764d108a1d093da3b35ed405
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features.py
@@ -0,0 +1,368 @@
+""" PyTorch Feature Extraction Helpers
+
+A collection of classes, functions, modules to help extract features from models
+and provide a common interface for describing them.
+
+The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter
+https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from collections import OrderedDict, defaultdict
+from copy import deepcopy
+from functools import partial
+from typing import Dict, List, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from timm.layers import Format
+
+
+__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet']
+
+
+class FeatureInfo:
+
+ def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
+ prev_reduction = 1
+ for i, fi in enumerate(feature_info):
+ # sanity check the mandatory fields, there may be additional fields depending on the model
+ assert 'num_chs' in fi and fi['num_chs'] > 0
+ assert 'reduction' in fi and fi['reduction'] >= prev_reduction
+ prev_reduction = fi['reduction']
+ assert 'module' in fi
+ fi.setdefault('index', i)
+ self.out_indices = out_indices
+ self.info = feature_info
+
+ def from_other(self, out_indices: Tuple[int]):
+ return FeatureInfo(deepcopy(self.info), out_indices)
+
+ def get(self, key, idx=None):
+ """ Get value by key at specified index (indices)
+ if idx == None, returns value for key at each output index
+ if idx is an integer, return value for that feature module index (ignoring output indices)
+ if idx is a list/tupple, return value for each module index (ignoring output indices)
+ """
+ if idx is None:
+ return [self.info[i][key] for i in self.out_indices]
+ if isinstance(idx, (tuple, list)):
+ return [self.info[i][key] for i in idx]
+ else:
+ return self.info[idx][key]
+
+ def get_dicts(self, keys=None, idx=None):
+ """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
+ """
+ if idx is None:
+ if keys is None:
+ return [self.info[i] for i in self.out_indices]
+ else:
+ return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
+ if isinstance(idx, (tuple, list)):
+ return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
+ else:
+ return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
+
+ def channels(self, idx=None):
+ """ feature channels accessor
+ """
+ return self.get('num_chs', idx)
+
+ def reduction(self, idx=None):
+ """ feature reduction (output stride) accessor
+ """
+ return self.get('reduction', idx)
+
+ def module_name(self, idx=None):
+ """ feature module name accessor
+ """
+ return self.get('module', idx)
+
+ def __getitem__(self, item):
+ return self.info[item]
+
+ def __len__(self):
+ return len(self.info)
+
+
+class FeatureHooks:
+ """ Feature Hook Helper
+
+ This module helps with the setup and extraction of hooks for extracting features from
+ internal nodes in a model by node name.
+
+ FIXME This works well in eager Python but needs redesign for torchscript.
+ """
+
+ def __init__(
+ self,
+ hooks: Sequence[str],
+ named_modules: dict,
+ out_map: Sequence[Union[int, str]] = None,
+ default_hook_type: str = 'forward',
+ ):
+ # setup feature hooks
+ self._feature_outputs = defaultdict(OrderedDict)
+ modules = {k: v for k, v in named_modules}
+ for i, h in enumerate(hooks):
+ hook_name = h['module']
+ m = modules[hook_name]
+ hook_id = out_map[i] if out_map else hook_name
+ hook_fn = partial(self._collect_output_hook, hook_id)
+ hook_type = h.get('hook_type', default_hook_type)
+ if hook_type == 'forward_pre':
+ m.register_forward_pre_hook(hook_fn)
+ elif hook_type == 'forward':
+ m.register_forward_hook(hook_fn)
+ else:
+ assert False, "Unsupported hook type"
+
+ def _collect_output_hook(self, hook_id, *args):
+ x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre
+ if isinstance(x, tuple):
+ x = x[0] # unwrap input tuple
+ self._feature_outputs[x.device][hook_id] = x
+
+ def get_output(self, device) -> Dict[str, torch.tensor]:
+ output = self._feature_outputs[device]
+ self._feature_outputs[device] = OrderedDict() # clear after reading
+ return output
+
+
+def _module_list(module, flatten_sequential=False):
+ # a yield/iter would be better for this but wouldn't be compatible with torchscript
+ ml = []
+ for name, module in module.named_children():
+ if flatten_sequential and isinstance(module, nn.Sequential):
+ # first level of Sequential containers is flattened into containing model
+ for child_name, child_module in module.named_children():
+ combined = [name, child_name]
+ ml.append(('_'.join(combined), '.'.join(combined), child_module))
+ else:
+ ml.append((name, name, module))
+ return ml
+
+
+def _get_feature_info(net, out_indices):
+ feature_info = getattr(net, 'feature_info')
+ if isinstance(feature_info, FeatureInfo):
+ return feature_info.from_other(out_indices)
+ elif isinstance(feature_info, (list, tuple)):
+ return FeatureInfo(net.feature_info, out_indices)
+ else:
+ assert False, "Provided feature_info is not valid"
+
+
+def _get_return_layers(feature_info, out_map):
+ module_names = feature_info.module_name()
+ return_layers = {}
+ for i, name in enumerate(module_names):
+ return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i]
+ return return_layers
+
+
+class FeatureDictNet(nn.ModuleDict):
+ """ Feature extractor with OrderedDict return
+
+ Wrap a model and extract features as specified by the out indices, the network is
+ partially re-built from contained modules.
+
+ There is a strong assumption that the modules have been registered into the model in the same
+ order as they are used. There should be no reuse of the same nn.Module more than once, including
+ trivial modules like `self.relu = nn.ReLU`.
+
+ Only submodules that are directly assigned to the model class (`model.feature1`) or at most
+ one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured.
+ All Sequential containers that are directly assigned to the original model will have their
+ modules assigned to this module with the name `model.features.1` being changed to `model.features_1`
+ """
+ def __init__(
+ self,
+ model: nn.Module,
+ out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+ out_map: Sequence[Union[int, str]] = None,
+ output_fmt: str = 'NCHW',
+ feature_concat: bool = False,
+ flatten_sequential: bool = False,
+ ):
+ """
+ Args:
+ model: Model from which to extract features.
+ out_indices: Output indices of the model features to extract.
+ out_map: Return id mapping for each output index, otherwise str(index) is used.
+ feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+ first element e.g. `x[0]`
+ flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+ """
+ super(FeatureDictNet, self).__init__()
+ self.feature_info = _get_feature_info(model, out_indices)
+ self.output_fmt = Format(output_fmt)
+ self.concat = feature_concat
+ self.grad_checkpointing = False
+ self.return_layers = {}
+
+ return_layers = _get_return_layers(self.feature_info, out_map)
+ modules = _module_list(model, flatten_sequential=flatten_sequential)
+ remaining = set(return_layers.keys())
+ layers = OrderedDict()
+ for new_name, old_name, module in modules:
+ layers[new_name] = module
+ if old_name in remaining:
+ # return id has to be consistently str type for torchscript
+ self.return_layers[new_name] = str(return_layers[old_name])
+ remaining.remove(old_name)
+ if not remaining:
+ break
+ assert not remaining and len(self.return_layers) == len(return_layers), \
+ f'Return layers ({remaining}) are not present in model'
+ self.update(layers)
+
+ def set_grad_checkpointing(self, enable: bool = True):
+ self.grad_checkpointing = enable
+
+ def _collect(self, x) -> (Dict[str, torch.Tensor]):
+ out = OrderedDict()
+ for i, (name, module) in enumerate(self.items()):
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ # Skipping checkpoint of first module because need a gradient at input
+ # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+ # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+ first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+ x = module(x) if first_or_last_module else checkpoint(module, x)
+ else:
+ x = module(x)
+
+ if name in self.return_layers:
+ out_id = self.return_layers[name]
+ if isinstance(x, (tuple, list)):
+ # If model tap is a tuple or list, concat or select first element
+ # FIXME this may need to be more generic / flexible for some nets
+ out[out_id] = torch.cat(x, 1) if self.concat else x[0]
+ else:
+ out[out_id] = x
+ return out
+
+ def forward(self, x) -> Dict[str, torch.Tensor]:
+ return self._collect(x)
+
+
+class FeatureListNet(FeatureDictNet):
+ """ Feature extractor with list return
+
+ A specialization of FeatureDictNet that always returns features as a list (values() of dict).
+ """
+ def __init__(
+ self,
+ model: nn.Module,
+ out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+ output_fmt: str = 'NCHW',
+ feature_concat: bool = False,
+ flatten_sequential: bool = False,
+ ):
+ """
+ Args:
+ model: Model from which to extract features.
+ out_indices: Output indices of the model features to extract.
+ feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting
+ first element e.g. `x[0]`
+ flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules)
+ """
+ super().__init__(
+ model,
+ out_indices=out_indices,
+ output_fmt=output_fmt,
+ feature_concat=feature_concat,
+ flatten_sequential=flatten_sequential,
+ )
+
+ def forward(self, x) -> (List[torch.Tensor]):
+ return list(self._collect(x).values())
+
+
+class FeatureHookNet(nn.ModuleDict):
+ """ FeatureHookNet
+
+ Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
+
+ If `no_rewrite` is True, features are extracted via hooks without modifying the underlying
+ network in any way.
+
+ If `no_rewrite` is False, the model will be re-written as in the
+ FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one.
+
+ FIXME this does not currently work with Torchscript, see FeatureHooks class
+ """
+ def __init__(
+ self,
+ model: nn.Module,
+ out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4),
+ out_map: Sequence[Union[int, str]] = None,
+ return_dict: bool = False,
+ output_fmt: str = 'NCHW',
+ no_rewrite: bool = False,
+ flatten_sequential: bool = False,
+ default_hook_type: str = 'forward',
+ ):
+ """
+
+ Args:
+ model: Model from which to extract features.
+ out_indices: Output indices of the model features to extract.
+ out_map: Return id mapping for each output index, otherwise str(index) is used.
+ return_dict: Output features as a dict.
+ no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed.
+ flatten_sequential arg must also be False if this is set True.
+ flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers.
+ default_hook_type: The default hook type to use if not specified in model.feature_info.
+ """
+ super().__init__()
+ assert not torch.jit.is_scripting()
+ self.feature_info = _get_feature_info(model, out_indices)
+ self.return_dict = return_dict
+ self.output_fmt = Format(output_fmt)
+ self.grad_checkpointing = False
+
+ layers = OrderedDict()
+ hooks = []
+ if no_rewrite:
+ assert not flatten_sequential
+ if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
+ model.reset_classifier(0)
+ layers['body'] = model
+ hooks.extend(self.feature_info.get_dicts())
+ else:
+ modules = _module_list(model, flatten_sequential=flatten_sequential)
+ remaining = {
+ f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
+ for f in self.feature_info.get_dicts()
+ }
+ for new_name, old_name, module in modules:
+ layers[new_name] = module
+ for fn, fm in module.named_modules(prefix=old_name):
+ if fn in remaining:
+ hooks.append(dict(module=fn, hook_type=remaining[fn]))
+ del remaining[fn]
+ if not remaining:
+ break
+ assert not remaining, f'Return layers ({remaining}) are not present in model'
+ self.update(layers)
+ self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
+
+ def set_grad_checkpointing(self, enable: bool = True):
+ self.grad_checkpointing = enable
+
+ def forward(self, x):
+ for i, (name, module) in enumerate(self.items()):
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ # Skipping checkpoint of first module because need a gradient at input
+ # Skipping last because networks with in-place ops might fail w/ checkpointing enabled
+ # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues
+ first_or_last_module = i == 0 or i == max(len(self) - 1, 0)
+ x = module(x) if first_or_last_module else checkpoint(module, x)
+ else:
+ x = module(x)
+ out = self.hooks.get_output(x.device)
+ return out if self.return_dict else list(out.values())
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features_fx.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features_fx.py
new file mode 100644
index 0000000000000000000000000000000000000000..c48c13b7fcaa8094fc9f4b2170100bbcee50fc34
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_features_fx.py
@@ -0,0 +1,141 @@
+""" PyTorch FX Based Feature Extraction Helpers
+Using https://pytorch.org/vision/stable/feature_extraction.html
+"""
+from typing import Callable, List, Dict, Union, Type
+
+import torch
+from torch import nn
+
+from ._features import _get_feature_info, _get_return_layers
+
+try:
+ from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor
+ has_fx_feature_extraction = True
+except ImportError:
+ has_fx_feature_extraction = False
+
+# Layers we went to treat as leaf modules
+from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
+from timm.layers.non_local_attn import BilinearAttnTransform
+from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
+from timm.layers.norm_act import (
+ BatchNormAct2d,
+ SyncBatchNormAct,
+ FrozenBatchNormAct2d,
+ GroupNormAct,
+ GroupNorm1Act,
+ LayerNormAct,
+ LayerNormAct2d
+)
+
+__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules',
+ 'register_notrace_function', 'is_notrace_function', 'get_notrace_functions',
+ 'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet']
+
+
+# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here
+# BUT modules from timm.models should use the registration mechanism below
+_leaf_modules = {
+ BilinearAttnTransform, # reason: flow control t <= 1
+ # Reason: get_same_padding has a max which raises a control flow error
+ Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame,
+ CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]),
+ BatchNormAct2d,
+ SyncBatchNormAct,
+ FrozenBatchNormAct2d,
+ GroupNormAct,
+ GroupNorm1Act,
+ LayerNormAct,
+ LayerNormAct2d,
+}
+
+try:
+ from timm.layers import InplaceAbn
+ _leaf_modules.add(InplaceAbn)
+except ImportError:
+ pass
+
+
+def register_notrace_module(module: Type[nn.Module]):
+ """
+ Any module not under timm.models.layers should get this decorator if we don't want to trace through it.
+ """
+ _leaf_modules.add(module)
+ return module
+
+
+def is_notrace_module(module: Type[nn.Module]):
+ return module in _leaf_modules
+
+
+def get_notrace_modules():
+ return list(_leaf_modules)
+
+
+# Functions we want to autowrap (treat them as leaves)
+_autowrap_functions = set()
+
+
+def register_notrace_function(func: Callable):
+ """
+ Decorator for functions which ought not to be traced through
+ """
+ _autowrap_functions.add(func)
+ return func
+
+
+def is_notrace_function(func: Callable):
+ return func in _autowrap_functions
+
+
+def get_notrace_functions():
+ return list(_autowrap_functions)
+
+
+def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]):
+ assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+ return _create_feature_extractor(
+ model, return_nodes,
+ tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)}
+ )
+
+
+class FeatureGraphNet(nn.Module):
+ """ A FX Graph based feature extractor that works with the model feature_info metadata
+ """
+ def __init__(self, model, out_indices, out_map=None):
+ super().__init__()
+ assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
+ self.feature_info = _get_feature_info(model, out_indices)
+ if out_map is not None:
+ assert len(out_map) == len(out_indices)
+ return_nodes = _get_return_layers(self.feature_info, out_map)
+ self.graph_module = create_feature_extractor(model, return_nodes)
+
+ def forward(self, x):
+ return list(self.graph_module(x).values())
+
+
+class GraphExtractNet(nn.Module):
+ """ A standalone feature extraction wrapper that maps dict -> list or single tensor
+ NOTE:
+ * one can use feature_extractor directly if dictionary output is desired
+ * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info
+ metadata for builtin feature extraction mode
+ * create_feature_extractor can be used directly if dictionary output is acceptable
+
+ Args:
+ model: model to extract features from
+ return_nodes: node names to return features from (dict or list)
+ squeeze_out: if only one output, and output in list format, flatten to single tensor
+ """
+ def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True):
+ super().__init__()
+ self.squeeze_out = squeeze_out
+ self.graph_module = create_feature_extractor(model, return_nodes)
+
+ def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]:
+ out = list(self.graph_module(x).values())
+ if self.squeeze_out and len(out) == 1:
+ return out[0]
+ return out
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_hub.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_hub.py
new file mode 100644
index 0000000000000000000000000000000000000000..55ab04bfa130aba01c513ef15bcb803de363c74c
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_hub.py
@@ -0,0 +1,402 @@
+import hashlib
+import json
+import logging
+import os
+from functools import partial
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Iterable, Optional, Union
+
+import torch
+from torch.hub import HASH_REGEX, download_url_to_file, urlparse
+
+try:
+ from torch.hub import get_dir
+except ImportError:
+ from torch.hub import _get_torch_home as get_dir
+
+try:
+ import safetensors.torch
+ _has_safetensors = True
+except ImportError:
+ _has_safetensors = False
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+from timm import __version__
+from timm.models._pretrained import filter_pretrained_cfg
+
+try:
+ from huggingface_hub import (
+ create_repo, get_hf_file_metadata,
+ hf_hub_download, hf_hub_url,
+ repo_type_and_id_from_hf_id, upload_folder)
+ from huggingface_hub.utils import EntryNotFoundError
+ hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
+ _has_hf_hub = True
+except ImportError:
+ hf_hub_download = None
+ _has_hf_hub = False
+
+_logger = logging.getLogger(__name__)
+
+__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
+ 'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
+
+# Default name for a weights file hosted on the Huggingface Hub.
+HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
+HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
+HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl
+HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version
+
+
def get_cache_dir(child_dir=''):
    """Return (creating it if necessary) the directory used to cache model checkpoints.

    Args:
        child_dir: optional subdirectory name created under the checkpoint dir.
    """
    # Warn about the deprecated env var; TORCH_HOME is the supported override.
    if os.getenv('TORCH_MODEL_ZOO'):
        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')

    subdirs = (child_dir,) if child_dir else ()
    model_dir = os.path.join(get_dir(), 'checkpoints', *subdirs)
    os.makedirs(model_dir, exist_ok=True)
    return model_dir
+
+
def download_cached_file(url, check_hash=True, progress=False):
    """Download ``url`` into the cache dir unless already present; return the local path.

    Args:
        url: URL string, or a (url, filename) pair overriding the target filename.
        check_hash: validate the hash prefix embedded in the filename when downloading.
        progress: display a download progress bar.
    """
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if os.path.exists(cached_file):
        return cached_file

    _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
    hash_prefix = None
    if check_hash:
        match = HASH_REGEX.search(filename)  # Optional[Match[str]]
        if match:
            hash_prefix = match.group(1)
    download_url_to_file(url, cached_file, hash_prefix, progress=progress)
    return cached_file
+
+
def check_cached_file(url, check_hash=True):
    """Return True if ``url``'s target file exists in the cache (and passes its hash check)."""
    if isinstance(url, (list, tuple)):
        url, filename = url
    else:
        filename = os.path.basename(urlparse(url).path)
    cached_file = os.path.join(get_cache_dir(), filename)
    if not os.path.exists(cached_file):
        return False
    if check_hash:
        match = HASH_REGEX.search(filename)  # Optional[Match[str]]
        hash_prefix = match.group(1) if match else None
        if hash_prefix:
            with open(cached_file, 'rb') as f:
                digest = hashlib.sha256(f.read()).hexdigest()
            # Only the prefix embedded in the filename is compared.
            if not digest.startswith(hash_prefix):
                return False
    return True
+
+
def has_hf_hub(necessary=False):
    """Return whether huggingface_hub is importable; raise if required but missing."""
    if necessary and not _has_hf_hub:
        # if no HF Hub module installed, and it is necessary to continue, raise error
        raise RuntimeError(
            'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
    return _has_hf_hub
+
+
def hf_split(hf_id: str):
    """Split a hub id of the form ``repo[@revision]`` into (repo, revision-or-None)."""
    # FIXME I may change @ -> # and be parsed as fragment in a URI model name scheme
    model_id, sep, revision = hf_id.partition('@')
    assert '@' not in revision, 'hf_hub id should only contain one @ character to identify revision.'
    return model_id, revision if sep else None
+
+
def load_cfg_from_json(json_file: Union[str, os.PathLike]):
    """Read ``json_file`` (UTF-8) and return the parsed JSON as a dict."""
    with open(json_file, "r", encoding="utf-8") as reader:
        return json.loads(reader.read())
+
+
def download_from_hf(model_id: str, filename: str):
    """Download ``filename`` from the HF Hub repo ``model_id`` (optionally ``@revision``)."""
    repo_id, revision = hf_split(model_id)
    return hf_hub_download(repo_id, filename, revision=revision)
+
+
def load_model_config_from_hf(model_id: str):
    """Fetch and normalize a timm model config from a HF Hub repo's config.json.

    Returns:
        Tuple of (pretrained_cfg, model_name, model_args).
    """
    assert has_hf_hub(True)
    hf_config = load_cfg_from_json(download_from_hf(model_id, 'config.json'))

    if 'pretrained_cfg' not in hf_config:
        # Legacy layout: the whole file is the pretrained cfg; lift arch metadata out.
        pretrained_cfg = hf_config
        hf_config = {
            'architecture': pretrained_cfg.pop('architecture'),
            'num_features': pretrained_cfg.pop('num_features', None),
        }
        if 'labels' in pretrained_cfg:  # deprecated name for 'label_names'
            pretrained_cfg['label_names'] = pretrained_cfg.pop('labels')
        hf_config['pretrained_cfg'] = pretrained_cfg

    # NOTE currently discarding parent config as only arch name and pretrained_cfg used in timm right now
    pretrained_cfg = hf_config['pretrained_cfg']
    pretrained_cfg['hf_hub_id'] = model_id  # insert hf_hub id for pretrained weight load during model creation
    pretrained_cfg['source'] = 'hf-hub'

    # model should be created with base config num_classes if its exist
    if 'num_classes' in hf_config:
        pretrained_cfg['num_classes'] = hf_config['num_classes']

    # label meta-data in base config overrides saved pretrained_cfg on load
    for label_key in ('label_names', 'label_descriptions'):
        if label_key in hf_config:
            pretrained_cfg[label_key] = hf_config.pop(label_key)

    return pretrained_cfg, hf_config['architecture'], hf_config.get('model_args', {})
+
+
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
    """Load a model state dict from the HF Hub, preferring a .safetensors alternative.

    Args:
        model_id: hub repo id, optionally suffixed with ``@revision``.
        filename: weights filename to fetch (defaults to the pytorch pickle name).

    Returns:
        State dict with all tensors mapped to CPU.
    """
    assert has_hf_hub(True)
    hf_model_id, hf_revision = hf_split(model_id)

    # Look for .safetensors alternatives and load from it if it exists
    if _has_safetensors:
        for safe_filename in _get_safe_alternatives(filename):
            try:
                cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
                # FIX: log messages previously contained the literal text '(unknown)'
                # where the requested filename belongs; interpolate it properly.
                _logger.info(
                    f"[{model_id}] Safe alternative available for '{filename}' "
                    f"(as '{safe_filename}'). Loading weights using safetensors.")
                return safetensors.torch.load_file(cached_safe_file, device="cpu")
            except EntryNotFoundError:
                pass

    # Otherwise, load using pytorch.load
    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
    _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
    return torch.load(cached_file, map_location='cpu')
+
+
def save_config_for_hf(
        model,
        config_path: Union[str, Path],
        model_config: Optional[dict] = None,
        model_args: Optional[dict] = None
):
    """Write a HF Hub compatible config.json for ``model``.

    Args:
        model: model exposing ``pretrained_cfg``, ``num_classes``, ``num_features``.
        config_path: destination path for the JSON file.
            FIX: was annotated ``str`` but the code requires ``Path.open``; now
            accepts either and coerces to ``Path``.
        model_config: optional overrides / extra root-level config entries.
        model_args: optional model construction kwargs recorded verbatim.
    """
    model_config = model_config or {}
    hf_config = {}
    pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
    # set some values at root config level
    hf_config['architecture'] = pretrained_cfg.pop('architecture')
    hf_config['num_classes'] = model_config.pop('num_classes', model.num_classes)

    # NOTE these attr saved for informational purposes, do not impact model build
    hf_config['num_features'] = model_config.pop('num_features', model.num_features)
    global_pool_type = model_config.pop('global_pool', getattr(model, 'global_pool', None))
    if isinstance(global_pool_type, str) and global_pool_type:
        hf_config['global_pool'] = global_pool_type

    # Save class label info
    if 'labels' in model_config:
        _logger.warning(
            "'labels' as a config field for is deprecated. Please use 'label_names' and 'label_descriptions'."
            " Renaming provided 'labels' field to 'label_names'.")
        model_config.setdefault('label_names', model_config.pop('labels'))

    label_names = model_config.pop('label_names', None)
    if label_names:
        assert isinstance(label_names, (dict, list, tuple))
        # map label id (classifier index) -> unique label name (ie synset for ImageNet, MID for OpenImages)
        # can be a dict id: name if there are id gaps, or tuple/list if no gaps.
        hf_config['label_names'] = label_names

    label_descriptions = model_config.pop('label_descriptions', None)
    if label_descriptions:
        assert isinstance(label_descriptions, dict)
        # maps label names -> descriptions
        hf_config['label_descriptions'] = label_descriptions

    if model_args:
        hf_config['model_args'] = model_args

    hf_config['pretrained_cfg'] = pretrained_cfg
    hf_config.update(model_config)

    with Path(config_path).open('w') as f:
        json.dump(hf_config, f, indent=2)
+
+
def save_for_hf(
        model,
        save_directory: str,
        model_config: Optional[dict] = None,
        model_args: Optional[dict] = None,
        safe_serialization: Union[bool, Literal["both"]] = False,
):
    """Serialize model weights and config.json into ``save_directory`` for HF Hub upload."""
    assert has_hf_hub(True)
    save_directory = Path(save_directory)
    save_directory.mkdir(exist_ok=True, parents=True)

    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
    tensors = model.state_dict()
    want_safe = safe_serialization is True or safe_serialization == "both"
    want_legacy = safe_serialization is False or safe_serialization == "both"
    if want_safe:
        assert _has_safetensors, "`pip install safetensors` to use .safetensors"
        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
    if want_legacy:
        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

    save_config_for_hf(
        model,
        save_directory / 'config.json',
        model_config=model_config,
        model_args=model_args,
    )
+
+
def push_to_hf_hub(
    model: torch.nn.Module,
    repo_id: str,
    commit_message: str = 'Add model',
    token: Optional[str] = None,
    revision: Optional[str] = None,
    private: bool = False,
    create_pr: bool = False,
    model_config: Optional[dict] = None,
    model_card: Optional[dict] = None,
    model_args: Optional[dict] = None,
    safe_serialization: Union[bool, Literal["both"]] = False,
):
    """Save ``model`` (weights, config, README) and upload it to a HF Hub repo.

    Arguments:
        (...)
        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
            Can be set to `"both"` in order to push both safe and unsafe weights.
    """
    # Create repo if it doesn't exist yet
    repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

    # Resolve the canonical repo_id from the returned URL
    # (may differ from the input `repo_id` if repo_owner was implicit)
    _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
    repo_id = f"{repo_owner}/{repo_name}"

    # Check if README file already exist in repo
    try:
        get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
        has_readme = True
    except EntryNotFoundError:
        has_readme = False

    # Dump model and push to Hub
    with TemporaryDirectory() as tmpdir:
        # Save model weights and config.
        save_for_hf(
            model,
            tmpdir,
            model_config=model_config,
            model_args=model_args,
            safe_serialization=safe_serialization,
        )

        # Generate a README only when the repo does not already have one.
        if not has_readme:
            readme_text = generate_readme(model_card or {}, repo_id.split('/')[-1])
            (Path(tmpdir) / "README.md").write_text(readme_text)

        # Upload model and return
        return upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            revision=revision,
            create_pr=create_pr,
            commit_message=commit_message,
        )
+
+
def generate_readme(model_card: dict, model_name: str):
    """Render a model card dict into README.md text (YAML front matter + sections)."""
    def _dataset_lines(value):
        # one "- name" line per dataset, lower-cased; accepts a single name or a sequence
        names = value if isinstance(value, (tuple, list)) else [value]
        return [f"- {name.lower()}\n" for name in names]

    parts = [
        "---\n",
        "tags:\n- image-classification\n- timm\n",
        "library_name: timm\n",
        f"license: {model_card.get('license', 'apache-2.0')}\n",
    ]
    if 'details' in model_card and 'Dataset' in model_card['details']:
        parts.append('datasets:\n')
        parts.extend(_dataset_lines(model_card['details']['Dataset']))
        if 'Pretrain Dataset' in model_card['details']:
            parts.extend(_dataset_lines(model_card['details']['Pretrain Dataset']))
    parts.append("---\n")
    parts.append(f"# Model card for {model_name}\n")
    if 'description' in model_card:
        parts.append(f"\n{model_card['description']}\n")
    if 'details' in model_card:
        parts.append("\n## Model Details\n")
        for key, value in model_card['details'].items():
            if isinstance(value, (list, tuple)):
                parts.append(f"- **{key}:**\n")
                parts.extend(f"  - {item}\n" for item in value)
            elif isinstance(value, dict):
                parts.append(f"- **{key}:**\n")
                parts.extend(f"  - {k}: {v}\n" for k, v in value.items())
            else:
                parts.append(f"- **{key}:** {value}\n")
    for section in ('usage', 'comparison'):
        if section in model_card:
            parts.append(f"\n## Model {section.capitalize()}\n")
            parts.append(model_card[section])
            parts.append('\n')
    if 'citation' in model_card:
        parts.append("\n## Citation\n")
        citations = model_card['citation']
        if not isinstance(citations, (list, tuple)):
            citations = [citations]
        parts.extend(f"```bibtex\n{c}\n```\n" for c in citations)
    return ''.join(parts)
+
+
def _get_safe_alternatives(filename: str) -> Iterable[str]:
    """Yield potential safetensors alternatives for a given filename.

    Use case:
        When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
    """
    known_safe = {
        HF_WEIGHTS_NAME: HF_SAFE_WEIGHTS_NAME,
        HF_OPEN_CLIP_WEIGHTS_NAME: HF_OPEN_CLIP_SAFE_WEIGHTS_NAME,
    }
    if filename in known_safe:
        yield known_safe[filename]
    elif filename.endswith(".bin"):
        # generic fallback: swap the .bin suffix for .safetensors
        yield filename[:-4] + ".safetensors"
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_prune.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_prune.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bbe71ecf30fb74c9b9292cc2a7bce7c3515fbd0
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/_prune.py
@@ -0,0 +1,113 @@
+import os
+import pkgutil
+from copy import deepcopy
+
+from torch import nn as nn
+
+from timm.layers import Conv2dSame, BatchNormAct2d, Linear
+
+__all__ = ['extract_layer', 'set_layer', 'adapt_model_from_string', 'adapt_model_from_file']
+
+
def extract_layer(model, layer):
    """Return the submodule of ``model`` addressed by the dotted path ``layer``.

    Handles DataParallel-style 'module.' wrapping on either side; traversal stops
    (returning the deepest module reached) when a path segment is missing.
    """
    parts = layer.split('.')
    wrapped = hasattr(model, 'module')
    current = model.module if wrapped and parts[0] != 'module' else model
    if not wrapped and parts[0] == 'module':
        parts = parts[1:]
    for name in parts:
        if not hasattr(current, name):
            return current
        # numeric segments index into container modules (e.g. nn.Sequential)
        current = current[int(name)] if name.isdigit() else getattr(current, name)
    return current
+
+
def set_layer(model, layer, val):
    """Replace the submodule at dotted path ``layer`` in ``model`` with ``val``."""
    parts = layer.split('.')
    root = model.module if hasattr(model, 'module') and parts[0] != 'module' else model

    # Probe how many path segments actually resolve (mirrors original non-breaking scan).
    probe = root
    resolved = 0
    for name in parts:
        if hasattr(probe, name):
            probe = probe[int(name)] if name.isdigit() else getattr(probe, name)
            resolved += 1
    resolved -= 1

    # Walk to the parent of the last resolved segment, then assign.
    parent = root
    for name in parts[:resolved]:
        parent = parent[int(name)] if name.isdigit() else getattr(parent, name)
    setattr(parent, parts[resolved], val)
+
+
def adapt_model_from_string(parent_module, model_string):
    """Build a pruned copy of ``parent_module`` from a serialized shape description.

    Args:
        parent_module: model to adapt (not modified; a deepcopy is returned).
        model_string: '***'-separated entries of the form 'param.name:[d0, d1, ...]'
            giving the target shape for each pruned parameter.

    Returns:
        A new module with conv / batchnorm / linear layers resized to the described shapes.
    """
    separator = '***'
    # Parse 'key:[shape]' entries into a name -> [int dims] map.
    state_dict = {}
    for entry in model_string.split(separator):
        key, _, shape_str = entry.partition(':')
        dims = shape_str[1:-1].split(',')
        if dims[0] != '':
            state_dict[key] = [int(d) for d in dims]

    new_module = deepcopy(parent_module)
    for n, m in parent_module.named_modules():
        old_module = extract_layer(parent_module, n)
        if isinstance(old_module, (nn.Conv2d, Conv2dSame)):
            conv = Conv2dSame if isinstance(old_module, Conv2dSame) else nn.Conv2d
            s = state_dict[n + '.weight']
            in_channels, out_channels = s[1], s[0]
            g = 1
            if old_module.groups > 1:
                # grouped/depthwise conv: keep groups == channels
                in_channels = out_channels
                g = in_channels
            new_conv = conv(
                in_channels=in_channels, out_channels=out_channels, kernel_size=old_module.kernel_size,
                bias=old_module.bias is not None, padding=old_module.padding, dilation=old_module.dilation,
                groups=g, stride=old_module.stride)
            set_layer(new_module, n, new_conv)
        elif isinstance(old_module, BatchNormAct2d):
            new_bn = BatchNormAct2d(
                state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            # carry over the fused dropout / activation sub-modules
            new_bn.drop = old_module.drop
            new_bn.act = old_module.act
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                num_features=state_dict[n + '.weight'][0], eps=old_module.eps, momentum=old_module.momentum,
                affine=old_module.affine, track_running_stats=True)
            set_layer(new_module, n, new_bn)
        elif isinstance(old_module, nn.Linear):
            # FIXME extra checks to ensure this is actually the FC classifier layer and not a diff Linear layer?
            num_features = state_dict[n + '.weight'][1]
            new_fc = Linear(
                in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
            set_layer(new_module, n, new_fc)
            if hasattr(new_module, 'num_features'):
                new_module.num_features = num_features
    new_module.eval()
    parent_module.eval()

    return new_module
+
+
def adapt_model_from_file(parent_module, model_variant):
    """Adapt ``parent_module`` using the packaged pruning description ``_pruned/<variant>.txt``."""
    pruned_txt = pkgutil.get_data(__name__, os.path.join('_pruned', model_variant + '.txt'))
    return adapt_model_from_string(parent_module, pruned_txt.decode('utf-8').strip())
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/beit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/beit.py
new file mode 100644
index 0000000000000000000000000000000000000000..0167099ce7aff0ad9c6ed5aa7e2219d547a2db3d
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/beit.py
@@ -0,0 +1,621 @@
+""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
+
+Model from official source: https://github.com/microsoft/unilm/tree/master/beit
+
+@inproceedings{beit,
+title={{BEiT}: {BERT} Pre-Training of Image Transformers},
+author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
+booktitle={International Conference on Learning Representations},
+year={2022},
+url={https://openreview.net/forum?id=p-BhZSz59o4}
+}
+
+BEiT-v2 from https://github.com/microsoft/unilm/tree/master/beit2
+
+@article{beitv2,
+title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
+author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
+year={2022},
+eprint={2208.06366},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
+}
+
+At this point only the 1k fine-tuned classification weights and model configs have been added,
+see original source above for pre-training models and procedure.
+
+Modifications by / Copyright 2021 Ross Wightman, original copyrights below
+"""
+# --------------------------------------------------------
+# BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
+# Github source: https://github.com/microsoft/unilm/tree/master/beit
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# By Hangbo Bao
+# Based on timm and DeiT code bases
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm
+# https://github.com/facebookresearch/deit/
+# https://github.com/facebookresearch/dino
+# --------------------------------------------------------'
+
+import math
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import PatchEmbed, Mlp, SwiGLU, LayerNorm, DropPath, trunc_normal_, use_fused_attn
+from timm.layers import resample_patch_embed, resample_abs_pos_embed, resize_rel_pos_bias_table, ndgrid
+
+
+from ._builder import build_model_with_cfg
+from ._registry import generate_default_cfgs, register_model
+from .vision_transformer import checkpoint_filter_fn
+
+__all__ = ['Beit']
+
+
def gen_relative_position_index(window_size: Tuple[int, int]) -> torch.Tensor:
    """Build the (Wh*Ww+1, Wh*Ww+1) relative position index map for windowed
    attention with a class token.

    The final three bias-table slots encode cls->token, token->cls and cls->cls.
    """
    wh, ww = window_size[0], window_size[1]
    num_relative_distance = (2 * wh - 1) * (2 * ww - 1) + 3
    window_area = wh * ww

    # pair-wise relative coordinates between every pair of tokens inside the window
    coords = torch.stack(ndgrid(torch.arange(wh), torch.arange(ww)))  # 2, Wh, Ww
    coords = torch.flatten(coords, 1)  # 2, Wh*Ww
    rel = coords[:, :, None] - coords[:, None, :]  # 2, Wh*Ww, Wh*Ww
    rel = rel.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
    # shift offsets to a non-negative range, then collapse (dy, dx) to a flat index
    rel[:, :, 0] += wh - 1
    rel[:, :, 1] += ww - 1
    rel[:, :, 0] *= 2 * ww - 1

    index = torch.zeros(size=(window_area + 1,) * 2, dtype=rel.dtype)
    index[1:, 1:] = rel.sum(-1)  # Wh*Ww, Wh*Ww
    index[0, 0:] = num_relative_distance - 3  # cls -> token
    index[0:, 0] = num_relative_distance - 2  # token -> cls
    index[0, 0] = num_relative_distance - 1  # cls -> cls
    return index
+
+
class Attention(nn.Module):
    """Multi-head self-attention with optional decomposed relative position bias.

    BEiT-style qkv bias: q and v have learnable biases while the k bias is a
    fixed zero buffer. When a ``window_size`` is given, a learnable relative
    position bias table (plus cls-token slots) is added to the attention logits.
    """
    fused_attn: torch.jit.Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            window_size: Optional[Tuple[int, int]] = None,
            attn_head_dim: Optional[int] = None,
    ):
        """
        Args:
            dim: input embedding dimension.
            num_heads: number of attention heads.
            qkv_bias: add learnable bias to q and v projections (k bias stays zero).
            attn_drop: attention probability dropout rate.
            proj_drop: output projection dropout rate.
            window_size: feature grid size; enables relative position bias when set.
            attn_head_dim: override per-head dimension (default is dim // num_heads).
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # single projection for q, k, v; bias is assembled manually in forward()
        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            # k bias is deliberately a non-persistent zero buffer (not learned)
            self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.k_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads))  # 2*Wh-1 * 2*Ww-1, nH
            self.register_buffer("relative_position_index", gen_relative_position_index(window_size), persistent=False)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _get_rel_pos_bias(self):
        """Gather the relative position bias as a (1, nH, N, N) tensor for addition to logits."""
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1] + 1,
            self.window_size[0] * self.window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        return relative_position_bias.unsqueeze(0)

    def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
        """Apply self-attention to x (B, N, C), optionally adding a shared relative position bias."""
        B, N, C = x.shape

        # concatenated (q, zero-k, v) bias applied in a single linear call
        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim

        if self.fused_attn:
            # fused path: biases are combined into a single additive attention mask
            rel_pos_bias = None
            if self.relative_position_bias_table is not None:
                rel_pos_bias = self._get_rel_pos_bias()
                if shared_rel_pos_bias is not None:
                    rel_pos_bias = rel_pos_bias + shared_rel_pos_bias
            elif shared_rel_pos_bias is not None:
                rel_pos_bias = shared_rel_pos_bias

            x = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=rel_pos_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            # explicit path: scale, matmul, add biases, softmax, dropout
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))

            if self.relative_position_bias_table is not None:
                attn = attn + self._get_rel_pos_bias()
            if shared_rel_pos_bias is not None:
                attn = attn + shared_rel_pos_bias

            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
class Block(nn.Module):
    """BEiT transformer block: pre-norm attention + MLP, each with optional
    layer-scale (gamma) and stochastic depth on the residual branch.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            qkv_bias: bool = False,
            mlp_ratio: float = 4.,
            scale_mlp: bool = False,
            swiglu_mlp: bool = False,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            init_values: Optional[float] = None,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm,
            window_size: Optional[Tuple[int, int]] = None,
            attn_head_dim: Optional[int] = None,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            window_size=window_size,
            attn_head_dim=attn_head_dim,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        hidden_features = int(dim * mlp_ratio)
        mlp_norm = norm_layer if scale_mlp else None
        if swiglu_mlp:
            self.mlp = SwiGLU(
                in_features=dim,
                hidden_features=hidden_features,
                norm_layer=mlp_norm,
                drop=proj_drop,
            )
        else:
            self.mlp = Mlp(
                in_features=dim,
                hidden_features=hidden_features,
                act_layer=act_layer,
                norm_layer=mlp_norm,
                drop=proj_drop,
            )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # layer-scale parameters; None disables branch scaling
        if init_values:
            self.gamma_1 = nn.Parameter(init_values * torch.ones(dim))
            self.gamma_2 = nn.Parameter(init_values * torch.ones(dim))
        else:
            self.gamma_1, self.gamma_2 = None, None

    def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
        attn_out = self.attn(self.norm1(x), shared_rel_pos_bias=shared_rel_pos_bias)
        if self.gamma_1 is not None:
            attn_out = self.gamma_1 * attn_out
        x = x + self.drop_path1(attn_out)

        mlp_out = self.mlp(self.norm2(x))
        if self.gamma_2 is not None:
            mlp_out = self.gamma_2 * mlp_out
        x = x + self.drop_path2(mlp_out)
        return x
+
+
class RelativePositionBias(nn.Module):
    """Relative position bias shared across all blocks (BEiT shared-rel-pos mode).

    Holds a single learnable bias table indexed by pre-computed relative position
    indices, covering window tokens plus the class token.
    """

    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.window_area = window_size[0] * window_size[1]
        table_size = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(torch.zeros(table_size, num_heads))
        # trunc_normal_(self.relative_position_bias_table, std=.02)
        self.register_buffer("relative_position_index", gen_relative_position_index(window_size))

    def forward(self):
        num_tokens = self.window_area + 1
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(num_tokens, num_tokens, -1)  # Wh*Ww,Wh*Ww,nH
        return bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
+
+
class Beit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    BEiT variant: optional absolute position embedding, per-block or shared
    relative position bias, layer-scale, and a scaled-down classifier init.
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            qkv_bias: bool = True,
            mlp_ratio: float = 4.,
            swiglu_mlp: bool = False,
            scale_mlp: bool = False,
            drop_rate: float = 0.,
            pos_drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            norm_layer: Callable = LayerNorm,
            init_values: Optional[float] = None,
            use_abs_pos_emb: bool = True,
            use_rel_pos_bias: bool = False,
            use_shared_rel_pos_bias: bool = False,
            head_init_scale: float = 0.001,
    ):
        """
        Args:
            img_size: input image size.
            patch_size: patch size for the patch embedding.
            in_chans: number of input channels.
            num_classes: classifier output classes (0 disables the head).
            global_pool: 'avg' pools patch tokens (and uses fc_norm); otherwise cls token is used.
            embed_dim: transformer embedding dimension.
            depth: number of transformer blocks.
            num_heads: attention heads per block.
            qkv_bias: enable q/v bias in attention (k bias stays zero).
            mlp_ratio: MLP hidden dim ratio.
            swiglu_mlp: use SwiGLU MLP instead of the standard Mlp.
            scale_mlp: apply a norm layer inside the MLP.
            drop_rate: dropout before the classifier head.
            pos_drop_rate: dropout after position embedding.
            proj_drop_rate: projection dropout inside blocks.
            attn_drop_rate: attention dropout.
            drop_path_rate: max stochastic depth rate (linearly increased per block).
            norm_layer: normalization layer constructor.
            init_values: layer-scale init value (None disables layer-scale).
            use_abs_pos_emb: add a learnable absolute position embedding.
            use_rel_pos_bias: per-block relative position bias.
            use_shared_rel_pos_bias: single relative position bias shared by all blocks.
            head_init_scale: multiplier applied to classifier weights/bias at init.
        """
        super().__init__()
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1
        self.grad_checkpointing = False

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 position for the class token
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        if use_shared_rel_pos_bias:
            self.rel_pos_bias = RelativePositionBias(
                window_size=self.patch_embed.grid_size,
                num_heads=num_heads,
            )
        else:
            self.rel_pos_bias = None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                mlp_ratio=mlp_ratio,
                scale_mlp=scale_mlp,
                swiglu_mlp=swiglu_mlp,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                init_values=init_values,
                window_size=self.patch_embed.grid_size if use_rel_pos_bias else None,
            )
            for i in range(depth)])

        # with avg pooling, normalize after pooling (fc_norm) instead of before
        use_fc_norm = self.global_pool == 'avg'
        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        self.apply(self._init_weights)
        if self.pos_embed is not None:
            trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        self.fix_init_weight()
        if isinstance(self.head, nn.Linear):
            # scale down the classifier init for fine-tuning stability
            trunc_normal_(self.head.weight, std=.02)
            self.head.weight.data.mul_(head_init_scale)
            self.head.bias.data.mul_(head_init_scale)

    def fix_init_weight(self):
        """Rescale block output projections by depth (per the BEiT reference impl)."""
        def rescale(param, layer_id):
            param.div_(math.sqrt(2.0 * layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    def _init_weights(self, m):
        """Default init: trunc-normal linear weights, zero biases, unit LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay (pos embed, cls token, rel pos tables)."""
        nwd = {'pos_embed', 'cls_token'}
        for n, _ in self.named_parameters():
            if 'relative_position_bias_table' in n:
                nwd.add(n)
        return nwd

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex grouping of parameters for layer-wise LR decay / optimizer setup."""
        matcher = dict(
            stem=r'^cls_token|pos_embed|patch_embed|rel_pos_bias',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
        )
        return matcher

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classifier head (and optionally the pooling mode)."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed patches, prepend cls token, run transformer blocks; returns (B, N+1, C)."""
        x = self.patch_embed(x)
        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        x = self.pos_drop(x)

        # shared bias computed once per forward and passed to every block
        rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
            else:
                x = blk(x, shared_rel_pos_bias=rel_pos_bias)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool tokens ('avg' over patches or cls token), norm, drop, and classify."""
        if self.global_pool:
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
# Pretrained weight configs, keyed '<arch>.<pretrain tag>'. Original MSFT
# blob-storage URLs are kept (commented) for provenance; weights are served
# from the Hugging Face hub ('timm/' organization) instead.
default_cfgs = generate_default_cfgs({
    'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/'),
    'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_base_patch16_224.in22k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
        hf_hub_id='timm/',
        num_classes=21841,
    ),
    'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/'),
    'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0,
    ),
    'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
        hf_hub_id='timm/',
        input_size=(3, 512, 512), crop_pct=1.0,
    ),
    'beit_large_patch16_224.in22k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
        hf_hub_id='timm/',
        num_classes=21841,
    ),

    # BEiT v2 weights use ImageNet default normalization, unlike v1.
    'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_base_patch16_224.in1k_ft_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft1k.pth',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
        hf_hub_id='timm/',
        num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in1k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft1k.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
    'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
        #url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
        hf_hub_id='timm/',
        num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
    ),
})
+
+
def _beit_checkpoint_filter_fn(state_dict, model, interpolation='bicubic', antialias=True):
    """Remap a BEiT checkpoint state_dict to fit `model`.

    Drops cached relative-position index buffers and resamples patch-embed
    kernels, absolute position embeddings, and relative position bias tables
    when checkpoint and model geometries differ.
    """
    # Unwrap nested checkpoint formats ({'model': ...} / {'module': ...}).
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('module', state_dict)
    # beit v2 didn't strip module

    out_dict = {}
    for k, v in state_dict.items():
        if 'relative_position_index' in k:
            # Index buffers are recomputed by the model; never load them.
            continue
        if 'patch_embed.proj.weight' in k:
            O, I, H, W = model.patch_embed.proj.weight.shape
            if v.shape[-1] != W or v.shape[-2] != H:
                v = resample_patch_embed(
                    v,
                    (H, W),
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            num_prefix_tokens = 1  # CLS token only
            v = resample_abs_pos_embed(
                v,
                new_size=model.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )
        elif k.endswith('relative_position_bias_table'):
            # 29 == len('.relative_position_bias_table'); strip to get the owning module path.
            m = model.get_submodule(k[:-29])
            if v.shape != m.relative_position_bias_table.shape or m.window_size[0] != m.window_size[1]:
                v = resize_rel_pos_bias_table(
                    v,
                    new_window_size=m.window_size,
                    new_bias_shape=m.relative_position_bias_table.shape,
                )
        out_dict[k] = v
    return out_dict
+
+
+def _create_beit(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for BEiT models.')
+
+ model = build_model_with_cfg(
+ Beit, variant, pretrained,
+ # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
+ pretrained_filter_fn=_beit_checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
@register_model
def beit_base_patch16_224(pretrained=False, **kwargs) -> Beit:
    """BEiT base (ViT-B/16) @ 224x224; rel-pos bias, no abs pos embed, layer scale 0.1."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
    model = _create_beit('beit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_base_patch16_384(pretrained=False, **kwargs) -> Beit:
    """BEiT base (ViT-B/16) @ 384x384; rel-pos bias, no abs pos embed."""
    model_args = dict(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=0.1)
    model = _create_beit('beit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_large_patch16_224(pretrained=False, **kwargs) -> Beit:
    """BEiT large (ViT-L/16) @ 224x224; rel-pos bias, no abs pos embed, layer scale 1e-5."""
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beit_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_large_patch16_384(pretrained=False, **kwargs) -> Beit:
    """BEiT large (ViT-L/16) @ 384x384."""
    model_args = dict(
        img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beit_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beit_large_patch16_512(pretrained=False, **kwargs) -> Beit:
    """BEiT large (ViT-L/16) @ 512x512."""
    model_args = dict(
        img_size=512, patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beit_large_patch16_512', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beitv2_base_patch16_224(pretrained=False, **kwargs) -> Beit:
    """BEiT v2 base (ViT-B/16) @ 224x224."""
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs) -> Beit:
    """BEiT v2 large (ViT-L/16) @ 224x224."""
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16,
        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5)
    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byoanet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byoanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..683ed0ca0177fbb4afa4913856bb638b212ad6e0
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byoanet.py
@@ -0,0 +1,455 @@
+""" Bring-Your-Own-Attention Network
+
+A flexible network w/ dataclass based config for stacking NN blocks including
+self-attention (or similar) layers.
+
+Currently used to implement experimental variants of:
+ * Bottleneck Transformers
+ * Lambda ResNets
+ * HaloNets
+
+Consider all of the models definitions here as experimental WIP and likely to change.
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
+
+__all__ = []
+
+
# Named architecture configs for the BYO-attention model family. Each entry is
# a ByoModelCfg: a tuple of per-stage block configs (d=depth, c=channels,
# s=stride, gs=group size, br=bottleneck ratio) plus stem and attention options.
model_cfgs = dict(

    # Bottleneck Transformer (BoTNet) variants
    botnet26t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        fixed_input_size=True,
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict()
    ),
    sebotnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
        num_features=1280,
        attn_layer='se',
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict()
    ),
    botnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        fixed_input_size=True,
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict()
    ),
    eca_botnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        fixed_input_size=True,
        act_layer='silu',
        attn_layer='eca',
        self_attn_layer='bottleneck',
        self_attn_kwargs=dict(dim_head=16)
    ),

    # HaloNet variants
    halonet_h1=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
            ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
            ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
            ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
        ),
        stem_chs=64,
        stem_type='7x7',
        stem_pool='maxpool',

        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3),
    ),
    halonet26t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=2)
    ),
    sehalonet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=[2], d=3, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg('self_attn', d=2, c=1536, s=2, gs=0, br=0.333),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
        num_features=1280,
        attn_layer='se',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3)
    ),
    halonet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            # Stage 2 overrides the model-level halo kwargs (adds num_heads=4).
            interleave_blocks(
                types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3, num_heads=4)),
            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=3)
    ),
    eca_halonext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=16, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='eca',
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=8, halo_size=2, dim_head=16)
    ),

    # Lambda-ResNet variants
    lambda_resnet26t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        self_attn_layer='lambda',
        self_attn_kwargs=dict(r=9)
    ),
    lambda_resnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), every=4, d=4, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        self_attn_layer='lambda',
        self_attn_kwargs=dict(r=9)
    ),
    lambda_resnet26rpt_256=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
            interleave_blocks(types=('bottle', 'self_attn'), d=2, c=1024, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        self_attn_layer='lambda',
        # r=None selects relative-position embedding instead of conv lambda.
        self_attn_kwargs=dict(r=None)
    ),

    # experimental
    haloregnetz_b=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
            interleave_blocks(types=('bottle', 'self_attn'), every=3, d=12, c=192, s=2, gs=16, br=3),
            ByoBlockCfg('self_attn', d=2, c=288, s=2, gs=16, br=3),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
        self_attn_layer='halo',
        self_attn_kwargs=dict(block_size=7, halo_size=2, qk_ratio=0.33)
    ),

    # experimental
    lamhalobotnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
                self_attn_layer='lambda', self_attn_kwargs=dict(r=13)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
                self_attn_layer='bottleneck', self_attn_kwargs=dict()),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
    ),
    halo2botnet50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=4, c=512, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=6, c=1024, s=2, gs=0, br=0.25,
                self_attn_layer='halo', self_attn_kwargs=dict(halo_size=3)),
            interleave_blocks(
                types=('bottle', 'self_attn'), d=3, c=2048, s=2, gs=0, br=0.25,
                self_attn_layer='bottleneck', self_attn_kwargs=dict()),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        act_layer='silu',
    ),
)
+
+
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Build a ByobNet for a named BYO-attention variant.

    ``cfg_variant`` selects the architecture config when it differs from the
    model/weight variant name (e.g. one config shared across input sizes).
    """
    model_cfg = model_cfgs[cfg_variant] if cfg_variant else model_cfgs[variant]
    return build_model_with_cfg(
        ByobNet, variant, pretrained,
        model_cfg=model_cfg,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-weight config for BYO-attention nets.

    Keyword arguments override the baseline entries below.
    """
    cfg = {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
        'fixed_input_size': False, 'min_input_size': (3, 224, 224),
    }
    cfg.update(kwargs)
    return cfg
+
+
default_cfgs = generate_default_cfgs({
    # Bottleneck Transformer (BoTNet) weights
    # NOTE(review): original comment said 'GPU-Efficient (ResNet) weights' — that
    # appears to be a copy-paste leftover from byobnet.py; these are BoTNet cfgs.
    'botnet26t_256.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/botnet26t_c1_256-167a0e9f.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
    'sebotnet33ts_256.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sebotnet33ts_a1h2_256-957e3c3e.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
    'botnet50ts_256.untrained': _cfg(
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
    'eca_botnext26ts_256.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_botnext26ts_c_256-95a898f6.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),

    # HaloNet weights
    'halonet_h1.untrained': _cfg(input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
    'halonet26t.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet26t_a1h_256-3083328c.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
    'sehalonet33ts.ra2_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/sehalonet33ts_256-87e053f9.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
    'halonet50ts.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halonet50ts_a1h2_256-f3a3daee.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),
    'eca_halonext26ts.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_halonext26ts_c_256-06906299.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256), crop_pct=0.94),

    # Lambda-ResNet weights
    'lambda_resnet26t.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26t_c_256-e5a5c857.pth',
        hf_hub_id='timm/',
        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),
    'lambda_resnet50ts.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet50ts_a1h_256-b87370f7.pth',
        hf_hub_id='timm/',
        min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
    'lambda_resnet26rpt_256.c1_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lambda_resnet26rpt_c_256-ab00292d.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=0.94),

    'haloregnetz_b.ra3_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/haloregnetz_c_raa_256-c8ad7616.pth',
        hf_hub_id='timm/',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
        first_conv='stem.conv', input_size=(3, 224, 224), pool_size=(7, 7), min_input_size=(3, 224, 224), crop_pct=0.94),

    # Experimental combo-attention weights
    'lamhalobotnet50ts_256.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/lamhalobotnet50ts_a1h2_256-fe3d9445.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
    'halo2botnet50ts_256.a1h_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/halo2botnet50ts_a1h2_256-fd9c11a3.pth',
        hf_hub_id='timm/',
        fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
})
+
+
@register_model
def botnet26t_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ ResNet26-T backbone.
    """
    # Bottleneck self-attention requires a fixed feature-map size.
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('botnet26t_256', 'botnet26t', pretrained=pretrained, **kwargs)


@register_model
def sebotnet33ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU,
    """
    return _create_byoanet('sebotnet33ts_256', 'sebotnet33ts', pretrained=pretrained, **kwargs)


@register_model
def botnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ ResNet50-T backbone, silu act.
    """
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)


@register_model
def eca_botnext26ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Bottleneck Transformer w/ ResNet26-T backbone, silu act.
    """
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)


@register_model
def halonet_h1(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet-H1. Halo attention in all stages as per the paper.
    NOTE: This runs very slowly!
    """
    return _create_byoanet('halonet_h1', pretrained=pretrained, **kwargs)


@register_model
def halonet26t(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet26-t backbone. Halo attention in final two stages
    """
    return _create_byoanet('halonet26t', pretrained=pretrained, **kwargs)


@register_model
def sehalonet33ts(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet33-t backbone, SE attn for non Halo blocks, SiLU, 1-2 Halo in stage 2,3,4.
    """
    return _create_byoanet('sehalonet33ts', pretrained=pretrained, **kwargs)


@register_model
def halonet50ts(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet50-t backbone, silu act. Halo attention in final two stages
    """
    return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)


@register_model
def eca_halonext26ts(pretrained=False, **kwargs) -> ByobNet:
    """ HaloNet w/ a ResNet26-t backbone, silu act. Halo attention in final two stages
    """
    return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet26t(pretrained=False, **kwargs) -> ByobNet:
    """ Lambda-ResNet-26-T. Lambda layers w/ conv pos in last two stages.
    """
    return _create_byoanet('lambda_resnet26t', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet50ts(pretrained=False, **kwargs) -> ByobNet:
    """ Lambda-ResNet-50-TS. SiLU act. Lambda layers w/ conv pos in last two stages.
    """
    return _create_byoanet('lambda_resnet50ts', pretrained=pretrained, **kwargs)


@register_model
def lambda_resnet26rpt_256(pretrained=False, **kwargs) -> ByobNet:
    """ Lambda-ResNet-26-R-T. Lambda layers w/ rel pos embed in last two stages.
    """
    # Relative position embedding requires a fixed input size.
    kwargs.setdefault('img_size', 256)
    return _create_byoanet('lambda_resnet26rpt_256', pretrained=pretrained, **kwargs)


@register_model
def haloregnetz_b(pretrained=False, **kwargs) -> ByobNet:
    """ Halo + RegNetZ
    """
    return _create_byoanet('haloregnetz_b', pretrained=pretrained, **kwargs)


@register_model
def lamhalobotnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Combo Attention (Lambda + Halo + Bot) Network
    """
    return _create_byoanet('lamhalobotnet50ts_256', 'lamhalobotnet50ts', pretrained=pretrained, **kwargs)


@register_model
def halo2botnet50ts_256(pretrained=False, **kwargs) -> ByobNet:
    """ Combo Attention (Halo + Halo + Bot) Network
    """
    return _create_byoanet('halo2botnet50ts_256', 'halo2botnet50ts', pretrained=pretrained, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byobnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byobnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..a504b7262b1770509a3a80b7698d7ec3ebb2354e
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/byobnet.py
@@ -0,0 +1,2245 @@
+""" Bring-Your-Own-Blocks Network
+
+A flexible network w/ dataclass based config for stacking those NN blocks.
+
+This model is currently used to implement the following networks:
+
+GPU Efficient (ResNets) - gernet_l/m/s (original versions called genet, but this was already used (by SENet author)).
+Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
+Code and weights: https://github.com/idstcv/GPU-Efficient-Networks, licensed Apache 2.0
+
+RepVGG - repvgg_*
+Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
+Code and weights: https://github.com/DingXiaoH/RepVGG, licensed MIT
+
+MobileOne - mobileone_*
+Paper: `MobileOne: An Improved One millisecond Mobile Backbone` - https://arxiv.org/abs/2206.04040
+Code and weights: https://github.com/apple/ml-mobileone, licensed MIT
+
+In all cases the models have been modified to fit within the design of ByobNet. I've remapped
+the original weights and verified accuracies.
+
+For GPU Efficient nets, I used the original names for the blocks since they were for the most part
+the same as original residual blocks in ResNe(X)t, DarkNet, and other existing models. Note also some
+changes introduced in RegNet were also present in the stem and bottleneck blocks for this model.
+
+A significant number of different network archs can be implemented here, including variants of the
+above nets that include attention.
+
+Hacked together by / copyright Ross Wightman, 2021.
+"""
+import math
+from dataclasses import dataclass, field, replace
+from functools import partial
+from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+ create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply, checkpoint_seq
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
+
+
@dataclass
class ByoBlockCfg:
    """Per-stage block configuration for ByobNet.

    One cfg describes `d` repeats of a block `type` within a stage. The attn/self-attn
    fields, when set, override the model-wide defaults in ByoModelCfg for these blocks.
    """
    type: Union[str, nn.Module]  # registered block type name (see create_block) or a module/factory
    d: int  # block depth (number of block repeats in stage)
    c: int  # number of output channels for each block in stage
    s: int = 2  # stride of stage (first block)
    gs: Optional[Union[int, Callable]] = None  # group-size of blocks in stage, conv is depthwise if gs == 1
    br: float = 1.  # bottleneck-ratio of blocks in stage

    # NOTE: these config items override the model cfgs that are applied to all blocks by default
    attn_layer: Optional[str] = None
    attn_kwargs: Optional[Dict[str, Any]] = None
    self_attn_layer: Optional[str] = None
    self_attn_kwargs: Optional[Dict[str, Any]] = None
    block_kwargs: Optional[Dict[str, Any]] = None
+
+
@dataclass
class ByoModelCfg:
    """Model-wide configuration for ByobNet.

    `blocks` holds the per-stage block cfg(s); the remaining fields configure the stem,
    head, and the default layer choices applied to every block unless overridden
    per-block via ByoBlockCfg.
    """
    blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...]
    downsample: str = 'conv1x1'  # shortcut downsample type: 'avg', 'conv1x1', or '' (no shortcut)
    stem_type: str = '3x3'
    stem_pool: Optional[str] = 'maxpool'
    stem_chs: int = 32
    width_factor: float = 1.0
    num_features: int = 0  # num out_channels for final conv, no final 1x1 conv if 0
    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
    fixed_input_size: bool = False  # model constrained to a fixed-input size / img_size must be provided on creation

    act_layer: str = 'relu'
    norm_layer: str = 'batchnorm'

    # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
    attn_layer: Optional[str] = None
    attn_kwargs: dict = field(default_factory=lambda: dict())
    self_attn_layer: Optional[str] = None
    self_attn_kwargs: dict = field(default_factory=lambda: dict())
    block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict())
+
+
def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
    """Build the per-stage RepVGG block configs.

    Args:
        d: per-stage depths.
        wf: per-stage width factors applied to the base channels (64, 128, 256, 512).
        groups: if > 0, every second block in a stage uses grouped convs of this count.
    """
    base_chs = (64, 128, 256, 512)
    group_size = 0
    if groups > 0:
        # grouped conv on every 2nd block only (odd block index within stage)
        group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0
    return tuple(
        ByoBlockCfg(type='rep', d=depth, c=chs * scale, gs=group_size)
        for depth, chs, scale in zip(d, base_chs, wf)
    )
+
+
def _mobileone_bcfg(d=(2, 8, 10, 1), wf=(1., 1., 1., 1.), se_blocks=(), num_conv_branches=1):
    """Build the per-stage MobileOne block configs.

    Each logical block expands into a depthwise 'one' block followed by a pointwise 'one'
    block. The last `se_blocks[stage]` blocks of a stage get SE attention.
    """
    base_chs = (64, 128, 256, 512)
    prev_c = min(64, base_chs[0] * wf[0])
    se_blocks = se_blocks or (0,) * len(d)
    bcfg = []
    for depth, chs, scale, num_se in zip(d, base_chs, wf, se_blocks):
        scfg = []
        for i in range(depth):
            out_c = chs * scale
            bk = dict(num_conv_branches=num_conv_branches)
            ak = {}
            if i >= depth - num_se:
                ak['attn_layer'] = 'se'
            # depthwise block
            scfg.append(ByoBlockCfg(type='one', d=1, c=prev_c, gs=1, block_kwargs=bk, **ak))
            # pointwise block
            scfg.append(ByoBlockCfg(type='one', d=1, c=out_c, gs=0, block_kwargs=dict(kernel_size=1, **bk), **ak))
            prev_c = out_c
        bcfg.append(scfg)
    return bcfg
+
+
def interleave_blocks(
        types: Tuple[str, str], d,
        every: Union[int, List[int]] = 1,
        first: bool = False,
        **kwargs,
) -> Tuple[ByoBlockCfg]:
    """ interleave 2 block types in stack

    Args:
        types: pair of block type names; types[0] is the default, types[1] is inserted
            at the interleave positions.
        d: total number of blocks in the stage.
        every: either an explicit list of indices that get types[1], or an int interval.
        first: when `every` is an int, also place types[1] at index 0.
        **kwargs: forwarded to every ByoBlockCfg.
    """
    assert len(types) == 2
    if isinstance(every, int):
        # expand the interval into explicit indices; fall back to the last block if empty
        every = list(range(0 if first else every, d, every + 1))
        if not every:
            every = [d - 1]
    # FIX: was a bare `set(every)` whose result was discarded (dead statement);
    # bind it so membership tests below are O(1). Same membership semantics.
    every = set(every)
    blocks = []
    for i in range(d):
        block_type = types[1] if i in every else types[0]
        blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)]
    return tuple(blocks)
+
+
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
    """Expand a stage's block cfg(s) into a flat list of per-block (d=1) configs."""
    if not isinstance(stage_blocks_cfg, Sequence):
        stage_blocks_cfg = (stage_blocks_cfg,)
    expanded = []
    for cfg in stage_blocks_cfg:
        expanded.extend(replace(cfg, d=1) for _ in range(cfg.d))
    return expanded
+
+
def num_groups(group_size, channels):
    """Convert a group-size spec into a conv group count.

    0 or None means a normal (single-group) conv; group_size == 1 yields a depthwise conv.
    `channels` must divide evenly by `group_size`.
    """
    if not group_size:  # 0 or None -> standard conv
        return 1
    assert channels % group_size == 0
    return channels // group_size
+
+
@dataclass
class LayerFn:
    """Bundle of layer factory callables shared by the stem and all blocks so they use
    the same conv/norm/act/attn constructors."""
    conv_norm_act: Callable = ConvNormAct
    norm_act: Callable = BatchNormAct2d
    act: Callable = nn.ReLU
    attn: Optional[Callable] = None  # channel attn factory taking (chs); None disables attn
    self_attn: Optional[Callable] = None  # self-attention factory; None disables self-attn blocks
+
+
class DownsampleAvg(nn.Module):
    """AvgPool downsampling as in 'D' ResNet variants: pool for the stride, then 1x1 conv
    for the channel change."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 1,
            dilation: int = 1,
            apply_act: bool = False,
            layers: LayerFn = None,
    ):
        super().__init__()
        layers = layers or LayerFn()
        avg_stride = stride if dilation == 1 else 1
        if stride > 1 or dilation > 1:
            # use 'same'-padded pool when dilation prevents strided pooling
            if avg_stride == 1 and dilation > 1:
                pool_cls = AvgPool2dSame
            else:
                pool_cls = nn.AvgPool2d
            self.pool = pool_cls(2, avg_stride, ceil_mode=True, count_include_pad=False)
        else:
            self.pool = nn.Identity()
        self.conv = layers.conv_norm_act(in_chs, out_chs, 1, apply_act=apply_act)

    def forward(self, x):
        x = self.pool(x)
        x = self.conv(x)
        return x
+
+
def create_shortcut(
        downsample_type: str,
        in_chs: int,
        out_chs: int,
        stride: int,
        dilation: Tuple[int, int],
        layers: LayerFn,
        **kwargs,
):
    """Build the residual shortcut path.

    Returns nn.Identity when no channel/stride/dilation change is needed, None when
    `downsample_type` is '' (block has no shortcut), otherwise an avg-pool or 1x1-conv
    downsample module.
    """
    assert downsample_type in ('avg', 'conv1x1', '')
    if in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]:
        return nn.Identity()  # identity shortcut
    if not downsample_type:
        return None  # no shortcut
    if downsample_type == 'avg':
        return DownsampleAvg(in_chs, out_chs, stride=stride, dilation=dilation[0], **kwargs)
    return layers.conv_norm_act(in_chs, out_chs, kernel_size=1, stride=stride, dilation=dilation[0], **kwargs)
+
+
class BasicBlock(nn.Module):
    """ ResNet Basic Block - kxk + kxk

    Two kxk convs with an additive shortcut; optional channel attention either between
    the convs (`attn`) or after the second conv (`attn_last`).
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            group_size: Optional[int] = None,
            bottle_ratio: float = 1.0,
            downsample: str = 'avg',
            attn_last: bool = True,
            linear_out: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super(BasicBlock, self).__init__()
        layers = layers or LayerFn()
        # mid channels scaled by bottle_ratio (== out_chs when ratio is 1.0)
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        # nn.Identity when shape-preserving, None when downsample == '' (no shortcut)
        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
        # attn is sized for mid_chs, attn_last for out_chs; only one of the two is active
        self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, out_chs, kernel_size,
            dilation=dilation[1], groups=groups, drop_layer=drop_block, apply_act=False,
        )
        self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        # zero-init last BN gamma so the residual branch starts near-identity
        if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv2_kxk.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        shortcut = x
        x = self.conv1_kxk(x)
        x = self.conv2_kxk(x)
        # NOTE(review): `self.attn` is constructed for mid_chs but applied here after
        # conv2_kxk (out_chs), and `self.attn_last` is constructed but never applied in
        # this forward. This looks safe only for the default attn_last=True (attn is
        # Identity) — verify against upstream before using attn_last=False.
        x = self.attn(x)
        x = self.drop_path(x)
        if self.shortcut is not None:
            x = x + self.shortcut(shortcut)
        return self.act(x)
+
+
class BottleneckBlock(nn.Module):
    """ResNet-like bottleneck block: 1x1 reduce, kxk (optionally doubled), 1x1 expand,
    with additive shortcut and optional channel attention in the middle or at the end."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            attn_last: bool = False,
            linear_out: bool = False,
            extra_conv: bool = False,
            bottle_in: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super().__init__()
        layers = layers or LayerFn()
        # bottleneck width based on input or output chs depending on bottle_in
        base_chs = in_chs if bottle_in else out_chs
        mid_chs = make_divisible(base_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)
        has_attn = layers.attn is not None

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, mid_chs, kernel_size,
            stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
        )
        if extra_conv:
            self.conv2b_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups)
        else:
            self.conv2b_kxk = nn.Identity()
        self.attn = layers.attn(mid_chs) if has_attn and not attn_last else nn.Identity()
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.attn_last = layers.attn(out_chs) if has_attn and attn_last else nn.Identity()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = layers.act(inplace=True) if not linear_out else nn.Identity()

    def init_weights(self, zero_init_last: bool = False):
        # zero-init the final BN gamma so the residual branch starts near-identity
        if zero_init_last and self.shortcut is not None:
            if getattr(self.conv3_1x1.bn, 'weight', None) is not None:
                nn.init.zeros_(self.conv3_1x1.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        residual = self.conv1_1x1(x)
        residual = self.conv2_kxk(residual)
        residual = self.conv2b_kxk(residual)
        residual = self.attn(residual)
        residual = self.conv3_1x1(residual)
        residual = self.attn_last(residual)
        residual = self.drop_path(residual)
        if self.shortcut is not None:
            residual = residual + self.shortcut(x)
        return self.act(residual)
+
+
class DarkBlock(nn.Module):
    """DarkNet-like (1x1 + 3x3 w/ stride) block.

    The GE-Net impl included a 1x1 + 3x3 block in their search space, not used in the
    feature models. It matches a DarkNet (and DenseNet) block, hence the name — though
    neither of those uses strides inside the block (they downsample externally).
    For frequent strided use the EdgeBlock (3x3 w/ stride + 1x1) is more compute-optimal.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            attn_last: bool = True,
            linear_out: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super().__init__()
        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)
        has_attn = layers.attn is not None

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        # attention either between the convs (mid_chs) or after conv2 (out_chs)
        self.attn = layers.attn(mid_chs) if has_attn and not attn_last else nn.Identity()
        self.conv2_kxk = layers.conv_norm_act(
            mid_chs, out_chs, kernel_size,
            stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
        )
        self.attn_last = layers.attn(out_chs) if has_attn and attn_last else nn.Identity()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = layers.act(inplace=True) if not linear_out else nn.Identity()

    def init_weights(self, zero_init_last: bool = False):
        # zero-init the final BN gamma so the residual branch starts near-identity
        if zero_init_last and self.shortcut is not None:
            if getattr(self.conv2_kxk.bn, 'weight', None) is not None:
                nn.init.zeros_(self.conv2_kxk.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        residual = self.conv1_1x1(x)
        residual = self.attn(residual)
        residual = self.conv2_kxk(residual)
        residual = self.attn_last(residual)
        residual = self.drop_path(residual)
        if self.shortcut is not None:
            residual = residual + self.shortcut(x)
        return self.act(residual)
+
+
class EdgeBlock(nn.Module):
    """EdgeResidual-like (3x3 + 1x1) block.

    A two-layer block like DarkBlock with the 3x3 and 1x1 conv order reversed. Similar to
    the EfficientNet Edge-Residual block but ends with activations, is intended for either
    expansion or bottleneck contraction, and supports DW/group/non-grouped convs.

    FIXME is there a more common 3x3 + 1x1 conv block to name this after?
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            attn_last: bool = False,
            linear_out: bool = False,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super().__init__()
        layers = layers or LayerFn()
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)
        has_attn = layers.attn is not None

        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_kxk = layers.conv_norm_act(
            in_chs, mid_chs, kernel_size,
            stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
        )
        # attention either between the convs (mid_chs) or after conv2 (out_chs)
        self.attn = layers.attn(mid_chs) if has_attn and not attn_last else nn.Identity()
        self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.attn_last = layers.attn(out_chs) if has_attn and attn_last else nn.Identity()
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = layers.act(inplace=True) if not linear_out else nn.Identity()

    def init_weights(self, zero_init_last: bool = False):
        # zero-init the final BN gamma so the residual branch starts near-identity
        if zero_init_last and self.shortcut is not None:
            if getattr(self.conv2_1x1.bn, 'weight', None) is not None:
                nn.init.zeros_(self.conv2_1x1.bn.weight)
        for attn in (self.attn, self.attn_last):
            if hasattr(attn, 'reset_parameters'):
                attn.reset_parameters()

    def forward(self, x):
        residual = self.conv1_kxk(x)
        residual = self.attn(residual)
        residual = self.conv2_1x1(residual)
        residual = self.attn_last(residual)
        residual = self.drop_path(residual)
        if self.shortcut is not None:
            residual = residual + self.shortcut(x)
        return self.act(residual)
+
+
class RepVggBlock(nn.Module):
    """ RepVGG Block.

    Adapted from impl at https://github.com/DingXiaoH/RepVGG

    Trains with parallel branches (kxk conv+BN, 1x1 conv+BN, optional identity BN) that
    can be fused into a single kxk conv for inference via `reparameterize()`.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,  # unused, kept for block interface compat
            group_size: Optional[int] = None,
            downsample: str = '',  # unused, kept for block interface compat
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
            inference_mode: bool = False
    ):
        """
        Args:
            in_chs: input channels.
            out_chs: output channels.
            kernel_size: kernel size of the main conv branch.
            stride: block stride.
            dilation: (first, residual) dilation; identity branch only used when they match.
            group_size: channels per group (1 -> depthwise, 0/None -> ungrouped).
            layers: layer factory bundle.
            drop_block: drop-block layer for the kxk branch.
            drop_path_rate: stochastic depth rate (applied only when an identity branch exists).
            inference_mode: build the fused conv directly instead of the training branches.
        """
        super(RepVggBlock, self).__init__()
        self.groups = groups = num_groups(group_size, in_chs)
        layers = layers or LayerFn()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=kernel_size,
                stride=stride,
                # FIX: 'same' padding was missing here. Without it this conv defaults to
                # padding=0 and disagrees spatially with the training-time conv_kxk branch
                # and with the conv built by reparameterize() (which inherits
                # conv_kxk.conv.padding).
                padding=kernel_size // 2,
                dilation=dilation,
                groups=groups,
                bias=True,
            )
        else:
            self.reparam_conv = None
            # identity (BN-only) branch is only valid when shapes are preserved
            use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
            self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
            self.conv_kxk = layers.conv_norm_act(
                in_chs, out_chs, kernel_size,
                stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block, apply_act=False,
            )
            self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
            self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()

        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
        self.act = layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        # NOTE this init overrides that base model init with specific changes for the block type
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, .1, .1)
                nn.init.normal_(m.bias, 0, .1)
        if hasattr(self.attn, 'reset_parameters'):
            self.attn.reset_parameters()

    def forward(self, x):
        if self.reparam_conv is not None:
            # fused inference path
            return self.act(self.attn(self.reparam_conv(x)))

        if self.identity is None:
            x = self.conv_1x1(x) + self.conv_kxk(x)
        else:
            identity = self.identity(x)
            x = self.conv_1x1(x) + self.conv_kxk(x)
            x = self.drop_path(x)  # not in the paper / official impl, experimental
            x += identity
        x = self.attn(x)  # no attn in the paper / official impl, experimental
        return self.act(x)

    def reparameterize(self):
        """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.reparam_conv is not None:
            return

        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.conv_kxk.conv.in_channels,
            out_channels=self.conv_kxk.conv.out_channels,
            kernel_size=self.conv_kxk.conv.kernel_size,
            stride=self.conv_kxk.conv.stride,
            padding=self.conv_kxk.conv.padding,
            dilation=self.conv_kxk.conv.dilation,
            groups=self.conv_kxk.conv.groups,
            bias=True,
        )
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__('conv_kxk')
        self.__delattr__('conv_1x1')
        self.__delattr__('identity')
        self.__delattr__('drop_path')

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
        """
        # get weights and bias of scale branch
        kernel_1x1 = 0
        bias_1x1 = 0
        if self.conv_1x1 is not None:
            kernel_1x1, bias_1x1 = self._fuse_bn_tensor(self.conv_1x1)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.conv_kxk.conv.kernel_size[0] // 2
            kernel_1x1 = torch.nn.functional.pad(kernel_1x1, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.identity is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        # get weights and bias of conv branches
        kernel_conv, bias_conv = self._fuse_bn_tensor(self.conv_kxk)

        kernel_final = kernel_conv + kernel_1x1 + kernel_identity
        bias_final = bias_conv + bias_1x1 + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to fuse batchnorm layer with preceeding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
        """
        if isinstance(branch, ConvNormAct):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            # identity branch: synthesize a centered-one kernel once and cache it
            if not hasattr(self, 'id_tensor'):
                in_chs = self.conv_kxk.conv.in_channels
                input_dim = in_chs // self.groups
                kernel_size = self.conv_kxk.conv.kernel_size
                kernel_value = torch.zeros_like(self.conv_kxk.conv.weight)
                for i in range(in_chs):
                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
+
+
class MobileOneBlock(nn.Module):
    """ MobileOne building block.

    This block has a multi-branched architecture at train-time
    and plain-CNN style architecture at inference time
    For more details, please refer to our paper:
    `An Improved One millisecond Mobile Backbone` -
    https://arxiv.org/pdf/2206.04040.pdf
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.0,  # unused
            group_size: Optional[int] = None,
            downsample: str = '',  # unused
            inference_mode: bool = False,
            num_conv_branches: int = 1,
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ) -> None:
        """ Construct a MobileOneBlock module.

        Args:
            in_chs: input channels.
            out_chs: output channels.
            kernel_size: kernel size of the main conv branches.
            stride: block stride.
            dilation: (first, residual) dilation; identity branch only when they match.
            group_size: channels per group (1 -> depthwise, 0/None -> ungrouped).
            inference_mode: build the fused conv directly instead of the training branches.
            num_conv_branches: number of parallel kxk conv branches at train time.
            layers: layer factory bundle.
            drop_path_rate: stochastic depth rate (applied only when identity branch exists).
        """
        super(MobileOneBlock, self).__init__()
        self.num_conv_branches = num_conv_branches
        self.groups = groups = num_groups(group_size, in_chs)
        layers = layers or LayerFn()

        if inference_mode:
            self.reparam_conv = nn.Conv2d(
                in_channels=in_chs,
                out_channels=out_chs,
                kernel_size=kernel_size,
                stride=stride,
                # FIX: 'same' padding was missing here. Without it this conv defaults to
                # padding=0 and disagrees spatially with the training-time conv_kxk
                # branches and with the conv built by reparameterize() (which inherits
                # conv_kxk[0].conv.padding).
                padding=kernel_size // 2,
                dilation=dilation,
                groups=groups,
                bias=True)
        else:
            self.reparam_conv = None

            # Re-parameterizable skip connection
            use_ident = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
            self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None

            # Re-parameterizable conv branches
            convs = []
            for _ in range(self.num_conv_branches):
                convs.append(layers.conv_norm_act(
                    in_chs, out_chs, kernel_size=kernel_size,
                    stride=stride, groups=groups, apply_act=False))
            self.conv_kxk = nn.ModuleList(convs)

            # Re-parameterizable scale branch (1x1), only meaningful for kernel_size > 1
            self.conv_scale = None
            if kernel_size > 1:
                self.conv_scale = layers.conv_norm_act(
                    in_chs, out_chs, kernel_size=1,
                    stride=stride, groups=groups, apply_act=False)
            self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()

        self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
        self.act = layers.act(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """ Apply forward pass. """
        # Inference mode forward pass.
        if self.reparam_conv is not None:
            return self.act(self.attn(self.reparam_conv(x)))

        # Multi-branched train-time forward pass.
        # Skip branch output
        identity_out = 0
        if self.identity is not None:
            identity_out = self.identity(x)

        # Scale branch output
        scale_out = 0
        if self.conv_scale is not None:
            scale_out = self.conv_scale(x)

        # Other branches
        out = scale_out
        for ck in self.conv_kxk:
            out += ck(x)
        out = self.drop_path(out)
        out += identity_out

        return self.act(self.attn(out))

    def reparameterize(self):
        """ Following works like `RepVGG: Making VGG-style ConvNets Great Again` -
        https://arxiv.org/pdf/2101.03697.pdf. We re-parameterize multi-branched
        architecture used at training time to obtain a plain CNN-like structure
        for inference.
        """
        if self.reparam_conv is not None:
            return

        kernel, bias = self._get_kernel_bias()
        self.reparam_conv = nn.Conv2d(
            in_channels=self.conv_kxk[0].conv.in_channels,
            out_channels=self.conv_kxk[0].conv.out_channels,
            kernel_size=self.conv_kxk[0].conv.kernel_size,
            stride=self.conv_kxk[0].conv.stride,
            padding=self.conv_kxk[0].conv.padding,
            dilation=self.conv_kxk[0].conv.dilation,
            groups=self.conv_kxk[0].conv.groups,
            bias=True)
        self.reparam_conv.weight.data = kernel
        self.reparam_conv.bias.data = bias

        # Delete un-used branches
        for name, para in self.named_parameters():
            if 'reparam_conv' in name:
                continue
            para.detach_()
        self.__delattr__('conv_kxk')
        self.__delattr__('conv_scale')
        self.__delattr__('identity')
        self.__delattr__('drop_path')

    def _get_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to obtain re-parameterized kernel and bias.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L83
        """
        # get weights and bias of scale branch
        kernel_scale = 0
        bias_scale = 0
        if self.conv_scale is not None:
            kernel_scale, bias_scale = self._fuse_bn_tensor(self.conv_scale)
            # Pad scale branch kernel to match conv branch kernel size.
            pad = self.conv_kxk[0].conv.kernel_size[0] // 2
            kernel_scale = torch.nn.functional.pad(kernel_scale, [pad, pad, pad, pad])

        # get weights and bias of skip branch
        kernel_identity = 0
        bias_identity = 0
        if self.identity is not None:
            kernel_identity, bias_identity = self._fuse_bn_tensor(self.identity)

        # get weights and bias of conv branches
        kernel_conv = 0
        bias_conv = 0
        for ix in range(self.num_conv_branches):
            _kernel, _bias = self._fuse_bn_tensor(self.conv_kxk[ix])
            kernel_conv += _kernel
            bias_conv += _bias

        kernel_final = kernel_conv + kernel_scale + kernel_identity
        bias_final = bias_conv + bias_scale + bias_identity
        return kernel_final, bias_final

    def _fuse_bn_tensor(self, branch) -> Tuple[torch.Tensor, torch.Tensor]:
        """ Method to fuse batchnorm layer with preceeding conv layer.
        Reference: https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py#L95
        """
        if isinstance(branch, ConvNormAct):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            # identity branch: synthesize a centered-one kernel once and cache it
            if not hasattr(self, 'id_tensor'):
                in_chs = self.conv_kxk[0].conv.in_channels
                input_dim = in_chs // self.groups
                kernel_size = self.conv_kxk[0].conv.kernel_size
                kernel_value = torch.zeros_like(self.conv_kxk[0].conv.weight)
                for i in range(in_chs):
                    kernel_value[i, i % input_dim, kernel_size[0] // 2, kernel_size[1] // 2] = 1
                self.id_tensor = kernel_value
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std
+
+
class SelfAttnBlock(nn.Module):
    """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1

    Requires `layers.self_attn` to be set (asserted below); striding goes through the
    optional kxk conv when `extra_conv` is True, otherwise through the self-attn layer.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            stride: int = 1,
            dilation: Tuple[int, int] = (1, 1),
            bottle_ratio: float = 1.,
            group_size: Optional[int] = None,
            downsample: str = 'avg',
            extra_conv: bool = False,
            linear_out: bool = False,
            bottle_in: bool = False,
            post_attn_na: bool = True,  # apply norm+act after the self-attn layer
            feat_size: Optional[Tuple[int, int]] = None,  # required by fixed-size self-attn layers
            layers: LayerFn = None,
            drop_block: Callable = None,
            drop_path_rate: float = 0.,
    ):
        super(SelfAttnBlock, self).__init__()
        assert layers is not None  # self-attn factory must be supplied
        mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        # nn.Identity when shape-preserving, None when downsample == '' (no shortcut)
        self.shortcut = create_shortcut(
            downsample, in_chs, out_chs,
            stride=stride, dilation=dilation, apply_act=False, layers=layers,
        )

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        if extra_conv:
            self.conv2_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size,
                stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block,
            )
            stride = 1  # striding done via conv if enabled
        else:
            self.conv2_kxk = nn.Identity()
        opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
        # FIXME need to dilate self attn to have dilated network support, moop moop
        self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
        self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last: bool = False):
        # zero-init the final BN gamma so the residual branch starts near-identity
        if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv3_1x1.bn.weight)
        if hasattr(self.self_attn, 'reset_parameters'):
            self.self_attn.reset_parameters()

    def forward(self, x):
        shortcut = x
        x = self.conv1_1x1(x)
        x = self.conv2_kxk(x)
        x = self.self_attn(x)
        x = self.post_attn(x)
        x = self.conv3_1x1(x)
        x = self.drop_path(x)
        if self.shortcut is not None:
            x = x + self.shortcut(shortcut)
        return self.act(x)
+
+
# Mapping of block-type name -> block class, consumed by create_block/register_block.
_block_registry = {
    'basic': BasicBlock,
    'bottle': BottleneckBlock,
    'dark': DarkBlock,
    'edge': EdgeBlock,
    'rep': RepVggBlock,
    'one': MobileOneBlock,
    'self_attn': SelfAttnBlock,
}
+
+
def register_block(block_type:str, block_fn: nn.Module):
    """Register (or override) a block class/factory under `block_type` for `create_block` lookup."""
    _block_registry[block_type] = block_fn
+
+
def create_block(block: Union[str, nn.Module], **kwargs):
    """Instantiate a block module.

    Args:
        block: a registered block-type name, or an nn.Module subclass / functools.partial
            that is called directly with **kwargs.
        **kwargs: forwarded to the block constructor.
    """
    if isinstance(block, (nn.Module, partial)):
        return block(**kwargs)
    # FIX: message f-string had an unbalanced paren: 'Unknown block type ({block}'
    assert block in _block_registry, f'Unknown block type ({block})'
    return _block_registry[block](**kwargs)
+
+
class Stem(nn.Sequential):
    """Configurable conv stem producing an overall reduction of `stride` (2 or 4).

    Stacks `num_rep` convs (first strided), optionally followed by a max-pool, and records
    intermediate feature taps in `self.feature_info` for feature extraction use.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,  # int final width, or an explicit list/tuple of per-conv widths
            kernel_size: int = 3,
            stride: int = 4,
            pool: str = 'maxpool',
            num_rep: int = 3,  # number of stem convs (ignored if out_chs is a list/tuple)
            num_act: Optional[int] = None,  # how many of the last convs get norm+act
            chs_decay: float = 0.5,  # geometric decay of widths toward the first conv
            layers: LayerFn = None,
    ):
        super().__init__()
        assert stride in (2, 4)
        layers = layers or LayerFn()

        if isinstance(out_chs, (list, tuple)):
            # explicit per-conv widths override num_rep/chs_decay
            num_rep = len(out_chs)
            stem_chs = out_chs
        else:
            # widths decay going backward from out_chs, e.g. [16, 32, 64] for out_chs=64
            stem_chs = [round(out_chs * chs_decay ** i) for i in range(num_rep)][::-1]

        self.stride = stride
        self.feature_info = []  # track intermediate features
        prev_feat = ''
        stem_strides = [2] + [1] * (num_rep - 1)
        if stride == 4 and not pool:
            # set last conv in stack to be strided if stride == 4 and no pooling layer
            stem_strides[-1] = 2

        num_act = num_rep if num_act is None else num_act
        # if num_act < num_rep, first convs in stack won't have bn + act
        stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
        prev_chs = in_chs
        curr_stride = 1
        for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
            layer_fn = layers.conv_norm_act if na else create_conv2d
            conv_name = f'conv{i + 1}'
            if i > 0 and s > 1:
                # tap the feature just before each additional stride increase
                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
            self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
            prev_chs = ch
            curr_stride *= s
            prev_feat = conv_name

        if pool and 'max' in pool.lower():
            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
            curr_stride *= 2
            prev_feat = 'pool'

        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
        assert curr_stride == stride  # sanity: built reduction must match requested stride
+
+
def create_byob_stem(
    in_chs: int,
    out_chs: int,
    stem_type: str = '',
    pool_type: str = '',
    feat_prefix: str = 'stem',
    layers: Optional[LayerFn] = None,
):
    """Create a stem module for a ByobNet from a stem type string.

    NOTE: branch order matters — substring matching is used and 'quad2' also
    contains 'quad', so 'quad' must be tested first with num_act adjusted inside.

    Args:
        in_chs: Number of input (image) channels.
        out_chs: Target output channels of the stem.
        stem_type: One of '', 'quad', 'quad2', 'tiered', 'deep', 'rep', 'one', '7x7', '3x3'.
        pool_type: Pool type passed to Stem ('' / None disables pooling).
        feat_prefix: Module-name prefix used in the returned feature_info entries.
        layers: LayerFn bundle; defaults to LayerFn().

    Returns:
        Tuple of (stem module, feature_info list for the stem).
    """
    layers = layers or LayerFn()
    assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', 'one', '7x7', '3x3')
    if 'quad' in stem_type:
        # based on NFNet stem, stack of 4 3x3 convs
        num_act = 2 if 'quad2' in stem_type else None
        stem = Stem(in_chs, out_chs, num_rep=4, num_act=num_act, pool=pool_type, layers=layers)
    elif 'tiered' in stem_type:
        # 3x3 stack of 3 convs as in my ResNet-T
        stem = Stem(in_chs, (3 * out_chs // 8, out_chs // 2, out_chs), pool=pool_type, layers=layers)
    elif 'deep' in stem_type:
        # 3x3 stack of 3 convs as in ResNet-D
        stem = Stem(in_chs, out_chs, num_rep=3, chs_decay=1.0, pool=pool_type, layers=layers)
    elif 'rep' in stem_type:
        # single RepVGG-style re-parameterizable block
        stem = RepVggBlock(in_chs, out_chs, stride=2, layers=layers)
    elif 'one' in stem_type:
        # single MobileOne-style block
        stem = MobileOneBlock(in_chs, out_chs, kernel_size=3, stride=2, layers=layers)
    elif '7x7' in stem_type:
        # 7x7 stem conv as in ResNet
        if pool_type:
            stem = Stem(in_chs, out_chs, 7, num_rep=1, pool=pool_type, layers=layers)
        else:
            stem = layers.conv_norm_act(in_chs, out_chs, 7, stride=2)
    else:
        # 3x3 stem conv as in RegNet is the default
        if pool_type:
            stem = Stem(in_chs, out_chs, 3, num_rep=1, pool=pool_type, layers=layers)
        else:
            stem = layers.conv_norm_act(in_chs, out_chs, 3, stride=2)

    if isinstance(stem, Stem):
        # Stem tracks its own feature_info; prefix module names w/ the stem's name
        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
    else:
        # single-layer stems reduce by 2 and expose one feature entry
        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
    return stem, feature_info
+
+
def reduce_feat_size(feat_size, stride=2):
    """Return feat_size scaled down by ``stride`` (floor div); None passes through."""
    if feat_size is None:
        return None
    return tuple(dim // stride for dim in feat_size)
+
+
def override_kwargs(block_kwargs, model_kwargs):
    """ Override model level attn/self-attn/block kwargs w/ block level

    NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs
    for the block if set to anything that isn't None.

    i.e. an empty block_kwargs dict will remove kwargs set at model level for that block
    """
    if block_kwargs is None:
        selected = model_kwargs
    else:
        selected = block_kwargs
    if not selected:
        return {}  # make sure None isn't returned
    return selected
+
+
def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg, ):
    """Update (in place) a block's constructor kwargs with block-level cfg overrides.

    Resolves the attn and self-attn layers for this block — block cfg takes
    precedence over model cfg, and an explicit empty-string layer name at block
    level disables the layer entirely — then rebuilds the LayerFn stored in
    ``block_kwargs['layers']`` and overlays any extra block kwargs.

    Args:
        block_kwargs: Block constructor kwargs; mutated in place ('layers'
            replaced, extra kwargs merged in).
        block_cfg: Per-block configuration.
        model_cfg: Model-wide configuration (fallback for values unset per-block).
    """
    layer_fns = block_kwargs['layers']

    # override attn layer / args with block local config
    attn_set = block_cfg.attn_layer is not None
    if attn_set or block_cfg.attn_kwargs is not None:
        # override attn layer config
        if attn_set and not block_cfg.attn_layer:
            # empty string for attn_layer type will disable attn for this block
            attn_layer = None
        else:
            attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
            attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
            attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
        # replace returns a new (dataclass) LayerFn; the original is untouched
        layer_fns = replace(layer_fns, attn=attn_layer)

    # override self-attn layer / args with block local cfg
    self_attn_set = block_cfg.self_attn_layer is not None
    if self_attn_set or block_cfg.self_attn_kwargs is not None:
        # override attn layer config
        if self_attn_set and not block_cfg.self_attn_layer:  # attn_layer == ''
            # empty string for self_attn_layer type will disable attn for this block
            self_attn_layer = None
        else:
            self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
            self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
            self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
                if self_attn_layer is not None else None
        layer_fns = replace(layer_fns, self_attn=self_attn_layer)

    block_kwargs['layers'] = layer_fns

    # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set
    block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs))
+
+
def create_byob_stages(
    cfg: ByoModelCfg,
    drop_path_rate: float,
    output_stride: int,
    stem_feat: Dict[str, Any],
    feat_size: Optional[int] = None,
    layers: Optional[LayerFn] = None,
    block_kwargs_fn: Optional[Callable] = update_block_kwargs,
):
    """Create the stage (body) modules of a ByobNet from cfg.blocks.

    Tracks the running network stride, converting stride to dilation once
    ``output_stride`` is reached, and accumulates feature_info entries at each
    resolution reduction.

    Args:
        cfg: Model configuration (blocks, width_factor, downsample type, etc.).
        drop_path_rate: Max stochastic-depth rate, linearly ramped over all blocks.
        output_stride: Desired max network stride (8, 16, or 32); further strides
            become dilation.
        stem_feat: feature_info dict of the stem output ('num_chs', 'reduction', ...).
        feat_size: Input feature size for fixed-size (self-attn) blocks, or None.
        layers: LayerFn bundle shared by all blocks; defaults to LayerFn().
        block_kwargs_fn: Hook applied to each block's kwargs before instantiation
            (defaults to update_block_kwargs).

    Returns:
        Tuple of (nn.Sequential of stages, feature_info list).
    """

    layers = layers or LayerFn()
    feature_info = []
    block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
    depths = [sum([bc.d for bc in stage_bcs]) for stage_bcs in block_cfgs]
    # per-block drop-path rates, linearly increasing with depth across the net
    dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
    dilation = 1
    net_stride = stem_feat['reduction']
    prev_chs = stem_feat['num_chs']
    prev_feat = stem_feat
    stages = []
    for stage_idx, stage_block_cfgs in enumerate(block_cfgs):
        stride = stage_block_cfgs[0].s
        if stride != 1 and prev_feat:
            # record the feature right before this stage reduces resolution
            feature_info.append(prev_feat)
        if net_stride >= output_stride and stride > 1:
            # output stride budget exhausted: use dilation instead of striding
            dilation *= stride
            stride = 1
        net_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2

        blocks = []
        for block_idx, block_cfg in enumerate(stage_block_cfgs):
            out_chs = make_divisible(block_cfg.c * cfg.width_factor)
            group_size = block_cfg.gs
            if isinstance(group_size, Callable):
                # group size may be a fn of (channels, block index)
                group_size = group_size(out_chs, block_idx)
            block_kwargs = dict(  # Blocks used in this model must accept these arguments
                in_chs=prev_chs,
                out_chs=out_chs,
                stride=stride if block_idx == 0 else 1,  # only first block in stage strides
                dilation=(first_dilation, dilation),
                group_size=group_size,
                bottle_ratio=block_cfg.br,
                downsample=cfg.downsample,
                drop_path_rate=dpr[stage_idx][block_idx],
                layers=layers,
            )
            if block_cfg.type in ('self_attn',):
                # add feat_size arg for blocks that support/need it
                block_kwargs['feat_size'] = feat_size
            block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg)
            blocks += [create_block(block_cfg.type, **block_kwargs)]
            first_dilation = dilation
            prev_chs = out_chs
            if stride > 1 and block_idx == 0:
                # feature size shrinks after the strided first block
                feat_size = reduce_feat_size(feat_size, stride)

        stages += [nn.Sequential(*blocks)]
        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')

    feature_info.append(prev_feat)
    return nn.Sequential(*stages), feature_info
+
+
def get_layer_fns(cfg: ByoModelCfg):
    """Build the LayerFn bundle (conv/norm/act factories + attn partials) from a model cfg."""
    act_layer = get_act_layer(cfg.act_layer)
    norm_act_layer = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act_layer)
    conv_norm_act_layer = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act_layer)
    attn_layer = None
    if cfg.attn_layer:
        attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs)
    self_attn_layer = None
    if cfg.self_attn_layer:
        self_attn_layer = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs)
    return LayerFn(
        conv_norm_act=conv_norm_act_layer,
        norm_act=norm_act_layer,
        act=act_layer,
        attn=attn_layer,
        self_attn=self_attn_layer,
    )
+
+
class ByobNet(nn.Module):
    """ 'Bring-your-own-blocks' Net

    A flexible network backbone that allows building model stem + blocks via
    dataclass cfg definition w/ factory functions for module instantiation.

    Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
    """
    def __init__(
        self,
        cfg: ByoModelCfg,
        num_classes: int = 1000,
        in_chans: int = 3,
        global_pool: str = 'avg',
        output_stride: int = 32,
        img_size: Optional[Union[int, Tuple[int, int]]] = None,
        drop_rate: float = 0.,
        drop_path_rate: float = 0.,
        zero_init_last: bool = True,
        **kwargs,
    ):
        """
        Args:
            cfg: Model architecture configuration.
            num_classes: Number of classifier classes.
            in_chans: Number of input channels.
            global_pool: Global pooling type.
            output_stride: Output stride of network, one of (8, 16, 32).
            img_size: Image size for fixed image size models (i.e. self-attn).
            drop_rate: Classifier dropout rate.
            drop_path_rate: Stochastic depth drop-path rate.
            zero_init_last: Zero-init last weight of residual path.
            **kwargs: Extra kwargs overlayed onto cfg.
        """
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False

        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
        layers = get_layer_fns(cfg)
        if cfg.fixed_input_size:
            assert img_size is not None, 'img_size argument is required for fixed input size model'
        # feat_size is only meaningful for fixed-size (self-attn) models
        feat_size = to_2tuple(img_size) if img_size is not None else None

        self.feature_info = []
        # stem width defaults to the first stage's channels, scaled by width_factor
        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
        self.feature_info.extend(stem_feat[:-1])  # last stem feature handed to stages instead
        feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])

        self.stages, stage_feat = create_byob_stages(
            cfg,
            drop_path_rate,
            output_stride,
            stem_feat[-1],
            layers=layers,
            feat_size=feat_size,
        )
        self.feature_info.extend(stage_feat[:-1])

        prev_chs = stage_feat[-1]['num_chs']
        if cfg.num_features:
            # optional pre-classifier 1x1 conv expanding to cfg.num_features
            self.num_features = int(round(cfg.width_factor * cfg.num_features))
            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
        else:
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        self.feature_info += [
            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]

        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
        )

        # init weights
        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Return regex group matcher for parameter grouping (e.g. layer-wise LR decay)."""
        matcher = dict(
            stem=r'^stem',
            blocks=[
                (r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.(\d+)', None),
                (r'^final_conv', (99999,))  # final conv grouped after all stages
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Enable/disable gradient checkpointing over the stages."""
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the classifier (fc) module."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the classifier head for a new number of classes / pool type."""
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        """Run stem, stages (optionally checkpointed), and final conv; no pooling/head."""
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        x = self.final_conv(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Apply the classifier head; pre_logits skips the final fc."""
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
+def _init_weights(module, name='', zero_init_last=False):
+ if isinstance(module, nn.Conv2d):
+ fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
+ fan_out //= module.groups
+ module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Linear):
+ nn.init.normal_(module.weight, mean=0.0, std=0.01)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.BatchNorm2d):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, 'init_weights'):
+ module.init_weights(zero_init_last=zero_init_last)
+
+
# Architecture definitions for every ByobNet variant, keyed by model name.
# Families: GENet ('gernet_*'), RepVGG ('repvgg_*'), experimental ResNets
# ('resnet*q', '*resnext26ts', '*resnet3[23]ts', 'gcresnet50t', 'gcresnext50ts'),
# RegNetZ-inspired ('regnetz_*'), and MobileOne ('mobileone_*').
model_cfgs = dict(
    gernet_l=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
            ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
            ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
            ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
            ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
        ),
        stem_chs=32,
        stem_pool=None,
        num_features=2560,
    ),
    gernet_m=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
            ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
            ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
            ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
            ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
        ),
        stem_chs=32,
        stem_pool=None,
        num_features=2560,
    ),
    gernet_s=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
            ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
            ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
            ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
            ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
        ),
        stem_chs=13,
        stem_pool=None,
        num_features=1920,
    ),

    # RepVGG variants: re-parameterizable 'rep' blocks + 'rep' stem
    repvgg_a0=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(0.75, 0.75, 0.75, 2.5)),
        stem_type='rep',
        stem_chs=48,
    ),
    repvgg_a1=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1, 1, 1, 2.5)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_a2=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b0=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b1=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b1g4=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b2=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b2g4=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b3=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_b3g4=ByoModelCfg(
        blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4),
        stem_type='rep',
        stem_chs=64,
    ),
    repvgg_d2se=ByoModelCfg(
        blocks=_rep_vgg_bcfg(d=(8, 14, 24, 1), wf=(2.5, 2.5, 2.5, 5.)),
        stem_type='rep',
        stem_chs=64,
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.0625, rd_divisor=1),
    ),

    # 4 x conv stem w/ 2 act, no maxpool, 2,4,6,4 repeats, group size 32 in first 3 blocks
    # DW convs in last block, 2048 pre-FC, silu act
    resnet51q=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
        ),
        stem_chs=128,
        stem_type='quad2',
        stem_pool=None,
        num_features=2048,
        act_layer='silu',
    ),

    # 4 x conv stem w/ 4 act, no maxpool, 1,4,6,4 repeats, edge block first, group size 32 in next 2 blocks
    # DW convs in last block, 4 conv for each bottle block, 2048 pre-FC, silu act
    resnet61q=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
        ),
        stem_chs=128,
        stem_type='quad',
        stem_pool=None,
        num_features=2048,
        act_layer='silu',
        block_kwargs=dict(extra_conv=True),
    ),

    # A series of ResNeXt-26 models w/ one of none, GC, SE, ECA, BAT attn, group size 32, SiLU act,
    # and a tiered stem w/ maxpool
    resnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
    ),
    gcresnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='gca',
    ),
    seresnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='se',
    ),
    eca_resnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='eca',
    ),
    bat_resnext26ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='bat',
        attn_kwargs=dict(block_size=8)
    ),

    # ResNet-32 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, no pre-fc feat layer, tiered stem w/o maxpool
    resnet32ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=0,
        act_layer='silu',
    ),

    # ResNet-33 (2, 3, 3, 2) models w/ no attn, no groups, SiLU act, 1280 pre-FC feat, tiered stem w/o maxpool
    resnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
    ),

    # A series of ResNet-33 (2, 3, 3, 2) models w/ one of GC, SE, ECA attn, no groups, SiLU act, 1280 pre-FC feat
    # and a tiered stem w/ no maxpool
    gcresnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
        attn_layer='gca',
    ),
    seresnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
        attn_layer='se',
    ),
    eca_resnet33ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=1536, s=2, gs=0, br=0.25),
            ByoBlockCfg(type='bottle', d=2, c=1536, s=2, gs=0, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        num_features=1280,
        act_layer='silu',
        attn_layer='eca',
    ),

    gcresnet50t=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        attn_layer='gca',
    ),

    gcresnext50ts=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=6, c=1024, s=2, gs=32, br=0.25),
            ByoBlockCfg(type='bottle', d=3, c=2048, s=2, gs=32, br=0.25),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='maxpool',
        act_layer='silu',
        attn_layer='gca',
    ),

    # experimental models, closer to a RegNetZ than a ResNet. Similar to EfficientNets but w/ groups instead of DW
    regnetz_b16=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_c16=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_d32=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=32, br=4),
            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=32, br=4),
            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=32, br=4),
            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=32, br=4),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        downsample='',
        num_features=1792,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_d8=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        downsample='',
        num_features=1792,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_e8=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=96, s=1, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=8, c=192, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=16, c=384, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=3, c=512, s=2, gs=8, br=4),
        ),
        stem_chs=64,
        stem_type='tiered',
        stem_pool='',
        downsample='',
        num_features=2048,
        act_layer='silu',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),

    # experimental EvoNorm configs
    regnetz_b16_evos=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        norm_layer=partial(EvoNorm2dS0a, group_size=16),
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_c16_evos=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=4),
            ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=4),
        ),
        stem_chs=32,
        stem_pool='',
        downsample='',
        num_features=1536,
        act_layer='silu',
        norm_layer=partial(EvoNorm2dS0a, group_size=16),
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),
    regnetz_d8_evos=ByoModelCfg(
        blocks=(
            ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=6, c=128, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=12, c=256, s=2, gs=8, br=4),
            ByoBlockCfg(type='bottle', d=3, c=384, s=2, gs=8, br=4),
        ),
        stem_chs=64,
        stem_type='deep',
        stem_pool='',
        downsample='',
        num_features=1792,
        act_layer='silu',
        norm_layer=partial(EvoNorm2dS0a, group_size=16),
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=0.25),
        block_kwargs=dict(bottle_in=True, linear_out=True),
    ),

    # MobileOne variants: re-parameterizable 'one' blocks + 'one' stem
    mobileone_s0=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(0.75, 1.0, 1.0, 2.), num_conv_branches=4),
        stem_type='one',
        stem_chs=48,
    ),
    mobileone_s1=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(1.5, 1.5, 2.0, 2.5)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s2=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(1.5, 2.0, 2.5, 4.0)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s3=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(2.0, 2.5, 3.0, 4.0)),
        stem_type='one',
        stem_chs=64,
    ),
    mobileone_s4=ByoModelCfg(
        blocks=_mobileone_bcfg(wf=(3.0, 3.5, 3.5, 4.0), se_blocks=(0, 0, 5, 1)),
        stem_type='one',
        stem_chs=64,
    ),
)
+
+
def _create_byobnet(variant, pretrained=False, **kwargs):
    """Build a ByobNet for the named variant (key into model_cfgs).

    Args:
        variant: Model variant name, must be a key in ``model_cfgs``.
        pretrained: Load pretrained weights if True.
        **kwargs: Extra args forwarded to build_model_with_cfg / ByobNet.
    """
    return build_model_with_cfg(
        ByobNet, variant, pretrained,
        model_cfg=model_cfgs[variant],
        feature_cfg=dict(flatten_sequential=True),
        **kwargs)
+
+
def _cfg(url='', **kwargs):
    """Default pretrained cfg for 224x224 models (GENet / RepVGG style stems)."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.875,
        'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv',
        'classifier': 'head.fc',
    }
    cfg.update(kwargs)
    return cfg
+
+
def _cfgr(url='', **kwargs):
    """Default pretrained cfg for 256x256 models w/ multi-conv stems (first conv 'stem.conv1.conv')."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 256, 256),
        'pool_size': (8, 8),
        'crop_pct': 0.9,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.conv',
        'classifier': 'head.fc',
    }
    cfg.update(kwargs)
    return cfg
+
+
+default_cfgs = generate_default_cfgs({
+ # GPU-Efficient (ResNet) weights
+ 'gernet_s.idstcv_in1k': _cfg(hf_hub_id='timm/'),
+ 'gernet_m.idstcv_in1k': _cfg(hf_hub_id='timm/'),
+ 'gernet_l.idstcv_in1k': _cfg(hf_hub_id='timm/', input_size=(3, 256, 256), pool_size=(8, 8)),
+
+ # RepVGG weights
+ 'repvgg_a0.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_a1.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_a2.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_b0.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_b1.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_b1g4.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_b2.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_b2g4.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_b3.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_b3g4.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit'),
+ 'repvgg_d2se.rvgg_in1k': _cfg(
+ hf_hub_id='timm/',
+ first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv'), license='mit',
+ input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0,
+ ),
+
+ # experimental ResNet configs
+ 'resnet51q.ra2_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
+ first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'resnet61q.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet61q_ra2-6afc536c.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+ # ResNeXt-26 models with different attention in Bottleneck blocks
+ 'resnext26ts.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnext26ts_256_ra2-8bbd9106.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'seresnext26ts.ch_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnext26ts_256-6f0d74a3.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'gcresnext26ts.ch_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext26ts_256-e414378b.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'eca_resnext26ts.ch_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnext26ts_256-5a1d030f.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'bat_resnext26ts.ch_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/bat_resnext26ts_256-fa6fd595.pth',
+ min_input_size=(3, 256, 256)),
+
+ # ResNet-32 / 33 models with different attention in Bottleneck blocks
+ 'resnet32ts.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet32ts_256-aacf5250.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'resnet33ts.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/resnet33ts_256-e91b09a4.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'gcresnet33ts.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet33ts_256-0e0cd345.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'seresnet33ts.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/seresnet33ts_256-f8ad44d9.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'eca_resnet33ts.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/eca_resnet33ts_256-8f98face.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+ 'gcresnet50t.ra2_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnet50t_256-96374d1c.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+ 'gcresnext50ts.ch_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/gcresnext50ts_256-3e0f515e.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+ # custom `timm` specific RegNetZ inspired models w/ different sizing from paper
+ 'regnetz_b16.ra3_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_b_raa-677d9606.pth',
+ first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.94, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'regnetz_c16.ra3_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_c_rab2_256-a54bf36a.pth',
+ first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+ 'regnetz_d32.ra3_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d_rab_256-b8073a89.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320)),
+ 'regnetz_d8.ra3_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_d8_bh-afc03c55.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+ 'regnetz_e8.ra3_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.94, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+
+ 'regnetz_b16_evos.untrained': _cfgr(
+ first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ input_size=(3, 224, 224), pool_size=(7, 7), crop_pct=0.95, test_input_size=(3, 288, 288)),
+ 'regnetz_c16_evos.ch_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_c16_evos_ch-d8311942.pth',
+ first_conv='stem.conv', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+ crop_pct=0.95, test_input_size=(3, 320, 320)),
+ 'regnetz_d8_evos.ch_in1k': _cfgr(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/regnetz_d8_evos_ch-2bc12646.pth',
+ mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+
+ 'mobileone_s0.apple_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.875,
+ first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+ ),
+ 'mobileone_s1.apple_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.9,
+ first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+ ),
+ 'mobileone_s2.apple_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.9,
+ first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+ ),
+ 'mobileone_s3.apple_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.9,
+ first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+ ),
+ 'mobileone_s4.apple_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.9,
+ first_conv=('stem.conv_kxk.0.conv', 'stem.conv_scale.conv'),
+ ),
+})
+
+
@register_model
def gernet_l(pretrained=False, **kwargs) -> ByobNet:
    """ GEResNet-Large (GENet-Large from official impl)
    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090

    Args:
        pretrained: load pretrained weights (passed through to the model builder).
    """
    return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)


@register_model
def gernet_m(pretrained=False, **kwargs) -> ByobNet:
    """ GEResNet-Medium (GENet-Normal from official impl)
    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090

    Args:
        pretrained: load pretrained weights (passed through to the model builder).
    """
    return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)


@register_model
def gernet_s(pretrained=False, **kwargs) -> ByobNet:
    """ GEResNet-Small (GENet-Small from official impl)
    `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090

    Args:
        pretrained: load pretrained weights (passed through to the model builder).
    """
    return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
+
+
# RepVGG model registrations. Each variant delegates to _create_byobnet with a
# preset architecture name; `pretrained` and extra kwargs pass through unchanged.
@register_model
def repvgg_a0(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-A0
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_a0', pretrained=pretrained, **kwargs)


@register_model
def repvgg_a1(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-A1
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_a1', pretrained=pretrained, **kwargs)


@register_model
def repvgg_a2(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-A2
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b0(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B0
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b1(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B1
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b1g4(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B1g4
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b2(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B2
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b2g4(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B2g4
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b3(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B3
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)


@register_model
def repvgg_b3g4(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-B3g4
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)


@register_model
def repvgg_d2se(pretrained=False, **kwargs) -> ByobNet:
    """ RepVGG-D2se
    `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    """
    return _create_byobnet('repvgg_d2se', pretrained=pretrained, **kwargs)
+
+
@register_model
def resnet51q(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `resnet51q` model.
    """
    return _create_byobnet('resnet51q', pretrained=pretrained, **kwargs)


@register_model
def resnet61q(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `resnet61q` model.
    """
    return _create_byobnet('resnet61q', pretrained=pretrained, **kwargs)


@register_model
def resnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `resnext26ts` model.
    """
    return _create_byobnet('resnext26ts', pretrained=pretrained, **kwargs)


@register_model
def gcresnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `gcresnext26ts` model.
    """
    return _create_byobnet('gcresnext26ts', pretrained=pretrained, **kwargs)


@register_model
def seresnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `seresnext26ts` model.
    """
    return _create_byobnet('seresnext26ts', pretrained=pretrained, **kwargs)


@register_model
def eca_resnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs an `eca_resnext26ts` model.
    """
    return _create_byobnet('eca_resnext26ts', pretrained=pretrained, **kwargs)


@register_model
def bat_resnext26ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `bat_resnext26ts` model.
    """
    return _create_byobnet('bat_resnext26ts', pretrained=pretrained, **kwargs)


@register_model
def resnet32ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `resnet32ts` model.
    """
    return _create_byobnet('resnet32ts', pretrained=pretrained, **kwargs)


@register_model
def resnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `resnet33ts` model.
    """
    return _create_byobnet('resnet33ts', pretrained=pretrained, **kwargs)


@register_model
def gcresnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `gcresnet33ts` model.
    """
    return _create_byobnet('gcresnet33ts', pretrained=pretrained, **kwargs)


@register_model
def seresnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `seresnet33ts` model.
    """
    return _create_byobnet('seresnet33ts', pretrained=pretrained, **kwargs)


@register_model
def eca_resnet33ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs an `eca_resnet33ts` model.
    """
    return _create_byobnet('eca_resnet33ts', pretrained=pretrained, **kwargs)


@register_model
def gcresnet50t(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `gcresnet50t` model.
    """
    return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs)


@register_model
def gcresnext50ts(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `gcresnext50ts` model.
    """
    return _create_byobnet('gcresnext50ts', pretrained=pretrained, **kwargs)


@register_model
def regnetz_b16(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_b16` model (custom timm RegNetZ-inspired variant).
    """
    return _create_byobnet('regnetz_b16', pretrained=pretrained, **kwargs)


@register_model
def regnetz_c16(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_c16` model (custom timm RegNetZ-inspired variant).
    """
    return _create_byobnet('regnetz_c16', pretrained=pretrained, **kwargs)


@register_model
def regnetz_d32(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_d32` model (custom timm RegNetZ-inspired variant).
    """
    return _create_byobnet('regnetz_d32', pretrained=pretrained, **kwargs)


@register_model
def regnetz_d8(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_d8` model (custom timm RegNetZ-inspired variant).
    """
    return _create_byobnet('regnetz_d8', pretrained=pretrained, **kwargs)


@register_model
def regnetz_e8(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_e8` model (custom timm RegNetZ-inspired variant).
    """
    return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs)


@register_model
def regnetz_b16_evos(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_b16_evos` model.
    """
    return _create_byobnet('regnetz_b16_evos', pretrained=pretrained, **kwargs)


@register_model
def regnetz_c16_evos(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_c16_evos` model.
    """
    return _create_byobnet('regnetz_c16_evos', pretrained=pretrained, **kwargs)


@register_model
def regnetz_d8_evos(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `regnetz_d8_evos` model.
    """
    return _create_byobnet('regnetz_d8_evos', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s0(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `mobileone_s0` model.
    """
    return _create_byobnet('mobileone_s0', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s1(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `mobileone_s1` model.
    """
    return _create_byobnet('mobileone_s1', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s2(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `mobileone_s2` model.
    """
    return _create_byobnet('mobileone_s2', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s3(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `mobileone_s3` model.
    """
    return _create_byobnet('mobileone_s3', pretrained=pretrained, **kwargs)


@register_model
def mobileone_s4(pretrained=False, **kwargs) -> ByobNet:
    """ Constructs a `mobileone_s4` model.
    """
    return _create_byobnet('mobileone_s4', pretrained=pretrained, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/coat.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/coat.py
new file mode 100644
index 0000000000000000000000000000000000000000..68358b3d6bea4913e120d42d9871f7d494153ab4
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/coat.py
@@ -0,0 +1,804 @@
+"""
+CoaT architecture.
+
+Paper: Co-Scale Conv-Attentional Image Transformers - https://arxiv.org/abs/2104.06399
+
+Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
+
+Modified from timm/models/vision_transformer.py
+"""
+from functools import partial
+from typing import Tuple, List, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert, LayerNorm
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['CoaT']
+
+
class ConvRelPosEnc(nn.Module):
    """ Convolutional relative position encoding. """
    def __init__(self, head_chs, num_heads, window):
        """
        Initialization.

        Args:
            head_chs: Number of channels per attention head.
            num_heads: Number of attention heads.
            window: Window size(s) in convolutional relative positional encoding.
                It can have two forms:
                1. An integer window size, which assigns the same window size to
                   all attention heads.
                2. A dict mapping window size -> number of attention heads using
                   that size (e.g. {window size 1: #head split 1, window size 2:
                   #head split 2}); different head splits then get different
                   window sizes.
        """
        super().__init__()

        if isinstance(window, int):
            # Set the same window size for all attention heads.
            window = {window: num_heads}
            self.window = window
        elif isinstance(window, dict):
            self.window = window
        else:
            raise ValueError()

        self.conv_list = nn.ModuleList()
        self.head_splits = []
        for cur_window, cur_head_split in window.items():
            dilation = 1
            # Determine padding size so the spatial size is preserved (odd windows).
            # Ref: https://discuss.pytorch.org/t/how-to-keep-the-shape-of-input-and-output-same-when-dilation-conv/14338
            padding_size = (cur_window + (cur_window - 1) * (dilation - 1)) // 2
            # Depthwise conv over this head split's channels (groups == channels).
            cur_conv = nn.Conv2d(
                cur_head_split * head_chs,
                cur_head_split * head_chs,
                kernel_size=(cur_window, cur_window),
                padding=(padding_size, padding_size),
                dilation=(dilation, dilation),
                groups=cur_head_split * head_chs,
            )
            self.conv_list.append(cur_conv)
            self.head_splits.append(cur_head_split)
        # Channel count per split, used to chunk the value tensor in forward().
        self.channel_splits = [x * head_chs for x in self.head_splits]

    def forward(self, q, v, size: Tuple[int, int]):
        """Compute the conv. relative position term from q/v of shape [B, h, 1 + H*W, Ch]."""
        B, num_heads, N, C = q.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        # Convolutional relative position encoding (image tokens only; CLS excluded).
        q_img = q[:, :, 1:, :]  # [B, h, H*W, Ch]
        v_img = v[:, :, 1:, :]  # [B, h, H*W, Ch]

        v_img = v_img.transpose(-1, -2).reshape(B, num_heads * C, H, W)
        v_img_list = torch.split(v_img, self.channel_splits, dim=1)  # Split according to channels
        conv_v_img_list = []
        for i, conv in enumerate(self.conv_list):
            conv_v_img_list.append(conv(v_img_list[i]))
        conv_v_img = torch.cat(conv_v_img_list, dim=1)
        conv_v_img = conv_v_img.reshape(B, num_heads, C, H * W).transpose(-1, -2)

        # Elementwise product with q, then zero-pad the CLS position back in front.
        EV_hat = q_img * conv_v_img
        EV_hat = F.pad(EV_hat, (0, 0, 1, 0, 0, 0))  # [B, h, N, Ch].
        return EV_hat
+
+
class FactorAttnConvRelPosEnc(nn.Module):
    """ Factorized attention with convolutional relative position encoding class. """
    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            shared_crpe=None,
    ):
        """
        Args:
            dim: Total embedding dimension (split evenly across heads).
            num_heads: Number of attention heads.
            qkv_bias: Add learnable bias to the fused qkv projection.
            attn_drop: Kept for interface parity; the dropout module is created
                but never applied (factorized attention forms no attention matrix).
            proj_drop: Dropout rate after the output projection.
            shared_crpe: ConvRelPosEnc module shared across blocks of a stage.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)  # Note: attn_drop is actually not used.
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Shared convolutional relative position encoding.
        self.crpe = shared_crpe

    def forward(self, x, size: Tuple[int, int]):
        """Factorized attention over [B, N, C] tokens; `size` is the (H, W) token grid."""
        B, N, C = x.shape

        # Generate Q, K, V.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # [B, h, N, Ch]

        # Factorized attention: softmax over keys first, then (K^T V) before Q,
        # giving linear rather than quadratic cost in N.
        k_softmax = k.softmax(dim=2)
        factor_att = k_softmax.transpose(-1, -2) @ v
        factor_att = q @ factor_att

        # Convolutional relative position encoding.
        crpe = self.crpe(q, v, size=size)  # [B, h, N, Ch]

        # Merge and reshape.
        x = self.scale * factor_att + crpe
        x = x.transpose(1, 2).reshape(B, N, C)  # [B, h, N, Ch] -> [B, N, h, Ch] -> [B, N, C]

        # Output projection.
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
+
+
class ConvPosEnc(nn.Module):
    """ Convolutional Position Encoding.
    Note: This module is similar to the conditional position encoding in CPVT.
    """
    def __init__(self, dim, k=3):
        super(ConvPosEnc, self).__init__()
        # Depthwise conv with 'same' padding provides the positional signal.
        self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)

    def forward(self, x, size: Tuple[int, int]):
        B, N, C = x.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        # Split off the CLS token; only image tokens get the positional encoding.
        cls_tok = x[:, :1]
        grid = x[:, 1:].transpose(1, 2).view(B, C, H, W)

        # Residual depthwise convolution, then back to token form.
        grid = grid + self.proj(grid)
        grid = grid.flatten(2).transpose(1, 2)

        # Re-attach the CLS token in front.
        return torch.cat((cls_tok, grid), dim=1)
+
+
class SerialBlock(nn.Module):
    """ Serial block class.
    Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            shared_cpe=None,
            shared_crpe=None,
    ):
        """
        Args:
            dim: Embedding dimension of this stage.
            num_heads: Number of attention heads.
            mlp_ratio: Hidden/input width ratio of the MLP.
            qkv_bias: Enable bias on the qkv projection.
            proj_drop: Dropout for projections and the MLP.
            attn_drop: Attention dropout (unused by factorized attention).
            drop_path: Stochastic depth rate (Identity when 0).
            act_layer: Activation layer for the MLP.
            norm_layer: Normalization layer constructor.
            shared_cpe: ConvPosEnc shared across blocks of the stage.
            shared_crpe: ConvRelPosEnc shared across blocks of the stage.
        """
        super().__init__()

        # Conv-Attention.
        self.cpe = shared_cpe

        self.norm1 = norm_layer(dim)
        self.factoratt_crpe = FactorAttnConvRelPosEnc(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            shared_crpe=shared_crpe,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=proj_drop,
        )

    def forward(self, x, size: Tuple[int, int]):
        """Pre-norm conv-attention then MLP, each with a residual connection."""
        # Conv-Attention (conv position encoding applied to the tokens first).
        x = self.cpe(x, size)
        cur = self.norm1(x)
        cur = self.factoratt_crpe(cur, size)
        x = x + self.drop_path(cur)

        # MLP.
        cur = self.norm2(x)
        cur = self.mlp(cur)
        x = x + self.drop_path(cur)

        return x
+
+
class ParallelBlock(nn.Module):
    """ Parallel block class.

    Applies factorized attention to the three finest token streams (scales 2-4)
    and exchanges information across scales by bilinearly resampling attention
    outputs before the residual add. The coarsest stream x1 passes through.
    """
    def __init__(
            self,
            dims,
            num_heads,
            mlp_ratios=(),  # immutable default; was a shared mutable list literal
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
            shared_crpes=None,
    ):
        """
        Args:
            dims: Embedding dims of the four streams; dims[1:] must be equal.
            num_heads: Number of attention heads.
            mlp_ratios: Per-stream MLP ratios; mlp_ratios[1:] must be equal.
            qkv_bias: Enable bias on qkv projections.
            proj_drop: Dropout for projections and the shared MLP.
            attn_drop: Attention dropout (unused by factorized attention).
            drop_path: Stochastic depth rate (Identity when 0).
            act_layer: Activation layer for the MLP.
            norm_layer: Normalization layer constructor.
            shared_crpes: Per-stream shared ConvRelPosEnc modules (indices 1..3 used).
        """
        super().__init__()

        # Conv-Attention.
        self.norm12 = norm_layer(dims[1])
        self.norm13 = norm_layer(dims[2])
        self.norm14 = norm_layer(dims[3])
        self.factoratt_crpe2 = FactorAttnConvRelPosEnc(
            dims[1],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            shared_crpe=shared_crpes[1],
        )
        self.factoratt_crpe3 = FactorAttnConvRelPosEnc(
            dims[2],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            shared_crpe=shared_crpes[2],
        )
        self.factoratt_crpe4 = FactorAttnConvRelPosEnc(
            dims[3],
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            shared_crpe=shared_crpes[3],
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        # MLP.
        self.norm22 = norm_layer(dims[1])
        self.norm23 = norm_layer(dims[2])
        self.norm24 = norm_layer(dims[3])
        # In parallel block, we assume dimensions are the same and share the linear transformation.
        assert dims[1] == dims[2] == dims[3]
        assert mlp_ratios[1] == mlp_ratios[2] == mlp_ratios[3]
        mlp_hidden_dim = int(dims[1] * mlp_ratios[1])
        # One Mlp instance deliberately shared by all three streams.
        self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
            in_features=dims[1],
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=proj_drop,
        )

    def upsample(self, x, factor: float, size: Tuple[int, int]):
        """ Feature map up-sampling. """
        return self.interpolate(x, scale_factor=factor, size=size)

    def downsample(self, x, factor: float, size: Tuple[int, int]):
        """ Feature map down-sampling. """
        return self.interpolate(x, scale_factor=1.0/factor, size=size)

    def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
        """ Feature map interpolation (bilinear) of the image tokens; the CLS token is untouched. """
        B, N, C = x.shape
        H, W = size
        _assert(N == 1 + H * W, '')

        cls_token = x[:, :1, :]
        img_tokens = x[:, 1:, :]

        img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
        img_tokens = F.interpolate(
            img_tokens,
            scale_factor=scale_factor,
            recompute_scale_factor=False,
            mode='bilinear',
            align_corners=False,
        )
        img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)

        out = torch.cat((cls_token, img_tokens), dim=1)

        return out

    def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
        """Cross-scale attention + shared MLP; `sizes` holds (H, W) per stream."""
        _, S2, S3, S4 = sizes
        cur2 = self.norm12(x2)
        cur3 = self.norm13(x3)
        cur4 = self.norm14(x4)
        cur2 = self.factoratt_crpe2(cur2, size=S2)
        cur3 = self.factoratt_crpe3(cur3, size=S3)
        cur4 = self.factoratt_crpe4(cur4, size=S4)
        # Resample every attention output to every other scale, then fuse.
        upsample3_2 = self.upsample(cur3, factor=2., size=S3)
        upsample4_3 = self.upsample(cur4, factor=2., size=S4)
        upsample4_2 = self.upsample(cur4, factor=4., size=S4)
        downsample2_3 = self.downsample(cur2, factor=2., size=S2)
        downsample3_4 = self.downsample(cur3, factor=2., size=S3)
        downsample2_4 = self.downsample(cur2, factor=4., size=S2)
        cur2 = cur2 + upsample3_2 + upsample4_2
        cur3 = cur3 + upsample4_3 + downsample2_3
        cur4 = cur4 + downsample3_4 + downsample2_4
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        # MLP.
        cur2 = self.norm22(x2)
        cur3 = self.norm23(x3)
        cur4 = self.norm24(x4)
        cur2 = self.mlp2(cur2)
        cur3 = self.mlp3(cur3)
        cur4 = self.mlp4(cur4)
        x2 = x2 + self.drop_path(cur2)
        x3 = x3 + self.drop_path(cur3)
        x4 = x4 + self.drop_path(cur4)

        return x1, x2, x3, x4
+
+
class CoaT(nn.Module):
    """ CoaT class.

    Four serial stages (patch embed + SerialBlocks each) followed by optional
    ParallelBlocks that mix scales 2-4. With parallel_depth == 0 this is the
    CoaT-Lite form, classifying from the last stage only.
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            embed_dims=(64, 128, 320, 512),
            serial_depths=(3, 4, 6, 3),
            parallel_depth=0,
            num_heads=8,
            mlp_ratios=(4, 4, 4, 4),
            qkv_bias=True,
            drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=LayerNorm,
            return_interm_layers=False,
            out_features=None,
            crpe_window=None,
            global_pool='token',
    ):
        """
        Args:
            img_size: Input image size (int or 2-tuple).
            patch_size: Patch size of the first patch embedding (later stages use 2).
            in_chans: Number of input image channels.
            num_classes: Number of classes (0 => Identity head).
            embed_dims: Embedding dim of each of the four serial stages.
            serial_depths: Number of serial blocks per stage.
            parallel_depth: Number of parallel blocks (0 => CoaT-Lite variant).
            num_heads: Attention heads, shared by all stages.
            mlp_ratios: MLP expansion ratio per stage.
            qkv_bias: Enable bias on qkv projections.
            drop_rate: Classifier-head dropout rate.
            proj_drop_rate: Projection/MLP dropout inside blocks.
            attn_drop_rate: Attention dropout (unused by factorized attention).
            drop_path_rate: Must be 0.; stochastic depth is disabled for CoaT.
            norm_layer: Normalization layer constructor.
            return_interm_layers: Return intermediate feature maps (detection use).
            out_features: Names of intermediate features to return ('x1_nocls', ...).
            crpe_window: Window-size -> head-split mapping for ConvRelPosEnc.
            global_pool: 'token' (CLS) or 'avg' pooling for classification.
        """
        super().__init__()
        assert global_pool in ('token', 'avg')
        crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
        self.return_interm_layers = return_interm_layers
        self.out_features = out_features
        self.embed_dims = embed_dims
        self.num_features = embed_dims[-1]
        self.num_classes = num_classes
        self.global_pool = global_pool

        # Patch embeddings. Stages downsample 4x, then 2x each (strides 4/8/16/32).
        img_size = to_2tuple(img_size)
        self.patch_embed1 = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans,
            embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
        self.patch_embed2 = PatchEmbed(
            img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
            embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
        self.patch_embed3 = PatchEmbed(
            img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
            embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
        self.patch_embed4 = PatchEmbed(
            img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
            embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)

        # Class tokens (one per stage).
        self.cls_token1 = nn.Parameter(torch.zeros(1, 1, embed_dims[0]))
        self.cls_token2 = nn.Parameter(torch.zeros(1, 1, embed_dims[1]))
        self.cls_token3 = nn.Parameter(torch.zeros(1, 1, embed_dims[2]))
        self.cls_token4 = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        # Convolutional position encodings (shared within each stage).
        self.cpe1 = ConvPosEnc(dim=embed_dims[0], k=3)
        self.cpe2 = ConvPosEnc(dim=embed_dims[1], k=3)
        self.cpe3 = ConvPosEnc(dim=embed_dims[2], k=3)
        self.cpe4 = ConvPosEnc(dim=embed_dims[3], k=3)

        # Convolutional relative position encodings (shared within each stage).
        self.crpe1 = ConvRelPosEnc(head_chs=embed_dims[0] // num_heads, num_heads=num_heads, window=crpe_window)
        self.crpe2 = ConvRelPosEnc(head_chs=embed_dims[1] // num_heads, num_heads=num_heads, window=crpe_window)
        self.crpe3 = ConvRelPosEnc(head_chs=embed_dims[2] // num_heads, num_heads=num_heads, window=crpe_window)
        self.crpe4 = ConvRelPosEnc(head_chs=embed_dims[3] // num_heads, num_heads=num_heads, window=crpe_window)

        # Disable stochastic depth.
        dpr = drop_path_rate
        assert dpr == 0.0
        skwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_drop=proj_drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=dpr,
            norm_layer=norm_layer,
        )

        # Serial blocks 1.
        self.serial_blocks1 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[0],
                mlp_ratio=mlp_ratios[0],
                shared_cpe=self.cpe1,
                shared_crpe=self.crpe1,
                **skwargs,
            )
            for _ in range(serial_depths[0])]
        )

        # Serial blocks 2.
        self.serial_blocks2 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[1],
                mlp_ratio=mlp_ratios[1],
                shared_cpe=self.cpe2,
                shared_crpe=self.crpe2,
                **skwargs,
            )
            for _ in range(serial_depths[1])]
        )

        # Serial blocks 3.
        self.serial_blocks3 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[2],
                mlp_ratio=mlp_ratios[2],
                shared_cpe=self.cpe3,
                shared_crpe=self.crpe3,
                **skwargs,
            )
            for _ in range(serial_depths[2])]
        )

        # Serial blocks 4.
        self.serial_blocks4 = nn.ModuleList([
            SerialBlock(
                dim=embed_dims[3],
                mlp_ratio=mlp_ratios[3],
                shared_cpe=self.cpe4,
                shared_crpe=self.crpe4,
                **skwargs,
            )
            for _ in range(serial_depths[3])]
        )

        # Parallel blocks.
        self.parallel_depth = parallel_depth
        if self.parallel_depth > 0:
            self.parallel_blocks = nn.ModuleList([
                ParallelBlock(
                    dims=embed_dims,
                    mlp_ratios=mlp_ratios,
                    shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4),
                    **skwargs,
                )
                for _ in range(parallel_depth)]
            )
        else:
            self.parallel_blocks = None

        # Classification head(s).
        # NOTE(review): norm2/norm3/norm4 are only created when
        # return_interm_layers is False; checkpoint_filter_fn relies on getattr
        # defaults to cope with the missing attributes.
        if not self.return_interm_layers:
            if self.parallel_blocks is not None:
                self.norm2 = norm_layer(embed_dims[1])
                self.norm3 = norm_layer(embed_dims[2])
            else:
                self.norm2 = self.norm3 = None
            self.norm4 = norm_layer(embed_dims[3])

            if self.parallel_depth > 0:
                # CoaT series: Aggregate features of last three scales for classification.
                assert embed_dims[1] == embed_dims[2] == embed_dims[3]
                self.aggregate = torch.nn.Conv1d(in_channels=3, out_channels=1, kernel_size=1)
                self.head_drop = nn.Dropout(drop_rate)
                self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
            else:
                # CoaT-Lite series: Use feature of last scale for classification.
                self.aggregate = None
                self.head_drop = nn.Dropout(drop_rate)
                self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # Initialize weights.
        trunc_normal_(self.cls_token1, std=.02)
        trunc_normal_(self.cls_token2, std=.02)
        trunc_normal_(self.cls_token3, std=.02)
        trunc_normal_(self.cls_token4, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; unit weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names to exclude from weight decay (the four CLS tokens)."""
        return {'cls_token1', 'cls_token2', 'cls_token3', 'cls_token4'}

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex grouping of parameters for layer-wise LR decay / freezing."""
        matcher = dict(
            stem1=r'^cls_token1|patch_embed1|crpe1|cpe1',
            serial_blocks1=r'^serial_blocks1\.(\d+)',
            stem2=r'^cls_token2|patch_embed2|crpe2|cpe2',
            serial_blocks2=r'^serial_blocks2\.(\d+)',
            stem3=r'^cls_token3|patch_embed3|crpe3|cpe3',
            serial_blocks3=r'^serial_blocks3\.(\d+)',
            stem4=r'^cls_token4|patch_embed4|crpe4|cpe4',
            serial_blocks4=r'^serial_blocks4\.(\d+)',
            parallel_blocks=[  # FIXME (partially?) overlap parallel w/ serial blocks??
                (r'^parallel_blocks\.(\d+)', None),
                (r'^norm|aggregate', (99999,)),
            ]
        )
        return matcher

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classification head (and optionally the pooling mode)."""
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('token', 'avg')
            self.global_pool = global_pool
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x0):
        """Run the serial (and optional parallel) stages.

        Returns a dict of feature maps when return_interm_layers is set,
        otherwise normalized token tensor(s) for the head.
        """
        B = x0.shape[0]

        # Serial blocks 1.
        x1 = self.patch_embed1(x0)
        H1, W1 = self.patch_embed1.grid_size
        x1 = insert_cls(x1, self.cls_token1)
        for blk in self.serial_blocks1:
            x1 = blk(x1, size=(H1, W1))
        x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 2.
        x2 = self.patch_embed2(x1_nocls)
        H2, W2 = self.patch_embed2.grid_size
        x2 = insert_cls(x2, self.cls_token2)
        for blk in self.serial_blocks2:
            x2 = blk(x2, size=(H2, W2))
        x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 3.
        x3 = self.patch_embed3(x2_nocls)
        H3, W3 = self.patch_embed3.grid_size
        x3 = insert_cls(x3, self.cls_token3)
        for blk in self.serial_blocks3:
            x3 = blk(x3, size=(H3, W3))
        x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()

        # Serial blocks 4.
        x4 = self.patch_embed4(x3_nocls)
        H4, W4 = self.patch_embed4.grid_size
        x4 = insert_cls(x4, self.cls_token4)
        for blk in self.serial_blocks4:
            x4 = blk(x4, size=(H4, W4))
        x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()

        # Only serial blocks: Early return.
        if self.parallel_blocks is None:
            if not torch.jit.is_scripting() and self.return_interm_layers:
                # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
                feat_out = {}
                if 'x1_nocls' in self.out_features:
                    feat_out['x1_nocls'] = x1_nocls
                if 'x2_nocls' in self.out_features:
                    feat_out['x2_nocls'] = x2_nocls
                if 'x3_nocls' in self.out_features:
                    feat_out['x3_nocls'] = x3_nocls
                if 'x4_nocls' in self.out_features:
                    feat_out['x4_nocls'] = x4_nocls
                return feat_out
            else:
                # Return features for classification.
                x4 = self.norm4(x4)
                return x4

        # Parallel blocks.
        for blk in self.parallel_blocks:
            x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
            x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])

        if not torch.jit.is_scripting() and self.return_interm_layers:
            # Return intermediate features for down-stream tasks (e.g. Deformable DETR and Detectron2).
            feat_out = {}
            if 'x1_nocls' in self.out_features:
                x1_nocls = remove_cls(x1).reshape(B, H1, W1, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x1_nocls'] = x1_nocls
            if 'x2_nocls' in self.out_features:
                x2_nocls = remove_cls(x2).reshape(B, H2, W2, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x2_nocls'] = x2_nocls
            if 'x3_nocls' in self.out_features:
                x3_nocls = remove_cls(x3).reshape(B, H3, W3, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x3_nocls'] = x3_nocls
            if 'x4_nocls' in self.out_features:
                x4_nocls = remove_cls(x4).reshape(B, H4, W4, -1).permute(0, 3, 1, 2).contiguous()
                feat_out['x4_nocls'] = x4_nocls
            return feat_out
        else:
            x2 = self.norm2(x2)
            x3 = self.norm3(x3)
            x4 = self.norm4(x4)
            return [x2, x3, x4]

    def forward_head(self, x_feat: Union[torch.Tensor, List[torch.Tensor]], pre_logits: bool = False):
        """Pool token features (CLS or mean) and classify; a list input is fused via the 1x1 Conv1d aggregate."""
        if isinstance(x_feat, list):
            assert self.aggregate is not None
            if self.global_pool == 'avg':
                x = torch.cat([xl[:, 1:].mean(dim=1, keepdim=True) for xl in x_feat], dim=1)  # [B, 3, C]
            else:
                x = torch.stack([xl[:, 0] for xl in x_feat], dim=1)  # [B, 3, C]
            x = self.aggregate(x).squeeze(dim=1)  # Shape: [B, C]
        else:
            x = x_feat[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x_feat[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x) -> torch.Tensor:
        if not torch.jit.is_scripting() and self.return_interm_layers:
            # Return intermediate features (for down-stream tasks).
            return self.forward_features(x)
        else:
            # Return features for classification.
            x_feat = self.forward_features(x)
            x = self.forward_head(x_feat)
            return x
+
+
def insert_cls(x, cls_token):
    """ Prepend the CLS token to a token sequence.

    Args:
        x: token tensor of shape [B, N, C].
        cls_token: token of shape [1, 1, C], broadcast across the batch.

    Returns:
        Tensor of shape [B, 1 + N, C] with the CLS token at position 0.
    """
    batched_cls = cls_token.expand(x.shape[0], -1, -1)
    return torch.cat((batched_cls, x), dim=1)
+
+
def remove_cls(x):
    """Return the token sequence without its leading CLS token."""
    return x.narrow(1, 1, x.shape[1] - 1)
+
+
def checkpoint_filter_fn(state_dict, model):
    """Drop checkpoint entries whose target modules are absent from the model.

    The original CoaT release contained unused norm layers; loading those
    checkpoints into the slimmed-down model requires filtering their weights.
    """
    state_dict = state_dict.get('model', state_dict)
    # Prefixes dropped when the same-named module attr on the model is None/missing.
    removable = ('norm2', 'norm3', 'norm4', 'aggregate', 'head')

    def _drop(key):
        if key.startswith('norm1'):
            return True  # norm1 is always unused
        return any(key.startswith(p) and getattr(model, p, None) is None for p in removable)

    return {k: v for k, v in state_dict.items() if not _drop(k)}
+
+
def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
    """Build a CoaT model through timm's model factory.

    Raises:
        RuntimeError: if ``features_only`` is requested (unsupported here).
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    return build_model_with_cfg(
        CoaT,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
+
+
def _cfg_coat(url='', **kwargs):
    """Default pretrained-config dict for CoaT models; ``kwargs`` override the defaults."""
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=None,
        crop_pct=.9, interpolation='bicubic', fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed1.proj', classifier='head',
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs, keyed '<arch>.<tag>'; weights hosted on the HF hub ('timm/' org).
default_cfgs = generate_default_cfgs({
    'coat_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_small.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_tiny.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_mini.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_small.in1k': _cfg_coat(hf_hub_id='timm/'),
    'coat_lite_medium.in1k': _cfg_coat(hf_hub_id='timm/'),
    # 384x384 variant: full-image crop with 'squash' resize mode.
    'coat_lite_medium_384.in1k': _cfg_coat(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash',
    ),
})
+
+
@register_model
def coat_tiny(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Tiny: uniform 152-dim branches, serial depths [2,2,2,2], parallel depth 6."""
    cfg = dict(patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6)
    return _create_coat('coat_tiny', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def coat_mini(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Mini: 152/216-dim branches, serial depths [2,2,2,2], parallel depth 6."""
    cfg = dict(patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6)
    return _create_coat('coat_mini', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def coat_small(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Small: 152/320-dim branches, serial depths [2,2,2,2], parallel depth 6.

    Fix: user kwargs were previously merged into ``model_cfg`` *and* again via
    ``dict(model_cfg, **kwargs)``. The redundant first merge is removed so this
    entrypoint matches its siblings; the resulting config is unchanged.
    """
    model_cfg = dict(
        patch_size=4, embed_dims=[152, 320, 320, 320], serial_depths=[2, 2, 2, 2], parallel_depth=6)
    model = _create_coat('coat_small', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model
+
+
@register_model
def coat_lite_tiny(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Lite Tiny: no parallel blocks; dims [64,128,256,320], MLP ratios [8,8,4,4]."""
    cfg = dict(patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
    return _create_coat('coat_lite_tiny', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def coat_lite_mini(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Lite Mini: no parallel blocks; dims [64,128,320,512], MLP ratios [8,8,4,4]."""
    cfg = dict(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], mlp_ratios=[8, 8, 4, 4])
    return _create_coat('coat_lite_mini', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def coat_lite_small(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Lite Small: deeper serial stages [3,4,6,3]; dims [64,128,320,512]."""
    cfg = dict(patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], mlp_ratios=[8, 8, 4, 4])
    return _create_coat('coat_lite_small', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def coat_lite_medium(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Lite Medium: dims [128,256,320,512], serial depths [3,6,10,8]."""
    cfg = dict(patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
    return _create_coat('coat_lite_medium', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def coat_lite_medium_384(pretrained=False, **kwargs) -> CoaT:
    """CoaT-Lite Medium at 384x384 input resolution."""
    cfg = dict(img_size=384, patch_size=4, embed_dims=[128, 256, 320, 512], serial_depths=[3, 6, 10, 8])
    return _create_coat('coat_lite_medium_384', pretrained=pretrained, **dict(cfg, **kwargs))
\ No newline at end of file
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cfcae2705d5eba5d299357c81742799dad5667e
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/convit.py
@@ -0,0 +1,430 @@
+""" ConViT Model
+
+@article{d2021convit,
+ title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
+ author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
+ journal={arXiv preprint arXiv:2103.10697},
+ year={2021}
+}
+
+Paper link: https://arxiv.org/abs/2103.10697
+Original code: https://github.com/facebookresearch/convit, original copyright below
+
+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
+"""
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the CC-by-NC license found in the
+# LICENSE file in the root directory of this source tree.
+#
+'''These modules are adapted from those of timm, see
+https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+'''
+
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import DropPath, trunc_normal_, PatchEmbed, Mlp, LayerNorm
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._registry import register_model, generate_default_cfgs
+from .vision_transformer_hybrid import HybridEmbed
+
+
+__all__ = ['ConVit']
+
+
@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
class GPSA(nn.Module):
    """Gated Positional Self-Attention (ConViT).

    Blends a content-based ("patch") attention map with a positional one
    computed from relative patch coordinates; a learned per-head gating
    parameter controls the mix. ``local_init`` biases each head toward a
    fixed spatial offset (a soft convolutional inductive bias).
    """

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            locality_strength=1.,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.locality_strength = locality_strength

        # Queries/keys are projected jointly; values get their own projection so
        # local_init can set it to the identity.
        self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.v = nn.Linear(dim, dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        # Maps the 3 relative-position features (dx, dy, d^2) to one score per head.
        self.pos_proj = nn.Linear(3, num_heads)
        self.proj_drop = nn.Dropout(proj_drop)
        # Per-head gate between positional and content attention.
        self.gating_param = nn.Parameter(torch.ones(self.num_heads))
        self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3)  # silly torchscript hack, won't work with None

    def forward(self, x):
        B, N, C = x.shape
        # (Re)build the relative-position table whenever the sequence length changes.
        if self.rel_indices is None or self.rel_indices.shape[1] != N:
            self.rel_indices = self.get_rel_indices(N)
        attn = self.get_attention(x)
        v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def get_attention(self, x):
        """Return the gated attention weights, shape [B, num_heads, N, N]."""
        B, N, C = x.shape
        qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qk[0], qk[1]
        pos_score = self.rel_indices.expand(B, -1, -1, -1)
        pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
        patch_score = (q @ k.transpose(-2, -1)) * self.scale
        patch_score = patch_score.softmax(dim=-1)
        pos_score = pos_score.softmax(dim=-1)

        # Sigmoid gate: 0 -> pure content attention, 1 -> pure positional attention.
        gating = self.gating_param.view(1, -1, 1, 1)
        attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
        attn /= attn.sum(dim=-1).unsqueeze(-1)  # renormalize the mixture to sum to 1
        attn = self.attn_drop(attn)
        return attn

    def get_attention_map(self, x, return_map=False):
        """Average attention distance per head; optionally also the batch-mean map."""
        attn_map = self.get_attention(x).mean(0)  # average over batch
        distances = self.rel_indices.squeeze()[:, :, -1] ** .5  # sqrt of stored squared distance
        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
        if return_map:
            return dist, attn_map
        else:
            return dist

    def local_init(self):
        """Initialize each head's positional scores to peak at one offset of a small grid."""
        self.v.weight.data.copy_(torch.eye(self.dim))  # values start as the identity map
        locality_distance = 1  # max(1,1/locality_strength**.5)

        kernel_size = int(self.num_heads ** .5)
        center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
        for h1 in range(kernel_size):
            for h2 in range(kernel_size):
                # Head `position` is biased toward offset (h1, h2) relative to center.
                position = h1 + kernel_size * h2
                self.pos_proj.weight.data[position, 2] = -1
                self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
                self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
        self.pos_proj.weight.data *= self.locality_strength

    def get_rel_indices(self, num_patches: int) -> torch.Tensor:
        """Build the [1, N, N, 3] table of (dx, dy, d^2) between patches on a square grid.

        Assumes ``num_patches`` is a perfect square (square patch grid) — TODO confirm callers.
        """
        img_size = int(num_patches ** .5)
        rel_indices = torch.zeros(1, num_patches, num_patches, 3)
        ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
        indx = ind.repeat(img_size, img_size)
        indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
        indd = indx ** 2 + indy ** 2
        rel_indices[:, :, :, 2] = indd.unsqueeze(0)
        rel_indices[:, :, :, 1] = indy.unsqueeze(0)
        rel_indices[:, :, :, 0] = indx.unsqueeze(0)
        device = self.qk.weight.device
        return rel_indices.to(device)
+
+
class MHSA(nn.Module):
    """Plain multi-head self-attention (no positional gating), used after the GPSA stage."""

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def get_attention_map(self, x, return_map=False):
        """Average attention distance per head; optionally also the batch-mean attention map."""
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k = qkv[0], qkv[1]
        attn_map = ((q @ k.transpose(-2, -1)) * self.scale).softmax(dim=-1).mean(0)

        # Pairwise patch distances on the implied sqrt(N) x sqrt(N) grid.
        side = int(N ** .5)
        delta = torch.arange(side).view(1, -1) - torch.arange(side).view(-1, 1)
        dx = delta.repeat(side, side)
        dy = delta.repeat_interleave(side, dim=0).repeat_interleave(side, dim=1)
        distances = ((dx ** 2 + dy ** 2) ** .5).to(x.device)

        dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
        return (dist, attn_map) if return_map else dist

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
+
+
class Block(nn.Module):
    """Pre-norm transformer block whose attention is either GPSA or plain MHSA."""

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=LayerNorm,
            use_gpsa=True,
            locality_strength=1.,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.use_gpsa = use_gpsa
        attn_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        if self.use_gpsa:
            # Gated positional self-attention for the early, patch-only stage.
            self.attn = GPSA(dim, locality_strength=locality_strength, **attn_kwargs)
        else:
            self.attn = MHSA(dim, **attn_kwargs)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )

    def forward(self, x):
        # Residual attention, then residual MLP (both pre-norm, with stochastic depth).
        x = x + self.drop_path(self.attn(self.norm1(x)))
        return x + self.drop_path(self.mlp(self.norm2(x)))
+
+
class ConVit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    The first ``local_up_to_layer`` blocks use GPSA over patch tokens only;
    the CLS token is inserted just before the remaining plain-attention blocks.
    """

    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            depth=12,
            num_heads=12,
            mlp_ratio=4.,
            qkv_bias=False,
            drop_rate=0.,
            pos_drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            hybrid_backbone=None,
            norm_layer=LayerNorm,
            local_up_to_layer=3,
            locality_strength=1.,
            use_pos_embed=True,
    ):
        super().__init__()
        assert global_pool in ('', 'avg', 'token')
        # NOTE: embed_dim is scaled by num_heads, i.e. the argument is a per-head width.
        embed_dim *= num_heads
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.local_up_to_layer = local_up_to_layer  # number of leading GPSA blocks
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.locality_strength = locality_strength
        self.use_pos_embed = use_pos_embed

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size,
                patch_size=patch_size,
                in_chans=in_chans,
                embed_dim=embed_dim,
            )
        num_patches = self.patch_embed.num_patches
        self.num_patches = num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # Positional embedding covers patch tokens only; the CLS token joins later
        # (see forward_features), so it is not part of pos_embed.
        if self.use_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.pos_embed, std=.02)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                use_gpsa=i < local_up_to_layer,  # GPSA for early blocks, MHSA afterwards
                locality_strength=locality_strength,
            ) for i in range(depth)])
        self.norm = norm_layer(embed_dim)

        # Classifier head
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        # Apply GPSA's locality-biased initialization after the generic init pass.
        for n, m in self.named_modules():
            if hasattr(m, 'local_init'):
                m.local_init()

    def _init_weights(self, m):
        # Truncated-normal linear weights; zero biases; unit LayerNorm scale.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Exclude positional embedding and CLS token from weight decay.
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # Parameter grouping regexes (used e.g. for layer-wise LR decay).
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classification head, optionally changing the pooling mode."""
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'token', 'avg')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.use_pos_embed:
            x = x + self.pos_embed
        x = self.pos_drop(x)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        for u, blk in enumerate(self.blocks):
            # The CLS token joins the sequence only once the GPSA stage is done.
            if u == self.local_up_to_layer:
                x = torch.cat((cls_tokens, x), dim=1)
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        # Pool (CLS token or mean of patch tokens), then classify.
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def _create_convit(variant, pretrained=False, **kwargs):
    """Instantiate a ConVit variant via timm's model factory.

    Raises:
        RuntimeError: if ``features_only`` is requested (unsupported here).
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    model = build_model_with_cfg(ConVit, variant, pretrained, **kwargs)
    return model
+
+
def _cfg(url='', **kwargs):
    """Default pretrained-config dict for ConViT models; ``kwargs`` override the defaults."""
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=None,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, fixed_input_size=True,
        first_conv='patch_embed.proj', classifier='head',
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs (Facebook ImageNet-1k releases), hosted on the HF hub.
default_cfgs = generate_default_cfgs({
    # ConViT
    'convit_tiny.fb_in1k': _cfg(hf_hub_id='timm/'),
    'convit_small.fb_in1k': _cfg(hf_hub_id='timm/'),
    'convit_base.fb_in1k': _cfg(hf_hub_id='timm/')
})
+
+
@register_model
def convit_tiny(pretrained=False, **kwargs) -> ConVit:
    """ConViT-Ti: 10 GPSA blocks, per-head width 48, 4 heads."""
    cfg = dict(local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=4)
    return _create_convit(variant='convit_tiny', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def convit_small(pretrained=False, **kwargs) -> ConVit:
    """ConViT-S: 10 GPSA blocks, per-head width 48, 9 heads."""
    cfg = dict(local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=9)
    return _create_convit(variant='convit_small', pretrained=pretrained, **dict(cfg, **kwargs))
+
+
@register_model
def convit_base(pretrained=False, **kwargs) -> ConVit:
    """ConViT-B: 10 GPSA blocks, per-head width 48, 16 heads."""
    cfg = dict(local_up_to_layer=10, locality_strength=1.0, embed_dim=48, num_heads=16)
    return _create_convit(variant='convit_base', pretrained=pretrained, **dict(cfg, **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/crossvit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/crossvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c90aec97c5cc504d6b5588a1b36d9d65c0ddb26
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/crossvit.py
@@ -0,0 +1,627 @@
+""" CrossViT Model
+
+@inproceedings{
+ chen2021crossvit,
+ title={{CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification}},
+ author={Chun-Fu (Richard) Chen and Quanfu Fan and Rameswar Panda},
+ booktitle={International Conference on Computer Vision (ICCV)},
+ year={2021}
+}
+
+Paper link: https://arxiv.org/abs/2103.14899
+Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
+
+NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
+
+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
+"""
+
+# Copyright IBM All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+
+"""
+Modified from Timm. https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+
+"""
+from functools import partial
+from typing import List
+from typing import Tuple
+
+import torch
+import torch.hub
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import DropPath, to_2tuple, trunc_normal_, _assert
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_function
+from ._registry import register_model, generate_default_cfgs
+from .vision_transformer import Block
+
+__all__ = ['CrossVit'] # model_registry will add each entrypoint fn to this
+
+
class PatchEmbed(nn.Module):
    """ Image to Patch Embedding

    When ``multi_conv`` is set, a small strided conv stack replaces the single
    strided conv (used by the "dagger" variants); only patch sizes 12 and 16
    have a multi-conv definition.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches
        if multi_conv:
            if patch_size[0] == 12:
                # Overall stride 4 * 3 * 1 = 12.
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=3, padding=0),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=1, padding=1),
                )
            elif patch_size[0] == 16:
                # Overall stride 4 * 2 * 2 = 16.
                self.proj = nn.Sequential(
                    nn.Conv2d(in_chans, embed_dim // 4, kernel_size=7, stride=4, padding=3),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=3, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=3, stride=2, padding=1),
                )
        else:
            self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        _assert(H == self.img_size[0],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        # [B, E, H', W'] -> [B, N, E]
        x = self.proj(x).flatten(2).transpose(1, 2)
        return x
+
+
class CrossAttention(nn.Module):
    """Attention where only the CLS token (as query) attends to the full sequence."""

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = head_dim ** -0.5

        self.wq = nn.Linear(dim, dim, bias=qkv_bias)
        self.wk = nn.Linear(dim, dim, bias=qkv_bias)
        self.wv = nn.Linear(dim, dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # Query from the CLS token only: [B, 1, C] -> [B, H, 1, C/H]
        q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, head_dim).transpose(1, 2)
        # Keys/values from every token: [B, N, C] -> [B, H, N, C/H]
        k = self.wk(x).reshape(B, N, self.num_heads, head_dim).transpose(1, 2)
        v = self.wv(x).reshape(B, N, self.num_heads, head_dim).transpose(1, 2)

        scores = (q @ k.transpose(-2, -1)) * self.scale  # [B, H, 1, N]
        weights = self.attn_drop(scores.softmax(dim=-1))

        out = (weights @ v).transpose(1, 2).reshape(B, 1, C)  # back to [B, 1, C]
        return self.proj_drop(self.proj(out))
+
+
class CrossAttentionBlock(nn.Module):
    """Residual cross-attention applied to (and returning only) the CLS token.

    ``mlp_ratio`` and ``act_layer`` are accepted for signature symmetry with
    regular blocks but are unused here (no MLP sub-block).
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = CrossAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        # Only the CLS token (index 0) is updated; patch tokens serve as context.
        cls_tok = x[:, 0:1, ...]
        return cls_tok + self.drop_path(self.attn(self.norm1(x)))
+
+
class MultiScaleBlock(nn.Module):
    """One CrossViT stage: per-branch transformer blocks followed by cross-branch CLS fusion.

    ``dim``, ``num_heads`` etc. are per-branch sequences; ``depth`` holds the
    per-branch block depths plus a final entry for the fusion depth.
    ``patches`` is accepted for config symmetry but not used here.
    """

    def __init__(
            self,
            dim,
            patches,
            depth,
            num_heads,
            mlp_ratio,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()

        num_branches = len(dim)
        self.num_branches = num_branches
        # different branch could have different embedding size, the first one is the base
        self.blocks = nn.ModuleList()
        for d in range(num_branches):
            tmp = []
            for i in range(depth[d]):
                tmp.append(Block(
                    dim=dim[d],
                    num_heads=num_heads[d],
                    mlp_ratio=mlp_ratio[d],
                    qkv_bias=qkv_bias,
                    proj_drop=proj_drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i],
                    norm_layer=norm_layer,
                ))
            if len(tmp) != 0:
                self.blocks.append(nn.Sequential(*tmp))

        if len(self.blocks) == 0:
            self.blocks = None

        # Project each branch's CLS token into the *next* branch's dim before fusion.
        # NOTE(review): the `and False` makes the identity branch dead code, so a
        # projection is always built — presumably kept for checkpoint compatibility; verify.
        self.projs = nn.ModuleList()
        for d in range(num_branches):
            if dim[d] == dim[(d + 1) % num_branches] and False:
                tmp = [nn.Identity()]
            else:
                tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
            self.projs.append(nn.Sequential(*tmp))

        # Cross-attention fusion: branch d's CLS attends over branch d+1's tokens.
        self.fusion = nn.ModuleList()
        for d in range(num_branches):
            d_ = (d + 1) % num_branches
            nh = num_heads[d_]
            if depth[-1] == 0:  # backward capability:
                self.fusion.append(
                    CrossAttentionBlock(
                        dim=dim[d_],
                        num_heads=nh,
                        mlp_ratio=mlp_ratio[d],
                        qkv_bias=qkv_bias,
                        proj_drop=proj_drop,
                        attn_drop=attn_drop,
                        drop_path=drop_path[-1],
                        norm_layer=norm_layer,
                    ))
            else:
                tmp = []
                for _ in range(depth[-1]):
                    tmp.append(CrossAttentionBlock(
                        dim=dim[d_],
                        num_heads=nh,
                        mlp_ratio=mlp_ratio[d],
                        qkv_bias=qkv_bias,
                        proj_drop=proj_drop,
                        attn_drop=attn_drop,
                        drop_path=drop_path[-1],
                        norm_layer=norm_layer,
                    ))
                self.fusion.append(nn.Sequential(*tmp))

        # Project the fused CLS token back to the originating branch's dim.
        # NOTE(review): same deliberately-disabled identity shortcut as self.projs.
        self.revert_projs = nn.ModuleList()
        for d in range(num_branches):
            if dim[(d + 1) % num_branches] == dim[d] and False:
                tmp = [nn.Identity()]
            else:
                tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
                       nn.Linear(dim[(d + 1) % num_branches], dim[d])]
            self.revert_projs.append(nn.Sequential(*tmp))

    def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:

        # Per-branch transformer blocks.
        outs_b = []
        for i, block in enumerate(self.blocks):
            outs_b.append(block(x[i]))

        # only take the cls token out
        proj_cls_token = torch.jit.annotate(List[torch.Tensor], [])
        for i, proj in enumerate(self.projs):
            proj_cls_token.append(proj(outs_b[i][:, 0:1, ...]))

        # cross attention
        outs = []
        for i, (fusion, revert_proj) in enumerate(zip(self.fusion, self.revert_projs)):
            # Branch i's projected CLS + branch i+1's patch tokens -> fusion input.
            tmp = torch.cat((proj_cls_token[i], outs_b[(i + 1) % self.num_branches][:, 1:, ...]), dim=1)
            tmp = fusion(tmp)
            # Map the fused CLS back and re-attach branch i's own patch tokens.
            reverted_proj_cls_token = revert_proj(tmp[:, 0:1, ...])
            tmp = torch.cat((reverted_proj_cls_token, outs_b[i][:, 1:, ...]), dim=1)
            outs.append(tmp)
        return outs
+
+
+def _compute_num_patches(img_size, patches):
+ return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
+
+
@register_notrace_function
def scale_image(x, ss: Tuple[int, int], crop_scale: bool = False):  # annotations for torchscript
    """Bring an image batch to size ``ss`` by center crop (optional) or bicubic resize.

    Pulled out of CrossViT.forward_features to bury conditional logic in a leaf node for FX tracing.
    Args:
        x (Tensor): input image batch
        ss (tuple[int, int]): target height and width
        crop_scale (bool): center-crop instead of interpolating when the target fits inside x
    Returns:
        Tensor: the "scaled" image batch tensor
    """
    H, W = x.shape[-2:]
    if H == ss[0] and W == ss[1]:
        return x  # already at the target size
    if crop_scale and ss[0] <= H and ss[1] <= W:
        top = int(round((H - ss[0]) / 2.))
        left = int(round((W - ss[1]) / 2.))
        return x[:, :, top:top + ss[0], left:left + ss[1]]
    return torch.nn.functional.interpolate(x, size=ss, mode='bicubic', align_corners=False)
+
+
class CrossVit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    Multi-branch ViT: each branch embeds the image at its own scale / patch
    size, and MultiScaleBlocks exchange information via CLS-token cross
    attention. Per-branch CLS/pos-embed parameters are stored as numbered
    attributes (``pos_embed_0`` etc.) for torchscript compatibility.
    """

    def __init__(
            self,
            img_size=224,
            img_scale=(1.0, 1.0),
            patch_size=(8, 16),
            in_chans=3,
            num_classes=1000,
            embed_dim=(192, 384),
            depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)),
            num_heads=(6, 12),
            mlp_ratio=(2., 2., 4.),
            multi_conv=False,
            crop_scale=False,
            qkv_bias=True,
            drop_rate=0.,
            pos_drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            global_pool='token',
    ):
        super().__init__()
        assert global_pool in ('token', 'avg')

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.img_size = to_2tuple(img_size)
        img_scale = to_2tuple(img_scale)
        # Per-branch input resolutions (base img_size scaled per branch).
        self.img_size_scaled = [tuple([int(sj * si) for sj in self.img_size]) for si in img_scale]
        self.crop_scale = crop_scale  # crop instead of interpolate for scale
        num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
        self.num_branches = len(patch_size)
        self.embed_dim = embed_dim
        self.num_features = sum(embed_dim)
        self.patch_embed = nn.ModuleList()

        # hard-coded for torch jit script
        for i in range(self.num_branches):
            setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
            setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))

        for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
            self.patch_embed.append(
                PatchEmbed(
                    img_size=im_s,
                    patch_size=p,
                    in_chans=in_chans,
                    embed_dim=d,
                    multi_conv=multi_conv,
                ))

        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        # Stochastic depth budget: last two depth entries per stage, summed over stages.
        total_depth = sum([sum(x[-2:]) for x in depth])
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, total_depth)]  # stochastic depth decay rule
        dpr_ptr = 0
        self.blocks = nn.ModuleList()
        for idx, block_cfg in enumerate(depth):
            # Each stage consumes max branch depth + fusion depth from the dpr schedule.
            curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
            dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
            blk = MultiScaleBlock(
                embed_dim,
                num_patches,
                block_cfg,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr_,
                norm_layer=norm_layer,
            )
            dpr_ptr += curr_depth
            self.blocks.append(blk)

        # One final norm and one classifier head per branch.
        self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.ModuleList([
            nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
            for i in range(self.num_branches)])

        for i in range(self.num_branches):
            trunc_normal_(getattr(self, f'pos_embed_{i}'), std=.02)
            trunc_normal_(getattr(self, f'cls_token_{i}'), std=.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal linear weights; zero biases; unit LayerNorm scale.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Exclude each branch's CLS token and trainable positional embedding.
        out = set()
        for i in range(self.num_branches):
            out.add(f'cls_token_{i}')
            pe = getattr(self, f'pos_embed_{i}', None)
            if pe is not None and pe.requires_grad:
                out.add(f'pos_embed_{i}')
        return out

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # Parameter grouping regexes (used e.g. for layer-wise LR decay).
        return dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the per-branch classifier heads, optionally changing pooling."""
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('token', 'avg')
            self.global_pool = global_pool
        self.head = nn.ModuleList(
            [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
             range(self.num_branches)])

    def forward_features(self, x) -> List[torch.Tensor]:
        B = x.shape[0]
        xs = []
        for i, patch_embed in enumerate(self.patch_embed):
            x_ = x
            ss = self.img_size_scaled[i]
            # Bring the input to this branch's resolution (crop or interpolate).
            x_ = scale_image(x_, ss, self.crop_scale)
            x_ = patch_embed(x_)
            cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1  # hard-coded for torch jit script
            cls_tokens = cls_tokens.expand(B, -1, -1)
            x_ = torch.cat((cls_tokens, x_), dim=1)
            pos_embed = self.pos_embed_0 if i == 0 else self.pos_embed_1  # hard-coded for torch jit script
            x_ = x_ + pos_embed
            x_ = self.pos_drop(x_)
            xs.append(x_)

        for i, blk in enumerate(self.blocks):
            xs = blk(xs)

        # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
        xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
        return xs

    def forward_head(self, xs: List[torch.Tensor], pre_logits: bool = False) -> torch.Tensor:
        # Pool each branch (CLS token or patch-token mean), then average branch logits.
        xs = [x[:, 1:].mean(dim=1) for x in xs] if self.global_pool == 'avg' else [x[:, 0] for x in xs]
        xs = [self.head_drop(x) for x in xs]
        if pre_logits or isinstance(self.head[0], nn.Identity):
            # No classifier heads: return concatenated branch features.
            return torch.cat([x for x in xs], dim=1)
        return torch.mean(torch.stack([head(xs[i]) for i, head in enumerate(self.head)], dim=0), dim=0)

    def forward(self, x):
        xs = self.forward_features(x)
        x = self.forward_head(xs)
        return x
+
+
def _create_crossvit(variant, pretrained=False, **kwargs):
    """Instantiate a CrossVit variant via the timm model builder.

    Pretrained checkpoints store dotted ``pos_embed.N`` / ``cls_token.N`` keys;
    these are remapped to the underscore-named attributes used by this impl.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    def pretrained_filter_fn(state_dict):
        # dotted pos_embed/cls_token keys -> underscore attribute names
        return {
            (k.replace(".", "_") if 'pos_embed' in k or 'cls_token' in k else k): v
            for k, v in state_dict.items()
        }

    return build_model_with_cfg(
        CrossVit,
        variant,
        pretrained,
        pretrained_filter_fn=pretrained_filter_fn,
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-weights config dict; ``kwargs`` override fields."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'crop_pct': 0.875,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
        'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
        'classifier': ('head.0', 'head.1'),
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs keyed by '<model>.<tag>'. The 'dagger' variants use a
# multi-conv patch embed, hence the deeper first_conv module path; the 408
# variants were trained at a larger input resolution.
default_cfgs = generate_default_cfgs({
    'crossvit_15_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_15_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_15_dagger_408.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
    ),
    'crossvit_18_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_18_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_18_dagger_408.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'), crop_pct=1.0,
    ),
    'crossvit_9_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_9_dagger_240.in1k': _cfg(
        hf_hub_id='timm/',
        first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
    ),
    'crossvit_base_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_small_240.in1k': _cfg(hf_hub_id='timm/'),
    'crossvit_tiny_240.in1k': _cfg(hf_hub_id='timm/'),
})
+
+
@register_model
def crossvit_tiny_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-Tiny @ 240x240."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[96, 192],
        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
        num_heads=[3, 3],
        mlp_ratio=[4, 4, 1],
    )
    return _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_small_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-Small @ 240x240."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[192, 384],
        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
        num_heads=[6, 6],
        mlp_ratio=[4, 4, 1],
    )
    return _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_base_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-Base @ 240x240."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[384, 768],
        depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
        num_heads=[12, 12],
        mlp_ratio=[4, 4, 1],
    )
    return _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_9_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-9 @ 240x240."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[128, 256],
        depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
        num_heads=[4, 4],
        mlp_ratio=[3, 3, 1],
    )
    return _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_15_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-15 @ 240x240."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[192, 384],
        depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
        num_heads=[6, 6],
        mlp_ratio=[3, 3, 1],
    )
    return _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_18_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-18 @ 240x240.

    Fix: ``**kwargs`` was previously merged into ``model_args`` AND again in the
    ``_create_crossvit`` call, applying user overrides twice; merge exactly once,
    consistent with every other crossvit_* constructor in this file.
    """
    model_args = dict(
        img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448],
        depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
        num_heads=[7, 7], mlp_ratio=[3, 3, 1])
    model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def crossvit_9_dagger_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-9† @ 240x240 (multi-conv patch embedding)."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[128, 256],
        depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
        num_heads=[4, 4],
        mlp_ratio=[3, 3, 1],
        multi_conv=True,
    )
    return _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_15_dagger_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-15† @ 240x240 (multi-conv patch embedding)."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[192, 384],
        depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
        num_heads=[6, 6],
        mlp_ratio=[3, 3, 1],
        multi_conv=True,
    )
    return _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_15_dagger_408(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-15† @ 408x408 (multi-conv patch embedding, larger input)."""
    model_args = dict(
        img_scale=(1.0, 384 / 408),
        patch_size=[12, 16],
        embed_dim=[192, 384],
        depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
        num_heads=[6, 6],
        mlp_ratio=[3, 3, 1],
        multi_conv=True,
    )
    return _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_18_dagger_240(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-18† @ 240x240 (multi-conv patch embedding)."""
    model_args = dict(
        img_scale=(1.0, 224 / 240),
        patch_size=[12, 16],
        embed_dim=[224, 448],
        depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
        num_heads=[7, 7],
        mlp_ratio=[3, 3, 1],
        multi_conv=True,
    )
    return _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def crossvit_18_dagger_408(pretrained=False, **kwargs) -> CrossVit:
    """CrossViT-18† @ 408x408 (multi-conv patch embedding, larger input)."""
    model_args = dict(
        img_scale=(1.0, 384 / 408),
        patch_size=[12, 16],
        embed_dim=[224, 448],
        depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
        num_heads=[7, 7],
        mlp_ratio=[3, 3, 1],
        multi_conv=True,
    )
    return _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **dict(model_args, **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/cspnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/cspnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..21b2cd344ade099df5815b905a7f8c3a4270b742
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/cspnet.py
@@ -0,0 +1,1106 @@
+"""PyTorch CspNet
+
+A PyTorch implementation of Cross Stage Partial Networks including:
+* CSPResNet50
+* CSPResNeXt50
+* CSPDarkNet53
+* and DarkNet53 for good measure
+
+Based on paper `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
+
+Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStagePartialNetworks
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+from dataclasses import dataclass, asdict, replace
+from functools import partial
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, get_attn, create_act_layer, make_divisible
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply, MATCH_PREV_GROUP
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['CspNet'] # model_registry will add each entrypoint fn to this
+
+
@dataclass
class CspStemCfg:
    """Configuration for the stem conv(s) that precede the CSP stages."""
    # single int -> one stem conv; tuple -> a sequence of stem convs
    out_chs: Union[int, Tuple[int, ...]] = 32
    # overall stem stride; split across convs / pool by create_csp_stem
    stride: Union[int, Tuple[int, ...]] = 2
    kernel_size: int = 3
    padding: Union[int, str] = ''
    # '' disables pooling; a truthy value ends the stem with a max pool
    pool: Optional[str] = ''
+
+
+def _pad_arg(x, n):
+ # pads an argument tuple to specified n by padding with last value
+ if not isinstance(x, (tuple, list)):
+ x = (x,)
+ curr_n = len(x)
+ pad_n = n - curr_n
+ if pad_n <= 0:
+ return x[:n]
+ return tuple(x + (x[-1],) * pad_n)
+
+
@dataclass
class CspStagesCfg:
    """Per-stage CSP configuration; scalar fields are broadcast to one entry per stage."""
    depth: Tuple[int, ...] = (3, 3, 5, 2)  # block depth (number of block repeats in stages)
    out_chs: Tuple[int, ...] = (128, 256, 512, 1024)  # number of output channels for blocks in stage
    stride: Union[int, Tuple[int, ...]] = 2  # stride of stage
    groups: Union[int, Tuple[int, ...]] = 1  # num kxk conv groups
    block_ratio: Union[float, Tuple[float, ...]] = 1.0
    bottle_ratio: Union[float, Tuple[float, ...]] = 1.  # bottleneck-ratio of blocks in stage
    avg_down: Union[bool, Tuple[bool, ...]] = False
    attn_layer: Optional[Union[str, Tuple[str, ...]]] = None
    attn_kwargs: Optional[Union[Dict, Tuple[Dict]]] = None
    stage_type: Union[str, Tuple[str]] = 'csp'  # stage type ('csp', 'cs2', 'dark')
    block_type: Union[str, Tuple[str]] = 'bottle'  # blocks type for stages ('bottle', 'dark')

    # cross-stage only
    expand_ratio: Union[float, Tuple[float, ...]] = 1.0
    cross_linear: Union[bool, Tuple[bool, ...]] = False
    down_growth: Union[bool, Tuple[bool, ...]] = False

    def __post_init__(self):
        # broadcast every scalar / short field to one value per stage
        n = len(self.depth)
        assert len(self.out_chs) == n
        for field_name in (
                'stride', 'groups', 'block_ratio', 'bottle_ratio', 'avg_down',
                'attn_layer', 'attn_kwargs', 'stage_type', 'block_type',
                'expand_ratio', 'cross_linear', 'down_growth'):
            setattr(self, field_name, _pad_arg(getattr(self, field_name), n))
+
+
@dataclass
class CspModelCfg:
    """Top-level CSP model architecture configuration: stem + stages + layer choices."""
    stem: CspStemCfg
    stages: CspStagesCfg
    zero_init_last: bool = True  # zero init last weight (usually bn) in residual path
    act_layer: str = 'leaky_relu'
    norm_layer: str = 'batchnorm'
    aa_layer: Optional[str] = None  # FIXME support string factory for this
+
+
def _cs3_cfg(
        width_multiplier=1.0,
        depth_multiplier=1.0,
        avg_down=False,
        act_layer='silu',
        focus=False,
        attn_layer=None,
        attn_kwargs=None,
        bottle_ratio=1.0,
        block_type='dark',
):
    """Build a CspModelCfg for the CS3 (CrossStage3) darknet family.

    ``focus`` selects the single 6x6/s2 'focus-style' stem over the two-conv stem;
    width/depth multipliers scale channels and block repeats respectively.
    """
    def _w(chs):
        # scale channel count and round to a hardware-friendly divisor
        return make_divisible(chs * width_multiplier)

    if focus:
        stem_cfg = CspStemCfg(out_chs=_w(64), kernel_size=6, stride=2, padding=2, pool='')
    else:
        stem_cfg = CspStemCfg(out_chs=(_w(32), _w(64)), kernel_size=3, stride=2, pool='')
    stages_cfg = CspStagesCfg(
        out_chs=tuple(_w(c) for c in (128, 256, 512, 1024)),
        depth=tuple(int(d * depth_multiplier) for d in (3, 6, 9, 3)),
        stride=2,
        bottle_ratio=bottle_ratio,
        block_ratio=0.5,
        avg_down=avg_down,
        attn_layer=attn_layer,
        attn_kwargs=attn_kwargs,
        stage_type='cs3',
        block_type=block_type,
    )
    return CspModelCfg(stem=stem_cfg, stages=stages_cfg, act_layer=act_layer)
+
+
class BottleneckBlock(nn.Module):
    """ ResNe(X)t Bottleneck Block
    1x1 reduce -> 3x3 grouped conv -> 1x1 expand, with optional attention either
    mid-block or at the end, residual add, and a final activation.
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            dilation=1,
            bottle_ratio=0.25,
            groups=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_last=False,
            attn_layer=None,
            drop_block=None,
            drop_path=0.
    ):
        super().__init__()
        mid_chs = int(round(out_chs * bottle_ratio))
        conv_args = dict(act_layer=act_layer, norm_layer=norm_layer)
        # attention sits either after conv2 (default) or after conv3 (attn_last)
        use_attn_last = attn_layer is not None and attn_last
        use_attn_mid = attn_layer is not None and not attn_last

        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **conv_args)
        self.conv2 = ConvNormAct(
            mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
            drop_layer=drop_block, **conv_args)
        self.attn2 = attn_layer(mid_chs, act_layer=act_layer) if use_attn_mid else nn.Identity()
        self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **conv_args)
        self.attn3 = attn_layer(out_chs, act_layer=act_layer) if use_attn_last else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()
        self.act3 = create_act_layer(act_layer)

    def zero_init_last(self):
        # zero the final BN gamma so the residual branch starts near-identity
        nn.init.zeros_(self.conv3.bn.weight)

    def forward(self, x):
        shortcut = x
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.attn2(out)
        out = self.conv3(out)
        out = self.attn3(out)
        out = self.drop_path(out) + shortcut
        # NOTE: original darknet uses a partial shortcut for the first block; not implemented here
        out = self.act3(out)
        return out
+
+
class DarkBlock(nn.Module):
    """ DarkNet Block
    1x1 reduce -> optional attention -> 3x3 grouped conv, with residual add.
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            dilation=1,
            bottle_ratio=0.5,
            groups=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_layer=None,
            drop_block=None,
            drop_path=0.
    ):
        super().__init__()
        mid_chs = int(round(out_chs * bottle_ratio))
        conv_args = dict(act_layer=act_layer, norm_layer=norm_layer)

        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **conv_args)
        self.attn = nn.Identity() if attn_layer is None else attn_layer(mid_chs, act_layer=act_layer)
        self.conv2 = ConvNormAct(
            mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
            drop_layer=drop_block, **conv_args)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def zero_init_last(self):
        # zero final BN gamma so the residual branch starts near-identity
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        residual = x
        out = self.conv2(self.attn(self.conv1(x)))
        return self.drop_path(out) + residual
+
+
class EdgeBlock(nn.Module):
    """ EdgeResidual / Fused-MBConv / MobileNetV1-like 3x3 + 1x1 block (w/ activated output)
    3x3 grouped conv -> optional attention -> 1x1 conv, with residual add.
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            dilation=1,
            bottle_ratio=0.5,
            groups=1,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_layer=None,
            drop_block=None,
            drop_path=0.
    ):
        super().__init__()
        mid_chs = int(round(out_chs * bottle_ratio))
        conv_args = dict(act_layer=act_layer, norm_layer=norm_layer)

        self.conv1 = ConvNormAct(
            in_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
            drop_layer=drop_block, **conv_args)
        self.attn = nn.Identity() if attn_layer is None else attn_layer(mid_chs, act_layer=act_layer)
        self.conv2 = ConvNormAct(mid_chs, out_chs, kernel_size=1, **conv_args)
        self.drop_path = DropPath(drop_path) if drop_path else nn.Identity()

    def zero_init_last(self):
        # zero final BN gamma so the residual branch starts near-identity
        nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        residual = x
        out = self.conv2(self.attn(self.conv1(x)))
        return self.drop_path(out) + residual
+
+
class CrossStage(nn.Module):
    """Cross Stage: optional downsample, 1x1 expand, channel split into a cross
    path and a block path, then two transition convs to merge and project."""
    def __init__(
            self,
            in_chs,
            out_chs,
            stride,
            dilation,
            depth,
            block_ratio=1.,
            bottle_ratio=1.,
            expand_ratio=1.,
            groups=1,
            first_dilation=None,
            avg_down=False,
            down_growth=False,
            cross_linear=False,
            block_dpr=None,
            block_fn=BottleneckBlock,
            **block_kwargs,
    ):
        super(CrossStage, self).__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        # aa_layer only applies to the downsample conv, not the blocks
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                )
            else:
                self.conv_down = ConvNormActAa(
                    in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                    aa_layer=aa_layer, **conv_kwargs)
            prev_chs = down_chs
        else:
            self.conv_down = nn.Identity()
            prev_chs = in_chs

        # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
        # there is also special case for the first stage for some of the model that results in uneven split
        # across the two paths. I did it this way for simplicity for now.
        self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
        prev_chs = exp_chs // 2  # output of conv_exp is always split in two

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
            ))
            prev_chs = block_out_chs

        # transition convs
        self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)

    def forward(self, x):
        x = self.conv_down(x)
        x = self.conv_exp(x)
        # split: xs is the cross (identity) path, xb goes through the blocks
        xs, xb = x.split(self.expand_chs // 2, dim=1)
        xb = self.blocks(xb)
        xb = self.conv_transition_b(xb).contiguous()
        out = self.conv_transition(torch.cat([xs, xb], dim=1))
        return out
+
+
class CrossStage3(nn.Module):
    """Cross Stage 3.
    Similar to CrossStage, but with only one transition conv for the output.
    """
    def __init__(
            self,
            in_chs,
            out_chs,
            stride,
            dilation,
            depth,
            block_ratio=1.,
            bottle_ratio=1.,
            expand_ratio=1.,
            groups=1,
            first_dilation=None,
            avg_down=False,
            down_growth=False,
            cross_linear=False,
            block_dpr=None,
            block_fn=BottleneckBlock,
            **block_kwargs,
    ):
        super(CrossStage3, self).__init__()
        first_dilation = first_dilation or dilation
        down_chs = out_chs if down_growth else in_chs  # grow downsample channels to output channels
        self.expand_chs = exp_chs = int(round(out_chs * expand_ratio))
        block_out_chs = int(round(out_chs * block_ratio))
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        # aa_layer only applies to the downsample conv, not the blocks
        aa_layer = block_kwargs.pop('aa_layer', None)

        if stride != 1 or first_dilation != dilation:
            if avg_down:
                self.conv_down = nn.Sequential(
                    nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                    ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
                )
            else:
                self.conv_down = ConvNormActAa(
                    in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                    aa_layer=aa_layer, **conv_kwargs)
            prev_chs = down_chs
        else:
            # Fix: was `None`, but forward() calls conv_down unconditionally, which
            # crashed for a stride-1 stage. Use Identity, consistent with CrossStage
            # (Identity holds no parameters, so state_dicts are unaffected).
            self.conv_down = nn.Identity()
            prev_chs = in_chs

        # expansion conv
        self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
        prev_chs = exp_chs // 2  # expanded output is split in 2 for blocks and cross stage

        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs,
            ))
            prev_chs = block_out_chs

        # transition conv (single, merges both paths)
        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)

    def forward(self, x):
        x = self.conv_down(x)
        x = self.conv_exp(x)
        # split: x1 goes through the blocks, x2 is the cross (identity) path
        x1, x2 = x.split(self.expand_chs // 2, dim=1)
        x1 = self.blocks(x1)
        out = self.conv_transition(torch.cat([x1, x2], dim=1))
        return out
+
+
class DarkStage(nn.Module):
    """DarkNet stage: a downsample conv followed by a plain sequence of blocks
    (no cross-stage channel split)."""

    def __init__(
            self,
            in_chs,
            out_chs,
            stride,
            dilation,
            depth,
            block_ratio=1.,
            bottle_ratio=1.,
            groups=1,
            first_dilation=None,
            avg_down=False,
            block_fn=BottleneckBlock,
            block_dpr=None,
            **block_kwargs,
    ):
        super(DarkStage, self).__init__()
        first_dilation = first_dilation or dilation
        conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
        # aa_layer only applies to the downsample conv, not the blocks
        aa_layer = block_kwargs.pop('aa_layer', None)

        if avg_down:
            self.conv_down = nn.Sequential(
                nn.AvgPool2d(2) if stride == 2 else nn.Identity(),  # FIXME dilation handling
                ConvNormActAa(in_chs, out_chs, kernel_size=1, stride=1, groups=groups, **conv_kwargs)
            )
        else:
            self.conv_down = ConvNormActAa(
                in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                aa_layer=aa_layer, **conv_kwargs)

        prev_chs = out_chs
        block_out_chs = int(round(out_chs * block_ratio))
        self.blocks = nn.Sequential()
        for i in range(depth):
            self.blocks.add_module(str(i), block_fn(
                in_chs=prev_chs,
                out_chs=block_out_chs,
                dilation=dilation,
                bottle_ratio=bottle_ratio,
                groups=groups,
                drop_path=block_dpr[i] if block_dpr is not None else 0.,
                **block_kwargs
            ))
            prev_chs = block_out_chs

    def forward(self, x):
        x = self.conv_down(x)
        x = self.blocks(x)
        return x
+
+
def create_csp_stem(
        in_chans=3,
        out_chs=32,
        kernel_size=3,
        stride=2,
        pool='',
        padding='',
        act_layer=nn.ReLU,
        norm_layer=nn.BatchNorm2d,
        aa_layer=None,
):
    """Build the stem (one or more convs, optional pool) and its feature metadata.

    Returns:
        (stem, feature_info): the ``nn.Sequential`` stem and a list of dicts
        describing channel count / reduction / module path at each stride change.
    """
    stem = nn.Sequential()
    feature_info = []
    if not isinstance(out_chs, (tuple, list)):
        out_chs = [out_chs]
    stem_depth = len(out_chs)
    assert stem_depth
    assert stride in (1, 2, 4)
    prev_feat = None
    prev_chs = in_chans
    last_idx = stem_depth - 1
    stem_stride = 1
    for i, chs in enumerate(out_chs):
        conv_name = f'conv{i + 1}'
        # stride goes on the first conv, and on the last conv when stride==4 without pooling
        conv_stride = 2 if (i == 0 and stride > 1) or (i == last_idx and stride > 2 and not pool) else 1
        if conv_stride > 1 and prev_feat is not None:
            # record the feature just before each resolution change
            feature_info.append(prev_feat)
        stem.add_module(conv_name, ConvNormAct(
            prev_chs, chs, kernel_size,
            stride=conv_stride,
            padding=padding if i == 0 else '',
            act_layer=act_layer,
            norm_layer=norm_layer,
        ))
        stem_stride *= conv_stride
        prev_chs = chs
        prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', conv_name]))
    if pool:
        assert stride > 2
        if prev_feat is not None:
            feature_info.append(prev_feat)
        if aa_layer is not None:
            # anti-aliased downsample: stride-1 max pool followed by the aa layer
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
            stem.add_module('aa', aa_layer(channels=prev_chs, stride=2))
            pool_name = 'aa'
        else:
            stem.add_module('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
            pool_name = 'pool'
        stem_stride *= 2
        prev_feat = dict(num_chs=prev_chs, reduction=stem_stride, module='.'.join(['stem', pool_name]))
    feature_info.append(prev_feat)
    return stem, feature_info
+
+
def _get_stage_fn(stage_args):
    """Pop 'stage_type' from *stage_args* and return (stage class, remaining args)."""
    stage_type = stage_args.pop('stage_type')
    assert stage_type in ('dark', 'csp', 'cs3')
    if stage_type == 'dark':
        # dark stages have no cross-stage options; drop them if present
        for key in ('expand_ratio', 'cross_linear', 'down_growth'):
            stage_args.pop(key, None)
        return DarkStage, stage_args
    if stage_type == 'csp':
        return CrossStage, stage_args
    return CrossStage3, stage_args
+
+
def _get_block_fn(stage_args):
    """Pop 'block_type' from *stage_args* and return (block class, remaining args)."""
    block_type = stage_args.pop('block_type')
    assert block_type in ('dark', 'edge', 'bottle')
    block_fn = {'dark': DarkBlock, 'edge': EdgeBlock, 'bottle': BottleneckBlock}[block_type]
    return block_fn, stage_args
+
+
def _get_attn_fn(stage_args):
    """Pop attention config and return (attn constructor or None, remaining args)."""
    attn_layer = stage_args.pop('attn_layer')
    attn_kwargs = stage_args.pop('attn_kwargs', None) or {}
    if attn_layer is None:
        return None, stage_args
    attn_layer = get_attn(attn_layer)
    if attn_kwargs:
        # pre-bind extra attention kwargs so stages can call attn_layer(chs, ...)
        attn_layer = partial(attn_layer, **attn_kwargs)
    return attn_layer, stage_args
+
+
def create_csp_stages(
        cfg: CspModelCfg,
        drop_path_rate: float,
        output_stride: int,
        stem_feat: Dict[str, Any],
):
    """Build the stage sequence from a CspModelCfg.

    Args:
        cfg: Full model config (per-stage fields already broadcast by CspStagesCfg).
        drop_path_rate: Max stochastic-depth rate, linearly scheduled across all blocks.
        output_stride: Target network stride; further striding becomes dilation past it.
        stem_feat: Feature-info dict of the stem output (num_chs / reduction).

    Returns:
        (stages, feature_info): an ``nn.Sequential`` of stages and per-stage feature metadata.
    """
    cfg_dict = asdict(cfg.stages)
    num_stages = len(cfg.stages.depth)
    # per-block stochastic depth rates, split into one list per stage
    cfg_dict['block_dpr'] = [None] * num_stages if not drop_path_rate else \
        [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.stages.depth)).split(cfg.stages.depth)]
    # transpose the dict-of-tuples into one kwargs dict per stage
    # (fix: previously the loop variable shadowed this list, both named `stage_args`)
    per_stage_args = [dict(zip(cfg_dict.keys(), values)) for values in zip(*cfg_dict.values())]
    block_kwargs = dict(
        act_layer=cfg.act_layer,
        norm_layer=cfg.norm_layer,
    )

    dilation = 1
    net_stride = stem_feat['reduction']
    prev_chs = stem_feat['num_chs']
    prev_feat = stem_feat
    feature_info = []
    stages = []
    for stage_idx, stage_args in enumerate(per_stage_args):
        stage_fn, stage_args = _get_stage_fn(stage_args)
        block_fn, stage_args = _get_block_fn(stage_args)
        attn_fn, stage_args = _get_attn_fn(stage_args)
        stride = stage_args.pop('stride')
        if stride != 1 and prev_feat:
            feature_info.append(prev_feat)
        if net_stride >= output_stride and stride > 1:
            # past the requested output stride: convert striding into dilation
            dilation *= stride
            stride = 1
        net_stride *= stride
        first_dilation = 1 if dilation in (1, 2) else 2

        stages += [stage_fn(
            prev_chs,
            **stage_args,
            stride=stride,
            first_dilation=first_dilation,
            dilation=dilation,
            block_fn=block_fn,
            aa_layer=cfg.aa_layer,
            attn_layer=attn_fn,  # will be passed through stage as block_kwargs
            **block_kwargs,
        )]
        prev_chs = stage_args['out_chs']
        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')

    feature_info.append(prev_feat)
    return nn.Sequential(*stages), feature_info
+
+
class CspNet(nn.Module):
    """Cross Stage Partial base model.

    Paper: `CSPNet: A New Backbone that can Enhance Learning Capability of CNN` - https://arxiv.org/abs/1911.11929
    Ref Impl: https://github.com/WongKinYiu/CrossStagePartialNetworks

    NOTE: There are differences in the way I handle the 1x1 'expansion' conv in this impl vs the
    darknet impl. I did it this way for simplicity and less special cases.
    """

    def __init__(
            self,
            cfg: CspModelCfg,
            in_chans=3,
            num_classes=1000,
            output_stride=32,
            global_pool='avg',
            drop_rate=0.,
            drop_path_rate=0.,
            zero_init_last=True,
            **kwargs,
    ):
        """
        Args:
            cfg (CspModelCfg): Model architecture configuration
            in_chans (int): Number of input channels (default: 3)
            num_classes (int): Number of classifier classes (default: 1000)
            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
            global_pool (str): Global pooling type (default: 'avg')
            drop_rate (float): Dropout rate (default: 0.)
            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
            zero_init_last (bool): Zero-init last weight of residual path
            kwargs (dict): Extra kwargs overlayed onto cfg
        """
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        assert output_stride in (8, 16, 32)

        cfg = replace(cfg, **kwargs)  # overlay kwargs onto cfg
        layer_args = dict(
            act_layer=cfg.act_layer,
            norm_layer=cfg.norm_layer,
            aa_layer=cfg.aa_layer
        )
        self.feature_info = []

        # Construct the stem
        self.stem, stem_feat_info = create_csp_stem(in_chans, **asdict(cfg.stem), **layer_args)
        # last stem feature entry describes the stem output; it is consumed by the
        # stage builder below, so only the earlier entries go into feature_info here
        self.feature_info.extend(stem_feat_info[:-1])

        # Construct the stages
        self.stages, stage_feat_info = create_csp_stages(
            cfg,
            drop_path_rate=drop_path_rate,
            output_stride=output_stride,
            stem_feat=stem_feat_info[-1],
        )
        prev_chs = stage_feat_info[-1]['num_chs']
        self.feature_info.extend(stage_feat_info)

        # Construct the head
        self.num_features = prev_chs
        self.head = ClassifierHead(
            in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)

        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex grouping of parameters for layer-wise LR decay."""
        matcher = dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^stages\.(\d+)\..*transition', MATCH_PREV_GROUP),  # map to last block in stage
                (r'^stages\.(\d+)', (0,)),
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        """Return the final classification layer."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the classifier head for a new class count / pooling mode."""
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        """Return backbone features (stem + stages), pre-pooling."""
        x = self.stem(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Apply pooling + classifier; ``pre_logits`` skips the final linear layer."""
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
+def _init_weights(module, name, zero_init_last=False):
+ if isinstance(module, nn.Conv2d):
+ nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Linear):
+ nn.init.normal_(module.weight, mean=0.0, std=0.01)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif zero_init_last and hasattr(module, 'zero_init_last'):
+ module.zero_init_last()
+
+
# Architecture definitions for every registered cspnet variant.
model_cfgs = dict(
    # CSP versions of ResNet / ResNeXt backbones
    cspresnet50=CspModelCfg(
        stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(128, 256, 512, 1024),
            stride=(1, 2),
            expand_ratio=2.,
            bottle_ratio=0.5,
            cross_linear=True,
        ),
    ),
    cspresnet50d=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(128, 256, 512, 1024),
            stride=(1,) + (2,),
            expand_ratio=2.,
            bottle_ratio=0.5,
            block_ratio=1.,
            cross_linear=True,
        ),
    ),
    cspresnet50w=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 32, 64), kernel_size=3, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1,) + (2,),
            expand_ratio=1.,
            bottle_ratio=0.25,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),
    cspresnext50=CspModelCfg(
        stem=CspStemCfg(out_chs=64, kernel_size=7, stride=4, pool='max'),
        stages=CspStagesCfg(
            depth=(3, 3, 5, 2),
            out_chs=(256, 512, 1024, 2048),
            stride=(1,) + (2,),
            groups=32,
            expand_ratio=1.,
            bottle_ratio=1.,
            block_ratio=0.5,
            cross_linear=True,
        ),
    ),
    # DarkNet backbones (plain and CSP variants)
    cspdarknet53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            expand_ratio=(2.,) + (1.,),
            bottle_ratio=(0.5,) + (1.,),
            block_ratio=(1.,) + (0.5,),
            down_growth=True,
            block_type='dark',
        ),
    ),
    darknet17=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1,) * 5,
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',
        ),
    ),
    darknet21=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=(2,),
            bottle_ratio=(0.5,),
            block_ratio=(1.,),
            stage_type='dark',
            block_type='dark',

        ),
    ),
    sedarknet21=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 1, 1, 2, 2),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            attn_layer='se',
            stage_type='dark',
            block_type='dark',

        ),
    ),
    darknet53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            stage_type='dark',
            block_type='dark',
        ),
    ),
    darknetaa53=CspModelCfg(
        stem=CspStemCfg(out_chs=32, kernel_size=3, stride=1, pool=''),
        stages=CspStagesCfg(
            depth=(1, 2, 8, 8, 4),
            out_chs=(64, 128, 256, 512, 1024),
            stride=2,
            bottle_ratio=0.5,
            block_ratio=1.,
            avg_down=True,
            stage_type='dark',
            block_type='dark',
        ),
    ),

    # CS3 (CrossStage3) darknet family, scaled via width/depth multipliers
    cs3darknet_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5),
    cs3darknet_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67),
    cs3darknet_l=_cs3_cfg(),
    cs3darknet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33),

    cs3darknet_focus_s=_cs3_cfg(width_multiplier=0.5, depth_multiplier=0.5, focus=True),
    cs3darknet_focus_m=_cs3_cfg(width_multiplier=0.75, depth_multiplier=0.67, focus=True),
    cs3darknet_focus_l=_cs3_cfg(focus=True),
    cs3darknet_focus_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, focus=True),

    cs3sedarknet_l=_cs3_cfg(attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
    cs3sedarknet_x=_cs3_cfg(attn_layer='se', width_multiplier=1.25, depth_multiplier=1.33),

    cs3sedarknet_xdw=CspModelCfg(
        stem=CspStemCfg(out_chs=(32, 64), kernel_size=3, stride=2, pool=''),
        stages=CspStagesCfg(
            depth=(3, 6, 12, 4),
            out_chs=(256, 512, 1024, 2048),
            stride=2,
            groups=(1, 1, 256, 512),
            bottle_ratio=0.5,
            block_ratio=0.5,
            attn_layer='se',
        ),
        act_layer='silu',
    ),

    cs3edgenet_x=_cs3_cfg(width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge'),
    cs3se_edgenet_x=_cs3_cfg(
        width_multiplier=1.25, depth_multiplier=1.33, bottle_ratio=1.5, block_type='edge',
        attn_layer='se', attn_kwargs=dict(rd_ratio=.25)),
)
+
+
+def _create_cspnet(variant, pretrained=False, **kwargs):
+ if variant.startswith('darknet') or variant.startswith('cspdarknet'):
+ # NOTE: DarkNet is one of few models with stride==1 features w/ 6 out_indices [0..5]
+ default_out_indices = (0, 1, 2, 3, 4, 5)
+ else:
+ default_out_indices = (0, 1, 2, 3, 4)
+ out_indices = kwargs.pop('out_indices', default_out_indices)
+ return build_model_with_cfg(
+ CspNet, variant, pretrained,
+ model_cfg=model_cfgs[variant],
+ feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+ **kwargs)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+ 'crop_pct': 0.887, 'interpolation': 'bilinear',
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'stem.conv1.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
+default_cfgs = generate_default_cfgs({
+ 'cspresnet50.ra_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnet50_ra-d3e8d487.pth'),
+ 'cspresnet50d.untrained': _cfg(),
+ 'cspresnet50w.untrained': _cfg(),
+ 'cspresnext50.ra_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspresnext50_ra_224-648b4713.pth',
+ ),
+ 'cspdarknet53.ra_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/cspdarknet53_ra_256-d05c7c21.pth'),
+
+ 'darknet17.untrained': _cfg(),
+ 'darknet21.untrained': _cfg(),
+ 'sedarknet21.untrained': _cfg(),
+ 'darknet53.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknet53_256_c2ns-3aeff817.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'darknetaa53.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/darknetaa53_c2ns-5c28ec8a.pth',
+ test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+ 'cs3darknet_s.untrained': _cfg(interpolation='bicubic'),
+ 'cs3darknet_m.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_m_c2ns-43f06604.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95,
+ ),
+ 'cs3darknet_l.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_l_c2ns-16220c5d.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
+ 'cs3darknet_x.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_x_c2ns-4e4490aa.pth',
+ interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+ 'cs3darknet_focus_s.untrained': _cfg(interpolation='bicubic'),
+ 'cs3darknet_focus_m.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_m_c2ns-e23bed41.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
+ 'cs3darknet_focus_l.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3darknet_focus_l_c2ns-65ef8888.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
+ 'cs3darknet_focus_x.untrained': _cfg(interpolation='bicubic'),
+
+ 'cs3sedarknet_l.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_l_c2ns-e8d1dc13.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=0.95),
+ 'cs3sedarknet_x.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3sedarknet_x_c2ns-b4d0abc0.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
+
+ 'cs3sedarknet_xdw.untrained': _cfg(interpolation='bicubic'),
+
+ 'cs3edgenet_x.c2_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3edgenet_x_c2-2e1610a9.pth',
+ interpolation='bicubic', test_input_size=(3, 288, 288), test_crop_pct=1.0),
+ 'cs3se_edgenet_x.c2ns_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/cs3se_edgenet_x_c2ns-76f8e3ac.pth',
+ interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0),
+})
+
+
+@register_model
+def cspresnet50(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cspresnet50', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspresnet50d(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cspresnet50d', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspresnet50w(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cspresnet50w', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspresnext50(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cspresnext50', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cspdarknet53(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cspdarknet53', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def darknet17(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('darknet17', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def darknet21(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('darknet21', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def sedarknet21(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('sedarknet21', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def darknet53(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('darknet53', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def darknetaa53(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('darknetaa53', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_s(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_m(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_m', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_l(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_l', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_x(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_x', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_focus_s(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_focus_s', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_focus_m(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_focus_m', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_focus_l(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_focus_l', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3darknet_focus_x(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3darknet_focus_x', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3sedarknet_l(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3sedarknet_l', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3sedarknet_x(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3sedarknet_x', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3sedarknet_xdw(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3sedarknet_xdw', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3edgenet_x(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3edgenet_x', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def cs3se_edgenet_x(pretrained=False, **kwargs) -> CspNet:
+ return _create_cspnet('cs3se_edgenet_x', pretrained=pretrained, **kwargs)
\ No newline at end of file
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/deit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/deit.py
new file mode 100644
index 0000000000000000000000000000000000000000..f80087e80df1f677c4fc90eaf1ec65b10db1f2c9
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/deit.py
@@ -0,0 +1,416 @@
+""" DeiT - Data-efficient Image Transformers
+
+DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
+
+paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
+
+Modifications copyright 2021, Ross Wightman
+"""
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+from functools import partial
+from typing import Sequence, Union
+
+import torch
+from torch import nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import resample_abs_pos_embed
+from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations
+
+__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
+
+
+class VisionTransformerDistilled(VisionTransformer):
+ """ Vision Transformer w/ Distillation Token and Head
+
+ Distillation token & head support for `DeiT: Data-efficient Image Transformers`
+ - https://arxiv.org/abs/2012.12877
+ """
+
+ def __init__(self, *args, **kwargs):
+ weight_init = kwargs.pop('weight_init', '')
+ super().__init__(*args, **kwargs, weight_init='skip')
+ assert self.global_pool in ('token',)
+
+ self.num_prefix_tokens = 2
+ self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+ self.distilled_training = False # must set this True to train w/ distillation token
+
+ self.init_weights(weight_init)
+
+ def init_weights(self, mode=''):
+ trunc_normal_(self.dist_token, std=.02)
+ super().init_weights(mode=mode)
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse=False):
+ return dict(
+ stem=r'^cls_token|pos_embed|patch_embed|dist_token',
+ blocks=[
+ (r'^blocks\.(\d+)', None),
+ (r'^norm', (99999,))] # final norm w/ last block
+ )
+
+ @torch.jit.ignore
+ def get_classifier(self):
+ return self.head, self.head_dist
+
+ def reset_classifier(self, num_classes, global_pool=None):
+ self.num_classes = num_classes
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+ self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ @torch.jit.ignore
+ def set_distilled_training(self, enable=True):
+ self.distilled_training = enable
+
+ def _pos_embed(self, x):
+ if self.dynamic_img_size:
+ B, H, W, C = x.shape
+ pos_embed = resample_abs_pos_embed(
+ self.pos_embed,
+ (H, W),
+ num_prefix_tokens=0 if self.no_embed_class else self.num_prefix_tokens,
+ )
+ x = x.view(B, -1, C)
+ else:
+ pos_embed = self.pos_embed
+ if self.no_embed_class:
+ # deit-3, updated JAX (big vision)
+ # position embedding does not overlap with class token, add then concat
+ x = x + pos_embed
+ x = torch.cat((
+ self.cls_token.expand(x.shape[0], -1, -1),
+ self.dist_token.expand(x.shape[0], -1, -1),
+ x),
+ dim=1)
+ else:
+ # original timm, JAX, and deit vit impl
+ # pos_embed has entry for class token, concat then add
+ x = torch.cat((
+ self.cls_token.expand(x.shape[0], -1, -1),
+ self.dist_token.expand(x.shape[0], -1, -1),
+ x),
+ dim=1)
+ x = x + pos_embed
+ return self.pos_drop(x)
+
+ def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
+ x, x_dist = x[:, 0], x[:, 1]
+ if pre_logits:
+ return (x + x_dist) / 2
+ x = self.head(x)
+ x_dist = self.head_dist(x_dist)
+ if self.distilled_training and self.training and not torch.jit.is_scripting():
+ # only return separate classification predictions when training in distilled mode
+ return x, x_dist
+ else:
+ # during standard train / finetune, inference average the classifier predictions
+ return (x + x_dist) / 2
+
+
+def _create_deit(variant, pretrained=False, distilled=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+ model_cls = VisionTransformerDistilled if distilled else VisionTransformer
+ model = build_model_with_cfg(
+ model_cls,
+ variant,
+ pretrained,
+ pretrained_filter_fn=partial(checkpoint_filter_fn, adapt_layer_scale=True),
+ **kwargs,
+ )
+ return model
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+ 'first_conv': 'patch_embed.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = generate_default_cfgs({
+ # deit models (FB weights)
+ 'deit_tiny_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
+ 'deit_small_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
+ 'deit_base_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'),
+ 'deit_base_patch16_384.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+
+ 'deit_tiny_distilled_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+ classifier=('head', 'head_dist')),
+ 'deit_small_distilled_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+ classifier=('head', 'head_dist')),
+ 'deit_base_distilled_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+ classifier=('head', 'head_dist')),
+ 'deit_base_distilled_patch16_384.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
+ input_size=(3, 384, 384), crop_pct=1.0,
+ classifier=('head', 'head_dist')),
+
+ 'deit3_small_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
+ 'deit3_small_patch16_384.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'deit3_medium_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_1k.pth'),
+ 'deit3_base_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
+ 'deit3_base_patch16_384.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'deit3_large_patch16_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
+ 'deit3_large_patch16_384.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'deit3_huge_patch14_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
+
+ 'deit3_small_patch16_224.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
+ crop_pct=1.0),
+ 'deit3_small_patch16_384.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'deit3_medium_patch16_224.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_medium_224_21k.pth',
+ crop_pct=1.0),
+ 'deit3_base_patch16_224.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
+ crop_pct=1.0),
+ 'deit3_base_patch16_384.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'deit3_large_patch16_224.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
+ crop_pct=1.0),
+ 'deit3_large_patch16_384.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
+ input_size=(3, 384, 384), crop_pct=1.0),
+ 'deit3_huge_patch14_224.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
+ crop_pct=1.0),
+})
+
+
+@register_model
+def deit_tiny_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
+ model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+ model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
+ """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
+ model = _create_deit(
+ 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit_small_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
+ """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
+ model = _create_deit(
+ 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_224(pretrained=False, **kwargs) -> VisionTransformerDistilled:
+ """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ model = _create_deit(
+ 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit_base_distilled_patch16_384(pretrained=False, **kwargs) -> VisionTransformerDistilled:
+ """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
+ model = _create_deit(
+ 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_small_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 medium model @ 224x224 (https://arxiv.org/abs/2012.12877).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=512, depth=12, num_heads=8, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_medium_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_base_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_large_patch16_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_large_patch16_384(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+@register_model
+def deit3_huge_patch14_224(pretrained=False, **kwargs) -> VisionTransformer:
+ """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+ ImageNet-1k weights from https://github.com/facebookresearch/deit.
+ """
+ model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6)
+ model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **dict(model_args, **kwargs))
+ return model
+
+
+register_model_deprecations(__name__, {
+ 'deit3_small_patch16_224_in21ft1k': 'deit3_small_patch16_224.fb_in22k_ft_in1k',
+ 'deit3_small_patch16_384_in21ft1k': 'deit3_small_patch16_384.fb_in22k_ft_in1k',
+ 'deit3_medium_patch16_224_in21ft1k': 'deit3_medium_patch16_224.fb_in22k_ft_in1k',
+ 'deit3_base_patch16_224_in21ft1k': 'deit3_base_patch16_224.fb_in22k_ft_in1k',
+ 'deit3_base_patch16_384_in21ft1k': 'deit3_base_patch16_384.fb_in22k_ft_in1k',
+ 'deit3_large_patch16_224_in21ft1k': 'deit3_large_patch16_224.fb_in22k_ft_in1k',
+ 'deit3_large_patch16_384_in21ft1k': 'deit3_large_patch16_384.fb_in22k_ft_in1k',
+ 'deit3_huge_patch14_224_in21ft1k': 'deit3_huge_patch14_224.fb_in22k_ft_in1k'
+})
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/dla.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/dla.py
new file mode 100644
index 0000000000000000000000000000000000000000..3052819db7a3a1c422be1b8f58cd87686f8f13e5
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/dla.py
@@ -0,0 +1,515 @@
+""" Deep Layer Aggregation and DLA w/ Res2Net
+DLA original adapted from Official Pytorch impl at: https://github.com/ucbdrive/dla
+DLA Paper: `Deep Layer Aggregation` - https://arxiv.org/abs/1707.06484
+
+Res2Net additions from: https://github.com/gasvn/Res2Net/
+Res2Net Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
+"""
+import math
+from typing import List, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import create_classifier
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['DLA']
+
+
+class DlaBasic(nn.Module):
+    """DLA basic block: two 3x3 conv-BN-ReLU layers with an additive shortcut."""
+
+    def __init__(self, inplanes, planes, stride=1, dilation=1, **_):
+        # **_ swallows extra block kwargs (e.g. cardinality/base_width) passed uniformly by DlaTree
+        super(DlaBasic, self).__init__()
+        self.conv1 = nn.Conv2d(
+            inplanes, planes, kernel_size=3,
+            stride=stride, padding=dilation, bias=False, dilation=dilation)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = nn.Conv2d(
+            planes, planes, kernel_size=3,
+            stride=1, padding=dilation, bias=False, dilation=dilation)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.stride = stride
+
+    def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
+        # `children` is unused here; the parameter exists so all DLA blocks share the
+        # same call signature as DlaTree.forward (helps torchscript).
+        if shortcut is None:
+            shortcut = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        out += shortcut
+        out = self.relu(out)
+
+        return out
+
+
+class DlaBottleneck(nn.Module):
+    """DLA/DLA-X bottleneck block: 1x1 reduce -> 3x3 (grouped) -> 1x1 expand, plus shortcut."""
+    expansion = 2
+
+    def __init__(self, inplanes, outplanes, stride=1, dilation=1, cardinality=1, base_width=64):
+        super(DlaBottleneck, self).__init__()
+        self.stride = stride
+        # ResNeXt-style width calc, then reduced by the bottleneck expansion factor
+        mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
+        mid_planes = mid_planes // self.expansion
+
+        self.conv1 = nn.Conv2d(inplanes, mid_planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(mid_planes)
+        self.conv2 = nn.Conv2d(
+            mid_planes, mid_planes, kernel_size=3,
+            stride=stride, padding=dilation, bias=False, dilation=dilation, groups=cardinality)
+        self.bn2 = nn.BatchNorm2d(mid_planes)
+        self.conv3 = nn.Conv2d(mid_planes, outplanes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(outplanes)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
+        # `children` unused; kept for signature parity with DlaTree.forward
+        if shortcut is None:
+            shortcut = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        out += shortcut
+        out = self.relu(out)
+
+        return out
+
+
+class DlaBottle2neck(nn.Module):
+    """ Res2Net/Res2NeXT DLA Bottleneck
+    Adapted from https://github.com/gasvn/Res2Net/blob/master/dla.py
+
+    The 3x3 stage is replaced by `scale` channel splits processed hierarchically:
+    each split (after the first) adds the previous split's output before its conv.
+    """
+    expansion = 2
+
+    def __init__(self, inplanes, outplanes, stride=1, dilation=1, scale=4, cardinality=8, base_width=4):
+        super(DlaBottle2neck, self).__init__()
+        self.is_first = stride > 1  # a strided block can't chain splits; each is conv'd independently
+        self.scale = scale
+        mid_planes = int(math.floor(outplanes * (base_width / 64)) * cardinality)
+        mid_planes = mid_planes // self.expansion
+        self.width = mid_planes  # per-split channel width
+
+        self.conv1 = nn.Conv2d(inplanes, mid_planes * scale, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(mid_planes * scale)
+
+        # one 3x3 conv per split except the last (which is pooled or passed through)
+        num_scale_convs = max(1, scale - 1)
+        convs = []
+        bns = []
+        for _ in range(num_scale_convs):
+            convs.append(nn.Conv2d(
+                mid_planes, mid_planes, kernel_size=3,
+                stride=stride, padding=dilation, dilation=dilation, groups=cardinality, bias=False))
+            bns.append(nn.BatchNorm2d(mid_planes))
+        self.convs = nn.ModuleList(convs)
+        self.bns = nn.ModuleList(bns)
+        # strided blocks downsample the last split via avg pool to match spatial size
+        self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1) if self.is_first else None
+
+        self.conv3 = nn.Conv2d(mid_planes * scale, outplanes, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(outplanes)
+        self.relu = nn.ReLU(inplace=True)
+
+    def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
+        # `children` unused; kept for signature parity with DlaTree.forward
+        if shortcut is None:
+            shortcut = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        spx = torch.split(out, self.width, 1)  # split channels into `scale` groups
+        spo = []
+        sp = spx[0]  # redundant, for torchscript
+        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
+            if i == 0 or self.is_first:
+                sp = spx[i]
+            else:
+                # hierarchical residual: feed previous split's output forward
+                sp = sp + spx[i]
+            sp = conv(sp)
+            sp = bn(sp)
+            sp = self.relu(sp)
+            spo.append(sp)
+        if self.scale > 1:
+            if self.pool is not None:  # self.is_first == True, None check for torchscript
+                spo.append(self.pool(spx[-1]))
+            else:
+                spo.append(spx[-1])
+        out = torch.cat(spo, 1)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        out += shortcut
+        out = self.relu(out)
+
+        return out
+
+
+class DlaRoot(nn.Module):
+    """DLA aggregation node: concat children -> 1x1 conv-BN, optional residual to first child."""
+
+    def __init__(self, in_channels, out_channels, kernel_size, shortcut):
+        super(DlaRoot, self).__init__()
+        # NOTE(review): the conv kernel is hard-coded to 1 while padding is still derived
+        # from `kernel_size` — presumably intentional to mirror the original DLA impl, where
+        # root_kernel_size is always 1 in practice; verify if other kernel sizes are ever used.
+        self.conv = nn.Conv2d(
+            in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2)
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.relu = nn.ReLU(inplace=True)
+        self.shortcut = shortcut  # bool: add first child as a residual
+
+    def forward(self, x_children: List[torch.Tensor]):
+        x = self.conv(torch.cat(x_children, 1))
+        x = self.bn(x)
+        if self.shortcut:
+            x += x_children[0]
+        x = self.relu(x)
+
+        return x
+
+
+class DlaTree(nn.Module):
+    """Recursive DLA tree node.
+
+    At `levels == 1` it holds two blocks joined by a DlaRoot; deeper trees recurse,
+    with the root living in the innermost tree2 (root_dim accumulates child channels).
+    """
+
+    def __init__(
+            self,
+            levels,
+            block,
+            in_channels,
+            out_channels,
+            stride=1,
+            dilation=1,
+            cardinality=1,
+            base_width=64,
+            level_root=False,
+            root_dim=0,
+            root_kernel_size=1,
+            root_shortcut=False,
+    ):
+        super(DlaTree, self).__init__()
+        if root_dim == 0:
+            root_dim = 2 * out_channels  # root aggregates tree1 + tree2 outputs by default
+        if level_root:
+            root_dim += in_channels  # level roots also aggregate the downsampled input
+        self.downsample = nn.MaxPool2d(stride, stride=stride) if stride > 1 else nn.Identity()
+        self.project = nn.Identity()
+        cargs = dict(dilation=dilation, cardinality=cardinality, base_width=base_width)
+        if levels == 1:
+            self.tree1 = block(in_channels, out_channels, stride, **cargs)
+            self.tree2 = block(out_channels, out_channels, 1, **cargs)
+            if in_channels != out_channels:
+                # NOTE the official impl/weights have project layers in levels > 1 case that are never
+                # used, I've moved the project layer here to avoid wasted params but old checkpoints will
+                # need strict=False while loading.
+                self.project = nn.Sequential(
+                    nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
+                    nn.BatchNorm2d(out_channels))
+            self.root = DlaRoot(root_dim, out_channels, root_kernel_size, root_shortcut)
+        else:
+            cargs.update(dict(root_kernel_size=root_kernel_size, root_shortcut=root_shortcut))
+            self.tree1 = DlaTree(
+                levels - 1,
+                block,
+                in_channels,
+                out_channels,
+                stride,
+                root_dim=0,
+                **cargs,
+            )
+            self.tree2 = DlaTree(
+                levels - 1,
+                block,
+                out_channels,
+                out_channels,
+                root_dim=root_dim + out_channels,
+                **cargs,
+            )
+            self.root = None  # aggregation happens inside the innermost tree2
+        self.level_root = level_root
+        self.root_dim = root_dim
+        self.levels = levels
+
+    def forward(self, x, shortcut: Optional[torch.Tensor] = None, children: Optional[List[torch.Tensor]] = None):
+        # `shortcut` arg is ignored — it's recomputed from the (possibly projected) downsampled input
+        if children is None:
+            children = []
+        bottom = self.downsample(x)
+        shortcut = self.project(bottom)
+        if self.level_root:
+            children.append(bottom)
+        x1 = self.tree1(x, shortcut)
+        if self.root is not None:  # levels == 1
+            x2 = self.tree2(x1)
+            x = self.root([x2, x1] + children)
+        else:
+            # defer aggregation: pass accumulated children down to tree2's root
+            children.append(x1)
+            x = self.tree2(x1, None, children)
+        return x
+
+
+class DLA(nn.Module):
+    """Deep Layer Aggregation network (`DLA` - https://arxiv.org/abs/1707.06484)."""
+
+    def __init__(
+            self,
+            levels,
+            channels,
+            output_stride=32,
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            cardinality=1,
+            base_width=64,
+            block=DlaBottle2neck,
+            shortcut_root=False,
+            drop_rate=0.0,
+    ):
+        """
+        Args:
+            levels: Number of block levels per stage (6 entries, level0..level5).
+            channels: Output channels per stage (6 entries).
+            output_stride: Overall network stride; only 32 supported.
+            num_classes: Classifier output classes.
+            in_chans: Input image channels.
+            global_pool: Global pooling type for the head.
+            cardinality: Conv groups for (Res2)NeXt-style blocks.
+            base_width: Base block width multiplier.
+            block: Block class used in the tree stages.
+            shortcut_root: Add residual shortcut inside DlaRoot nodes.
+            drop_rate: Dropout rate before the classifier.
+        """
+        super(DLA, self).__init__()
+        self.channels = channels
+        self.num_classes = num_classes
+        self.cardinality = cardinality
+        self.base_width = base_width
+        assert output_stride == 32  # FIXME support dilation
+
+        self.base_layer = nn.Sequential(
+            nn.Conv2d(in_chans, channels[0], kernel_size=7, stride=1, padding=3, bias=False),
+            nn.BatchNorm2d(channels[0]),
+            nn.ReLU(inplace=True),
+        )
+        # level0/level1 are plain conv stacks; level2+ are recursive aggregation trees
+        self.level0 = self._make_conv_level(channels[0], channels[0], levels[0])
+        self.level1 = self._make_conv_level(channels[0], channels[1], levels[1], stride=2)
+        cargs = dict(cardinality=cardinality, base_width=base_width, root_shortcut=shortcut_root)
+        self.level2 = DlaTree(levels[2], block, channels[1], channels[2], 2, level_root=False, **cargs)
+        self.level3 = DlaTree(levels[3], block, channels[2], channels[3], 2, level_root=True, **cargs)
+        self.level4 = DlaTree(levels[4], block, channels[3], channels[4], 2, level_root=True, **cargs)
+        self.level5 = DlaTree(levels[5], block, channels[4], channels[5], 2, level_root=True, **cargs)
+        self.feature_info = [
+            dict(num_chs=channels[0], reduction=1, module='level0'),  # rare to have a meaningful stride 1 level
+            dict(num_chs=channels[1], reduction=2, module='level1'),
+            dict(num_chs=channels[2], reduction=4, module='level2'),
+            dict(num_chs=channels[3], reduction=8, module='level3'),
+            dict(num_chs=channels[4], reduction=16, module='level4'),
+            dict(num_chs=channels[5], reduction=32, module='level5'),
+        ]
+
+        self.num_features = channels[-1]
+        self.global_pool, self.head_drop, self.fc = create_classifier(
+            self.num_features,
+            self.num_classes,
+            pool_type=global_pool,
+            use_conv=True,  # conv head; output flattened below
+            drop_rate=drop_rate,
+        )
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+
+        # He-style init for convs, unit-gamma/zero-beta for BN
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+    def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1):
+        """Stack `convs` conv-BN-ReLU units; only the first carries the stride."""
+        modules = []
+        for i in range(convs):
+            modules.extend([
+                nn.Conv2d(
+                    inplanes, planes, kernel_size=3,
+                    stride=stride if i == 0 else 1,
+                    padding=dilation, bias=False, dilation=dilation),
+                nn.BatchNorm2d(planes),
+                nn.ReLU(inplace=True)])
+            inplanes = planes
+        return nn.Sequential(*modules)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^base_layer',
+            blocks=r'^level(\d+)' if coarse else [
+                # an unusual arch, this achieves somewhat more granularity without getting super messy
+                (r'^level(\d+)\.tree(\d+)', None),
+                (r'^level(\d+)\.root', (2,)),
+                (r'^level(\d+)', (1,))
+            ]
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, 'gradient checkpointing not supported'
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        # NOTE(review): unlike __init__, this does not pass drop_rate and leaves
+        # self.head_drop untouched — presumably acceptable since head_drop persists,
+        # but verify the configured drop rate survives a classifier reset.
+        self.num_classes = num_classes
+        self.global_pool, self.fc = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, use_conv=True)
+        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+
+    def forward_features(self, x):
+        x = self.base_layer(x)
+        x = self.level0(x)
+        x = self.level1(x)
+        x = self.level2(x)
+        x = self.level3(x)
+        x = self.level4(x)
+        x = self.level5(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        x = self.head_drop(x)
+        if pre_logits:
+            return self.flatten(x)
+        x = self.fc(x)
+        return self.flatten(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _create_dla(variant, pretrained=False, **kwargs):
+    """Build a DLA model via the shared timm factory.
+
+    pretrained_strict=False: checkpoints predate the relocated `project` layers
+    (see the NOTE in DlaTree), so loading must tolerate missing/unexpected keys.
+    """
+    return build_model_with_cfg(
+        DLA,
+        variant,
+        pretrained,
+        pretrained_strict=False,
+        feature_cfg=dict(out_indices=(1, 2, 3, 4, 5)),  # skip the stride-1 level0 feature
+        **kwargs,
+    )
+
+
+def _cfg(url='', **kwargs):
+    """Return a default pretrained-config dict; **kwargs override any default entry."""
+    return {
+        'url': url,
+        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.875, 'interpolation': 'bilinear',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'base_layer.0', 'classifier': 'fc',
+        **kwargs
+    }
+
+
+# Pretrained weight configs; all weights hosted on the HF hub under the 'timm/' org.
+default_cfgs = generate_default_cfgs({
+    'dla34.in1k': _cfg(hf_hub_id='timm/'),
+    'dla46_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla46x_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60x_c.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60x.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102x.in1k': _cfg(hf_hub_id='timm/'),
+    'dla102x2.in1k': _cfg(hf_hub_id='timm/'),
+    'dla169.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60_res2net.in1k': _cfg(hf_hub_id='timm/'),
+    'dla60_res2next.in1k': _cfg(hf_hub_id='timm/'),
+})
+
+
+@register_model
+def dla60_res2net(pretrained=False, **kwargs) -> DLA:
+    """DLA-60 with Res2Net bottleneck blocks."""
+    model_args = dict(
+        levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
+        block=DlaBottle2neck, cardinality=1, base_width=28)
+    return _create_dla('dla60_res2net', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla60_res2next(pretrained=False,**kwargs):
+ model_args = dict(
+ levels=(1, 1, 1, 2, 3, 1), channels=(16, 32, 128, 256, 512, 1024),
+ block=DlaBottle2neck, cardinality=8, base_width=4)
+ return _create_dla('dla60_res2next', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla34(pretrained=False, **kwargs) -> DLA:  # DLA-34
+    """DLA-34: basic (non-bottleneck) blocks."""
+    model_args = dict(
+        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 128, 256, 512], block=DlaBasic)
+    return _create_dla('dla34', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla46_c(pretrained=False, **kwargs) -> DLA:  # DLA-46-C
+    """DLA-46-C: compact variant with bottleneck blocks."""
+    model_args = dict(
+        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256], block=DlaBottleneck)
+    return _create_dla('dla46_c', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla46x_c(pretrained=False, **kwargs) -> DLA:  # DLA-X-46-C
+    """DLA-X-46-C: compact variant with grouped (cardinality 32) bottlenecks."""
+    model_args = dict(
+        levels=[1, 1, 1, 2, 2, 1], channels=[16, 32, 64, 64, 128, 256],
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla46x_c', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla60x_c(pretrained=False, **kwargs) -> DLA:  # DLA-X-60-C
+    """DLA-X-60-C: compact variant with grouped (cardinality 32) bottlenecks."""
+    model_args = dict(
+        levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 64, 64, 128, 256],
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla60x_c', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla60(pretrained=False, **kwargs) -> DLA:  # DLA-60
+    """DLA-60: standard bottleneck blocks."""
+    model_args = dict(
+        levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
+        block=DlaBottleneck)
+    return _create_dla('dla60', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla60x(pretrained=False, **kwargs) -> DLA:  # DLA-X-60
+    """DLA-X-60: grouped (cardinality 32) bottlenecks."""
+    model_args = dict(
+        levels=[1, 1, 1, 2, 3, 1], channels=[16, 32, 128, 256, 512, 1024],
+        block=DlaBottleneck, cardinality=32, base_width=4)
+    return _create_dla('dla60x', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla102(pretrained=False, **kwargs) -> DLA:  # DLA-102
+    """DLA-102: bottleneck blocks with residual shortcuts in the root nodes."""
+    model_args = dict(
+        levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
+        block=DlaBottleneck, shortcut_root=True)
+    return _create_dla('dla102', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla102x(pretrained=False, **kwargs) -> DLA:  # DLA-X-102
+    """DLA-X-102: grouped (cardinality 32) bottlenecks, root shortcuts."""
+    model_args = dict(
+        levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
+        block=DlaBottleneck, cardinality=32, base_width=4, shortcut_root=True)
+    return _create_dla('dla102x', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla102x2(pretrained=False, **kwargs) -> DLA:  # DLA-X-102 64
+    """DLA-X-102 64: grouped (cardinality 64) bottlenecks, root shortcuts."""
+    model_args = dict(
+        levels=[1, 1, 1, 3, 4, 1], channels=[16, 32, 128, 256, 512, 1024],
+        block=DlaBottleneck, cardinality=64, base_width=4, shortcut_root=True)
+    return _create_dla('dla102x2', pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def dla169(pretrained=False, **kwargs) -> DLA:  # DLA-169
+    """DLA-169: deepest variant, bottleneck blocks with root shortcuts."""
+    model_args = dict(
+        levels=[1, 1, 2, 3, 5, 1], channels=[16, 32, 128, 256, 512, 1024],
+        block=DlaBottleneck, shortcut_root=True)
+    return _create_dla('dla169', pretrained, **dict(model_args, **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/eva.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/eva.py
new file mode 100644
index 0000000000000000000000000000000000000000..82fff28acffa5549e7e49c36539e4360cf35ff7d
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/eva.py
@@ -0,0 +1,1109 @@
+""" EVA
+
+EVA from https://github.com/baaivision/EVA , paper: https://arxiv.org/abs/2211.07636
+
+@article{EVA,
+ title={EVA: Exploring the Limits of Masked Visual Representation Learning at Scale},
+ author={Fang, Yuxin and Wang, Wen and Xie, Binhui and Sun, Quan and Wu, Ledell and Wang, Xinggang and Huang,
+ Tiejun and Wang, Xinlong and Cao, Yue},
+ journal={arXiv preprint arXiv:2211.07636},
+ year={2022}
+}
+
+EVA-02: A Visual Representation for Neon Genesis - https://arxiv.org/abs/2303.11331
+@article{EVA02,
+ title={EVA-02: A Visual Representation for Neon Genesis},
+ author={Fang, Yuxin and Sun, Quan and Wang, Xinggang and Huang, Tiejun and Wang, Xinlong and Cao, Yue},
+ journal={arXiv preprint arXiv:2303.11331},
+ year={2023}
+}
+
+This file contains EVA & EVA02 model implementations evolved from BEiT, additional models in vision_transformer.py.
+
+Modifications by / Copyright 2023 Ross Wightman, original copyrights below
+"""
+# EVA models Copyright (c) 2022 BAAI-Vision
+# EVA02 models Copyright (c) 2023 BAAI-Vision
+
+import math
+from typing import Callable, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from timm.layers import PatchEmbed, Mlp, GluMlp, SwiGLU, LayerNorm, DropPath, PatchDropout, RotaryEmbeddingCat, \
+ apply_rot_embed_cat, apply_keep_indices_nlc, trunc_normal_, resample_patch_embed, resample_abs_pos_embed, \
+ to_2tuple, use_fused_attn
+
+from ._builder import build_model_with_cfg
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['Eva']
+
+
+class EvaAttention(nn.Module):
+ fused_attn: torch.jit.Final[bool]
+
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = True,
+ qkv_fused: bool = True,
+ attn_drop: float = 0.,
+ proj_drop: float = 0.,
+ attn_head_dim: Optional[int] = None,
+ norm_layer: Optional[Callable] = None,
+ ):
+ """
+
+ Args:
+ dim:
+ num_heads:
+ qkv_bias:
+ qkv_fused:
+ attn_drop:
+ proj_drop:
+ attn_head_dim:
+ norm_layer:
+ """
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ if attn_head_dim is not None:
+ head_dim = attn_head_dim
+ all_head_dim = head_dim * self.num_heads
+ self.scale = head_dim ** -0.5
+ self.fused_attn = use_fused_attn()
+
+ if qkv_fused:
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
+ self.q_proj = self.k_proj = self.v_proj = None
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
+ self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
+ else:
+ self.q_bias = self.k_bias = self.v_bias = None
+ else:
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=qkv_bias)
+ self.qkv = None
+ self.q_bias = self.k_bias = self.v_bias = None
+
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.norm = norm_layer(all_head_dim) if norm_layer is not None else nn.Identity()
+ self.proj = nn.Linear(all_head_dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ def forward(
+ self,
+ x,
+ rope: Optional[torch.Tensor] = None,
+ attn_mask: Optional[torch.Tensor] = None,
+ ):
+ B, N, C = x.shape
+
+ if self.qkv is not None:
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
+ else:
+ q = self.q_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2) # B, num_heads, N, C
+ k = self.k_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+ v = self.v_proj(x).reshape(B, N, self.num_heads, -1).transpose(1, 2)
+
+ if rope is not None:
+ q = torch.cat([q[:, :, :1, :], apply_rot_embed_cat(q[:, :, 1:, :], rope)], 2).type_as(v)
+ k = torch.cat([k[:, :, :1, :], apply_rot_embed_cat(k[:, :, 1:, :], rope)], 2).type_as(v)
+
+ if self.fused_attn:
+ x = F.scaled_dot_product_attention(
+ q, k, v,
+ attn_mask=attn_mask,
+ dropout_p=self.attn_drop.p if self.training else 0.,
+ )
+ else:
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+ attn = attn.softmax(dim=-1)
+ if attn_mask is not None:
+ attn_mask = attn_mask.to(torch.bool)
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+ attn = self.attn_drop(attn)
+ x = attn @ v
+
+ x = x.transpose(1, 2).reshape(B, N, C)
+ x = self.norm(x)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class EvaBlock(nn.Module):
+    """EVA pre-norm transformer block w/ optional LayerScale, SwiGLU/GluMlp MLP, ROPE."""
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int,
+            qkv_bias: bool = True,
+            qkv_fused: bool = True,
+            mlp_ratio: float = 4.,
+            swiglu_mlp: bool = False,
+            scale_mlp: bool = False,
+            scale_attn_inner: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            init_values: Optional[float] = None,
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = LayerNorm,
+            attn_head_dim: Optional[int] = None,
+    ):
+        """
+        Args:
+            dim: Embedding dim.
+            num_heads: Attention heads.
+            qkv_bias: Enable q/v projection bias.
+            qkv_fused: Use fused qkv projection in attention.
+            mlp_ratio: MLP hidden dim ratio.
+            swiglu_mlp: Use a SwiGLU-style MLP.
+            scale_mlp: Apply norm inside the MLP (normformer-style).
+            scale_attn_inner: Apply norm to attention output before projection.
+            proj_drop: Projection dropout rate.
+            attn_drop: Attention dropout rate.
+            drop_path: Stochastic depth rate.
+            init_values: LayerScale init value (None disables gamma_1/gamma_2).
+            act_layer: MLP activation layer.
+            norm_layer: Normalization layer.
+            attn_head_dim: Override attention head dim.
+        """
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = EvaAttention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qkv_fused=qkv_fused,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+            attn_head_dim=attn_head_dim,
+            norm_layer=norm_layer if scale_attn_inner else None,
+        )
+        # gamma_1/gamma_2: LayerScale residual-branch scaling, enabled via init_values
+        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        hidden_features = int(dim * mlp_ratio)
+        if swiglu_mlp:
+            if scale_mlp:
+                # when norm in SwiGLU used, an impl with separate fc for gate & x is used
+                self.mlp = SwiGLU(
+                    in_features=dim,
+                    hidden_features=hidden_features,
+                    norm_layer=norm_layer if scale_mlp else None,
+                    drop=proj_drop,
+                )
+            else:
+                # w/o any extra norm, an impl with packed weights is used, matches existing GluMLP
+                self.mlp = GluMlp(
+                    in_features=dim,
+                    hidden_features=hidden_features * 2,
+                    norm_layer=norm_layer if scale_mlp else None,
+                    act_layer=nn.SiLU,
+                    gate_last=False,
+                    drop=proj_drop,
+                )
+        else:
+            self.mlp = Mlp(
+                in_features=dim,
+                hidden_features=hidden_features,
+                act_layer=act_layer,
+                norm_layer=norm_layer if scale_mlp else None,
+                drop=proj_drop,
+            )
+        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim)) if init_values is not None else None
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
+        if self.gamma_1 is None:
+            x = x + self.drop_path1(self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
+            x = x + self.drop_path2(self.mlp(self.norm2(x)))
+        else:
+            # LayerScale variant: scale each residual branch by its learned gamma
+            x = x + self.drop_path1(self.gamma_1 * self.attn(self.norm1(x), rope=rope, attn_mask=attn_mask))
+            x = x + self.drop_path2(self.gamma_2 * self.mlp(self.norm2(x)))
+        return x
+
+
+class EvaBlockPostNorm(nn.Module):
+    """ EVA block w/ post-norm and support for swiglu, MLP norm scale, ROPE. """
+
+    def __init__(
+            self,
+            dim: int,
+            num_heads: int,
+            qkv_bias: bool = True,
+            qkv_fused: bool = True,
+            mlp_ratio: float = 4.,
+            swiglu_mlp: bool = False,
+            scale_mlp: bool = False,
+            scale_attn_inner: bool = False,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: float = 0.,
+            init_values: Optional[float] = None,  # ignore for post-norm
+            act_layer: Callable = nn.GELU,
+            norm_layer: Callable = nn.LayerNorm,
+            attn_head_dim: Optional[int] = None,
+    ):
+        """
+        Args:
+            dim: Embedding dim.
+            num_heads: Attention heads.
+            qkv_bias: Enable q/v projection bias.
+            qkv_fused: Use fused qkv projection in attention.
+            mlp_ratio: MLP hidden dim ratio.
+            swiglu_mlp: Use a SwiGLU-style MLP.
+            scale_mlp: Apply norm inside the MLP (normformer-style).
+            scale_attn_inner: Apply norm to attention output before projection.
+            proj_drop: Projection dropout rate.
+            attn_drop: Attention dropout rate.
+            drop_path: Stochastic depth rate.
+            init_values: Accepted for interface parity with EvaBlock; unused here.
+            act_layer: MLP activation layer.
+            norm_layer: Normalization layer.
+            attn_head_dim: Override attention head dim.
+        """
+        super().__init__()
+        self.attn = EvaAttention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qkv_fused=qkv_fused,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+            attn_head_dim=attn_head_dim,
+            norm_layer=norm_layer if scale_attn_inner else None,
+        )
+        self.norm1 = norm_layer(dim)
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        hidden_features = int(dim * mlp_ratio)
+        if swiglu_mlp:
+            if scale_mlp:
+                # when norm in SwiGLU used, an impl with separate fc for gate & x is used
+                self.mlp = SwiGLU(
+                    in_features=dim,
+                    hidden_features=hidden_features,
+                    norm_layer=norm_layer if scale_mlp else None,
+                    drop=proj_drop,
+                )
+            else:
+                # w/o any extra norm, an impl with packed fc1 weights is used, matches existing GluMLP
+                self.mlp = GluMlp(
+                    in_features=dim,
+                    hidden_features=hidden_features * 2,
+                    norm_layer=norm_layer if scale_mlp else None,
+                    act_layer=nn.SiLU,
+                    gate_last=False,
+                    drop=proj_drop,
+                )
+        else:
+            self.mlp = Mlp(
+                in_features=dim,
+                hidden_features=hidden_features,
+                act_layer=act_layer,
+                norm_layer=norm_layer if scale_mlp else None,
+                drop=proj_drop,
+            )
+        self.norm2 = norm_layer(dim)
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x, rope: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None):
+        # post-norm: normalize the branch output, not the branch input
+        x = x + self.drop_path1(self.norm1(self.attn(x, rope=rope, attn_mask=attn_mask)))
+        x = x + self.drop_path2(self.norm2(self.mlp(x)))
+        return x
+
+
+class Eva(nn.Module):
+    """ Eva Vision Transformer w/ Abs & Rotary Pos Embed
+
+    This class implements the EVA and EVA02 models that were based on the BEiT ViT variant
+    * EVA - abs pos embed, global avg pool
+    * EVA02 - abs + rope pos embed, global avg pool, SwiGLU, scale Norm in MLP (ala normformer)
+    """
+
+    def __init__(
+            self,
+            img_size: Union[int, Tuple[int, int]] = 224,
+            patch_size: Union[int, Tuple[int, int]] = 16,
+            in_chans: int = 3,
+            num_classes: int = 1000,
+            global_pool: str = 'avg',
+            embed_dim: int = 768,
+            depth: int = 12,
+            num_heads: int = 12,
+            qkv_bias: bool = True,
+            qkv_fused: bool = True,
+            mlp_ratio: float = 4.,
+            swiglu_mlp: bool = False,
+            scale_mlp: bool = False,
+            scale_attn_inner: bool = False,
+            drop_rate: float = 0.,
+            pos_drop_rate: float = 0.,
+            patch_drop_rate: float = 0.,
+            proj_drop_rate: float = 0.,
+            attn_drop_rate: float = 0.,
+            drop_path_rate: float = 0.,
+            norm_layer: Callable = LayerNorm,
+            init_values: Optional[float] = None,
+            class_token: bool = True,
+            use_abs_pos_emb: bool = True,
+            use_rot_pos_emb: bool = False,
+            use_post_norm: bool = False,
+            dynamic_img_size: bool = False,
+            dynamic_img_pad: bool = False,
+            ref_feat_shape: Optional[Union[Tuple[int, int], int]] = None,
+            head_init_scale: float = 0.001,
+    ):
+        """
+        Args:
+            img_size: Input image size.
+            patch_size: Patch size.
+            in_chans: Input image channels.
+            num_classes: Classifier output classes.
+            global_pool: Pooling for the head ('avg' or '' for cls token).
+            embed_dim: Embedding dim.
+            depth: Number of transformer blocks.
+            num_heads: Attention heads per block.
+            qkv_bias: Enable q/v projection bias.
+            qkv_fused: Use fused qkv projection.
+            mlp_ratio: MLP hidden dim ratio.
+            swiglu_mlp: Use SwiGLU-style MLP (EVA02).
+            scale_mlp: Norm inside MLP (normformer-style).
+            scale_attn_inner: Norm on attention output before projection.
+            drop_rate: Head dropout rate.
+            pos_drop_rate: Position embedding dropout rate.
+            patch_drop_rate: Patch (token) dropout rate.
+            proj_drop_rate: Projection dropout rate.
+            attn_drop_rate: Attention dropout rate.
+            drop_path_rate: Stochastic depth rate (linearly scaled per block).
+            norm_layer: Normalization layer.
+            init_values: LayerScale init value.
+            class_token: Prepend a class token.
+            use_abs_pos_emb: Learnable absolute position embedding.
+            use_rot_pos_emb: Rotary position embedding (EVA02).
+            use_post_norm: Use post-norm blocks instead of pre-norm.
+            dynamic_img_size: Support variable input sizes at inference.
+            dynamic_img_pad: Pad input to a multiple of patch size.
+            ref_feat_shape: Reference feature grid shape for ROPE rescaling.
+            head_init_scale: Scale applied to classifier weights after init.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.global_pool = global_pool
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_prefix_tokens = 1 if class_token else 0
+        self.dynamic_img_size = dynamic_img_size
+        self.grad_checkpointing = False
+
+        embed_args = {}
+        if dynamic_img_size:
+            # flatten deferred until after pos embed
+            embed_args.update(dict(strict_img_size=False, output_fmt='NHWC'))
+        self.patch_embed = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=embed_dim,
+            dynamic_img_pad=dynamic_img_pad,
+            **embed_args,
+        )
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, num_patches + self.num_prefix_tokens, embed_dim)) if use_abs_pos_emb else None
+        self.pos_drop = nn.Dropout(p=pos_drop_rate)
+        if patch_drop_rate > 0:
+            # return_indices=True so ROPE embeds can be subset to the kept tokens
+            self.patch_drop = PatchDropout(
+                patch_drop_rate,
+                num_prefix_tokens=self.num_prefix_tokens,
+                return_indices=True,
+            )
+        else:
+            self.patch_drop = None
+
+        if use_rot_pos_emb:
+            ref_feat_shape = to_2tuple(ref_feat_shape) if ref_feat_shape is not None else None
+            # feat_shape deferred (None) for dynamic image sizes; computed per-forward instead
+            self.rope = RotaryEmbeddingCat(
+                embed_dim // num_heads,
+                in_pixels=False,
+                feat_shape=None if dynamic_img_size else self.patch_embed.grid_size,
+                ref_feat_shape=ref_feat_shape,
+            )
+        else:
+            self.rope = None
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+        block_fn = EvaBlockPostNorm if use_post_norm else EvaBlock
+        self.blocks = nn.ModuleList([
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                qkv_bias=qkv_bias,
+                qkv_fused=qkv_fused,
+                mlp_ratio=mlp_ratio,
+                swiglu_mlp=swiglu_mlp,
+                scale_mlp=scale_mlp,
+                scale_attn_inner=scale_attn_inner,
+                proj_drop=proj_drop_rate,
+                attn_drop=attn_drop_rate,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                init_values=init_values,
+            )
+            for i in range(depth)])
+
+        # avg pooling uses a norm after pooling (fc_norm), cls-token pooling norms the sequence
+        use_fc_norm = self.global_pool == 'avg'
+        self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim)
+        self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
+        self.head_drop = nn.Dropout(drop_rate)
+        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+        self.apply(self._init_weights)
+        if self.pos_embed is not None:
+            trunc_normal_(self.pos_embed, std=.02)
+        if self.cls_token is not None:
+            trunc_normal_(self.cls_token, std=.02)
+
+        self.fix_init_weight()
+        if isinstance(self.head, nn.Linear):
+            trunc_normal_(self.head.weight, std=.02)
+            self.head.weight.data.mul_(head_init_scale)
+            self.head.bias.data.mul_(head_init_scale)
+
+    def fix_init_weight(self):
+        # BEiT-style depth rescale: divide residual branch output weights by sqrt(2 * depth_index)
+        def rescale(param, layer_id):
+            param.div_(math.sqrt(2.0 * layer_id))
+
+        for layer_id, layer in enumerate(self.blocks):
+            rescale(layer.attn.proj.weight.data, layer_id + 1)
+            rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.zeros_(m.bias)
+
+    @torch.jit.ignore
+    def no_weight_decay(self):
+        nwd = {'pos_embed', 'cls_token'}
+        return nwd
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        matcher = dict(
+            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))],
+        )
+        return matcher
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head
+
+    def reset_classifier(self, num_classes, global_pool=None):
+        self.num_classes = num_classes
+        if global_pool is not None:
+            self.global_pool = global_pool
+        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+    def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Add cls token & abs pos embed; return (tokens, rotary pos embed or None)."""
+        if self.dynamic_img_size:
+            # x arrives NHWC here (see embed_args in __init__); resample pos embed to H, W
+            B, H, W, C = x.shape
+            if self.pos_embed is not None:
+                pos_embed = resample_abs_pos_embed(
+                    self.pos_embed,
+                    (H, W),
+                    num_prefix_tokens=self.num_prefix_tokens,
+                )
+            else:
+                pos_embed = None
+            x = x.view(B, -1, C)
+            rot_pos_embed = self.rope.get_embed(shape=(H, W)) if self.rope is not None else None
+        else:
+            pos_embed = self.pos_embed
+            rot_pos_embed = self.rope.get_embed() if self.rope is not None else None
+
+        if self.cls_token is not None:
+            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        if pos_embed is not None:
+            x = x + pos_embed
+        x = self.pos_drop(x)
+
+        # obtain shared rotary position embedding and apply patch dropout
+        if self.patch_drop is not None:
+            x, keep_indices = self.patch_drop(x)
+            if rot_pos_embed is not None and keep_indices is not None:
+                rot_pos_embed = apply_keep_indices_nlc(x, rot_pos_embed, keep_indices)
+        return x, rot_pos_embed
+
+    def forward_features(self, x):
+        x = self.patch_embed(x)
+        x, rot_pos_embed = self._pos_embed(x)
+        for blk in self.blocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x, rope=rot_pos_embed)
+            else:
+                x = blk(x, rope=rot_pos_embed)
+        x = self.norm(x)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        if self.global_pool:
+            # 'avg' pools patch tokens only (prefix tokens excluded); otherwise take cls token
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+        x = self.fc_norm(x)
+        x = self.head_drop(x)
+        return x if pre_logits else self.head(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
def checkpoint_filter_fn(
        state_dict,
        model,
        interpolation='bicubic',
        antialias=True,
):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    # Unwrap common checkpoint containers (EMA copy, DDP 'module', lightning 'state_dict').
    state_dict = state_dict.get('model_ema', state_dict)
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('module', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    # prefix for loading OpenCLIP compatible weights
    if 'visual.trunk.pos_embed' in state_dict:
        prefix = 'visual.trunk.'
    elif 'visual.pos_embed' in state_dict:
        prefix = 'visual.'
    else:
        prefix = ''
    mim_weights = prefix + 'mask_token' in state_dict  # masked-image-modeling checkpoint?
    no_qkv = prefix + 'blocks.0.attn.q_proj.weight' in state_dict  # unfused q/k/v projections?

    len_prefix = len(prefix)
    for k, v in state_dict.items():
        if prefix:
            # Keep only image-tower weights; strip the prefix.
            if k.startswith(prefix):
                k = k[len_prefix:]
            else:
                continue

        if 'rope' in k:
            # fixed embedding no need to load buffer from checkpoint
            continue

        if 'patch_embed.proj.weight' in k:
            _, _, H, W = model.patch_embed.proj.weight.shape
            if v.shape[-1] != W or v.shape[-2] != H:
                # Resample the conv patch-embed kernel to the model's patch size.
                v = resample_patch_embed(
                    v,
                    (H, W),
                    interpolation=interpolation,
                    antialias=antialias,
                    verbose=True,
                )
        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
            # To resize pos embedding when using model at different size from pretrained weights
            num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1)
            v = resample_abs_pos_embed(
                v,
                new_size=model.patch_embed.grid_size,
                num_prefix_tokens=num_prefix_tokens,
                interpolation=interpolation,
                antialias=antialias,
                verbose=True,
            )

        # Rename EVA checkpoint keys to timm module names. NOTE: 'mlp.w12' must be
        # replaced before 'mlp.w1'/'mlp.w2' since those are prefixes of it.
        k = k.replace('mlp.ffn_ln', 'mlp.norm')
        k = k.replace('attn.inner_attn_ln', 'attn.norm')
        k = k.replace('mlp.w12', 'mlp.fc1')
        k = k.replace('mlp.w1', 'mlp.fc1_g')
        k = k.replace('mlp.w2', 'mlp.fc1_x')
        k = k.replace('mlp.w3', 'mlp.fc2')
        if no_qkv:
            k = k.replace('q_bias', 'q_proj.bias')
            k = k.replace('v_bias', 'v_proj.bias')

        if mim_weights and k in ('mask_token', 'lm_head.weight', 'lm_head.bias', 'norm.weight', 'norm.bias'):
            if k == 'norm.weight' or k == 'norm.bias':
                # try moving norm -> fc norm on fine-tune, probably a better starting point than new init
                k = k.replace('norm', 'fc_norm')
            else:
                # skip pretrain mask token & head weights
                continue

        out_dict[k] = v

    return out_dict
+
+
+def _create_eva(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Eva models.')
+
+ model = build_model_with_cfg(
+ Eva, variant, pretrained,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs)
+ return model
+
+
def _cfg(url='', **kwargs):
    """Base pretrained-config dict for EVA models; kwargs override any default field."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': OPENAI_CLIP_MEAN, 'std': OPENAI_CLIP_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head',
        'license': 'mit',
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configurations for every EVA/EVA02 variant, keyed '<arch>.<tag>'.
default_cfgs = generate_default_cfgs({

    # EVA 01 CLIP fine-tuned on imagenet-1k
    'eva_giant_patch14_224.clip_ft_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
        hf_hub_id='timm/',
    ),
    'eva_giant_patch14_336.clip_ft_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
        hf_hub_id='timm/',
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),

    # MIM EVA 01 pretrain, ft on in22k -> in1k
    'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
    'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
        # hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),

    # in22k or m38m MIM pretrain w/ intermediate in22k fine-tune and final in1k fine-tune
    'eva02_base_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_B_pt_in21k_medft_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
    ),
    'eva02_large_patch14_448.mim_in22k_ft_in22k_in1k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_in21k_medft_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
    ),
    'eva02_large_patch14_448.mim_m38m_ft_in22k_in1k': _cfg(
        hf_hub_id='timm/',
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k_to_in1k/eva02_L_pt_m38m_medft_in21k_ft_in1k_p14.pt',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash',
    ),

    # in22k or m3m MIM pretrain w/ in1k fine-tune
    'eva02_tiny_patch14_336.mim_in22k_ft_in1k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_Ti_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 336, 336), crop_pct=1.0,
    ),
    'eva02_small_patch14_336.mim_in22k_ft_in1k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_S_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 336, 336), crop_pct=1.0,
    ),
    'eva02_base_patch14_448.mim_in22k_ft_in1k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_B_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0,
    ),
    'eva02_large_patch14_448.mim_in22k_ft_in1k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_in21k_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0,
    ),
    'eva02_large_patch14_448.mim_m38m_ft_in1k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in1k/eva02_L_pt_m38m_ft_in1k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0,
    ),

    # in22k or m3m MIM pretrain w/ in22k fine-tune
    'eva02_base_patch14_448.mim_in22k_ft_in22k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_B_pt_in21k_medft_in21k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
    ),
    'eva02_large_patch14_448.mim_in22k_ft_in22k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_in21k_medft_in21k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
    ),
    'eva02_large_patch14_448.mim_m38m_ft_in22k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/cls/in21k/eva02_L_pt_m38m_medft_in21k_p14.pt',
        hf_hub_id='timm/',
        input_size=(3, 448, 448), crop_pct=1.0, crop_mode='squash', num_classes=21841,
    ),

    # in22k or m38m MIM pretrain
    'eva02_tiny_patch14_224.mim_in22k': _cfg(
        # hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_Ti_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_small_patch14_224.mim_in22k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_S_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_base_patch14_224.mim_in22k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_B_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_large_patch14_224.mim_in22k': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_in21k_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),
    'eva02_large_patch14_224.mim_m38m': _cfg(
        #hf_hub_id='Yuxin-CV/EVA-02', hf_hub_filename='eva02/pt/eva02_L_pt_m38m_p14.pt',
        hf_hub_id='timm/',
        num_classes=0,
    ),

    # EVA01 and EVA02 CLIP image towers
    'eva_giant_patch14_clip_224.laion400m': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
        hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva_giant_patch14_clip_224.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
        hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva02_base_patch16_clip_224.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
        hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=512,
    ),
    'eva02_large_patch14_clip_224.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
        hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=768,
    ),
    'eva02_large_patch14_clip_336.merged2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
        hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        input_size=(3, 336, 336), crop_pct=1.0,
        num_classes=768,
    ),
    'eva02_enormous_patch14_clip_224.laion2b': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
        hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k',  # float16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
        hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k',  # bfloat16 weights
        hf_hub_filename='open_clip_pytorch_model.bin',
        num_classes=1024,
    ),
    'eva02_enormous_patch14_clip_224.pretrain': _cfg(
        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
        num_classes=0,
    ),

})
+
+
@register_model
def eva_giant_patch14_224(pretrained=False, **kwargs) -> Eva:
    """ EVA-g model https://arxiv.org/abs/2211.07636 """
    args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
    args.update(kwargs)
    return _create_eva('eva_giant_patch14_224', pretrained=pretrained, **args)
+
+
@register_model
def eva_giant_patch14_336(pretrained=False, **kwargs) -> Eva:
    """ EVA-g model https://arxiv.org/abs/2211.07636 """
    args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
    args.update(kwargs)
    return _create_eva('eva_giant_patch14_336', pretrained=pretrained, **args)
+
+
@register_model
def eva_giant_patch14_560(pretrained=False, **kwargs) -> Eva:
    """ EVA-g model https://arxiv.org/abs/2211.07636 """
    args = dict(patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408)
    args.update(kwargs)
    return _create_eva('eva_giant_patch14_560', pretrained=pretrained, **args)
+
+
@register_model
def eva02_tiny_patch14_224(pretrained=False, **kwargs) -> Eva:
    """EVA02-Ti/14 @ 224 with SwiGLU MLP and rotary position embedding."""
    args = dict(
        img_size=224, patch_size=14, embed_dim=192, depth=12, num_heads=3,
        mlp_ratio=4 * 2 / 3, swiglu_mlp=True, use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_tiny_patch14_224', pretrained=pretrained, **args)
+
+
@register_model
def eva02_small_patch14_224(pretrained=False, **kwargs) -> Eva:
    """EVA02-S/14 @ 224 with SwiGLU MLP and rotary position embedding."""
    args = dict(
        img_size=224, patch_size=14, embed_dim=384, depth=12, num_heads=6,
        mlp_ratio=4 * 2 / 3, swiglu_mlp=True, use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_small_patch14_224', pretrained=pretrained, **args)
+
+
@register_model
def eva02_base_patch14_224(pretrained=False, **kwargs) -> Eva:
    """EVA02-B/14 @ 224: unfused qkv, SwiGLU MLP w/ norm, rotary position embedding."""
    args = dict(
        img_size=224, patch_size=14, embed_dim=768, depth=12, num_heads=12,
        qkv_fused=False, mlp_ratio=4 * 2 / 3, swiglu_mlp=True, scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_base_patch14_224', pretrained=pretrained, **args)
+
+
@register_model
def eva02_large_patch14_224(pretrained=False, **kwargs) -> Eva:
    """EVA02-L/14 @ 224: unfused qkv, SwiGLU MLP w/ norm, rotary position embedding."""
    args = dict(
        img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4 * 2 / 3, qkv_fused=False, swiglu_mlp=True, scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_large_patch14_224', pretrained=pretrained, **args)
+
+
@register_model
def eva02_tiny_patch14_336(pretrained=False, **kwargs) -> Eva:
    """EVA02-Ti/14 @ 336 with SwiGLU MLP and rotary position embedding."""
    args = dict(
        img_size=336, patch_size=14, embed_dim=192, depth=12, num_heads=3,
        mlp_ratio=4 * 2 / 3, swiglu_mlp=True, use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_tiny_patch14_336', pretrained=pretrained, **args)
+
+
@register_model
def eva02_small_patch14_336(pretrained=False, **kwargs) -> Eva:
    """EVA02-S/14 @ 336 with SwiGLU MLP and rotary position embedding."""
    args = dict(
        img_size=336, patch_size=14, embed_dim=384, depth=12, num_heads=6,
        mlp_ratio=4 * 2 / 3, swiglu_mlp=True, use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_small_patch14_336', pretrained=pretrained, **args)
+
+
@register_model
def eva02_base_patch14_448(pretrained=False, **kwargs) -> Eva:
    """EVA02-B/14 @ 448: unfused qkv, SwiGLU MLP w/ norm, rotary position embedding."""
    args = dict(
        img_size=448, patch_size=14, embed_dim=768, depth=12, num_heads=12,
        qkv_fused=False, mlp_ratio=4 * 2 / 3, swiglu_mlp=True, scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_base_patch14_448', pretrained=pretrained, **args)
+
+
@register_model
def eva02_large_patch14_448(pretrained=False, **kwargs) -> Eva:
    """EVA02-L/14 @ 448: unfused qkv, SwiGLU MLP w/ norm, rotary position embedding."""
    args = dict(
        img_size=448, patch_size=14, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4 * 2 / 3, qkv_fused=False, swiglu_mlp=True, scale_mlp=True,
        use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
    )
    args.update(kwargs)
    return _create_eva('eva02_large_patch14_448', pretrained=pretrained, **args)
+
+
@register_model
def eva_giant_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
    """ EVA-g CLIP model (only difference from non-CLIP is the pooling) """
    args = dict(
        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408,
        global_pool=kwargs.pop('global_pool', 'token'),  # CLIP tower defaults to token pool
    )
    args.update(kwargs)
    return _create_eva('eva_giant_patch14_clip_224', pretrained=pretrained, **args)
+
+
@register_model
def eva02_base_patch16_clip_224(pretrained=False, **kwargs) -> Eva:
    """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_base """
    args = dict(
        img_size=224, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        qkv_fused=False, mlp_ratio=4 * 2 / 3, swiglu_mlp=True, scale_mlp=True,
        scale_attn_inner=True, use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    args.update(kwargs)
    return _create_eva('eva02_base_patch16_clip_224', pretrained=pretrained, **args)
+
+
@register_model
def eva02_large_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
    """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
    args = dict(
        img_size=224, patch_size=14, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4 * 2 / 3, qkv_fused=False, swiglu_mlp=True, scale_mlp=True,
        scale_attn_inner=True, use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    args.update(kwargs)
    return _create_eva('eva02_large_patch14_clip_224', pretrained=pretrained, **args)
+
+
@register_model
def eva02_large_patch14_clip_336(pretrained=False, **kwargs) -> Eva:
    """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
    args = dict(
        img_size=336, patch_size=14, embed_dim=1024, depth=24, num_heads=16,
        mlp_ratio=4 * 2 / 3, qkv_fused=False, swiglu_mlp=True, scale_mlp=True,
        scale_attn_inner=True, use_rot_pos_emb=True,
        ref_feat_shape=(16, 16),  # 224/14
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    args.update(kwargs)
    return _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **args)
+
+
@register_model
def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs) -> Eva:
    """ A EVA-CLIP specific variant that uses residual post-norm in blocks """
    args = dict(
        img_size=224, patch_size=14, embed_dim=1792, depth=64, num_heads=16,
        mlp_ratio=15360 / 1792, use_post_norm=True,
        global_pool=kwargs.pop('global_pool', 'token'),
    )
    args.update(kwargs)
    return _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **args)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/factory.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ae83dc08e51931866feeac00f2e99646aa2667c
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/factory.py
@@ -0,0 +1,4 @@
# Back-compat shim: `timm.models.factory` moved to `timm.models._factory`.
# Re-export everything and warn callers that still import the old path.
from ._factory import *

import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/features.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/features.py
new file mode 100644
index 0000000000000000000000000000000000000000..25605d99daa908ab16b42f4cc8c3a5585e305df4
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/features.py
@@ -0,0 +1,4 @@
# Back-compat shim: `timm.models.features` moved to `timm.models._features`.
# Re-export everything and warn callers that still import the old path.
from ._features import *

import warnings
warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/gcvit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/gcvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..29536a7dd2481c24827d107add12e3d3593ec03f
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/gcvit.py
@@ -0,0 +1,592 @@
+""" Global Context ViT
+
+From scratch implementation of GCViT in the style of timm swin_transformer_v2_cr.py
+
+Global Context Vision Transformers -https://arxiv.org/abs/2206.09959
+
+@article{hatamizadeh2022global,
+ title={Global Context Vision Transformers},
+ author={Hatamizadeh, Ali and Yin, Hongxu and Kautz, Jan and Molchanov, Pavlo},
+ journal={arXiv preprint arXiv:2206.09959},
+ year={2022}
+}
+
+Free of any code related to NVIDIA GCVit impl at https://github.com/NVlabs/GCVit.
+The license for this code release is Apache 2.0 with no commercial restrictions.
+
+However, weight files adapted from NVIDIA GCVit impl ARE under a non-commercial share-alike license
+(https://creativecommons.org/licenses/by-nc-sa/4.0/) until I have a chance to train new ones...
+
+Hacked together by / Copyright 2022, Ross Wightman
+"""
+import math
+from functools import partial
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
+ get_attn, get_act_layer, get_norm_layer, RelPosBias, _assert
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_function
+from ._manipulate import named_apply
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['GlobalContextVit']
+
+
class MbConvBlock(nn.Module):
    """Depthwise-separable (MBConv-style) residual block with channel attention, no norm."""

    def __init__(
            self,
            in_chs,
            out_chs=None,
            expand_ratio=1.0,
            attn_layer='se',
            bias=False,
            act_layer=nn.GELU,
    ):
        super().__init__()
        attn_kwargs = dict(act_layer=act_layer)
        # NOTE: `and` binds tighter than `or` — rd_ratio/bias are set for the string
        # values 'se' and 'eca' (same condition as upstream, kept verbatim).
        if isinstance(attn_layer, str) and attn_layer == 'se' or attn_layer == 'eca':
            attn_kwargs['rd_ratio'] = 0.25
            attn_kwargs['bias'] = False
        attn_layer = get_attn(attn_layer)
        out_chs = out_chs or in_chs
        mid_chs = int(expand_ratio * in_chs)

        self.conv_dw = nn.Conv2d(in_chs, mid_chs, 3, 1, 1, groups=in_chs, bias=bias)
        self.act = act_layer()
        self.se = attn_layer(mid_chs, **attn_kwargs)
        self.conv_pw = nn.Conv2d(mid_chs, out_chs, 1, 1, 0, bias=bias)

    def forward(self, x):
        # Residual: dw-conv -> act -> channel attention -> pw-conv, added to the input.
        return x + self.conv_pw(self.se(self.act(self.conv_dw(x))))
+
+
class Downsample2d(nn.Module):
    """Norm -> MbConv -> 2x spatial reduction -> norm, operating on NCHW tensors."""

    def __init__(
            self,
            dim,
            dim_out=None,
            reduction='conv',
            act_layer=nn.GELU,
            norm_layer=LayerNorm2d,  # NOTE in NCHW
    ):
        super().__init__()
        dim_out = dim_out or dim

        self.norm1 = nn.Identity() if norm_layer is None else norm_layer(dim)
        self.conv_block = MbConvBlock(dim, act_layer=act_layer)
        assert reduction in ('conv', 'max', 'avg')
        if reduction == 'conv':
            # Only the strided-conv reduction may change the channel count.
            self.reduction = nn.Conv2d(dim, dim_out, 3, 2, 1, bias=False)
        elif reduction == 'max':
            assert dim == dim_out
            self.reduction = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        else:
            assert dim == dim_out
            self.reduction = nn.AvgPool2d(kernel_size=2)
        self.norm2 = nn.Identity() if norm_layer is None else norm_layer(dim_out)

    def forward(self, x):
        return self.norm2(self.reduction(self.conv_block(self.norm1(x))))
+
+
class FeatureBlock(nn.Module):
    """Stack of MbConv blocks with pooling; builds the stage-level global query features."""

    def __init__(
            self,
            dim,
            levels=0,
            reduction='max',
            act_layer=nn.GELU,
    ):
        super().__init__()
        reductions = levels
        levels = max(1, levels)
        if reduction == 'avg':
            pool_fn = partial(nn.AvgPool2d, kernel_size=2)
        else:
            pool_fn = partial(nn.MaxPool2d, kernel_size=3, stride=2, padding=1)
        self.blocks = nn.Sequential()
        # One conv block per level, pooling after each until `reductions` pools are placed.
        for level_idx in range(1, levels + 1):
            self.blocks.add_module(f'conv{level_idx}', MbConvBlock(dim, act_layer=act_layer))
            if reductions:
                self.blocks.add_module(f'pool{level_idx}', pool_fn())
                reductions -= 1

    def forward(self, x):
        return self.blocks(x)
+
+
class Stem(nn.Module):
    """Initial 4x reduction: strided 3x3 conv followed by a Downsample2d."""

    def __init__(
            self,
            in_chs: int = 3,
            out_chs: int = 96,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = LayerNorm2d,  # NOTE stem in NCHW
    ):
        super().__init__()
        self.conv1 = nn.Conv2d(in_chs, out_chs, kernel_size=3, stride=2, padding=1)
        self.down = Downsample2d(out_chs, act_layer=act_layer, norm_layer=norm_layer)

    def forward(self, x):
        return self.down(self.conv1(x))
+
+
class WindowAttentionGlobal(nn.Module):
    """Multi-head window attention with an optional shared global query.

    When ``use_global`` is set, only k/v are projected from the window tokens (the qkv
    layer is 2*dim wide) and q is taken from the stage-level ``q_global`` tensor;
    otherwise this is standard fused-qkv window attention. Relative position bias is
    always added to the attention logits.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: Tuple[int, int],
        use_global: bool = True,
        qkv_bias: bool = True,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
    ):
        super().__init__()
        window_size = to_2tuple(window_size)
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.use_global = use_global

        self.rel_pos = RelPosBias(window_size=window_size, num_heads=num_heads)
        if self.use_global:
            # Only k/v come from x; q is supplied externally via q_global.
            self.qkv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        else:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, q_global: Optional[torch.Tensor] = None):
        # x is (B_windows, N, C) where B_windows = batch * num_windows per image.
        B, N, C = x.shape
        if self.use_global and q_global is not None:
            _assert(x.shape[-1] == q_global.shape[-1], 'x and q_global seq lengths should be equal')

            kv = self.qkv(x)
            kv = kv.reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            k, v = kv.unbind(0)

            # Tile the per-image global query across that image's windows.
            # NOTE(review): assumes q_global is 4-D with leading batch dim and N total
            # tokens per image — confirm against GlobalContextVitStage.forward.
            q = q_global.repeat(B // q_global.shape[0], 1, 1, 1)
            q = q.reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        else:
            # Fused-qkv path; requires the qkv layer to be 3*dim wide (use_global=False).
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)
        q = q * self.scale

        attn = q @ k.transpose(-2, -1).contiguous()  # NOTE contiguous() fixes an odd jit bug in PyTorch 2.0
        attn = self.rel_pos(attn)  # learned relative position bias
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
def window_partition(x, window_size: Tuple[int, int]):
    """Split a (B, H, W, C) tensor into non-overlapping windows -> (B*num_windows, wh, ww, C)."""
    B, H, W, C = x.shape
    wh, ww = window_size
    x = x.view(B, H // wh, wh, W // ww, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, wh, ww, C)
+
+
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    """Inverse of window_partition: reassemble windows into a (B, H, W, C) tensor."""
    H, W = img_size
    wh, ww = window_size
    C = windows.shape[-1]
    x = windows.view(-1, H // wh, W // ww, wh, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
+
+
class LayerScale(nn.Module):
    """Per-channel learnable scaling (gamma), optionally applied in-place."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
+
+
class GlobalContextVitBlock(nn.Module):
    """GCViT transformer block: windowed (optionally global-query) attention + MLP.

    Input/output are NHWC feature maps; attention is computed within non-overlapping
    windows, with pre-norm residual branches.
    """

    def __init__(
            self,
            dim: int,
            feat_size: Tuple[int, int],
            num_heads: int,
            window_size: int = 7,
            mlp_ratio: float = 4.,
            use_global: bool = True,
            qkv_bias: bool = True,
            layer_scale: Optional[float] = None,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            attn_layer: Callable = WindowAttentionGlobal,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
    ):
        super().__init__()
        feat_size = to_2tuple(feat_size)
        window_size = to_2tuple(window_size)
        self.window_size = window_size
        # Number of non-overlapping windows tiling the feature map.
        self.num_windows = int((feat_size[0] // window_size[0]) * (feat_size[1] // window_size[1]))

        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            num_heads=num_heads,
            window_size=window_size,
            use_global=use_global,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        # Optional LayerScale + stochastic depth on each residual branch.
        self.ls1 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=proj_drop)
        self.ls2 = LayerScale(dim, layer_scale) if layer_scale is not None else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def _window_attn(self, x, q_global: Optional[torch.Tensor] = None):
        # Partition NHWC input into windows, attend within each, then stitch back together.
        B, H, W, C = x.shape
        x_win = window_partition(x, self.window_size)
        x_win = x_win.view(-1, self.window_size[0] * self.window_size[1], C)
        attn_win = self.attn(x_win, q_global)
        x = window_reverse(attn_win, self.window_size, (H, W))
        return x

    def forward(self, x, q_global: Optional[torch.Tensor] = None):
        # Pre-norm residual: windowed attention branch, then MLP branch.
        x = x + self.drop_path1(self.ls1(self._window_attn(self.norm1(x), q_global)))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
+
+
class GlobalContextVitStage(nn.Module):
    """One GCViT stage: optional 2x downsample, a global-query generator, and blocks.

    Blocks alternate between plain window attention (even index) and global-query
    window attention (odd index); all global blocks share one query tensor computed
    once per stage by ``global_block``.
    """

    def __init__(
            self,
            dim,
            depth: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            window_size: Tuple[int, int],
            downsample: bool = True,
            global_norm: bool = False,
            stage_norm: bool = False,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            layer_scale: Optional[float] = None,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: Union[List[float], float] = 0.0,
            act_layer: Callable = nn.GELU,
            norm_layer: Callable = nn.LayerNorm,
            norm_layer_cl: Callable = LayerNorm2d,
    ):
        super().__init__()
        if downsample:
            # Halve spatial size and double channel count entering this stage.
            self.downsample = Downsample2d(
                dim=dim,
                dim_out=dim * 2,
                norm_layer=norm_layer,
            )
            dim = dim * 2
            feat_size = (feat_size[0] // 2, feat_size[1] // 2)
        else:
            self.downsample = nn.Identity()
        self.feat_size = feat_size
        window_size = to_2tuple(window_size)

        # Number of 2x reductions needed to shrink the feature map to window size.
        feat_levels = int(math.log2(min(feat_size) / min(window_size)))
        self.global_block = FeatureBlock(dim, feat_levels)
        self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()

        self.blocks = nn.ModuleList([
            GlobalContextVitBlock(
                dim=dim,
                num_heads=num_heads,
                feat_size=feat_size,
                window_size=window_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                use_global=(i % 2 != 0),  # odd-index blocks consume the shared global query
                layer_scale=layer_scale,
                proj_drop=proj_drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                act_layer=act_layer,
                norm_layer=norm_layer_cl,
            )
            for i in range(depth)
        ])
        self.norm = norm_layer_cl(dim) if stage_norm else nn.Identity()
        self.dim = dim
        self.feat_size = feat_size
        self.grad_checkpointing = False

    def forward(self, x):
        # input NCHW, downsample & global block are 2d conv + pooling
        x = self.downsample(x)
        global_query = self.global_block(x)

        # reshape NCHW --> NHWC for transformer blocks
        x = x.permute(0, 2, 3, 1)
        global_query = self.global_norm(global_query.permute(0, 2, 3, 1))
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                # FIX: forward global_query through checkpoint too. The previous call
                # checkpoint.checkpoint(blk, x) dropped it, so use_global blocks fell
                # into the fused-qkv path with a 2*dim qkv layer -> reshape error.
                x = checkpoint.checkpoint(blk, x, global_query)
            else:
                x = blk(x, global_query)
        x = self.norm(x)
        x = x.permute(0, 3, 1, 2).contiguous()  # back to NCHW
        return x
+
+
class GlobalContextVit(nn.Module):
    """Global Context Vision Transformer (GCViT).

    Hierarchy: a conv stem reducing resolution 4x, then ``len(depths)`` stages
    of GlobalContextVitStage (each after the first downsamples 2x and doubles
    width), then a pooled classifier head.

    Args:
        in_chans: Number of input image channels.
        num_classes: Classifier output size (0 disables the head's Linear).
        global_pool: Pooling type for the classifier head (e.g. 'avg').
        img_size: Input image size; an int or (H, W) tuple (to_2tuple applied).
        window_ratio: Per-stage divisor of img_size used to derive window_size
            when window_size is not given explicitly.
        window_size: Optional explicit per-stage attention window sizes.
        embed_dim: Stem output width; stage i uses embed_dim * 2**max(i-1, 0).
        depths: Number of transformer blocks per stage.
        num_heads: Attention heads per stage.
        mlp_ratio: MLP hidden dim ratio inside transformer blocks.
        qkv_bias: Enable bias on qkv projections.
        layer_scale: Optional LayerScale init value.
        drop_rate: Head dropout rate.
        proj_drop_rate: Projection dropout inside blocks.
        attn_drop_rate: Attention dropout inside blocks.
        drop_path_rate: Max stochastic depth rate (linearly scheduled).
        weight_init: Non-empty string triggers custom init ('vit' or default).
        act_layer: Activation layer name.
        norm_layer: Norm layer name for NCHW (conv-side) tensors.
        norm_layer_cl: Norm layer name for NHWC (channels-last) tensors.
        norm_eps: Epsilon applied to both norm layers.
    """

    def __init__(
            self,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            img_size: Tuple[int, int] = 224,  # int is accepted; to_2tuple below
            window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
            window_size: Optional[Tuple[int, ...]] = None,
            embed_dim: int = 64,
            depths: Tuple[int, ...] = (3, 4, 19, 5),
            num_heads: Tuple[int, ...] = (2, 4, 8, 16),
            mlp_ratio: float = 3.0,
            qkv_bias: bool = True,
            layer_scale: Optional[float] = None,
            drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: str = '',
            act_layer: str = 'gelu',
            norm_layer: str = 'layernorm2d',
            norm_layer_cl: str = 'layernorm',
            norm_eps: float = 1e-5,
    ):
        super().__init__()
        act_layer = get_act_layer(act_layer)
        # two norm variants: one for NCHW conv features, one for NHWC tokens
        norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
        norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)

        img_size = to_2tuple(img_size)
        feat_size = tuple(d // 4 for d in img_size)  # stem reduction by 4
        self.global_pool = global_pool
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        num_stages = len(depths)
        self.num_features = int(embed_dim * 2 ** (num_stages - 1))
        if window_size is not None:
            window_size = to_ntuple(num_stages)(window_size)
        else:
            # derive per-stage window sizes from image size and the ratios
            assert window_ratio is not None
            window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=embed_dim,
            act_layer=act_layer,
            norm_layer=norm_layer
        )

        # per-block stochastic depth rates, linearly increasing across all blocks
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        stages = []
        for i in range(num_stages):
            last_stage = i == num_stages - 1
            # stage 0 keeps stem width; later stages double width each step
            stage_scale = 2 ** max(i - 1, 0)
            stages.append(GlobalContextVitStage(
                dim=embed_dim * stage_scale,
                depth=depths[i],
                num_heads=num_heads[i],
                feat_size=(feat_size[0] // stage_scale, feat_size[1] // stage_scale),
                window_size=window_size[i],
                downsample=i != 0,
                stage_norm=last_stage,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                layer_scale=layer_scale,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                act_layer=act_layer,
                norm_layer=norm_layer,
                norm_layer_cl=norm_layer_cl,
            ))
        self.stages = nn.Sequential(*stages)

        # Classifier head
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        if weight_init:
            named_apply(partial(self._init_weights, scheme=weight_init), self)

    def _init_weights(self, module, name, scheme='vit'):
        """Initialize Linear layers; 'vit' scheme uses xavier + tiny mlp bias."""
        # note Conv2d left as default init
        if scheme == 'vit':
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
        else:
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay (relative position terms)."""
        return {
            k for k, _ in self.named_parameters()
            if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex groups for layer-wise LR decay / parameter grouping."""
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=r'^stages\.(\d+)'
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classifier head, optionally changing the pooling type."""
        self.num_classes = num_classes
        if global_pool is None:
            # keep the existing pooling when not specified
            global_pool = self.head.global_pool.pool_type
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stem(x)
        x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def _create_gcvit(variant, pretrained=False, **kwargs):
    """Instantiate a GCViT variant through the timm model builder.

    ``features_only`` is rejected up front: feature extraction is not wired up
    for this model family.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')
    return build_model_with_cfg(GlobalContextVit, variant, pretrained, **kwargs)
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-weight config dict for GCViT models."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.conv1',
        classifier='head.fc',
        fixed_input_size=True,
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs, keyed as '<model>.<tag>'; weights are the NVIDIA
# GCViT releases re-hosted under the timm weight repository.
default_cfgs = generate_default_cfgs({
    'gcvit_xxtiny.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xxtiny_224_nvidia-d1d86009.pth'),
    'gcvit_xtiny.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_xtiny_224_nvidia-274b92b7.pth'),
    'gcvit_tiny.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_tiny_224_nvidia-ac783954.pth'),
    'gcvit_small.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_small_224_nvidia-4e98afa2.pth'),
    'gcvit_base.in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-morevit/gcvit_base_224_nvidia-f009139b.pth'),
})
+
+
@register_model
def gcvit_xxtiny(pretrained=False, **kwargs) -> GlobalContextVit:
    """GCViT-XXTiny."""
    return _create_gcvit(
        'gcvit_xxtiny',
        pretrained=pretrained,
        **dict(depths=(2, 2, 6, 2), num_heads=(2, 4, 8, 16), **kwargs),
    )
+
+
@register_model
def gcvit_xtiny(pretrained=False, **kwargs) -> GlobalContextVit:
    """GCViT-XTiny."""
    return _create_gcvit(
        'gcvit_xtiny',
        pretrained=pretrained,
        **dict(depths=(3, 4, 6, 5), num_heads=(2, 4, 8, 16), **kwargs),
    )
+
+
@register_model
def gcvit_tiny(pretrained=False, **kwargs) -> GlobalContextVit:
    """GCViT-Tiny."""
    return _create_gcvit(
        'gcvit_tiny',
        pretrained=pretrained,
        **dict(depths=(3, 4, 19, 5), num_heads=(2, 4, 8, 16), **kwargs),
    )
+
+
@register_model
def gcvit_small(pretrained=False, **kwargs) -> GlobalContextVit:
    """GCViT-Small: wider (96) with layer scale and a reduced MLP ratio."""
    return _create_gcvit(
        'gcvit_small',
        pretrained=pretrained,
        **dict(
            depths=(3, 4, 19, 5),
            num_heads=(3, 6, 12, 24),
            embed_dim=96,
            mlp_ratio=2,
            layer_scale=1e-5,
            **kwargs,
        ),
    )
+
+
@register_model
def gcvit_base(pretrained=False, **kwargs) -> GlobalContextVit:
    """GCViT-Base: widest (128) with layer scale and a reduced MLP ratio."""
    return _create_gcvit(
        'gcvit_base',
        pretrained=pretrained,
        **dict(
            depths=(3, 4, 19, 5),
            num_heads=(4, 8, 16, 32),
            embed_dim=128,
            mlp_ratio=2,
            layer_scale=1e-5,
            **kwargs,
        ),
    )
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/ghostnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/ghostnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d34b54852159fd91218a067e85e5ea511d76a30e
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/ghostnet.py
@@ -0,0 +1,432 @@
+"""
+An implementation of GhostNet & GhostNetV2 Models as defined in:
+GhostNet: More Features from Cheap Operations. https://arxiv.org/abs/1911.11907
+GhostNetV2: Enhance Cheap Operation with Long-Range Attention. https://proceedings.neurips.cc/paper_files/paper/2022/file/40b60852a4abdaa696b5a1a78da34635-Paper-Conference.pdf
+
+The train script & code of models at:
+Original model: https://github.com/huawei-noah/CV-backbones/tree/master/ghostnet_pytorch
+Original model: https://github.com/huawei-noah/Efficient-AI-Backbones/blob/master/ghostnetv2_pytorch/model/ghostnetv2_torch.py
+"""
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
+from ._builder import build_model_with_cfg
+from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['GhostNet']
+
+
# Squeeze-and-Excite variant used throughout GhostNet: hard-sigmoid gate, with
# the reduced channel count rounded to a multiple of 4 to match reference weights.
_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
+
+
class GhostModule(nn.Module):
    """GhostNet ghost module.

    Generates ``out_chs`` feature maps from two sources: a primary conv that
    produces the 'intrinsic' channels, and a cheap depthwise conv over those
    intrinsic maps that produces the remaining 'ghost' channels. The two sets
    are concatenated and truncated to exactly ``out_chs``.
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=1,
            ratio=2,
            dw_size=3,
            stride=1,
            use_act=True,
            act_layer=nn.ReLU,
    ):
        super().__init__()
        self.out_chs = out_chs
        # intrinsic channels from the primary conv; ghosts fill the remainder
        intrinsic_chs = math.ceil(out_chs / ratio)
        ghost_chs = intrinsic_chs * (ratio - 1)

        def _act():
            return act_layer(inplace=True) if use_act else nn.Identity()

        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_chs, intrinsic_chs, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(intrinsic_chs),
            _act(),
        )
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(intrinsic_chs, ghost_chs, dw_size, 1, dw_size // 2, groups=intrinsic_chs, bias=False),
            nn.BatchNorm2d(ghost_chs),
            _act(),
        )

    def forward(self, x):
        intrinsic = self.primary_conv(x)
        ghosts = self.cheap_operation(intrinsic)
        combined = torch.cat((intrinsic, ghosts), dim=1)
        # ceil rounding above may overshoot; keep exactly out_chs channels
        return combined[:, :self.out_chs]
+
+
class GhostModuleV2(nn.Module):
    """GhostNetV2 ghost module with decoupled fully-connected (DFC) attention.

    Same intrinsic + ghost channel generation as GhostModule, but the output is
    gated by a long-range attention map computed at half resolution (pointwise
    mix followed by 1x5 and 5x1 depthwise strips) and upsampled back with
    nearest interpolation.
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=1,
            ratio=2,
            dw_size=3,
            stride=1,
            use_act=True,
            act_layer=nn.ReLU,
    ):
        super().__init__()
        self.gate_fn = nn.Sigmoid()
        self.out_chs = out_chs
        intrinsic_chs = math.ceil(out_chs / ratio)
        ghost_chs = intrinsic_chs * (ratio - 1)

        def _act():
            return act_layer(inplace=True) if use_act else nn.Identity()

        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_chs, intrinsic_chs, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(intrinsic_chs),
            _act(),
        )
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(intrinsic_chs, ghost_chs, dw_size, 1, dw_size // 2, groups=intrinsic_chs, bias=False),
            nn.BatchNorm2d(ghost_chs),
            _act(),
        )
        # DFC attention branch: horizontal + vertical depthwise strips capture
        # long-range context cheaply
        self.short_conv = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(out_chs),
            nn.Conv2d(out_chs, out_chs, kernel_size=(1, 5), stride=1, padding=(0, 2), groups=out_chs, bias=False),
            nn.BatchNorm2d(out_chs),
            nn.Conv2d(out_chs, out_chs, kernel_size=(5, 1), stride=1, padding=(2, 0), groups=out_chs, bias=False),
            nn.BatchNorm2d(out_chs),
        )

    def forward(self, x):
        # attention is computed at half resolution, then upsampled as a gate
        attn = self.short_conv(F.avg_pool2d(x, kernel_size=2, stride=2))
        intrinsic = self.primary_conv(x)
        ghosts = self.cheap_operation(intrinsic)
        feats = torch.cat((intrinsic, ghosts), dim=1)[:, :self.out_chs]
        gate = F.interpolate(self.gate_fn(attn), size=(feats.shape[-2], feats.shape[-1]), mode='nearest')
        return feats * gate
+
+
class GhostBottleneck(nn.Module):
    """Ghost bottleneck w/ optional SE.

    Layout: expansion ghost module -> optional stride-s depthwise conv ->
    optional squeeze-excite -> linear (no-activation) ghost module, summed with
    a residual shortcut (identity, or dw + pw projection when shape changes).
    ``mode='original'`` uses GhostModule; anything else uses the DFC-attention
    GhostModuleV2 for the expansion step.
    """

    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            dw_kernel_size=3,
            stride=1,
            act_layer=nn.ReLU,
            se_ratio=0.,
            mode='original',
    ):
        super().__init__()
        self.stride = stride

        # Point-wise expansion; V2 mode swaps in the attention ghost module
        ghost_cls = GhostModule if mode == 'original' else GhostModuleV2
        self.ghost1 = ghost_cls(in_chs, mid_chs, use_act=True, act_layer=act_layer)

        # Depth-wise convolution, only present when spatially downsampling
        if stride > 1:
            self.conv_dw = nn.Conv2d(
                mid_chs, mid_chs, dw_kernel_size, stride=stride,
                padding=(dw_kernel_size - 1) // 2, groups=mid_chs, bias=False)
            self.bn_dw = nn.BatchNorm2d(mid_chs)
        else:
            self.conv_dw = None
            self.bn_dw = None

        # Optional squeeze-and-excitation on the expanded channels
        use_se = se_ratio is not None and se_ratio > 0.
        self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if use_se else None

        # Point-wise linear projection back down to out_chs (no activation)
        self.ghost2 = GhostModule(mid_chs, out_chs, use_act=False)

        # Shortcut: identity when shape is preserved, else dw + pw projection
        if in_chs == out_chs and stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_chs, in_chs, dw_kernel_size, stride=stride,
                    padding=(dw_kernel_size - 1) // 2, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        residual = self.shortcut(x)

        # 1st ghost bottleneck (expansion)
        x = self.ghost1(x)

        # optional depth-wise downsampling
        if self.conv_dw is not None:
            x = self.bn_dw(self.conv_dw(x))

        # optional squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # 2nd ghost bottleneck (linear projection)
        x = self.ghost2(x)

        return x + residual
+
+
class GhostNet(nn.Module):
    """GhostNet / GhostNetV2 image classification backbone.

    The architecture is described by ``cfgs``: a list of stages, each a list of
    block specs ``[kernel, expansion_chs, out_chs, se_ratio, stride]`` that are
    turned into GhostBottleneck blocks. ``version='v2'`` switches all but the
    first two blocks to the DFC-attention GhostModuleV2.

    Args:
        cfgs: Per-stage block specs (see above).
        num_classes: Classifier output size; 0 disables the final Linear.
        width: Channel width multiplier (channels rounded to multiples of 4).
        in_chans: Number of input image channels.
        output_stride: Must be 32; dilation is not supported.
        global_pool: Pooling type for the efficient head.
        drop_rate: Dropout rate applied before the classifier.
        version: 'v1' for GhostNet, 'v2' for GhostNetV2.
    """

    def __init__(
            self,
            cfgs,
            num_classes=1000,
            width=1.0,
            in_chans=3,
            output_stride=32,
            global_pool='avg',
            drop_rate=0.2,
            version='v1',
    ):
        super(GhostNet, self).__init__()
        # setting of inverted residual blocks
        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
        self.cfgs = cfgs
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        # building first layer
        stem_chs = make_divisible(16 * width, 4)
        self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem'))
        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.act1 = nn.ReLU(inplace=True)
        prev_chs = stem_chs

        # building inverted residual blocks
        stages = nn.ModuleList([])
        stage_idx = 0
        layer_idx = 0  # global block index, used for the v2 mode switch below
        net_stride = 2
        for cfg in self.cfgs:
            layers = []
            s = 1
            for k, exp_size, c, se_ratio, s in cfg:
                out_chs = make_divisible(c * width, 4)
                mid_chs = make_divisible(exp_size * width, 4)
                layer_kwargs = {}
                # V2 uses DFC attention in every block except the first two
                if version == 'v2' and layer_idx > 1:
                    layer_kwargs['mode'] = 'attn'
                layers.append(GhostBottleneck(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, **layer_kwargs))
                prev_chs = out_chs
                layer_idx += 1
            if s > 1:
                # stage downsampled: record a feature tap at the new stride
                net_stride *= 2
                self.feature_info.append(dict(
                    num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
            stages.append(nn.Sequential(*layers))
            stage_idx += 1

        # final 1x1 conv stage; NOTE: exp_size intentionally carries over the
        # last block spec's expansion width from the loop above
        out_chs = make_divisible(exp_size * width, 4)
        stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
        self.pool_dim = prev_chs = out_chs

        self.blocks = nn.Sequential(*stages)

        # building last several layers (efficient head: pool, then 1x1 conv)
        self.num_features = out_chs = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()

        # FIXME init

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex groups for layer-wise LR decay / parameter grouping."""
        matcher = dict(
            stem=r'^conv_stem|bn1',
            blocks=[
                (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
                (r'conv_head', (99999,))
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the classifier (and pooling) for a new number of classes."""
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Stem + bottleneck stages; returns NCHW features at stride 32."""
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x, flatten=True)
        else:
            x = self.blocks(x)
        return x

    def forward_head(self, x):
        """Efficient head: pooling happens BEFORE the 1x1 head conv."""
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        x = self.flatten(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classifier(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def checkpoint_filter_fn(state_dict, model: nn.Module):
    """Drop FLOP-counter buffers (any key containing 'total') from a checkpoint."""
    return {k: v for k, v in state_dict.items() if 'total' not in k}
+
+
def _create_ghostnet(variant, width=1.0, pretrained=False, **kwargs):
    """
    Constructs a GhostNet model.

    The architecture table below is shared by all GhostNet variants; only the
    ``width`` multiplier and (for V2) ``version`` kwarg differ per variant.
    Each row is [kernel, expansion_chs, out_chs, se_ratio, stride].
    """
    cfgs = [
        # k, t, c, SE, s
        # stage1
        [[3, 16, 16, 0, 1]],
        # stage2
        [[3, 48, 24, 0, 2]],
        [[3, 72, 24, 0, 1]],
        # stage3
        [[5, 72, 40, 0.25, 2]],
        [[5, 120, 40, 0.25, 1]],
        # stage4
        [[3, 240, 80, 0, 2]],
        [[3, 200, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 184, 80, 0, 1],
         [3, 480, 112, 0.25, 1],
         [3, 672, 112, 0.25, 1]
         ],
        # stage5
        [[5, 672, 160, 0.25, 2]],
        [[5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1],
         [5, 960, 160, 0, 1],
         [5, 960, 160, 0.25, 1]
         ]
    ]
    model_kwargs = dict(
        cfgs=cfgs,
        width=width,
        **kwargs,
    )
    return build_model_with_cfg(
        GhostNet,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True),
        **model_kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-weight config dict for GhostNet models."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='conv_stem',
        classifier='classifier',
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs keyed as '<model>.<tag>'; hf_hub_id 'timm/' points
# to the timm Hugging Face hub re-hosts (original upstream urls kept commented).
default_cfgs = generate_default_cfgs({
    'ghostnet_050.untrained': _cfg(),
    'ghostnet_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/CV-backbones/releases/download/ghostnet_pth/ghostnet_1x.pth'
    ),
    'ghostnet_130.untrained': _cfg(),
    'ghostnetv2_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_10.pth.tar'
    ),
    'ghostnetv2_130.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_13.pth.tar'
    ),
    'ghostnetv2_160.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/huawei-noah/Efficient-AI-Backbones/releases/download/GhostNetV2/ck_ghostnetv2_16.pth.tar'
    ),
})
+
+
@register_model
def ghostnet_050(pretrained=False, **kwargs) -> GhostNet:
    """GhostNet with a 0.5x width multiplier."""
    return _create_ghostnet('ghostnet_050', width=0.5, pretrained=pretrained, **kwargs)
+
+
@register_model
def ghostnet_100(pretrained=False, **kwargs) -> GhostNet:
    """GhostNet with a 1.0x width multiplier."""
    return _create_ghostnet('ghostnet_100', width=1.0, pretrained=pretrained, **kwargs)
+
+
@register_model
def ghostnet_130(pretrained=False, **kwargs) -> GhostNet:
    """GhostNet with a 1.3x width multiplier."""
    return _create_ghostnet('ghostnet_130', width=1.3, pretrained=pretrained, **kwargs)
+
+
@register_model
def ghostnetv2_100(pretrained=False, **kwargs) -> GhostNet:
    """GhostNetV2 (DFC attention) with a 1.0x width multiplier."""
    return _create_ghostnet('ghostnetv2_100', width=1.0, pretrained=pretrained, version='v2', **kwargs)
+
+
@register_model
def ghostnetv2_130(pretrained=False, **kwargs) -> GhostNet:
    """GhostNetV2 (DFC attention) with a 1.3x width multiplier."""
    return _create_ghostnet('ghostnetv2_130', width=1.3, pretrained=pretrained, version='v2', **kwargs)
+
+
@register_model
def ghostnetv2_160(pretrained=False, **kwargs) -> GhostNet:
    """GhostNetV2 (DFC attention) with a 1.6x width multiplier."""
    return _create_ghostnet('ghostnetv2_160', width=1.6, pretrained=pretrained, version='v2', **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_v4.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..a43290a3db7d3e846e658572f58a1213f5d486f0
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/inception_v4.py
@@ -0,0 +1,325 @@
+""" Pytorch Inception-V4 implementation
+Sourced from https://github.com/Cadene/tensorflow-model-zoo.torch (MIT License) which is
+based upon Google's Tensorflow implementation and pretrained weights (Apache 2.0 License)
+"""
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.layers import create_classifier, ConvNormAct
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['InceptionV4']
+
+
class Mixed3a(nn.Module):
    """Stem mixing block: parallel stride-2 3x3 max-pool and conv, concatenated."""

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.maxpool = nn.MaxPool2d(3, stride=2)
        self.conv = conv_block(64, 96, kernel_size=3, stride=2)

    def forward(self, x):
        return torch.cat((self.maxpool(x), self.conv(x)), 1)
+
+
class Mixed4a(nn.Module):
    """Stem mixing block: a 1x1->3x3 branch and a 1x1->1x7->7x1->3x3 branch."""

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.branch0 = nn.Sequential(
            conv_block(160, 64, kernel_size=1, stride=1),
            conv_block(64, 96, kernel_size=3, stride=1),
        )
        self.branch1 = nn.Sequential(
            conv_block(160, 64, kernel_size=1, stride=1),
            conv_block(64, 64, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(64, 64, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(64, 96, kernel_size=(3, 3), stride=1),
        )

    def forward(self, x):
        return torch.cat((self.branch0(x), self.branch1(x)), 1)
+
+
class Mixed5a(nn.Module):
    """Stem mixing block: parallel stride-2 3x3 conv and max-pool, concatenated."""

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.conv = conv_block(192, 192, kernel_size=3, stride=2)
        self.maxpool = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        return torch.cat((self.conv(x), self.maxpool(x)), 1)
+
+
class InceptionA(nn.Module):
    """Inception-A block (384 -> 384 channels): 1x1, 3x3, double-3x3 and
    avg-pool branches of 96 channels each, concatenated.
    """

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.branch0 = conv_block(384, 96, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            conv_block(384, 64, kernel_size=1, stride=1),
            conv_block(64, 96, kernel_size=3, stride=1, padding=1),
        )
        self.branch2 = nn.Sequential(
            conv_block(384, 64, kernel_size=1, stride=1),
            conv_block(64, 96, kernel_size=3, stride=1, padding=1),
            conv_block(96, 96, kernel_size=3, stride=1, padding=1),
        )
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            conv_block(384, 96, kernel_size=1, stride=1),
        )

    def forward(self, x):
        branches = (self.branch0(x), self.branch1(x), self.branch2(x), self.branch3(x))
        return torch.cat(branches, 1)
+
+
class ReductionA(nn.Module):
    """Reduction-A block (stride 2, 384 -> 1024 channels): strided 3x3 conv,
    factorized 3x3 stack, and max-pool branches.
    """

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.branch0 = conv_block(384, 384, kernel_size=3, stride=2)
        self.branch1 = nn.Sequential(
            conv_block(384, 192, kernel_size=1, stride=1),
            conv_block(192, 224, kernel_size=3, stride=1, padding=1),
            conv_block(224, 256, kernel_size=3, stride=2),
        )
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        branches = (self.branch0(x), self.branch1(x), self.branch2(x))
        return torch.cat(branches, 1)
+
+
class InceptionB(nn.Module):
    """Inception-B block (1024 -> 1024 channels) with factorized 7x7 branches
    (1x7 and 7x1 pairs) plus 1x1 and avg-pool branches.
    """

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.branch0 = conv_block(1024, 384, kernel_size=1, stride=1)
        self.branch1 = nn.Sequential(
            conv_block(1024, 192, kernel_size=1, stride=1),
            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(224, 256, kernel_size=(7, 1), stride=1, padding=(3, 0)),
        )
        self.branch2 = nn.Sequential(
            conv_block(1024, 192, kernel_size=1, stride=1),
            conv_block(192, 192, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(192, 224, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(224, 224, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(224, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
        )
        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            conv_block(1024, 128, kernel_size=1, stride=1),
        )

    def forward(self, x):
        branches = (self.branch0(x), self.branch1(x), self.branch2(x), self.branch3(x))
        return torch.cat(branches, 1)
+
+
class ReductionB(nn.Module):
    """Reduction-B block (stride 2, 1024 -> 1536 channels): strided 3x3 conv,
    factorized 7x7 then strided 3x3, and max-pool branches.
    """

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.branch0 = nn.Sequential(
            conv_block(1024, 192, kernel_size=1, stride=1),
            conv_block(192, 192, kernel_size=3, stride=2),
        )
        self.branch1 = nn.Sequential(
            conv_block(1024, 256, kernel_size=1, stride=1),
            conv_block(256, 256, kernel_size=(1, 7), stride=1, padding=(0, 3)),
            conv_block(256, 320, kernel_size=(7, 1), stride=1, padding=(3, 0)),
            conv_block(320, 320, kernel_size=3, stride=2),
        )
        self.branch2 = nn.MaxPool2d(3, stride=2)

    def forward(self, x):
        branches = (self.branch0(x), self.branch1(x), self.branch2(x))
        return torch.cat(branches, 1)
+
+
class InceptionC(nn.Module):
    """Inception-C block (1536 -> 1536 channels) with branches that split into
    parallel 1x3 / 3x1 convolutions before concatenation.
    """

    def __init__(self, conv_block=ConvNormAct):
        super().__init__()
        self.branch0 = conv_block(1536, 256, kernel_size=1, stride=1)

        # branch 1: 1x1 stem, then parallel 1x3 and 3x1 heads
        self.branch1_0 = conv_block(1536, 384, kernel_size=1, stride=1)
        self.branch1_1a = conv_block(384, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch1_1b = conv_block(384, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        # branch 2: 1x1 -> 3x1 -> 1x3 stem, then parallel 1x3 and 3x1 heads
        self.branch2_0 = conv_block(1536, 384, kernel_size=1, stride=1)
        self.branch2_1 = conv_block(384, 448, kernel_size=(3, 1), stride=1, padding=(1, 0))
        self.branch2_2 = conv_block(448, 512, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3a = conv_block(512, 256, kernel_size=(1, 3), stride=1, padding=(0, 1))
        self.branch2_3b = conv_block(512, 256, kernel_size=(3, 1), stride=1, padding=(1, 0))

        self.branch3 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
            conv_block(1536, 256, kernel_size=1, stride=1),
        )

    def forward(self, x):
        b0 = self.branch0(x)

        b1_stem = self.branch1_0(x)
        b1 = torch.cat((self.branch1_1a(b1_stem), self.branch1_1b(b1_stem)), 1)

        b2_stem = self.branch2_2(self.branch2_1(self.branch2_0(x)))
        b2 = torch.cat((self.branch2_3a(b2_stem), self.branch2_3b(b2_stem)), 1)

        b3 = self.branch3(x)

        return torch.cat((b0, b1, b2, b3), 1)
+
+
class InceptionV4(nn.Module):
    """Inception-V4 classification network.

    Stem convs + Mixed3a/4a/5a, then 4x InceptionA, ReductionA, 7x InceptionB,
    ReductionB, 3x InceptionC, followed by global pooling and a Linear head.

    Args:
        num_classes: Classifier output size.
        in_chans: Number of input image channels.
        output_stride: Must be 32 (no dilation support).
        drop_rate: Dropout rate before the classifier.
        global_pool: Pooling type for the head.
        norm_layer: Norm layer name for the shared conv block.
        norm_eps: Norm epsilon (1e-3 matches the TF-ported weights).
        act_layer: Activation layer name for the shared conv block.
    """

    def __init__(
            self,
            num_classes=1000,
            in_chans=3,
            output_stride=32,
            drop_rate=0.,
            global_pool='avg',
            norm_layer='batchnorm2d',
            norm_eps=1e-3,
            act_layer='relu',
    ):
        super(InceptionV4, self).__init__()
        assert output_stride == 32
        self.num_classes = num_classes
        self.num_features = 1536
        # shared conv+norm+act block passed to every sub-module
        conv_block = partial(
            ConvNormAct,
            padding=0,
            norm_layer=norm_layer,
            act_layer=act_layer,
            norm_kwargs=dict(eps=norm_eps),
            act_kwargs=dict(inplace=True),
        )

        features = [
            conv_block(in_chans, 32, kernel_size=3, stride=2),
            conv_block(32, 32, kernel_size=3, stride=1),
            conv_block(32, 64, kernel_size=3, stride=1, padding=1),
            Mixed3a(conv_block),
            Mixed4a(conv_block),
            Mixed5a(conv_block),
        ]
        features += [InceptionA(conv_block) for _ in range(4)]
        features += [ReductionA(conv_block)]  # Mixed6a
        features += [InceptionB(conv_block) for _ in range(7)]
        features += [ReductionB(conv_block)]  # Mixed7a
        features += [InceptionC(conv_block) for _ in range(3)]
        self.features = nn.Sequential(*features)
        # feature taps at strides 2/4/8/16/32 for features_only extraction
        self.feature_info = [
            dict(num_chs=64, reduction=2, module='features.2'),
            dict(num_chs=160, reduction=4, module='features.3'),
            dict(num_chs=384, reduction=8, module='features.9'),
            dict(num_chs=1024, reduction=16, module='features.17'),
            dict(num_chs=1536, reduction=32, module='features.21'),
        ]
        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex groups for layer-wise LR decay / parameter grouping."""
        return dict(
            stem=r'^features\.[012]\.',
            blocks=r'^features\.(\d+)'
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace pooling + classifier; the existing head_drop is kept."""
        self.num_classes = num_classes
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        return self.features(x)

    def forward_head(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        x = self.head_drop(x)
        return x if pre_logits else self.last_linear(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def _create_inception_v4(variant, pretrained=False, **kwargs) -> InceptionV4:
    """Instantiate an Inception-V4 through the timm model builder."""
    return build_model_with_cfg(
        InceptionV4, variant, pretrained,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )
+
+
# Single pretrained config: TF-ported ImageNet-1k weights re-hosted on the hub.
default_cfgs = generate_default_cfgs({
    'inception_v4.tf_in1k': {
        'hf_hub_id': 'timm/',
        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
        'first_conv': 'features.0.conv', 'classifier': 'last_linear',
    }
})
+
+
@register_model
def inception_v4(pretrained=False, **kwargs):
    """Inception-V4 (299x299 ImageNet model)."""
    return _create_inception_v4('inception_v4', pretrained, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/levit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/levit.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca0708bd5944efe747382e5120b6980e13784072
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/levit.py
@@ -0,0 +1,933 @@
+""" LeViT
+
+Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
+ - https://arxiv.org/abs/2104.01136
+
+@article{graham2021levit,
+ title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
+ author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
    journal={arXiv preprint arXiv:2104.01136},
+ year={2021}
+}
+
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright below.
+
+This version combines both conv/linear models and fixes torchscript compatibility.
+
+Modifications and additions for timm hacked together by / Copyright 2021, Ross Wightman
+"""
+
+# Copyright (c) 2015-present, Facebook, Inc.
+# All rights reserved.
+
+# Modified from
+# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
+# Copyright 2020 Ross Wightman, Apache-2.0 License
+from collections import OrderedDict
+from functools import partial
+from typing import Dict
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
+from timm.layers import to_ntuple, to_2tuple, get_act_layer, DropPath, trunc_normal_, ndgrid
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['Levit']
+
+
class ConvNorm(nn.Module):
    """Conv2d (no bias) followed by BatchNorm2d, with BN-folding for inference.

    ``fuse()`` returns an equivalent single ``nn.Conv2d`` with the batch-norm
    statistics folded into its weight and bias.
    """
    def __init__(
            self, in_chs, out_chs, kernel_size=1, stride=1, padding=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.linear = nn.Conv2d(in_chs, out_chs, kernel_size, stride, padding, dilation, groups, bias=False)
        self.bn = nn.BatchNorm2d(out_chs)

        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        """Fold the BN affine + running stats into the conv, returning a plain Conv2d."""
        c, bn = self.linear, self.bn
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        # FIX: a grouped conv's weight has shape (out, in // groups, kh, kw), so
        # the fused module's in_channels must be w.size(1) * groups — using
        # w.size(1) alone builds a mis-sized Conv2d whenever groups > 1.
        m = nn.Conv2d(
            w.size(1) * self.linear.groups, w.size(0), w.shape[2:], stride=self.linear.stride,
            padding=self.linear.padding, dilation=self.linear.dilation, groups=self.linear.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m

    def forward(self, x):
        return self.bn(self.linear(x))
+
+
class LinearNorm(nn.Module):
    """Linear (no bias) followed by BatchNorm1d applied across flattened tokens.

    ``fuse()`` folds the batch-norm into a single equivalent ``nn.Linear``.
    """
    def __init__(self, in_features, out_features, bn_weight_init=1):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.bn = nn.BatchNorm1d(out_features)
        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        """Return an nn.Linear equivalent to linear + bn in eval mode."""
        lin, bn = self.linear, self.bn
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        fused_w = lin.weight * scale[:, None]
        fused_b = bn.bias - bn.running_mean * scale
        fused = nn.Linear(fused_w.size(1), fused_w.size(0))
        fused.weight.data.copy_(fused_w)
        fused.bias.data.copy_(fused_b)
        return fused

    def forward(self, x):
        # BatchNorm1d expects (N, C): flatten batch+token dims, then restore.
        out = self.linear(x)
        flat = out.flatten(0, 1)
        return self.bn(flat).reshape_as(out)
+
+
class NormLinear(nn.Module):
    """BatchNorm1d -> Dropout -> Linear classifier head, fusable for inference."""

    def __init__(self, in_features, out_features, bias=True, std=0.02, drop=0.):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.drop = nn.Dropout(drop)
        self.linear = nn.Linear(in_features, out_features, bias=bias)

        trunc_normal_(self.linear.weight, std=std)
        if self.linear.bias is not None:
            nn.init.constant_(self.linear.bias, 0)

    @torch.no_grad()
    def fuse(self):
        """Fold the leading BN into the linear layer (dropout is a no-op in eval)."""
        bn, lin = self.bn, self.linear
        scale = bn.weight / (bn.running_var + bn.eps) ** 0.5
        shift = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
        fused_w = lin.weight * scale[None, :]
        if lin.bias is None:
            fused_b = shift @ self.linear.weight.T
        else:
            fused_b = (lin.weight @ shift[:, None]).view(-1) + self.linear.bias
        fused = nn.Linear(fused_w.size(1), fused_w.size(0))
        fused.weight.data.copy_(fused_w)
        fused.bias.data.copy_(fused_b)
        return fused

    def forward(self, x):
        return self.linear(self.drop(self.bn(x)))
+
+
class Stem8(nn.Sequential):
    """Three stride-2 ConvNorm layers: 8x spatial reduction into the embed dim."""
    def __init__(self, in_chs, out_chs, act_layer):
        super().__init__()
        self.stride = 8  # total reduction; used by Levit to size the token grid

        self.add_module('conv1', ConvNorm(in_chs, out_chs // 4, 3, stride=2, padding=1))
        self.add_module('act1', act_layer())
        self.add_module('conv2', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
        self.add_module('act2', act_layer())
        self.add_module('conv3', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))
+
+
class Stem16(nn.Sequential):
    """Four stride-2 ConvNorm layers: 16x spatial reduction into the embed dim."""
    def __init__(self, in_chs, out_chs, act_layer):
        super().__init__()
        self.stride = 16  # total reduction; used by Levit to size the token grid

        self.add_module('conv1', ConvNorm(in_chs, out_chs // 8, 3, stride=2, padding=1))
        self.add_module('act1', act_layer())
        self.add_module('conv2', ConvNorm(out_chs // 8, out_chs // 4, 3, stride=2, padding=1))
        self.add_module('act2', act_layer())
        self.add_module('conv3', ConvNorm(out_chs // 4, out_chs // 2, 3, stride=2, padding=1))
        self.add_module('act3', act_layer())
        self.add_module('conv4', ConvNorm(out_chs // 2, out_chs, 3, stride=2, padding=1))
+
+
class Downsample(nn.Module):
    """Spatially subsample a (B, N, C) token sequence by ``stride``.

    Tokens are reshaped onto their 2D grid, then either strided-sliced or
    average-pooled, and flattened back to a sequence.
    """
    def __init__(self, stride, resolution, use_pool=False):
        super().__init__()
        self.stride = stride
        self.resolution = to_2tuple(resolution)
        self.pool = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False) if use_pool else None

    def forward(self, x):
        B, N, C = x.shape
        H, W = self.resolution
        grid = x.view(B, H, W, C)
        if self.pool is None:
            grid = grid[:, ::self.stride, ::self.stride]
        else:
            # pooling operates channels-first, so permute in and back out
            grid = self.pool(grid.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)
        return grid.reshape(B, -1, C)
+
+
class Attention(nn.Module):
    """LeViT multi-head attention with learned relative-position attention biases.

    Operates on (B, C, H, W) feature maps when ``use_conv`` is True, else on
    (B, N, C) token sequences.  Gathered bias tensors are cached per device in
    eval mode to avoid recomputing the index_select every forward.
    """
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            dim,
            key_dim,
            num_heads=8,
            attn_ratio=4.,
            resolution=14,
            use_conv=False,
            act_layer=nn.SiLU,
    ):
        super().__init__()
        ln_layer = ConvNorm if use_conv else LinearNorm
        resolution = to_2tuple(resolution)

        self.use_conv = use_conv
        self.num_heads = num_heads
        self.scale = key_dim ** -0.5
        self.key_dim = key_dim
        # total key width across heads; values are wider than keys by attn_ratio
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = int(attn_ratio * key_dim) * num_heads

        # single projection produces q, k, v concatenated along channels
        self.qkv = ln_layer(dim, self.val_attn_dim + self.key_attn_dim * 2)
        self.proj = nn.Sequential(OrderedDict([
            ('act', act_layer()),
            ('ln', ln_layer(self.val_attn_dim, dim, bn_weight_init=0))
        ]))

        # one learnable bias per head per relative offset; the index buffer maps
        # each (query, key) position pair to its offset bucket
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
        pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
        rel_pos = (pos[..., :, None] - pos[..., None, :]).abs()
        rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)
        self.attention_bias_cache = {}

    @torch.no_grad()
    def train(self, mode=True):
        # entering train mode invalidates the eval-only bias cache
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        # recompute every call while training/tracing; cache per device in eval
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):  # x (B,C,H,W)
        if self.use_conv:
            B, C, H, W = x.shape
            # heads are laid out along dim 1; q/k/v split along the channel dim
            q, k, v = self.qkv(x).view(
                B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.val_dim], dim=2)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
        else:
            B, N, C = x.shape
            q, k, v = self.qkv(x).view(
                B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.val_dim], dim=3)
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 3, 1)
            v = v.permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, N, self.val_attn_dim)
        x = self.proj(x)
        return x
+
+
class AttentionDownsample(nn.Module):
    """Attention block that reduces spatial resolution by ``stride``.

    Queries are computed from a subsampled grid while keys/values come from the
    full-resolution input, so the output has fewer tokens (and out_dim channels).
    Relative-position biases are indexed between the subsampled query grid and
    the full key grid.
    """
    attention_bias_cache: Dict[str, torch.Tensor]

    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            num_heads=8,
            attn_ratio=2.0,
            stride=2,
            resolution=14,
            use_conv=False,
            use_pool=False,
            act_layer=nn.SiLU,
    ):
        super().__init__()
        resolution = to_2tuple(resolution)

        self.stride = stride
        self.resolution = resolution
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.key_attn_dim = key_dim * num_heads
        self.val_dim = int(attn_ratio * key_dim)
        self.val_attn_dim = self.val_dim * self.num_heads
        self.scale = key_dim ** -0.5
        self.use_conv = use_conv

        if self.use_conv:
            ln_layer = ConvNorm
            # with use_pool=False this is a 1x1 stride-2 avg-pool, i.e. plain subsampling
            sub_layer = partial(
                nn.AvgPool2d,
                kernel_size=3 if use_pool else 1, padding=1 if use_pool else 0, count_include_pad=False)
        else:
            ln_layer = LinearNorm
            sub_layer = partial(Downsample, resolution=resolution, use_pool=use_pool)

        # keys/values from full resolution, queries from the subsampled grid
        self.kv = ln_layer(in_dim, self.val_attn_dim + self.key_attn_dim)
        self.q = nn.Sequential(OrderedDict([
            ('down', sub_layer(stride=stride)),
            ('ln', ln_layer(in_dim, self.key_attn_dim))
        ]))
        self.proj = nn.Sequential(OrderedDict([
            ('act', act_layer()),
            ('ln', ln_layer(self.val_attn_dim, out_dim))
        ]))

        # biases indexed by relative offset between subsampled q grid and full k grid
        self.attention_biases = nn.Parameter(torch.zeros(num_heads, resolution[0] * resolution[1]))
        k_pos = torch.stack(ndgrid(torch.arange(resolution[0]), torch.arange(resolution[1]))).flatten(1)
        q_pos = torch.stack(ndgrid(
            torch.arange(0, resolution[0], step=stride),
            torch.arange(0, resolution[1], step=stride)
        )).flatten(1)
        rel_pos = (q_pos[..., :, None] - k_pos[..., None, :]).abs()
        rel_pos = (rel_pos[0] * resolution[1]) + rel_pos[1]
        self.register_buffer('attention_bias_idxs', rel_pos, persistent=False)

        self.attention_bias_cache = {}  # per-device attention_biases cache

    @torch.no_grad()
    def train(self, mode=True):
        # entering train mode invalidates the eval-only bias cache
        super().train(mode)
        if mode and self.attention_bias_cache:
            self.attention_bias_cache = {}  # clear ab cache

    def get_attention_biases(self, device: torch.device) -> torch.Tensor:
        # recompute every call while training/tracing; cache per device in eval
        if torch.jit.is_tracing() or self.training:
            return self.attention_biases[:, self.attention_bias_idxs]
        else:
            device_key = str(device)
            if device_key not in self.attention_bias_cache:
                self.attention_bias_cache[device_key] = self.attention_biases[:, self.attention_bias_idxs]
            return self.attention_bias_cache[device_key]

    def forward(self, x):
        if self.use_conv:
            B, C, H, W = x.shape
            # ceil-divide by stride for the output spatial size
            HH, WW = (H - 1) // self.stride + 1, (W - 1) // self.stride + 1
            k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.val_dim], dim=2)
            q = self.q(x).view(B, self.num_heads, self.key_dim, -1)

            attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (v @ attn.transpose(-2, -1)).reshape(B, self.val_attn_dim, HH, WW)
        else:
            B, N, C = x.shape
            k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.val_dim], dim=3)
            k = k.permute(0, 2, 3, 1)  # BHCN
            v = v.permute(0, 2, 1, 3)  # BHNC
            q = self.q(x).view(B, -1, self.num_heads, self.key_dim).permute(0, 2, 1, 3)

            attn = q @ k * self.scale + self.get_attention_biases(x.device)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B, -1, self.val_attn_dim)
        x = self.proj(x)
        return x
+
+
class LevitMlp(nn.Module):
    """ MLP for Levit w/ normalization + ability to switch btw conv and linear
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            use_conv=False,
            act_layer=nn.SiLU,
            drop=0.
    ):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        layer_cls = ConvNorm if use_conv else LinearNorm

        self.ln1 = layer_cls(in_features, hidden_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        # bn_weight_init=0 zeros the residual branch output at init
        self.ln2 = layer_cls(hidden_features, out_features, bn_weight_init=0)

    def forward(self, x):
        return self.ln2(self.drop(self.act(self.ln1(x))))
+
+
class LevitDownsample(nn.Module):
    """Stage-transition block: attention-based downsample followed by an MLP.

    Note there is no residual around the attention downsample itself since its
    input and output dims/resolutions differ; only the MLP is residual.
    """
    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            num_heads=8,
            attn_ratio=4.,
            mlp_ratio=2.,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            resolution=14,
            use_conv=False,
            use_pool=False,
            drop_path=0.,
    ):
        super().__init__()
        attn_act_layer = attn_act_layer or act_layer

        self.attn_downsample = AttentionDownsample(
            in_dim=in_dim,
            out_dim=out_dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            act_layer=attn_act_layer,
            resolution=resolution,
            use_conv=use_conv,
            use_pool=use_pool,
        )

        self.mlp = LevitMlp(
            out_dim,
            int(out_dim * mlp_ratio),
            use_conv=use_conv,
            act_layer=act_layer
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = self.attn_downsample(x)
        x = x + self.drop_path(self.mlp(x))
        return x
+
+
class LevitBlock(nn.Module):
    """Standard LeViT block: residual attention followed by residual MLP."""
    def __init__(
            self,
            dim,
            key_dim,
            num_heads=8,
            attn_ratio=4.,
            mlp_ratio=2.,
            resolution=14,
            use_conv=False,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            drop_path=0.,
    ):
        super().__init__()
        attn_act_layer = attn_act_layer or act_layer

        self.attn = Attention(
            dim=dim,
            key_dim=key_dim,
            num_heads=num_heads,
            attn_ratio=attn_ratio,
            resolution=resolution,
            use_conv=use_conv,
            act_layer=attn_act_layer,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.mlp = LevitMlp(
            dim,
            int(dim * mlp_ratio),
            use_conv=use_conv,
            act_layer=act_layer
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        x = x + self.drop_path1(self.attn(x))
        x = x + self.drop_path2(self.mlp(x))
        return x
+
+
class LevitStage(nn.Module):
    """One LeViT stage: an optional attention downsample plus `depth` blocks."""
    def __init__(
            self,
            in_dim,
            out_dim,
            key_dim,
            depth=4,
            num_heads=8,
            attn_ratio=4.0,
            mlp_ratio=4.0,
            act_layer=nn.SiLU,
            attn_act_layer=None,
            resolution=14,
            downsample='',
            use_conv=False,
            drop_path=0.,
    ):
        super().__init__()
        resolution = to_2tuple(resolution)

        if downsample:
            # downsample uses fixed attn_ratio=4 / mlp_ratio=2 and derives its
            # head count from the incoming dim (per the LeViT design)
            self.downsample = LevitDownsample(
                in_dim,
                out_dim,
                key_dim=key_dim,
                num_heads=in_dim // key_dim,
                attn_ratio=4.,
                mlp_ratio=2.,
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                drop_path=drop_path,
            )
            # ceil-divide resolution by the stride-2 downsample
            resolution = [(r - 1) // 2 + 1 for r in resolution]
        else:
            assert in_dim == out_dim
            self.downsample = nn.Identity()

        blocks = []
        for _ in range(depth):
            blocks += [LevitBlock(
                out_dim,
                key_dim,
                num_heads=num_heads,
                attn_ratio=attn_ratio,
                mlp_ratio=mlp_ratio,
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                drop_path=drop_path,
            )]
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x
+
+
class Levit(nn.Module):
    """ Vision Transformer with support for patch or hybrid CNN input stage

    NOTE: distillation is defaulted to True since pretrained weights use it, will cause problems
    w/ train scripts that don't take tuple outputs,
    """

    def __init__(
            self,
            img_size=224,
            in_chans=3,
            num_classes=1000,
            embed_dim=(192,),
            key_dim=64,
            depth=(12,),
            num_heads=(3,),
            attn_ratio=2.,
            mlp_ratio=2.,
            stem_backbone=None,
            stem_stride=None,
            stem_type='s16',
            down_op='subsample',
            act_layer='hard_swish',
            attn_act_layer=None,
            use_conv=False,
            global_pool='avg',
            drop_rate=0.,
            drop_path_rate=0.):
        super().__init__()
        act_layer = get_act_layer(act_layer)
        attn_act_layer = get_act_layer(attn_act_layer or act_layer)
        self.use_conv = use_conv
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = embed_dim[-1]
        self.embed_dim = embed_dim
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        # broadcast scalar per-stage settings to one value per stage
        num_stages = len(embed_dim)
        assert len(depth) == num_stages
        num_heads = to_ntuple(num_stages)(num_heads)
        attn_ratio = to_ntuple(num_stages)(attn_ratio)
        mlp_ratio = to_ntuple(num_stages)(mlp_ratio)

        if stem_backbone is not None:
            # hybrid mode: caller supplies a conv backbone and its stride
            assert stem_stride >= 2
            self.stem = stem_backbone
            stride = stem_stride
        else:
            assert stem_type in ('s16', 's8')
            if stem_type == 's16':
                self.stem = Stem16(in_chans, embed_dim[0], act_layer=act_layer)
            else:
                self.stem = Stem8(in_chans, embed_dim[0], act_layer=act_layer)
            stride = self.stem.stride
        # token grid after the stem
        resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])

        in_dim = embed_dim[0]
        stages = []
        for i in range(num_stages):
            stage_stride = 2 if i > 0 else 1  # every stage after the first downsamples
            stages += [LevitStage(
                in_dim,
                embed_dim[i],
                key_dim,
                depth=depth[i],
                num_heads=num_heads[i],
                attn_ratio=attn_ratio[i],
                mlp_ratio=mlp_ratio[i],
                act_layer=act_layer,
                attn_act_layer=attn_act_layer,
                resolution=resolution,
                use_conv=use_conv,
                downsample=down_op if stage_stride == 2 else '',
                drop_path=drop_path_rate
            )]
            stride *= stage_stride
            resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
            in_dim = embed_dim[i]
        self.stages = nn.Sequential(*stages)

        # Classifier head
        self.head = NormLinear(embed_dim[-1], num_classes, drop=drop_rate) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def no_weight_decay(self):
        # relative-position bias tables are excluded from weight decay
        return {x for x in self.state_dict().keys() if 'attention_biases' in x}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # NOTE(review): these regexes reference ViT module names (cls_token,
        # pos_embed, patch_embed, blocks) that do not exist in this model's
        # stem/stages layout — looks copied from vision_transformer; verify
        # before relying on grouped LR decay for LeViT.
        matcher = dict(
            stem=r'^cls_token|pos_embed|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
        # `distillation` is accepted for interface parity with LevitDistilled
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = NormLinear(
            self.embed_dim[-1], num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
        if not self.use_conv:
            # flatten conv stem output (B, C, H, W) to tokens (B, N, C)
            x = x.flatten(2).transpose(1, 2)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        # only 'avg' pooling is implemented; otherwise features pass through unpooled
        if self.global_pool == 'avg':
            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
class LevitDistilled(Levit):
    """LeViT with an extra distillation head.

    In distilled-training mode returns (head, head_dist) logits; otherwise the
    two heads' predictions are averaged.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.head_dist = NormLinear(self.num_features, self.num_classes) if self.num_classes > 0 else nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

    @torch.jit.ignore
    def get_classifier(self):
        return self.head, self.head_dist

    def reset_classifier(self, num_classes, global_pool=None, distillation=None):
        # rebuild both heads; `distillation` accepted for interface compatibility
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = NormLinear(
            self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
        self.head_dist = NormLinear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        self.distilled_training = enable

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool == 'avg':
            x = x.mean(dim=(-2, -1)) if self.use_conv else x.mean(dim=1)
        if pre_logits:
            return x
        x, x_dist = self.head(x), self.head_dist(x)
        if self.distilled_training and self.training and not torch.jit.is_scripting():
            # only return separate classification predictions when training in distilled mode
            return x, x_dist
        else:
            # during standard train/finetune, inference average the classifier predictions
            return (x + x_dist) / 2
+
+
def checkpoint_filter_fn(state_dict, model):
    """Remap pretrained checkpoint tensors onto the current model's param names.

    NOTE(review): this pairs the model's state_dict with the checkpoint purely
    by ``zip`` over insertion order (keys are never compared) and assumes equal
    lengths after filtering — verify whenever the layouts could diverge.
    """
    if 'model' in state_dict:
        state_dict = state_dict['model']

    # filter out attn biases, should not have been persistent
    state_dict = {k: v for k, v in state_dict.items() if 'attention_bias_idxs' not in k}

    D = model.state_dict()
    out_dict = {}
    for ka, kb, va, vb in zip(D.keys(), state_dict.keys(), D.values(), state_dict.values()):
        if va.ndim == 4 and vb.ndim == 2:
            # linear-mode checkpoint weight -> conv-mode 1x1 weight
            vb = vb[:, :, None, None]
        if va.shape != vb.shape:
            # head or first-conv shapes may change for fine-tune
            assert 'head' in ka or 'stem.conv1.linear' in ka
        out_dict[ka] = vb

    return out_dict
+
+
# Architecture hyper-parameters per LeViT variant: per-stage embed dims,
# attention key dim, head counts and block depths (see the LeViT paper).
model_cfgs = dict(
    levit_128s=dict(
        embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
    levit_128=dict(
        embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
    levit_192=dict(
        embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
    levit_256=dict(
        embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
    levit_384=dict(
        embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),

    # stride-8 stem experiments
    levit_384_s8=dict(
        embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4),
        act_layer='silu', stem_type='s8'),
    levit_512_s8=dict(
        embed_dim=(512, 640, 896), key_dim=64, num_heads=(8, 10, 14), depth=(4, 4, 4),
        act_layer='silu', stem_type='s8'),

    # wider experiments
    levit_512=dict(
        embed_dim=(512, 768, 1024), key_dim=64, num_heads=(8, 12, 16), depth=(4, 4, 4), act_layer='silu'),

    # deeper experiments
    levit_256d=dict(
        embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 8, 6), act_layer='silu'),
    levit_512d=dict(
        embed_dim=(512, 640, 768), key_dim=64, num_heads=(8, 10, 12), depth=(4, 8, 6), act_layer='silu'),
)
+
+
def create_levit(variant, cfg_variant=None, pretrained=False, distilled=True, **kwargs):
    """Build a LeViT model from `model_cfgs`, optionally with distilled head.

    `_conv` variants share the config of their linear-mode counterpart.
    NOTE(review): if `variant` is neither in `model_cfgs` nor a '_conv' name,
    `cfg_variant` stays None and the lookup below raises KeyError — presumably
    unreachable for registered variants.
    """
    is_conv = '_conv' in variant
    out_indices = kwargs.pop('out_indices', (0, 1, 2))
    if kwargs.get('features_only', None):
        # feature extraction relies on conv-mode (B, C, H, W) intermediates
        if not is_conv:
            raise RuntimeError('features_only not implemented for LeVit in non-convolutional mode.')
    if cfg_variant is None:
        if variant in model_cfgs:
            cfg_variant = variant
        elif is_conv:
            cfg_variant = variant.replace('_conv', '')

    model_cfg = dict(model_cfgs[cfg_variant], **kwargs)
    model = build_model_with_cfg(
        LevitDistilled if distilled else Levit,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **model_cfg,
    )
    return model
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-config dict for LeViT; kwargs override defaults."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.linear', 'classifier': ('head.linear', 'head_dist.linear'),
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs; distilled FB weights exist in both linear and
# conv token modes, the remaining (experimental) variants are untrained.
default_cfgs = generate_default_cfgs({
    # weights in nn.Linear mode
    'levit_128s.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_128.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_192.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_256.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'levit_384.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
    ),

    # weights in nn.Conv2d mode
    'levit_conv_128s.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_128.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_192.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_256.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),
    'levit_conv_384.fb_dist_in1k': _cfg(
        hf_hub_id='timm/',
        pool_size=(4, 4),
    ),

    'levit_384_s8.untrained': _cfg(classifier='head.linear'),
    'levit_512_s8.untrained': _cfg(classifier='head.linear'),
    'levit_512.untrained': _cfg(classifier='head.linear'),
    'levit_256d.untrained': _cfg(classifier='head.linear'),
    'levit_512d.untrained': _cfg(classifier='head.linear'),

    'levit_conv_384_s8.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512_s8.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512.untrained': _cfg(classifier='head.linear'),
    'levit_conv_256d.untrained': _cfg(classifier='head.linear'),
    'levit_conv_512d.untrained': _cfg(classifier='head.linear'),
})
+
+
# Registered model entry points.  `levit_*` build linear-token models,
# `levit_conv_*` build conv-token models; experimental variants (s8/512/256d/
# 512d) have no distillation head, matching their untrained configs above.
@register_model
def levit_128s(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_128s', pretrained=pretrained, **kwargs)


@register_model
def levit_128(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_128', pretrained=pretrained, **kwargs)


@register_model
def levit_192(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_192', pretrained=pretrained, **kwargs)


@register_model
def levit_256(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_256', pretrained=pretrained, **kwargs)


@register_model
def levit_384(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_384', pretrained=pretrained, **kwargs)


@register_model
def levit_384_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_384_s8', pretrained=pretrained, **kwargs)


@register_model
def levit_512_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_512_s8', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_512(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_512', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_256d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_256d', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_512d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_512d', pretrained=pretrained, distilled=False, **kwargs)


@register_model
def levit_conv_128s(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_128s', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_128(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_128', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_192(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_192', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_256(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_256', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_384(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_384', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_384_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_384_s8', pretrained=pretrained, use_conv=True, **kwargs)


@register_model
def levit_conv_512_s8(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_512_s8', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)


@register_model
def levit_conv_512(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_512', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)


@register_model
def levit_conv_256d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_256d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)


@register_model
def levit_conv_512d(pretrained=False, **kwargs) -> Levit:
    return create_levit('levit_conv_512d', pretrained=pretrained, use_conv=True, distilled=False, **kwargs)
+
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/maxxvit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/maxxvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..6283443ce5466344e9966d9e81325d7cc9e9c8bc
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/maxxvit.py
@@ -0,0 +1,2331 @@
+""" MaxVit and CoAtNet Vision Transformer - CNN Hybrids in PyTorch
+
+This is a from-scratch implementation of both CoAtNet and MaxVit in PyTorch.
+
+99% of the implementation was done from papers, however last minute some adjustments were made
+based on the (as yet unfinished?) public code release https://github.com/google-research/maxvit
+
+There are multiple sets of models defined for both architectures. Typically, names with a
+ `_rw` suffix are my own original configs prior to referencing https://github.com/google-research/maxvit.
+These configs work well and appear to be a bit faster / lower resource than the paper.
+
+The models without an extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc), are intended to
+match the paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
+
+Papers:
+
+MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
+@article{tu2022maxvit,
+ title={MaxViT: Multi-Axis Vision Transformer},
+ author={Tu, Zhengzhong and Talebi, Hossein and Zhang, Han and Yang, Feng and Milanfar, Peyman and Bovik, Alan and Li, Yinxiao},
+ journal={ECCV},
+ year={2022},
+}
+
+CoAtNet: Marrying Convolution and Attention for All Data Sizes - https://arxiv.org/abs/2106.04803
+@article{DBLP:journals/corr/abs-2106-04803,
+ author = {Zihang Dai and Hanxiao Liu and Quoc V. Le and Mingxing Tan},
+ title = {CoAtNet: Marrying Convolution and Attention for All Data Sizes},
+ journal = {CoRR},
+ volume = {abs/2106.04803},
+ year = {2021}
+}
+
+Hacked together by / Copyright 2022, Ross Wightman
+"""
+
+import math
+from collections import OrderedDict
+from dataclasses import dataclass, replace, field
+from functools import partial
+from typing import Callable, Optional, Union, Tuple, List
+
+import torch
+from torch import nn
+from torch.jit import Final
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
+from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
+from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
+from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf, use_fused_attn, resize_rel_pos_bias_table
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_function
+from ._manipulate import named_apply, checkpoint_seq
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
+
+
@dataclass
class MaxxVitTransformerCfg:
    """Config for the transformer (attention) blocks of a MaxxVit model.

    Shared by both the NCHW (`Attention2d`) and channels-last (`AttentionCl`)
    block variants; `__post_init__` normalizes window/grid sizes to 2-tuples.
    """
    dim_head: int = 32  # per-head channel dim; num_heads is derived from it
    head_first: bool = True  # head ordering in qkv channel dim
    expand_ratio: float = 4.0  # MLP hidden dim multiplier
    expand_first: bool = True  # expand channels in attn (vs in output proj)
    shortcut_bias: bool = True
    attn_bias: bool = True  # bias on qkv / proj layers
    attn_drop: float = 0.
    proj_drop: float = 0.
    pool_type: str = 'avg2'  # downsample pool used by shortcut paths
    rel_pos_type: str = 'bias'  # one of '', 'bias', 'bias_tf', 'mlp' (see get_rel_pos_cls)
    rel_pos_dim: int = 512  # for relative position types w/ MLP
    partition_ratio: int = 32  # img_size // partition_ratio -> default window/grid size
    window_size: Optional[Tuple[int, int]] = None  # block-attn partition size
    grid_size: Optional[Tuple[int, int]] = None  # grid-attn partition size (defaults to window_size)
    no_block_attn: bool = False  # disable window block attention for maxvit (ie only grid)
    use_nchw_attn: bool = False  # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
    init_values: Optional[float] = None  # LayerScale init value; None disables LayerScale
    act_layer: str = 'gelu'
    norm_layer: str = 'layernorm2d'  # norm for NCHW blocks
    norm_layer_cl: str = 'layernorm'  # norm for channels-last blocks
    norm_eps: float = 1e-6

    def __post_init__(self):
        # Normalize sizes to 2-tuples; grid defaults to the window size.
        if self.grid_size is not None:
            self.grid_size = to_2tuple(self.grid_size)
        if self.window_size is not None:
            self.window_size = to_2tuple(self.window_size)
            if self.grid_size is None:
                self.grid_size = self.window_size
+
+
@dataclass
class MaxxVitConvCfg:
    """Config for the conv blocks (MBConv or ConvNeXt) of a MaxxVit model.

    `__post_init__` fills in block-type-dependent norm defaults so callers
    don't need to specify them explicitly.
    """
    block_type: str = 'mbconv'  # 'mbconv' or 'convnext'
    expand_ratio: float = 4.0
    expand_output: bool = True  # calculate expansion channels from output (vs input chs)
    kernel_size: int = 3
    group_size: int = 1  # 1 == depthwise
    pre_norm_act: bool = False  # activation after pre-norm
    output_bias: bool = True  # bias for shortcut + final 1x1 projection conv
    stride_mode: str = 'dw'  # stride done via one of 'pool', '1x1', 'dw'
    pool_type: str = 'avg2'
    downsample_pool_type: str = 'avg2'
    padding: str = ''  # '' -> create_conv2d / create_pool2d default padding
    attn_early: bool = False  # apply attn between conv2 and norm2, instead of after norm2
    attn_layer: str = 'se'
    attn_act_layer: str = 'silu'
    attn_ratio: float = 0.25  # SE reduction ratio
    init_values: Optional[float] = 1e-6  # for ConvNeXt block, ignored by MBConv
    act_layer: str = 'gelu'
    norm_layer: str = ''  # '' -> filled by __post_init__ per block_type
    norm_layer_cl: str = ''
    norm_eps: Optional[float] = None  # None -> filled by __post_init__ per block_type

    def __post_init__(self):
        # mbconv vs convnext blocks have different defaults, set in post_init to avoid explicit config args
        assert self.block_type in ('mbconv', 'convnext')
        use_mbconv = self.block_type == 'mbconv'
        if not self.norm_layer:
            self.norm_layer = 'batchnorm2d' if use_mbconv else 'layernorm2d'
        if not self.norm_layer_cl and not use_mbconv:
            self.norm_layer_cl = 'layernorm'
        if self.norm_eps is None:
            self.norm_eps = 1e-5 if use_mbconv else 1e-6
        self.downsample_pool_type = self.downsample_pool_type or self.pool_type
+
+
@dataclass
class MaxxVitCfg:
    """Top-level architecture config for a MaxxVit / CoAtNet model.

    Per-stage settings (`embed_dim`, `depths`, `block_type`) are index-aligned
    tuples, one entry per stage.
    """
    embed_dim: Tuple[int, ...] = (96, 192, 384, 768)  # channels per stage
    depths: Tuple[int, ...] = (2, 3, 5, 2)  # blocks per stage
    block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')  # 'C' conv, 'T' transformer, 'M' maxvit, 'PM' parallel-maxvit
    stem_width: Union[int, Tuple[int, int]] = 64
    stem_bias: bool = False
    conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
    transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
    head_hidden_size: Optional[int] = None  # hidden dim of pre-logits head; None = plain classifier
    weight_init: str = 'vit_eff'
+
+
class Attention2d(nn.Module):
    """Multi-head self-attention for 2D NCHW tensors.

    qkv / output projections are 1x1 convs; spatial positions (H*W) act as
    the token dimension. Uses fused SDPA when `use_fused_attn()` allows.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,  # defaults to dim
        dim_head: int = 32,
        bias: bool = True,
        expand_first: bool = True,  # attn width = dim_out (vs dim) when expanding
        head_first: bool = True,  # qkv channel layout: heads-major vs qkv-major
        rel_pos_cls: Callable = None,  # factory for relative-position module (takes num_heads)
        attn_drop: float = 0.,
        proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first else dim
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        """x: (B, C, H, W); shared_rel_pos: optional precomputed attn bias."""
        B, C, H, W = x.shape

        # Split qkv according to channel layout; tokens flattened to last dim.
        if self.head_first:
            q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2)
        else:
            q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1)

        if self.fused_attn:
            # Own rel-pos module takes priority over a shared (per-stage) bias.
            attn_bias = None
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos

            # SDPA expects (..., tokens, dim_head); transpose in and back out.
            x = torch.nn.functional.scaled_dot_product_attention(
                q.transpose(-1, -2).contiguous(),
                k.transpose(-1, -2).contiguous(),
                v.transpose(-1, -2).contiguous(),
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            ).transpose(-1, -2).reshape(B, -1, H, W)
        else:
            # Manual path: scaled q @ k, optional rel-pos, softmax, @ v.
            q = q * self.scale
            attn = q.transpose(-2, -1) @ k
            if self.rel_pos is not None:
                attn = self.rel_pos(attn)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
class AttentionCl(nn.Module):
    """ Channels-last multi-head attention (B, ..., C).

    Any number of leading spatial dims is supported; they are flattened into
    the token dim and restored on output. Uses fused SDPA when available.
    """
    fused_attn: Final[bool]

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,  # defaults to dim
        dim_head: int = 32,
        bias: bool = True,
        expand_first: bool = True,  # only expands when dim_out > dim
        head_first: bool = True,  # qkv channel layout: heads-major vs qkv-major
        rel_pos_cls: Callable = None,  # factory for relative-position module (takes num_heads)
        attn_drop: float = 0.,
        proj_drop: float = 0.
    ):
        super().__init__()
        dim_out = dim_out or dim
        dim_attn = dim_out if expand_first and dim_out > dim else dim
        assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim'
        self.num_heads = dim_attn // dim_head
        self.dim_head = dim_head
        self.head_first = head_first
        self.scale = dim_head ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias)
        self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim_attn, dim_out, bias=bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        """x: (B, ..., C) channels-last; shared_rel_pos: optional attn bias."""
        B = x.shape[0]
        restore_shape = x.shape[:-1]  # spatial shape to rebuild on output

        # (B, heads, tokens, dim_head) per q/k/v, depending on channel layout.
        if self.head_first:
            q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3)
        else:
            q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2)

        if self.fused_attn:
            # Own rel-pos module takes priority over a shared (per-stage) bias.
            attn_bias = None
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos

            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            # Manual path: scaled q @ k^T, optional rel-pos, softmax, @ v.
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if self.rel_pos is not None:
                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(restore_shape + (-1,))
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling for channels-last tensors (CaiT-style).

    Multiplies the input by a learnable vector `gamma` of shape (dim,),
    initialized to `init_values`. With `inplace=True` the input is scaled
    in place via `mul_`.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
+
+
class LayerScale2d(nn.Module):
    """Learnable per-channel scaling for NCHW tensors.

    Same as `LayerScale` but broadcasts `gamma` over the channel dim of a
    (B, C, H, W) tensor. With `inplace=True` the input is scaled in place.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        scale = self.gamma.view(1, -1, 1, 1)
        if self.inplace:
            return x.mul_(scale)
        return x * scale
+
+
class Downsample2d(nn.Module):
    """ A downsample pooling module supporting several maxpool and avgpool modes
    * 'max' - MaxPool2d w/ kernel_size 3, stride 2, padding 1
    * 'max2' - MaxPool2d w/ kernel_size = stride = 2
    * 'avg' - AvgPool2d w/ kernel_size 3, stride 2, padding 1
    * 'avg2' - AvgPool2d w/ kernel_size = stride = 2

    Always halves spatial resolution; a 1x1 conv expands channels only when
    dim != dim_out.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        pool_type: str = 'avg2',
        padding: str = '',  # '' -> per-mode default padding below
        bias: bool = True,  # bias on the 1x1 channel-expansion conv
    ):
        super().__init__()
        assert pool_type in ('max', 'max2', 'avg', 'avg2')
        if pool_type == 'max':
            self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1)
        elif pool_type == 'max2':
            self.pool = create_pool2d('max', 2, padding=padding or 0)  # kernel_size == stride == 2
        elif pool_type == 'avg':
            # count_include_pad=False so border cells aren't diluted by zero pad
            self.pool = create_pool2d(
                'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1)
        else:
            self.pool = create_pool2d('avg', 2, padding=padding or 0)

        if dim != dim_out:
            self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias)
        else:
            self.expand = nn.Identity()

    def forward(self, x):
        x = self.pool(x)  # spatial downsample
        x = self.expand(x)  # expand chs
        return x
+
+
def _init_transformer(module, name, scheme=''):
    """Weight init for transformer blocks, applied via `named_apply`.

    `scheme` selects the init style: 'normal', 'trunc_normal', 'xavier_normal',
    or (default) ViT-like xavier_uniform with near-zero MLP bias.
    Non Conv2d/Linear modules are left untouched.
    """
    if isinstance(module, (nn.Conv2d, nn.Linear)):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'trunc_normal':
            trunc_normal_tf_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # vit like
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                if 'mlp' in name:
                    # small normal bias for MLP layers, matching ViT convention
                    nn.init.normal_(module.bias, std=1e-6)
                else:
                    nn.init.zeros_(module.bias)
+
+
class TransformerBlock2d(nn.Module):
    """ Transformer block with 2D downsampling
    '2D' NCHW tensor layout

    Some gains can be seen on GPU using a 1D / CL block, BUT w/ the need to switch back/forth to NCHW
    for spatial pooling, the benefit is minimal so ended up using just this variant for CoAt configs.

    This impl was faster on TPU w/ PT XLA than the 1D experiment.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        stride: int = 1,  # 2 -> pooled shortcut + pooled pre-norm branch
        rel_pos_cls: Callable = None,  # relative-position factory passed to Attention2d
        cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
        drop_path: float = 0.,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        act_layer = get_act_layer(cfg.act_layer)

        if stride == 2:
            # Downsample both the residual shortcut and the normed attn input.
            self.shortcut = Downsample2d(dim, dim_out, pool_type=cfg.pool_type, bias=cfg.shortcut_bias)
            self.norm1 = nn.Sequential(OrderedDict([
                ('norm', norm_layer(dim)),
                ('down', Downsample2d(dim, dim, pool_type=cfg.pool_type)),
            ]))
        else:
            assert dim == dim_out
            self.shortcut = nn.Identity()
            self.norm1 = norm_layer(dim)

        self.attn = Attention2d(
            dim,
            dim_out,
            dim_head=cfg.dim_head,
            expand_first=cfg.expand_first,
            bias=cfg.attn_bias,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop
        )
        self.ls1 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = ConvMlp(
            in_features=dim_out,
            hidden_features=int(dim_out * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale2d(dim_out, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        named_apply(partial(_init_transformer, scheme=scheme), self)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        # Pre-norm residual attention then MLP, both with LayerScale/DropPath.
        x = self.shortcut(x) + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
+
+
def _init_conv(module, name, scheme=''):
    """Weight init for conv blocks, applied via `named_apply`.

    `scheme` selects the init style: 'normal', 'trunc_normal', 'xavier_normal',
    or (default) EfficientNet-style fan-out normal. Non-Conv2d modules are
    left untouched.
    """
    if isinstance(module, nn.Conv2d):
        if scheme == 'normal':
            nn.init.normal_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'trunc_normal':
            trunc_normal_tf_(module.weight, std=.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == 'xavier_normal':
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # efficientnet like
            fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            fan_out //= module.groups  # grouped/depthwise convs see fewer outputs per filter
            nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
+
+
def num_groups(group_size, channels):
    """Convert a per-group channel count into a conv `groups` argument.

    A falsy `group_size` (0 or None) means a normal conv (groups=1);
    group_size == 1 yields a depthwise conv (groups == channels).
    `channels` must divide evenly by `group_size`.
    """
    if not group_size:  # 0 or None -> standard convolution
        return 1
    groups, remainder = divmod(channels, group_size)
    assert remainder == 0
    return groups
+
+
class MbConvBlock(nn.Module):
    """ Pre-Norm Conv Block - 1x1 - kxk - 1x1, w/ inverted bottleneck (expand)

    Pre-norm MBConv: shortcut taken before the norm, optional SE attention,
    stride applied via pool / 1x1 / depthwise conv per cfg.stride_mode.
    """
    def __init__(
        self,
        in_chs: int,
        out_chs: int,
        stride: int = 1,
        dilation: Tuple[int, int] = (1, 1),  # (first, second) dilation, per stride_mode
        cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
        drop_path: float = 0.
    ):
        super(MbConvBlock, self).__init__()
        norm_act_layer = partial(get_norm_act_layer(cfg.norm_layer, cfg.act_layer), eps=cfg.norm_eps)
        # Expansion width computed from output chs (default) or input chs.
        mid_chs = make_divisible((out_chs if cfg.expand_output else in_chs) * cfg.expand_ratio)
        groups = num_groups(cfg.group_size, mid_chs)

        if stride == 2:
            self.shortcut = Downsample2d(
                in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding)
        else:
            self.shortcut = nn.Identity()

        # Decide where the stride is applied: pooling, the 1x1 expand, or the kxk dw conv.
        assert cfg.stride_mode in ('pool', '1x1', 'dw')
        stride_pool, stride_1, stride_2 = 1, 1, 1
        if cfg.stride_mode == 'pool':
            # NOTE this is not described in paper, experiment to find faster option that doesn't stride in 1x1
            stride_pool, dilation_2 = stride, dilation[1]
            # FIXME handle dilation of avg pool
        elif cfg.stride_mode == '1x1':
            # NOTE I don't like this option described in paper, 1x1 w/ stride throws info away
            stride_1, dilation_2 = stride, dilation[1]
        else:
            stride_2, dilation_2 = stride, dilation[0]

        self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act)
        if stride_pool > 1:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding)
        else:
            self.down = nn.Identity()
        self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1)
        self.norm1 = norm_act_layer(mid_chs)

        self.conv2_kxk = create_conv2d(
            mid_chs, mid_chs, cfg.kernel_size,
            stride=stride_2, dilation=dilation_2, groups=groups, padding=cfg.padding)

        attn_kwargs = {}
        if isinstance(cfg.attn_layer, str):
            if cfg.attn_layer == 'se' or cfg.attn_layer == 'eca':
                attn_kwargs['act_layer'] = cfg.attn_act_layer
                attn_kwargs['rd_channels'] = int(cfg.attn_ratio * (out_chs if cfg.expand_output else mid_chs))

        # two different orderings for SE and norm2 (due to some weights and trials using SE before norm2)
        if cfg.attn_early:
            self.se_early = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)
            self.norm2 = norm_act_layer(mid_chs)
            self.se = None
        else:
            self.se_early = None
            self.norm2 = norm_act_layer(mid_chs)
            self.se = create_attn(cfg.attn_layer, mid_chs, **attn_kwargs)

        self.conv3_1x1 = create_conv2d(mid_chs, out_chs, 1, bias=cfg.output_bias)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def init_weights(self, scheme=''):
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.pre_norm(x)
        x = self.down(x)

        # 1x1 expansion conv & norm-act
        x = self.conv1_1x1(x)
        x = self.norm1(x)

        # depthwise / grouped 3x3 conv w/ SE (or other) channel attention & norm-act
        x = self.conv2_kxk(x)
        if self.se_early is not None:
            x = self.se_early(x)
        x = self.norm2(x)
        if self.se is not None:
            x = self.se(x)

        # 1x1 linear projection to output width
        x = self.conv3_1x1(x)
        x = self.drop_path(x) + shortcut
        return x
+
+
class ConvNeXtBlock(nn.Module):
    """ ConvNeXt Block

    Depthwise kxk conv -> norm -> MLP, with residual. Supports either a
    conv (NCHW) MLP or a channels-last Linear MLP (permutes around it).
    """

    def __init__(
        self,
        in_chs: int,
        out_chs: Optional[int] = None,  # defaults to in_chs
        kernel_size: int = 7,
        stride: int = 1,  # 2 -> downsample via pool or strided dw conv
        dilation: Tuple[int, int] = (1, 1),
        cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
        conv_mlp: bool = True,  # True: ConvMlp in NCHW; False: Mlp in NHWC
        drop_path: float = 0.
    ):
        super().__init__()
        out_chs = out_chs or in_chs
        act_layer = get_act_layer(cfg.act_layer)
        if conv_mlp:
            norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
            mlp_layer = ConvMlp
        else:
            # channels-last path requires a layernorm-family norm
            assert 'layernorm' in cfg.norm_layer
            norm_layer = LayerNorm
            mlp_layer = Mlp
        self.use_conv_mlp = conv_mlp

        if stride == 2:
            self.shortcut = Downsample2d(in_chs, out_chs)
        elif in_chs != out_chs:
            self.shortcut = nn.Conv2d(in_chs, out_chs, kernel_size=1, bias=cfg.output_bias)
        else:
            self.shortcut = nn.Identity()

        assert cfg.stride_mode in ('pool', 'dw')
        stride_pool, stride_dw = 1, 1
        # FIXME handle dilation?
        if cfg.stride_mode == 'pool':
            stride_pool = stride
        else:
            stride_dw = stride

        if stride_pool == 2:
            self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type)
        else:
            self.down = nn.Identity()

        self.conv_dw = create_conv2d(
            in_chs, out_chs, kernel_size=kernel_size, stride=stride_dw, dilation=dilation[1],
            depthwise=True, bias=cfg.output_bias)
        self.norm = norm_layer(out_chs)
        self.mlp = mlp_layer(out_chs, int(cfg.expand_ratio * out_chs), bias=cfg.output_bias, act_layer=act_layer)
        if conv_mlp:
            self.ls = LayerScale2d(out_chs, cfg.init_values) if cfg.init_values else nn.Identity()
        else:
            self.ls = LayerScale(out_chs, cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        shortcut = self.shortcut(x)
        x = self.down(x)
        x = self.conv_dw(x)
        if self.use_conv_mlp:
            x = self.norm(x)
            x = self.mlp(x)
            x = self.ls(x)
        else:
            # NCHW -> NHWC for Linear MLP, then back
            x = x.permute(0, 2, 3, 1)
            x = self.norm(x)
            x = self.mlp(x)
            x = self.ls(x)
            x = x.permute(0, 3, 1, 2)

        x = self.drop_path(x) + shortcut
        return x
+
+
def window_partition(x, window_size: List[int]):
    """Split an NHWC tensor into non-overlapping local windows.

    Args:
        x: (B, H, W, C) input tensor.
        window_size: (win_h, win_w); H and W must be divisible by it.

    Returns:
        (B * num_windows, win_h, win_w, C) window tensor.
    """
    B, H, W, C = x.shape
    _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    # FIX: width check previously had an empty message; match the height check.
    _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
    x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
    return windows
+
+
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: List[int], img_size: List[int]):
    """Inverse of `window_partition`: reassemble windows into (B, H, W, C)."""
    H, W = img_size
    C = windows.shape[-1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
    return x
+
+
def grid_partition(x, grid_size: List[int]):
    """Split an NHWC tensor into a dilated (strided) grid of tokens.

    Unlike `window_partition`, elements of each partition are sampled at a
    stride of H//grid_h, W//grid_w across the whole image (MaxViT grid attn).

    Args:
        x: (B, H, W, C) input tensor.
        grid_size: (grid_h, grid_w); H and W must be divisible by it.

    Returns:
        (B * num_partitions, grid_h, grid_w, C) tensor.
    """
    B, H, W, C = x.shape
    _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
    # FIX: width check previously had an empty message; match the height check.
    _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
    x = x.view(B, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1], C)
    windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, grid_size[0], grid_size[1], C)
    return windows
+
+
@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse(windows, grid_size: List[int], img_size: List[int]):
    """Inverse of `grid_partition`: reassemble grid tokens into (B, H, W, C)."""
    H, W = img_size
    C = windows.shape[-1]
    x = windows.view(-1, H // grid_size[0], W // grid_size[1], grid_size[0], grid_size[1], C)
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, H, W, C)
    return x
+
+
def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size):
    """Build a relative-position module factory for the given partition size.

    Returns a partial (taking `num_heads`) for the configured rel-pos type
    ('mlp', 'bias', or 'bias_tf'), or None when no rel-pos is configured.
    """
    if cfg.rel_pos_type == 'mlp':
        return partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim)
    if cfg.rel_pos_type == 'bias':
        return partial(RelPosBias, window_size=window_size)
    if cfg.rel_pos_type == 'bias_tf':
        return partial(RelPosBiasTf, window_size=window_size)
    return None
+
+
class PartitionAttentionCl(nn.Module):
    """ Grid or Block partition + Attn + FFN.
    NxC 'channels last' tensor layout.

    partition_type 'block' -> local window attention; anything else ('grid')
    -> dilated grid attention. Partition / reverse wrap an `AttentionCl`.
    """

    def __init__(
        self,
        dim: int,
        partition_type: str = 'block',  # 'block' (window) or 'grid'
        cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
        drop_path: float = 0.,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)  # NOTE this block is channels-last
        act_layer = get_act_layer(cfg.act_layer)

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn = AttentionCl(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def _partition_attn(self, x):
        """Partition NHWC input, run attention per partition, then reverse."""
        img_size = x.shape[1:3]
        if self.partition_block:
            partitioned = window_partition(x, self.partition_size)
        else:
            partitioned = grid_partition(x, self.partition_size)

        partitioned = self.attn(partitioned)

        if self.partition_block:
            x = window_reverse(partitioned, self.partition_size, img_size)
        else:
            x = grid_reverse(partitioned, self.partition_size, img_size)
        return x

    def forward(self, x):
        # Pre-norm residual partition-attention then MLP.
        x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
+
+
class ParallelPartitionAttention(nn.Module):
    """ Experimental. Grid and Block partition + single FFN
    NxC tensor layout.

    Runs window attention and grid attention in parallel, each producing
    dim // 2 channels, concatenates, then applies one shared MLP.
    Requires cfg.window_size == cfg.grid_size.
    """

    def __init__(
        self,
        dim: int,
        cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
        drop_path: float = 0.,
    ):
        super().__init__()
        assert dim % 2 == 0  # each branch outputs dim // 2 channels
        norm_layer = partial(get_norm_layer(cfg.norm_layer_cl), eps=cfg.norm_eps)  # NOTE this block is channels-last
        act_layer = get_act_layer(cfg.act_layer)

        assert cfg.window_size == cfg.grid_size
        self.partition_size = to_2tuple(cfg.window_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn_block = AttentionCl(
            dim,
            dim // 2,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.attn_grid = AttentionCl(
            dim,
            dim // 2,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            out_features=dim,
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def _partition_attn(self, x):
        """Window-attend and grid-attend the same input; cat on channel dim."""
        img_size = x.shape[1:3]

        partitioned_block = window_partition(x, self.partition_size)
        partitioned_block = self.attn_block(partitioned_block)
        x_window = window_reverse(partitioned_block, self.partition_size, img_size)

        partitioned_grid = grid_partition(x, self.partition_size)
        partitioned_grid = self.attn_grid(partitioned_grid)
        x_grid = grid_reverse(partitioned_grid, self.partition_size, img_size)

        return torch.cat([x_window, x_grid], dim=-1)

    def forward(self, x):
        x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
+
+
def window_partition_nchw(x, window_size: List[int]):
    """NCHW counterpart of `window_partition`.

    Args:
        x: (B, C, H, W) input tensor.
        window_size: (win_h, win_w); H and W must be divisible by it.

    Returns:
        (B * num_windows, C, win_h, win_w) window tensor.
    """
    B, C, H, W = x.shape
    _assert(H % window_size[0] == 0, f'height ({H}) must be divisible by window ({window_size[0]})')
    # FIX: width check previously had an empty message; match the height check.
    _assert(W % window_size[1] == 0, f'width ({W}) must be divisible by window ({window_size[1]})')
    x = x.view(B, C, H // window_size[0], window_size[0], W // window_size[1], window_size[1])
    windows = x.permute(0, 2, 4, 1, 3, 5).contiguous().view(-1, C, window_size[0], window_size[1])
    return windows
+
+
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse_nchw(windows, window_size: List[int], img_size: List[int]):
    """Inverse of `window_partition_nchw`: reassemble windows into (B, C, H, W)."""
    H, W = img_size
    C = windows.shape[1]
    x = windows.view(-1, H // window_size[0], W // window_size[1], C, window_size[0], window_size[1])
    x = x.permute(0, 3, 1, 4, 2, 5).contiguous().view(-1, C, H, W)
    return x
+
+
def grid_partition_nchw(x, grid_size: List[int]):
    """NCHW counterpart of `grid_partition` (dilated/strided grid tokens).

    Args:
        x: (B, C, H, W) input tensor.
        grid_size: (grid_h, grid_w); H and W must be divisible by it.

    Returns:
        (B * num_partitions, C, grid_h, grid_w) tensor.
    """
    B, C, H, W = x.shape
    _assert(H % grid_size[0] == 0, f'height {H} must be divisible by grid {grid_size[0]}')
    # FIX: width check previously had an empty message; match the height check.
    _assert(W % grid_size[1] == 0, f'width {W} must be divisible by grid {grid_size[1]}')
    x = x.view(B, C, grid_size[0], H // grid_size[0], grid_size[1], W // grid_size[1])
    windows = x.permute(0, 3, 5, 1, 2, 4).contiguous().view(-1, C, grid_size[0], grid_size[1])
    return windows
+
+
@register_notrace_function  # reason: int argument is a Proxy
def grid_reverse_nchw(windows, grid_size: List[int], img_size: List[int]):
    """Inverse of `grid_partition_nchw`: reassemble grid tokens into (B, C, H, W)."""
    H, W = img_size
    C = windows.shape[1]
    x = windows.view(-1, H // grid_size[0], W // grid_size[1], C, grid_size[0], grid_size[1])
    x = x.permute(0, 3, 4, 1, 5, 2).contiguous().view(-1, C, H, W)
    return x
+
+
class PartitionAttention2d(nn.Module):
    """ Grid or Block partition + Attn + FFN

    '2D' NCHW tensor layout.

    NCHW twin of `PartitionAttentionCl`, using `Attention2d` and the
    `*_nchw` partition helpers.
    """

    def __init__(
        self,
        dim: int,
        partition_type: str = 'block',  # 'block' (window) or 'grid'
        cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
        drop_path: float = 0.,
    ):
        super().__init__()
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)  # NOTE this block is NCHW (uses cfg.norm_layer, not norm_layer_cl)
        act_layer = get_act_layer(cfg.act_layer)

        self.partition_block = partition_type == 'block'
        self.partition_size = to_2tuple(cfg.window_size if self.partition_block else cfg.grid_size)
        rel_pos_cls = get_rel_pos_cls(cfg, self.partition_size)

        self.norm1 = norm_layer(dim)
        self.attn = Attention2d(
            dim,
            dim,
            dim_head=cfg.dim_head,
            bias=cfg.attn_bias,
            head_first=cfg.head_first,
            rel_pos_cls=rel_pos_cls,
            attn_drop=cfg.attn_drop,
            proj_drop=cfg.proj_drop,
        )
        self.ls1 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = ConvMlp(
            in_features=dim,
            hidden_features=int(dim * cfg.expand_ratio),
            act_layer=act_layer,
            drop=cfg.proj_drop)
        self.ls2 = LayerScale2d(dim, init_values=cfg.init_values) if cfg.init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def _partition_attn(self, x):
        """Partition NCHW input, run attention per partition, then reverse."""
        img_size = x.shape[-2:]
        if self.partition_block:
            partitioned = window_partition_nchw(x, self.partition_size)
        else:
            partitioned = grid_partition_nchw(x, self.partition_size)

        partitioned = self.attn(partitioned)

        if self.partition_block:
            x = window_reverse_nchw(partitioned, self.partition_size, img_size)
        else:
            x = grid_reverse_nchw(partitioned, self.partition_size, img_size)
        return x

    def forward(self, x):
        # Pre-norm residual partition-attention then MLP.
        x = x + self.drop_path1(self.ls1(self._partition_attn(self.norm1(x))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
+
+
class MaxxVitBlock(nn.Module):
    """ MaxVit conv, window partition + FFN , grid partition + FFN

    One full MaxViT block: MBConv (or ConvNeXt) -> block (window) attention
    -> grid attention. Attention runs in NHWC unless cfg.use_nchw_attn.
    """

    def __init__(
        self,
        dim: int,
        dim_out: int,
        stride: int = 1,  # stride is applied by the conv block only
        conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
        transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
        drop_path: float = 0.,
    ):
        super().__init__()
        self.nchw_attn = transformer_cfg.use_nchw_attn

        conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
        self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)

        attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
        partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
        # Block (window) attention is optional; grid attention always present.
        self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
        self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)

    def init_weights(self, scheme=''):
        if self.attn_block is not None:
            named_apply(partial(_init_transformer, scheme=scheme), self.attn_block)
        named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x):
        # NCHW format
        x = self.conv(x)

        if not self.nchw_attn:
            x = x.permute(0, 2, 3, 1)  # to NHWC (channels-last)
        if self.attn_block is not None:
            x = self.attn_block(x)
        x = self.attn_grid(x)
        if not self.nchw_attn:
            x = x.permute(0, 3, 1, 2)  # back to NCHW
        return x
+
+
class ParallelMaxxVitBlock(nn.Module):
    """ MaxVit block with parallel cat(window + grid), one FF
    Experimental timm block.

    num_conv conv blocks are followed by a `ParallelPartitionAttention`
    (window + grid attention in parallel, single shared MLP).
    """

    def __init__(
        self,
        dim,
        dim_out,
        stride=1,  # stride applied by the first conv block only
        num_conv=2,
        conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
        transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
        drop_path=0.,
    ):
        super().__init__()

        conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
        if num_conv > 1:
            # First conv handles stride / channel change; the rest keep dims.
            convs = [conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)]
            # FIX: build distinct instances. The previous `[module] * (num_conv - 1)`
            # repeated the SAME module object, silently sharing weights when
            # num_conv > 2 (identical behavior for the default num_conv=2).
            convs += [
                conv_cls(dim_out, dim_out, cfg=conv_cfg, drop_path=drop_path)
                for _ in range(num_conv - 1)
            ]
            self.conv = nn.Sequential(*convs)
        else:
            self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
        self.attn = ParallelPartitionAttention(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)

    def init_weights(self, scheme=''):
        named_apply(partial(_init_transformer, scheme=scheme), self.attn)
        named_apply(partial(_init_conv, scheme=scheme), self.conv)

    def forward(self, x):
        # NCHW convs -> NHWC for partition attention -> back to NCHW.
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1)
        x = self.attn(x)
        x = x.permute(0, 3, 1, 2)
        return x
+
+
class MaxxVitStage(nn.Module):
    """A sequence of MaxxVit blocks at one feature resolution.

    The first block applies `stride` (downsampling); the rest run at stride 1.
    Per-position block types: 'C' conv only, 'T' transformer only,
    'M' MaxVit (conv + window attn + grid attn), 'PM' parallel MaxVit.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            stride: int = 2,
            depth: int = 4,
            feat_size: Tuple[int, int] = (14, 14),
            block_types: Union[str, Tuple[str]] = 'C',
            transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
            conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
            drop_path: Union[float, List[float]] = 0.,
    ):
        super().__init__()
        self.grad_checkpointing = False

        block_types = extend_tuple(block_types, depth)
        # BUG FIX: drop_path is typed Union[float, List[float]] (default 0.) but was
        # indexed unconditionally as drop_path[i] below, so a scalar raised
        # TypeError. Broadcast scalars to one rate per block.
        if not isinstance(drop_path, (list, tuple)):
            drop_path = [drop_path] * depth
        blocks = []
        for i, t in enumerate(block_types):
            block_stride = stride if i == 0 else 1  # only the first block downsamples
            assert t in ('C', 'T', 'M', 'PM')
            if t == 'C':
                conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
                blocks += [conv_cls(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    cfg=conv_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'T':
                rel_pos_cls = get_rel_pos_cls(transformer_cfg, feat_size)
                blocks += [TransformerBlock2d(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    rel_pos_cls=rel_pos_cls,
                    cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'M':
                blocks += [MaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            elif t == 'PM':
                blocks += [ParallelMaxxVitBlock(
                    in_chs,
                    out_chs,
                    stride=block_stride,
                    conv_cfg=conv_cfg,
                    transformer_cfg=transformer_cfg,
                    drop_path=drop_path[i],
                )]
            in_chs = out_chs
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Trade compute for memory: re-run blocks during backward.
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
+
+
class Stem(nn.Module):
    """Two-conv stem: stride-2 conv + norm/act, then a stride-1 conv (overall stride 2)."""

    def __init__(
            self,
            in_chs: int,
            out_chs: int,
            kernel_size: int = 3,
            padding: str = '',
            bias: bool = False,
            act_layer: str = 'gelu',
            norm_layer: str = 'batchnorm2d',
            norm_eps: float = 1e-5,
    ):
        super().__init__()
        # Allow a single int out_chs; expand to per-conv channel counts.
        if not isinstance(out_chs, (list, tuple)):
            out_chs = to_2tuple(out_chs)

        self.out_chs = out_chs[-1]
        self.stride = 2

        norm_act = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps)
        self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias)
        self.norm1 = norm_act(out_chs[0])
        self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias)

    def init_weights(self, scheme=''):
        """Apply the named conv init scheme to all stem modules."""
        named_apply(partial(_init_conv, scheme=scheme), self)

    def forward(self, x):
        return self.conv2(self.norm1(self.conv1(x)))
+
+
def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
    """Return cfg with window/grid partition sizes resolved for img_size.

    If the config already sets an explicit window size it is returned as-is
    (the grid size must then also be set); otherwise both sizes are derived
    by dividing the image size by cfg.partition_ratio.
    """
    if cfg.window_size is not None:
        assert cfg.grid_size
        return cfg
    partition_size = tuple(d // cfg.partition_ratio for d in img_size)
    return replace(cfg, window_size=partition_size, grid_size=partition_size)
+
+
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
    """Overlay flat keyword overrides onto a MaxxVitCfg.

    Keys prefixed 'transformer_' / 'conv_' are routed (prefix stripped) to the
    nested transformer_cfg / conv_cfg dataclasses; everything else overrides a
    top-level cfg field. Returns a new cfg; the input is not mutated.
    """
    transformer_kwargs = {}
    conv_kwargs = {}
    base_kwargs = {}
    for k, v in kwargs.items():
        if k.startswith('transformer_'):
            # BUG FIX: str.replace() removed EVERY occurrence of the prefix in
            # the key, not just the leading one; strip only the prefix.
            transformer_kwargs[k[len('transformer_'):]] = v
        elif k.startswith('conv_'):
            conv_kwargs[k[len('conv_'):]] = v
        else:
            base_kwargs[k] = v
    cfg = replace(
        cfg,
        transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
        conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
        **base_kwargs,
    )
    return cfg
+
+
class MaxxVit(nn.Module):
    """ CoaTNet + MaxVit base model.

    Highly configurable for different block compositions, tensor layouts, pooling types.
    """

    def __init__(
            self,
            cfg: MaxxVitCfg,
            img_size: Union[int, Tuple[int, int]] = 224,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            **kwargs,
    ):
        """
        Args:
            cfg: Architecture config (stem width, per-stage embed dims/depths/block types, sub-configs).
            img_size: Input image size, int or (H, W).
            in_chans: Number of input image channels.
            num_classes: Number of classifier outputs.
            global_pool: Global pooling type for the head.
            drop_rate: Classifier dropout rate.
            drop_path_rate: Max stochastic-depth rate, scheduled linearly over all blocks.
            **kwargs: Extra overrides folded into cfg via _overlay_kwargs
                ('transformer_*' / 'conv_*' prefixes route to the sub-configs).
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        if kwargs:
            # Fold any remaining kwargs into the dataclass config.
            cfg = _overlay_kwargs(cfg, **kwargs)
        # Resolve window/grid partition sizes from img_size if not explicitly set.
        transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = cfg.embed_dim[-1]
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(
            in_chs=in_chans,
            out_chs=cfg.stem_width,
            padding=cfg.conv_cfg.padding,
            bias=cfg.stem_bias,
            act_layer=cfg.conv_cfg.act_layer,
            norm_layer=cfg.conv_cfg.norm_layer,
            norm_eps=cfg.conv_cfg.norm_eps,
        )
        stride = self.stem.stride
        self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
        feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])

        num_stages = len(cfg.embed_dim)
        assert len(cfg.depths) == num_stages
        # Per-block drop-path rates: linear ramp 0 -> drop_path_rate, split into per-stage lists.
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        in_chs = self.stem.out_chs
        stages = []
        for i in range(num_stages):
            stage_stride = 2  # every stage downsamples by 2 (in its first block)
            out_chs = cfg.embed_dim[i]
            # ceil-divide the running feature size by the stage stride
            feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size])
            stages += [MaxxVitStage(
                in_chs,
                out_chs,
                depth=cfg.depths[i],
                block_types=cfg.block_type[i],
                conv_cfg=cfg.conv_cfg,
                transformer_cfg=transformer_cfg,
                feat_size=feat_size,
                drop_path=dpr[i],
            )]
            stride *= stage_stride
            in_chs = out_chs
            self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
        self.stages = nn.Sequential(*stages)

        final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
        self.head_hidden_size = cfg.head_hidden_size
        if self.head_hidden_size:
            # MLP head applies its own norm; no separate final norm here.
            self.norm = nn.Identity()
            self.head = NormMlpClassifierHead(
                self.num_features,
                num_classes,
                hidden_size=self.head_hidden_size,
                pool_type=global_pool,
                drop_rate=drop_rate,
                norm_layer=final_norm_layer,
            )
        else:
            # standard classifier head w/ norm, pooling, fc classifier
            self.norm = final_norm_layer(self.num_features)
            self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        # Weight init (default PyTorch init works well for AdamW if scheme not set)
        assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff')
        if cfg.weight_init:
            named_apply(partial(self._init_weights, scheme=cfg.weight_init), self)

    def _init_weights(self, module, name, scheme=''):
        # Delegate to per-module init_weights; fall back for impls without a scheme arg.
        if hasattr(module, 'init_weights'):
            try:
                module.init_weights(scheme=scheme)
            except TypeError:
                module.init_weights()

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names (rel-pos tables/MLPs) to exclude from weight decay."""
        return {
            k for k, _ in self.named_parameters()
            if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex grouping of parameters for layer-wise LR decay / freezing."""
        matcher = dict(
            stem=r'^stem',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Toggle gradient checkpointing on every stage."""
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        """Return the final classification layer."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classifier for a new number of classes / pooling type."""
        self.num_classes = num_classes
        self.head.reset(num_classes, global_pool)

    def forward_features(self, x):
        """Stem + stages + final norm; returns the pre-head feature map."""
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Classification head; pre_logits returns pooled features before the fc."""
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def _rw_coat_cfg(
        stride_mode='pool',
        pool_type='avg2',
        conv_output_bias=False,
        conv_attn_early=False,
        conv_attn_act_layer='relu',
        conv_norm_layer='',
        transformer_shortcut_bias=True,
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        init_values=None,
        rel_pos_type='bias',
        rel_pos_dim=512,
):
    """Sub-config pair for timm 'RW' CoAtNet variants.

    'RW' timm variant models were created and trained before seeing
    https://github.com/google-research/maxvit
    Common differences for initial timm models:
      - pre-norm layer in MZBConv included an activation after norm
      - mbconv expansion calculated from input instead of output chs
      - mbconv shortcut and final 1x1 conv did not have a bias
      - SE act layer was relu, not silu
      - mbconv uses silu in timm, not gelu
      - expansion in attention block done via output proj, not input proj
    Variable differences (evolved over training initial models):
      - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
      - SE attention was between conv2 and norm/act
      - default to avg pool for mbconv downsample instead of 1x1 or dw conv
      - transformer block shortcut has no bias
    """
    conv_cfg = MaxxVitConvCfg(
        stride_mode=stride_mode,
        pool_type=pool_type,
        pre_norm_act=True,
        expand_output=False,
        output_bias=conv_output_bias,
        attn_early=conv_attn_early,
        attn_act_layer=conv_attn_act_layer,
        act_layer='silu',
        norm_layer=conv_norm_layer,
    )
    transformer_cfg = MaxxVitTransformerCfg(
        expand_first=False,
        shortcut_bias=transformer_shortcut_bias,
        pool_type=pool_type,
        init_values=init_values,
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return dict(conv_cfg=conv_cfg, transformer_cfg=transformer_cfg)
+
+
def _rw_max_cfg(
        stride_mode='dw',
        pool_type='avg2',
        conv_output_bias=False,
        conv_attn_ratio=1 / 16,
        conv_norm_layer='',
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        window_size=None,
        dim_head=32,
        init_values=None,
        rel_pos_type='bias',
        rel_pos_dim=512,
):
    """Sub-config pair for timm 'RW' MaxVit variants.

    'RW' timm variant models were created and trained before seeing
    https://github.com/google-research/maxvit
    Differences of initial timm models:
      - mbconv expansion calculated from input instead of output chs
      - mbconv shortcut and final 1x1 conv did not have a bias
      - mbconv uses silu in timm, not gelu
      - expansion in attention block done via output proj, not input proj
    """
    conv_cfg = MaxxVitConvCfg(
        stride_mode=stride_mode,
        pool_type=pool_type,
        expand_output=False,
        output_bias=conv_output_bias,
        attn_ratio=conv_attn_ratio,
        act_layer='silu',
        norm_layer=conv_norm_layer,
    )
    transformer_cfg = MaxxVitTransformerCfg(
        expand_first=False,
        pool_type=pool_type,
        dim_head=dim_head,
        window_size=window_size,
        init_values=init_values,
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return dict(conv_cfg=conv_cfg, transformer_cfg=transformer_cfg)
+
+
def _next_cfg(
        stride_mode='dw',
        pool_type='avg2',
        conv_norm_layer='layernorm2d',
        conv_norm_layer_cl='layernorm',
        transformer_norm_layer='layernorm2d',
        transformer_norm_layer_cl='layernorm',
        window_size=None,
        no_block_attn=False,
        init_values=1e-6,
        rel_pos_type='mlp',  # MLP by default for maxxvit
        rel_pos_dim=512,
):
    """Sub-config pair for experimental MaxxViT models: ConvNeXt conv blocks
    instead of MBConv.

    init_values may be a scalar or a (conv, transformer) pair; a scalar is
    broadcast to both.
    """
    init_values = to_2tuple(init_values)
    conv_cfg = MaxxVitConvCfg(
        block_type='convnext',
        stride_mode=stride_mode,
        pool_type=pool_type,
        expand_output=False,
        init_values=init_values[0],
        norm_layer=conv_norm_layer,
        norm_layer_cl=conv_norm_layer_cl,
    )
    transformer_cfg = MaxxVitTransformerCfg(
        expand_first=False,
        pool_type=pool_type,
        window_size=window_size,
        no_block_attn=no_block_attn,  # enabled for MaxxViT-V2
        init_values=init_values[1],
        norm_layer=transformer_norm_layer,
        norm_layer_cl=transformer_norm_layer_cl,
        rel_pos_type=rel_pos_type,
        rel_pos_dim=rel_pos_dim,
    )
    return dict(conv_cfg=conv_cfg, transformer_cfg=transformer_cfg)
+
+
def _tf_cfg():
    """Sub-config pair matching the official Tensorflow MaxVit implementation."""
    conv_cfg = MaxxVitConvCfg(
        norm_eps=1e-3,
        act_layer='gelu_tanh',
        padding='same',
    )
    transformer_cfg = MaxxVitTransformerCfg(
        norm_eps=1e-5,
        act_layer='gelu_tanh',
        head_first=False,  # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....)
        rel_pos_type='bias_tf',
    )
    return dict(conv_cfg=conv_cfg, transformer_cfg=transformer_cfg)
+
+
# Registry of architecture configs keyed by base variant name (resolution suffix
# of the model entrypoints is stripped before lookup in _create_maxxvit).
model_cfgs = dict(
    # timm specific CoAtNet configs
    coatnet_pico_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 3, 5, 2),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(  # using newer max defaults here
            stride_mode='pool',
            conv_output_bias=True,
            conv_attn_ratio=0.25,
        ),
    ),
    coatnet_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        ),
    ),
    coatnet_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
        )
    ),
    coatnet_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            #init_values=1e-6,
        ),
    ),
    coatnet_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
        ),
    ),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    coatnet_bn_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            transformer_norm_layer='batchnorm2d',
        )
    ),
    coatnet_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        **_rw_max_cfg(
            conv_output_bias=True,
            conv_attn_ratio=0.25,
            rel_pos_type='mlp',
            rel_pos_dim=384,
        ),
    ),
    coatnet_rmlp_0_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 7, 2),  # deeper than paper '0' model
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
        ),
    ),
    coatnet_rmlp_1_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            pool_type='max',
            conv_attn_early=True,
            transformer_shortcut_bias=False,
            rel_pos_type='mlp',
            rel_pos_dim=384,  # was supposed to be 512, woops
        ),
    ),
    coatnet_rmlp_1_rw2=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=(32, 64),
        **_rw_coat_cfg(
            stride_mode='dw',
            rel_pos_type='mlp',
            rel_pos_dim=512,  # 512 as intended (the rw1 config above erroneously used 384)
        ),
    ),
    coatnet_rmlp_2_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=(64, 128),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),
    coatnet_rmlp_3_rw=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=(96, 192),
        **_rw_coat_cfg(
            stride_mode='dw',
            conv_attn_act_layer='silu',
            init_values=1e-6,
            rel_pos_type='mlp'
        ),
    ),

    coatnet_nano_cc=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
        **_rw_coat_cfg(),
    ),
    coatnext_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(3, 4, 6, 3),
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(
            rel_pos_type='bias',
            init_values=(1e-5, None)
        ),
    ),

    # Trying to be like the CoAtNet paper configs
    coatnet_0=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 3, 5, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_1=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        stem_width=64,
        head_hidden_size=768,
    ),
    coatnet_2=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        stem_width=128,
        head_hidden_size=1024,
    ),
    coatnet_3=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_4=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=1536,
    ),
    coatnet_5=MaxxVitCfg(
        embed_dim=(256, 512, 1280, 2048),
        depths=(2, 12, 28, 2),
        stem_width=192,
        head_hidden_size=2048,
    ),

    # Experimental MaxVit configs
    maxvit_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(),
    ),
    maxvit_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),
    maxvit_tiny_pm=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('PM',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(),
    ),

    maxvit_rmlp_pico_rw=MaxxVitCfg(
        embed_dim=(32, 64, 128, 256),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(24, 32),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(rel_pos_type='mlp'),
    ),
    maxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_rw_max_cfg(
            rel_pos_type='mlp',
            init_values=1e-6,
        ),
    ),
    maxvit_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        head_hidden_size=768,
        **_rw_max_cfg(
            rel_pos_type='mlp',
        ),
    ),

    maxxvit_rmlp_nano_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        weight_init='normal',
        **_next_cfg(),
    ),
    maxxvit_rmlp_tiny_rw=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(32, 64),
        **_next_cfg(),
    ),
    maxxvit_rmlp_small_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        **_next_cfg(),
    ),

    maxxvitv2_nano_rw=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(1, 2, 3, 1),
        block_type=('M',) * 4,
        stem_width=(48, 96),
        weight_init='normal',
        **_next_cfg(
            no_block_attn=True,
            rel_pos_type='bias',
        ),
    ),
    maxxvitv2_rmlp_base_rw=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 12, 2),
        block_type=('M',) * 4,
        stem_width=(64, 128),
        **_next_cfg(
            no_block_attn=True,
        ),
    ),
    maxxvitv2_rmlp_large_rw=MaxxVitCfg(
        embed_dim=(160, 320, 640, 1280),
        depths=(2, 6, 16, 2),
        block_type=('M',) * 4,
        stem_width=(80, 160),
        head_hidden_size=1280,
        **_next_cfg(
            no_block_attn=True,
        ),
    ),

    # Trying to be like the MaxViT paper configs
    maxvit_tiny_tf=MaxxVitCfg(
        embed_dim=(64, 128, 256, 512),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=512,
        **_tf_cfg(),
    ),
    maxvit_small_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 2, 5, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_base_tf=MaxxVitCfg(
        embed_dim=(96, 192, 384, 768),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=64,
        stem_bias=True,
        head_hidden_size=768,
        **_tf_cfg(),
    ),
    maxvit_large_tf=MaxxVitCfg(
        embed_dim=(128, 256, 512, 1024),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=128,
        stem_bias=True,
        head_hidden_size=1024,
        **_tf_cfg(),
    ),
    maxvit_xlarge_tf=MaxxVitCfg(
        embed_dim=(192, 384, 768, 1536),
        depths=(2, 6, 14, 2),
        block_type=('M',) * 4,
        stem_width=192,
        stem_bias=True,
        head_hidden_size=1536,
        **_tf_cfg(),
    ),
)
+
+
def checkpoint_filter_fn(state_dict, model: nn.Module):
    """Remap a pretrained checkpoint onto `model`.

    Resizes relative-position bias tables whose shape (or window aspect) differs
    from the target module, and reshapes weights between conv2d (4D) and linear
    (2D) layouts when the element counts match.
    """
    model_state_dict = model.state_dict()
    out_dict = {}
    suffix = 'relative_position_bias_table'
    for k, v in state_dict.items():
        if k.endswith(suffix):
            # Strip '.<suffix>' to address the owning module.
            m = model.get_submodule(k[:-(len(suffix) + 1)])
            needs_resize = (
                v.shape != m.relative_position_bias_table.shape
                or m.window_size[0] != m.window_size[1]
            )
            if needs_resize:
                v = resize_rel_pos_bias_table(
                    v,
                    new_window_size=m.window_size,
                    new_bias_shape=m.relative_position_bias_table.shape,
                )

        target = model_state_dict.get(k)
        if target is not None and v.ndim != target.ndim and v.numel() == target.numel():
            # adapt between conv2d / linear layers
            assert v.ndim in (2, 4)
            v = v.reshape(target.shape)
        out_dict[k] = v
    return out_dict
+
+
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Instantiate a MaxxVit model from the registry.

    cfg_variant defaults to the variant name itself when it is registered,
    otherwise to the variant with its trailing (resolution) token dropped.
    """
    if cfg_variant is None:
        cfg_variant = variant if variant in model_cfgs else '_'.join(variant.split('_')[:-1])
    return build_model_with_cfg(
        MaxxVit,
        variant,
        pretrained,
        model_cfg=model_cfgs[cfg_variant],
        feature_cfg=dict(flatten_sequential=True),
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Default pretrained-cfg dict for MaxxVit models; kwargs override defaults."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.95,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'first_conv': 'stem.conv1',
        'classifier': 'head.fc',
        'fixed_input_size': True,
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs keyed by '<model>.<tag>' (tag = training recipe/dataset).
default_cfgs = generate_default_cfgs({
    # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
    'coatnet_pico_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
        crop_pct=0.9),
    'coatnet_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
    'coatnet_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
    ),

    # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
    'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    #'coatnet_3_rw_224.untrained': _cfg(url=''),

    # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
    'coatnet_bn_0_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        crop_pct=0.95),
    'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
        crop_pct=0.9),
    'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
    'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
    'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
    'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
    'coatnet_nano_cc_224.untrained': _cfg(url=''),
    'coatnext_nano_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
        crop_pct=0.9),

    # ImageNet-12k pretrain CoAtNet
    'coatnet_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_3_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),
    'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
    'coatnet_0_224.untrained': _cfg(url=''),
    'coatnet_1_224.untrained': _cfg(url=''),
    'coatnet_2_224.untrained': _cfg(url=''),
    'coatnet_3_224.untrained': _cfg(url=''),
    'coatnet_4_224.untrained': _cfg(url=''),
    'coatnet_5_224.untrained': _cfg(url=''),

    # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
    'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
    'maxvit_tiny_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
    'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
        crop_pct=0.9,
    ),
    'maxvit_rmlp_small_rw_256.untrained': _cfg(
        url='',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
    ),
    'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),

    # timm specific MaxVit w/ ImageNet-12k pretrain
    'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
    ),

    # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
    'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
        input_size=(3, 256, 256), pool_size=(8, 8)),

    # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
    'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 256, 256), pool_size=(8, 8)),
    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/'),
    'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),

    'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821),

    # MaxViT models ported from official Tensorflow impl
    'maxvit_tiny_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_tiny_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_tiny_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_small_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_small_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_base_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in1k': _cfg(
        hf_hub_id='timm/',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
    'maxvit_large_tf_384.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_512.in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),

    'maxvit_base_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_large_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    # BUG FIX: pool_size was omitted here, leaving the 224px default (7, 7) on a
    # 512px input config; every other 512px cfg above uses (16, 16) (= 512 / 32).
    'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_224.in21k': _cfg(
        hf_hub_id='timm/',
        num_classes=21843),
    'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
})
+
+
# ---------------------------------------------------------------------------
# CoAtNet model registrations.
# Each factory simply forwards its canonical variant name to _create_maxxvit();
# the architecture and pretrained configuration are looked up by that name in
# the model_cfgs / default_cfgs tables defined earlier in this module.
# ---------------------------------------------------------------------------

@register_model
def coatnet_pico_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_pico_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_nano_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_0_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_0_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_1_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_1_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_2_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_3_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_bn_0_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_nano_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_0_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_0_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_1_rw2_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_1_rw2_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)


@register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_nano_cc_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs)


@register_model
def coatnext_nano_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnext_nano_rw_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_0_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_0_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_1_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_1_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_2_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_2_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_3_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_3_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_4_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_4_224', pretrained=pretrained, **kwargs)


@register_model
def coatnet_5_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('coatnet_5_224', pretrained=pretrained, **kwargs)
+
+
# ---------------------------------------------------------------------------
# MaxVit / MaxxVit (timm `rw` / `pm` variants) model registrations.
# Each factory forwards its canonical variant name to _create_maxxvit().
# ---------------------------------------------------------------------------

@register_model
def maxvit_pico_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_pico_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvit_rmlp_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_nano_rw_256(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)


@register_model
def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
+
+
# ---------------------------------------------------------------------------
# MaxVit TF-port registrations (weights ported from the TensorFlow release).
# The second positional argument (e.g. 'maxvit_tiny_tf') is presumably a
# shared cfg-variant key letting the three resolutions of one size reuse a
# single architecture config — TODO(review): confirm against
# _create_maxxvit's signature (not visible in this chunk).
# ---------------------------------------------------------------------------

@register_model
def maxvit_tiny_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_tf_224', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_tf_384', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_tiny_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_tiny_tf_512', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_small_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_small_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_small_tf_384', 'maxvit_small_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_small_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_small_tf_512', 'maxvit_small_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_base_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_base_tf_224', 'maxvit_base_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_base_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_base_tf_384', 'maxvit_base_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_base_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_base_tf_512', 'maxvit_base_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_large_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_large_tf_224', 'maxvit_large_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_large_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_large_tf_384', 'maxvit_large_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_large_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_large_tf_512', 'maxvit_large_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_xlarge_tf_224(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_xlarge_tf_224', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_xlarge_tf_384(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_xlarge_tf_384', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)


@register_model
def maxvit_xlarge_tf_512(pretrained=False, **kwargs) -> MaxxVit:
    return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/metaformer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/metaformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b026a2e438ccfcbcde15e9c7839c19ecbde8aec
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/metaformer.py
@@ -0,0 +1,1056 @@
+"""
+Poolformer from MetaFormer is Actually What You Need for Vision https://arxiv.org/abs/2111.11418
+
+IdentityFormer, RandFormer, PoolFormerV2, ConvFormer, and CAFormer
+from MetaFormer Baselines for Vision https://arxiv.org/abs/2210.13452
+
+All implemented models support feature extraction and variable input resolution.
+
+Original implementation by Weihao Yu et al.,
+adapted for timm by Fredo Guan and Ross Wightman.
+
+Adapted from https://github.com/sail-sg/metaformer, original copyright below
+"""
+
+# Copyright 2022 Garena Online Private Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.jit import Final
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_, DropPath, SelectAdaptivePool2d, GroupNorm1, LayerNorm, LayerNorm2d, Mlp, \
+ use_fused_attn
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['MetaFormer']
+
+
class Stem(nn.Module):
    """Convolutional stem: a single 7x7 / stride-4 convolution followed by an
    optional normalization layer (identity when ``norm_layer`` is not given).

    The conv hyper-parameters are fixed; they are constant across all models
    in this family.
    """

    def __init__(self, in_channels, out_channels, norm_layer=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=7, stride=4, padding=2)
        self.norm = nn.Identity() if not norm_layer else norm_layer(out_channels)

    def forward(self, x):
        return self.norm(self.conv(x))
+
+
class Downsampling(nn.Module):
    """Downsampling layer: optional normalization followed by a strided
    convolution (norm-first ordering, unlike :class:`Stem`).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, norm_layer=None):
        super().__init__()
        self.norm = nn.Identity() if not norm_layer else norm_layer(in_channels)
        self.conv = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
        )

    def forward(self, x):
        return self.conv(self.norm(x))
+
+
class Scale(nn.Module):
    """Learnable per-channel multiplicative scale.

    In NCHW mode the scale is viewed as (dim, 1, 1) so it broadcasts over the
    spatial dimensions; otherwise it is applied over the last axis.
    """

    def __init__(self, dim, init_value=1.0, trainable=True, use_nchw=True):
        super().__init__()
        self.shape = (dim, 1, 1) if use_nchw else (dim,)
        gamma = torch.ones(dim) * init_value
        self.scale = nn.Parameter(gamma, requires_grad=trainable)

    def forward(self, x):
        return self.scale.view(self.shape) * x
+
+
class SquaredReLU(nn.Module):
    """Squared ReLU activation, relu(x)**2: https://arxiv.org/abs/2109.08668"""

    def __init__(self, inplace=False):
        super().__init__()
        self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        y = self.relu(x)
        return y * y
+
+
class StarReLU(nn.Module):
    """StarReLU activation: ``s * relu(x) ** 2 + b`` with learnable scalar
    scale ``s`` and bias ``b`` (both single-element parameters).

    ``mode`` is accepted for config compatibility but never read.
    """

    def __init__(
            self,
            scale_value=1.0,
            bias_value=0.0,
            scale_learnable=True,
            bias_learnable=True,
            mode=None,
            inplace=False,
    ):
        super().__init__()
        self.inplace = inplace
        self.relu = nn.ReLU(inplace=inplace)
        self.scale = nn.Parameter(torch.ones(1) * scale_value, requires_grad=scale_learnable)
        self.bias = nn.Parameter(torch.ones(1) * bias_value, requires_grad=bias_learnable)

    def forward(self, x):
        return self.bias + self.scale * torch.square(self.relu(x))
+
+
class Attention(nn.Module):
    """
    Vanilla self-attention from Transformer: https://arxiv.org/abs/1706.03762.
    Modified from timm.

    Operates on channels-last token sequences of shape (B, N, C).  Unlike the
    usual timm Attention, the per-head width is fixed (``head_dim``) and the
    head count is derived from ``dim`` unless given explicitly.
    """
    # resolved once at construction via use_fused_attn(); Final so torchscript
    # can specialize the branch in forward()
    fused_attn: Final[bool]

    def __init__(
            self,
            dim,
            head_dim=32,
            num_heads=None,
            qkv_bias=False,
            attn_drop=0.,
            proj_drop=0.,
            proj_bias=False,
            **kwargs
    ):
        super().__init__()

        self.head_dim = head_dim
        self.scale = head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        # derive head count from dim when not provided; never allow zero heads
        self.num_heads = num_heads if num_heads else dim // head_dim
        if self.num_heads == 0:
            self.num_heads = 1

        self.attention_dim = self.num_heads * self.head_dim

        self.qkv = nn.Linear(dim, self.attention_dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(self.attention_dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        # (B, N, 3*attention_dim) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        if self.fused_attn:
            # fused kernel applies the 1/sqrt(head_dim) scaling internally;
            # dropout only while training
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # NOTE(review): reshaping back to C assumes attention_dim == C, i.e.
        # dim is a multiple of head_dim — confirm for all configured widths.
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
+# custom norm modules that disable the bias term, since the original models defs
+# used a custom norm with a weight term but no bias term.
+
class GroupNorm1NoBias(GroupNorm1):
    # Single-group GroupNorm with the affine bias removed (original model defs
    # used a weight-only norm).  eps defaults to 1e-6, overriding the parent
    # default unless explicitly passed in kwargs.
    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        self.eps = kwargs.get('eps', 1e-6)
        self.bias = None
+
+
class LayerNorm2dNoBias(LayerNorm2d):
    # NCHW LayerNorm with the affine bias removed (weight-only, matching the
    # original model defs).  eps defaults to 1e-6 unless passed in kwargs.
    def __init__(self, num_channels, **kwargs):
        super().__init__(num_channels, **kwargs)
        self.eps = kwargs.get('eps', 1e-6)
        self.bias = None
+
+
class LayerNormNoBias(nn.LayerNorm):
    """Channels-last LayerNorm with the affine bias removed (weight-only,
    matching the original model defs); eps defaults to 1e-6.
    """

    def __init__(self, num_channels, **kwargs):
        kwargs.setdefault('eps', 1e-6)
        super().__init__(num_channels, **kwargs)
        self.bias = None
+
+
class SepConv(nn.Module):
    r"""Inverted separable convolution from MobileNetV2
    (https://arxiv.org/abs/1801.04381):
    pointwise expand -> act1 -> depthwise -> act2 -> pointwise project.
    """

    def __init__(
            self,
            dim,
            expansion_ratio=2,
            act1_layer=StarReLU,
            act2_layer=nn.Identity,
            bias=False,
            kernel_size=7,
            padding=3,
            **kwargs
    ):
        super().__init__()
        hidden = int(expansion_ratio * dim)
        self.pwconv1 = nn.Conv2d(dim, hidden, kernel_size=1, bias=bias)
        self.act1 = act1_layer()
        # depthwise conv: one filter per expanded channel
        self.dwconv = nn.Conv2d(
            hidden, hidden, kernel_size=kernel_size,
            padding=padding, groups=hidden, bias=bias)
        self.act2 = act2_layer()
        self.pwconv2 = nn.Conv2d(hidden, dim, kernel_size=1, bias=bias)

    def forward(self, x):
        for layer in (self.pwconv1, self.act1, self.dwconv, self.act2, self.pwconv2):
            x = layer(x)
        return x
+
+
class Pooling(nn.Module):
    """PoolFormer token mixer (https://arxiv.org/abs/2111.11418): stride-1
    average pooling with the identity subtracted, i.e. returns pool(x) - x.
    """

    def __init__(self, pool_size=3, **kwargs):
        super().__init__()
        self.pool = nn.AvgPool2d(
            pool_size,
            stride=1,
            padding=pool_size // 2,
            count_include_pad=False,  # edge windows average only valid pixels
        )

    def forward(self, x):
        return self.pool(x) - x
+
+
class MlpHead(nn.Module):
    """MLP classification head: fc1 -> act -> norm -> dropout -> fc2."""

    def __init__(
            self,
            dim,
            num_classes=1000,
            mlp_ratio=4,
            act_layer=SquaredReLU,
            norm_layer=LayerNorm,
            drop_rate=0.,
            bias=True,
    ):
        super().__init__()
        hidden = int(mlp_ratio * dim)
        self.fc1 = nn.Linear(dim, hidden, bias=bias)
        self.act = act_layer()
        self.norm = norm_layer(hidden)
        self.fc2 = nn.Linear(hidden, num_classes, bias=bias)
        self.head_drop = nn.Dropout(drop_rate)

    def forward(self, x):
        x = self.norm(self.act(self.fc1(x)))
        return self.fc2(self.head_drop(x))
+
+
class MetaFormerBlock(nn.Module):
    """
    Implementation of one MetaFormer block.

    Two residual sub-blocks — token mixing, then an MLP — each of the form
    ``res_scale(x) + layer_scale(drop_path(f(norm(x))))`` where both scale
    wrappers collapse to identity when their init value is None.
    """

    def __init__(
            self,
            dim,
            token_mixer=Pooling,
            mlp_act=StarReLU,
            mlp_bias=False,
            norm_layer=LayerNorm2d,
            proj_drop=0.,
            drop_path=0.,
            use_nchw=True,
            layer_scale_init_value=None,
            res_scale_init_value=None,
            **kwargs
    ):
        super().__init__()
        # pre-bound Scale factories; only instantiated when the corresponding
        # init value is not None (otherwise Identity is used below)
        ls_layer = partial(Scale, dim=dim, init_value=layer_scale_init_value, use_nchw=use_nchw)
        rs_layer = partial(Scale, dim=dim, init_value=res_scale_init_value, use_nchw=use_nchw)

        self.norm1 = norm_layer(dim)
        # extra kwargs (e.g. mixer-specific options) flow through to the mixer
        self.token_mixer = token_mixer(dim=dim, proj_drop=proj_drop, **kwargs)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
        self.res_scale1 = rs_layer() if res_scale_init_value is not None else nn.Identity()

        self.norm2 = norm_layer(dim)
        # MLP expansion ratio fixed at 4; runs as 1x1 convs when use_nchw
        self.mlp = Mlp(
            dim,
            int(4 * dim),
            act_layer=mlp_act,
            bias=mlp_bias,
            drop=proj_drop,
            use_conv=use_nchw,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = ls_layer() if layer_scale_init_value is not None else nn.Identity()
        self.res_scale2 = rs_layer() if res_scale_init_value is not None else nn.Identity()

    def forward(self, x):
        # token-mixing residual sub-block
        x = self.res_scale1(x) + \
            self.layer_scale1(
                self.drop_path1(
                    self.token_mixer(self.norm1(x))
                )
            )
        # MLP residual sub-block
        x = self.res_scale2(x) + \
            self.layer_scale2(
                self.drop_path2(
                    self.mlp(self.norm2(x))
                )
            )
        return x
+
+
class MetaFormerStage(nn.Module):
    """One MetaFormer stage: an optional strided-conv downsample followed by
    ``depth`` MetaFormerBlocks sharing a single token-mixer type.

    Attention-style mixers consume channels-last (B, N, C) input, so the stage
    flattens/restores the spatial dims around the blocks in that case; every
    other mixer runs in NCHW (see ``use_nchw``).

    Args:
        in_chs: Input channels; a downsample layer is inserted iff != out_chs.
        out_chs: Channels of every block in this stage.
        depth: Number of MetaFormerBlocks.
        token_mixer: Token-mixer class passed to each block.
        dp_rates: Per-block stochastic-depth rates, indexed 0..depth-1.
        layer_scale_init_value / res_scale_init_value: Optional scale inits
            forwarded to each block (None disables the corresponding Scale).
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            depth=2,
            token_mixer=nn.Identity,
            mlp_act=StarReLU,
            mlp_bias=False,
            downsample_norm=LayerNorm2d,
            norm_layer=LayerNorm2d,
            proj_drop=0.,
            dp_rates=(0.,) * 2,  # FIX: was a mutable list default ([0.] * 2), shared across calls
            layer_scale_init_value=None,
            res_scale_init_value=None,
            **kwargs,
    ):
        super().__init__()

        self.grad_checkpointing = False
        # Attention (and subclasses) run channels-last; everything else NCHW
        self.use_nchw = not issubclass(token_mixer, Attention)

        # don't downsample if in_chs and out_chs are the same
        self.downsample = nn.Identity() if in_chs == out_chs else Downsampling(
            in_chs,
            out_chs,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_layer=downsample_norm,
        )

        self.blocks = nn.Sequential(*[MetaFormerBlock(
            dim=out_chs,
            token_mixer=token_mixer,
            mlp_act=mlp_act,
            mlp_bias=mlp_bias,
            norm_layer=norm_layer,
            proj_drop=proj_drop,
            drop_path=dp_rates[i],
            layer_scale_init_value=layer_scale_init_value,
            res_scale_init_value=res_scale_init_value,
            use_nchw=self.use_nchw,
            **kwargs,
        ) for i in range(depth)])

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, x: Tensor):
        x = self.downsample(x)
        B, C, H, W = x.shape

        if not self.use_nchw:
            # flatten spatial dims for channels-last (attention) blocks
            x = x.reshape(B, C, -1).transpose(1, 2)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)

        if not self.use_nchw:
            # restore NCHW using the post-downsample spatial shape
            x = x.transpose(1, 2).reshape(B, C, H, W)

        return x
+
+
class MetaFormer(nn.Module):
    r""" MetaFormer
    A PyTorch impl of : `MetaFormer Baselines for Vision` -
    https://arxiv.org/abs/2210.13452

    Args:
        in_chans (int): Number of input image channels.
        num_classes (int): Number of classes for classification head.
        global_pool: Pooling for classifier head.
        depths (list, tuple or int): Number of blocks at each stage (a bare
            int means a single stage).
        dims (list, tuple or int): Feature dimension at each stage.
        token_mixers (list, tuple or token_fcn): Token mixer for each stage.
        mlp_act: Activation layer for MLP.
        mlp_bias (boolean): Enable or disable mlp bias term.
        drop_path_rate (float): Stochastic depth rate.
        drop_rate (float): Dropout rate.
        layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale.
            None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239.
        res_scale_init_values (list, tuple, float or None): Init value for res Scale on residual connections.
            None means not use the res scale. From: https://arxiv.org/abs/2110.09456.
        downsample_norm (nn.Module): Norm layer used in stem and downsampling layers.
        norm_layers (list, tuple or norm_fcn): Norm layers for each stage.
        output_norm: Norm layer before classifier head.
        use_mlp_head: Use MLP classification head.
    """

    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            depths=(2, 2, 6, 2),
            dims=(64, 128, 320, 512),
            token_mixers=Pooling,
            mlp_act=StarReLU,
            mlp_bias=False,
            drop_path_rate=0.,
            proj_drop_rate=0.,
            drop_rate=0.0,
            layer_scale_init_values=None,
            res_scale_init_values=(None, None, 1.0, 1.0),
            downsample_norm=LayerNorm2dNoBias,
            norm_layers=LayerNorm2dNoBias,
            output_norm=LayerNorm2d,
            use_mlp_head=True,
            **kwargs,
    ):
        super().__init__()
        # FIX: normalize scalar depths/dims to single-stage lists *before* the
        # len()/indexing below — previously a scalar raised TypeError here even
        # though the conversion code implied scalar support.
        if not isinstance(depths, (list, tuple)):
            depths = [depths]  # it means the model has only one stage
        if not isinstance(dims, (list, tuple)):
            dims = [dims]

        self.num_classes = num_classes
        self.num_features = dims[-1]
        self.drop_rate = drop_rate
        self.use_mlp_head = use_mlp_head
        self.num_stages = len(depths)

        # broadcast per-stage settings that were given as scalars
        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * self.num_stages
        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * self.num_stages
        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * self.num_stages
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * self.num_stages

        self.grad_checkpointing = False
        self.feature_info = []

        self.stem = Stem(
            in_chans,
            dims[0],
            norm_layer=downsample_norm
        )

        stages = []
        prev_dim = dims[0]
        # per-block stochastic-depth rates, linearly increasing over all blocks
        dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
        for i in range(self.num_stages):
            stages += [MetaFormerStage(
                prev_dim,
                dims[i],
                depth=depths[i],
                token_mixer=token_mixers[i],
                mlp_act=mlp_act,
                mlp_bias=mlp_bias,
                proj_drop=proj_drop_rate,
                dp_rates=dp_rates[i],
                layer_scale_init_value=layer_scale_init_values[i],
                res_scale_init_value=res_scale_init_values[i],
                downsample_norm=downsample_norm,
                norm_layer=norm_layers[i],
                **kwargs,
            )]
            prev_dim = dims[i]
            # NOTE(review): reduction=2 for every stage looks suspicious (the
            # stem is stride 4 and each later stage strides by 2) — confirm
            # the intended per-stage reductions for feature extraction.
            self.feature_info += [dict(num_chs=dims[i], reduction=2, module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)

        # if using MlpHead, dropout is handled by MlpHead
        if num_classes > 0:
            if self.use_mlp_head:
                final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
            else:
                final = nn.Linear(self.num_features, num_classes)
        else:
            final = nn.Identity()

        self.head = nn.Sequential(OrderedDict([
            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
            ('norm', output_norm(self.num_features)),
            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
            ('drop', nn.Dropout(drop_rate) if self.use_mlp_head else nn.Identity()),
            ('fc', final)
        ]))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for conv/linear weights, zero-init biases
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable
        for stage in self.stages:
            stage.set_grad_checkpointing(enable=enable)

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes=0, global_pool=None):
        """Swap the classifier (and optionally the pooling/flatten) in place.

        NOTE(review): self.num_classes is not updated here — confirm whether
        any caller relies on it staying in sync with the new head.
        """
        if global_pool is not None:
            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
        if num_classes > 0:
            if self.use_mlp_head:
                final = MlpHead(self.num_features, num_classes, drop_rate=self.drop_rate)
            else:
                final = nn.Linear(self.num_features, num_classes)
        else:
            final = nn.Identity()
        self.head.fc = final

    def forward_head(self, x: Tensor, pre_logits: bool = False):
        # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
        x = self.head.global_pool(x)
        x = self.head.norm(x)
        x = self.head.flatten(x)
        x = self.head.drop(x)
        return x if pre_logits else self.head.fc(x)

    def forward_features(self, x: Tensor):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward(self, x: Tensor):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
# this works but it's long and breaks backwards compatability with weights from the poolformer-only impl
def checkpoint_filter_fn(state_dict, model):
    """Remap original MetaFormer/PoolFormer checkpoints to this module's layout.

    Checkpoints already in timm format (detected via 'stem.conv.weight') are
    returned unchanged.  Otherwise keys are rewritten (network.* ->
    stages.*/stem.*, head/norm renames) and tensors reshaped to match the
    target module's parameter shapes when only the shape (not numel) differs.
    """
    if 'stem.conv.weight' in state_dict:
        return state_dict  # already converted

    import re
    out_dict = {}
    # poolformer v1 used a single flat 'network' list interleaving stages and
    # downsample layers; detect it by one of its block keys
    is_poolformerv1 = 'network.0.0.mlp.fc1.weight' in state_dict
    model_state_dict = model.state_dict()
    for k, v in state_dict.items():
        if is_poolformerv1:
            k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
            k = k.replace('network.1', 'downsample_layers.1')
            k = k.replace('network.3', 'downsample_layers.2')
            k = k.replace('network.5', 'downsample_layers.3')
            k = k.replace('network.2', 'network.1')
            k = k.replace('network.4', 'network.2')
            k = k.replace('network.6', 'network.3')
            k = k.replace('network', 'stages')

        k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
        k = k.replace('downsample.proj', 'downsample.conv')
        k = k.replace('patch_embed.proj', 'patch_embed.conv')
        k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
        k = k.replace('stages.0.downsample', 'patch_embed')
        k = k.replace('patch_embed', 'stem')
        k = k.replace('post_norm', 'norm')
        k = k.replace('pre_norm', 'norm')
        k = re.sub(r'^head', 'head.fc', k)
        k = re.sub(r'^norm', 'head.norm', k)

        # FIX: compare shapes — the original compared a torch.Size against a
        # tensor (always unequal), forcing a redundant reshape on every
        # matching-numel tensor.
        if v.shape != model_state_dict[k].shape and v.numel() == model_state_dict[k].numel():
            v = v.reshape(model_state_dict[k].shape)

        out_dict[k] = v
    return out_dict
+
+
def _create_metaformer(variant, pretrained=False, **kwargs):
    """Instantiate a MetaFormer variant via build_model_with_cfg, wiring up
    the checkpoint remapping filter and per-stage feature-extraction defaults
    (one feature tap per entry in `depths`).
    """
    depths = kwargs.get('depths', (2, 2, 6, 2))
    out_indices = kwargs.pop('out_indices', tuple(range(len(depths))))

    return build_model_with_cfg(
        MetaFormer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-config dict for a MetaFormer variant;
    any keyword arguments override the baseline entries below.
    """
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=(7, 7),
        crop_pct=1.0, interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        classifier='head.fc', first_conv='stem.conv',
    )
    cfg.update(kwargs)
    return cfg
+
+
+default_cfgs = generate_default_cfgs({
+ 'poolformer_s12.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.9),
+ 'poolformer_s24.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.9),
+ 'poolformer_s36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.9),
+ 'poolformer_m36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.95),
+ 'poolformer_m48.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ crop_pct=0.95),
+
+ 'poolformerv2_s12.sail_in1k': _cfg(hf_hub_id='timm/'),
+ 'poolformerv2_s24.sail_in1k': _cfg(hf_hub_id='timm/'),
+ 'poolformerv2_s36.sail_in1k': _cfg(hf_hub_id='timm/'),
+ 'poolformerv2_m36.sail_in1k': _cfg(hf_hub_id='timm/'),
+ 'poolformerv2_m48.sail_in1k': _cfg(hf_hub_id='timm/'),
+
+ 'convformer_s18.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_s18.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_s18.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_s18.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_s18.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+
+ 'convformer_s36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_s36.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_s36.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_s36.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_s36.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+
+ 'convformer_m36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_m36.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_m36.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_m36.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_m36.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+
+ 'convformer_b36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_b36.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_b36.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'convformer_b36.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'convformer_b36.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+
+ 'caformer_s18.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_s18.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_s18.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_s18.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_s18.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+
+ 'caformer_s36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_s36.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_s36.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_s36.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_s36.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+
+ 'caformer_m36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_m36.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_m36.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_m36.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_m36.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+
+ 'caformer_b36.sail_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_b36.sail_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_b36.sail_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2'),
+ 'caformer_b36.sail_in22k_ft_in1k_384': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', input_size=(3, 384, 384), pool_size=(12, 12)),
+ 'caformer_b36.sail_in22k': _cfg(
+ hf_hub_id='timm/',
+ classifier='head.fc.fc2', num_classes=21841),
+})
+
+
+@register_model
+def poolformer_s12(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[2, 2, 6, 2],
+ dims=[64, 128, 320, 512],
+ downsample_norm=None,
+ mlp_act=nn.GELU,
+ mlp_bias=True,
+ norm_layers=GroupNorm1,
+ layer_scale_init_values=1e-5,
+ res_scale_init_values=None,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformer_s12', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformer_s24(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[4, 4, 12, 4],
+ dims=[64, 128, 320, 512],
+ downsample_norm=None,
+ mlp_act=nn.GELU,
+ mlp_bias=True,
+ norm_layers=GroupNorm1,
+ layer_scale_init_values=1e-5,
+ res_scale_init_values=None,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformer_s24', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformer_s36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[6, 6, 18, 6],
+ dims=[64, 128, 320, 512],
+ downsample_norm=None,
+ mlp_act=nn.GELU,
+ mlp_bias=True,
+ norm_layers=GroupNorm1,
+ layer_scale_init_values=1e-6,
+ res_scale_init_values=None,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformer_s36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformer_m36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[6, 6, 18, 6],
+ dims=[96, 192, 384, 768],
+ downsample_norm=None,
+ mlp_act=nn.GELU,
+ mlp_bias=True,
+ norm_layers=GroupNorm1,
+ layer_scale_init_values=1e-6,
+ res_scale_init_values=None,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformer_m36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformer_m48(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[8, 8, 24, 8],
+ dims=[96, 192, 384, 768],
+ downsample_norm=None,
+ mlp_act=nn.GELU,
+ mlp_bias=True,
+ norm_layers=GroupNorm1,
+ layer_scale_init_values=1e-6,
+ res_scale_init_values=None,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformer_m48', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformerv2_s12(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[2, 2, 6, 2],
+ dims=[64, 128, 320, 512],
+ norm_layers=GroupNorm1NoBias,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformerv2_s12', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformerv2_s24(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[4, 4, 12, 4],
+ dims=[64, 128, 320, 512],
+ norm_layers=GroupNorm1NoBias,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformerv2_s24', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformerv2_s36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[6, 6, 18, 6],
+ dims=[64, 128, 320, 512],
+ norm_layers=GroupNorm1NoBias,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformerv2_s36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformerv2_m36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[6, 6, 18, 6],
+ dims=[96, 192, 384, 768],
+ norm_layers=GroupNorm1NoBias,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformerv2_m36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def poolformerv2_m48(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[8, 8, 24, 8],
+ dims=[96, 192, 384, 768],
+ norm_layers=GroupNorm1NoBias,
+ use_mlp_head=False,
+ **kwargs)
+ return _create_metaformer('poolformerv2_m48', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def convformer_s18(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 3, 9, 3],
+ dims=[64, 128, 320, 512],
+ token_mixers=SepConv,
+ norm_layers=LayerNorm2dNoBias,
+ **kwargs)
+ return _create_metaformer('convformer_s18', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def convformer_s36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 12, 18, 3],
+ dims=[64, 128, 320, 512],
+ token_mixers=SepConv,
+ norm_layers=LayerNorm2dNoBias,
+ **kwargs)
+ return _create_metaformer('convformer_s36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def convformer_m36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 12, 18, 3],
+ dims=[96, 192, 384, 576],
+ token_mixers=SepConv,
+ norm_layers=LayerNorm2dNoBias,
+ **kwargs)
+ return _create_metaformer('convformer_m36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def convformer_b36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 12, 18, 3],
+ dims=[128, 256, 512, 768],
+ token_mixers=SepConv,
+ norm_layers=LayerNorm2dNoBias,
+ **kwargs)
+ return _create_metaformer('convformer_b36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def caformer_s18(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 3, 9, 3],
+ dims=[64, 128, 320, 512],
+ token_mixers=[SepConv, SepConv, Attention, Attention],
+ norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
+ **kwargs)
+ return _create_metaformer('caformer_s18', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def caformer_s36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 12, 18, 3],
+ dims=[64, 128, 320, 512],
+ token_mixers=[SepConv, SepConv, Attention, Attention],
+ norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
+ **kwargs)
+ return _create_metaformer('caformer_s36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def caformer_m36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 12, 18, 3],
+ dims=[96, 192, 384, 576],
+ token_mixers=[SepConv, SepConv, Attention, Attention],
+ norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
+ **kwargs)
+ return _create_metaformer('caformer_m36', pretrained=pretrained, **model_kwargs)
+
+
+@register_model
+def caformer_b36(pretrained=False, **kwargs) -> MetaFormer:
+ model_kwargs = dict(
+ depths=[3, 12, 18, 3],
+ dims=[128, 256, 512, 768],
+ token_mixers=[SepConv, SepConv, Attention, Attention],
+ norm_layers=[LayerNorm2dNoBias] * 2 + [LayerNormNoBias] * 2,
+ **kwargs)
+ return _create_metaformer('caformer_b36', pretrained=pretrained, **model_kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mlp_mixer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mlp_mixer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7d64349d4dc3a3f546d6c69afbebb783b79849f
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mlp_mixer.py
@@ -0,0 +1,636 @@
+""" MLP-Mixer, ResMLP, and gMLP in PyTorch
+
+This impl originally based on MLP-Mixer paper.
+
+Official JAX impl: https://github.com/google-research/vision_transformer/blob/linen/vit_jax/models_mixer.py
+
+Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+
+@article{tolstikhin2021,
+ title={MLP-Mixer: An all-MLP Architecture for Vision},
+ author={Tolstikhin, Ilya and Houlsby, Neil and Kolesnikov, Alexander and Beyer, Lucas and Zhai, Xiaohua and Unterthiner,
+ Thomas and Yung, Jessica and Keysers, Daniel and Uszkoreit, Jakob and Lucic, Mario and Dosovitskiy, Alexey},
+ journal={arXiv preprint arXiv:2105.01601},
+ year={2021}
+}
+
+Also supporting ResMlp, and a preliminary (not verified) implementations of gMLP
+
+Code: https://github.com/facebookresearch/deit
+Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+@misc{touvron2021resmlp,
+ title={ResMLP: Feedforward networks for image classification with data-efficient training},
+ author={Hugo Touvron and Piotr Bojanowski and Mathilde Caron and Matthieu Cord and Alaaeldin El-Nouby and
+ Edouard Grave and Armand Joulin and Gabriel Synnaeve and Jakob Verbeek and Hervé Jégou},
+ year={2021},
+ eprint={2105.03404},
+}
+
+Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+@misc{liu2021pay,
+ title={Pay Attention to MLPs},
+ author={Hanxiao Liu and Zihang Dai and David R. So and Quoc V. Le},
+ year={2021},
+ eprint={2105.08050},
+}
+
+A thank you to paper authors for releasing code and weights.
+
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import math
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply, checkpoint_seq
+from ._registry import generate_default_cfgs, register_model, register_model_deprecations
+
+__all__ = ['MixerBlock', 'MlpMixer'] # model_registry will add each entrypoint fn to this
+
+
+class MixerBlock(nn.Module):
+ """ Residual Block w/ token mixing and channel MLPs
+ Based on: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ def __init__(
+ self,
+ dim,
+ seq_len,
+ mlp_ratio=(0.5, 4.0),
+ mlp_layer=Mlp,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ drop=0.,
+ drop_path=0.,
+ ):
+ super().__init__()
+ tokens_dim, channels_dim = [int(x * dim) for x in to_2tuple(mlp_ratio)]
+ self.norm1 = norm_layer(dim)
+ self.mlp_tokens = mlp_layer(seq_len, tokens_dim, act_layer=act_layer, drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
+
+ def forward(self, x):
+ x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+ x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
+ return x
+
+
+class Affine(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.alpha = nn.Parameter(torch.ones((1, 1, dim)))
+ self.beta = nn.Parameter(torch.zeros((1, 1, dim)))
+
+ def forward(self, x):
+ return torch.addcmul(self.beta, self.alpha, x)
+
+
+class ResBlock(nn.Module):
+ """ Residual MLP block w/ LayerScale and Affine 'norm'
+
+ Based on: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ def __init__(
+ self,
+ dim,
+ seq_len,
+ mlp_ratio=4,
+ mlp_layer=Mlp,
+ norm_layer=Affine,
+ act_layer=nn.GELU,
+ init_values=1e-4,
+ drop=0.,
+ drop_path=0.,
+ ):
+ super().__init__()
+ channel_dim = int(dim * mlp_ratio)
+ self.norm1 = norm_layer(dim)
+ self.linear_tokens = nn.Linear(seq_len, seq_len)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, drop=drop)
+ self.ls1 = nn.Parameter(init_values * torch.ones(dim))
+ self.ls2 = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x):
+ x = x + self.drop_path(self.ls1 * self.linear_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
+ x = x + self.drop_path(self.ls2 * self.mlp_channels(self.norm2(x)))
+ return x
+
+
+class SpatialGatingUnit(nn.Module):
+ """ Spatial Gating Unit
+
+ Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ def __init__(self, dim, seq_len, norm_layer=nn.LayerNorm):
+ super().__init__()
+ gate_dim = dim // 2
+ self.norm = norm_layer(gate_dim)
+ self.proj = nn.Linear(seq_len, seq_len)
+
+ def init_weights(self):
+ # special init for the projection gate, called as override by base model init
+ nn.init.normal_(self.proj.weight, std=1e-6)
+ nn.init.ones_(self.proj.bias)
+
+ def forward(self, x):
+ u, v = x.chunk(2, dim=-1)
+ v = self.norm(v)
+ v = self.proj(v.transpose(-1, -2))
+ return u * v.transpose(-1, -2)
+
+
+class SpatialGatingBlock(nn.Module):
+ """ Residual Block w/ Spatial Gating
+
+ Based on: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ def __init__(
+ self,
+ dim,
+ seq_len,
+ mlp_ratio=4,
+ mlp_layer=GatedMlp,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ drop=0.,
+ drop_path=0.,
+ ):
+ super().__init__()
+ channel_dim = int(dim * mlp_ratio)
+ self.norm = norm_layer(dim)
+ sgu = partial(SpatialGatingUnit, seq_len=seq_len)
+ self.mlp_channels = mlp_layer(dim, channel_dim, act_layer=act_layer, gate_layer=sgu, drop=drop)
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+ def forward(self, x):
+ x = x + self.drop_path(self.mlp_channels(self.norm(x)))
+ return x
+
+
+class MlpMixer(nn.Module):
+
+ def __init__(
+ self,
+ num_classes=1000,
+ img_size=224,
+ in_chans=3,
+ patch_size=16,
+ num_blocks=8,
+ embed_dim=512,
+ mlp_ratio=(0.5, 4.0),
+ block_layer=MixerBlock,
+ mlp_layer=Mlp,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ act_layer=nn.GELU,
+ drop_rate=0.,
+ proj_drop_rate=0.,
+ drop_path_rate=0.,
+ nlhb=False,
+ stem_norm=False,
+ global_pool='avg',
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ self.global_pool = global_pool
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
+ self.grad_checkpointing = False
+
+ self.stem = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if stem_norm else None,
+ )
+ # FIXME drop_path (stochastic depth scaling rule or all the same?)
+ self.blocks = nn.Sequential(*[
+ block_layer(
+ embed_dim,
+ self.stem.num_patches,
+ mlp_ratio,
+ mlp_layer=mlp_layer,
+ norm_layer=norm_layer,
+ act_layer=act_layer,
+ drop=proj_drop_rate,
+ drop_path=drop_path_rate,
+ )
+ for _ in range(num_blocks)])
+ self.norm = norm_layer(embed_dim)
+ self.head_drop = nn.Dropout(drop_rate)
+ self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
+
+ self.init_weights(nlhb=nlhb)
+
+ @torch.jit.ignore
+ def init_weights(self, nlhb=False):
+ head_bias = -math.log(self.num_classes) if nlhb else 0.
+ named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse=False):
+ return dict(
+ stem=r'^stem', # stem and embed
+ blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+ )
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ self.grad_checkpointing = enable
+
+ @torch.jit.ignore
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=None):
+ self.num_classes = num_classes
+ if global_pool is not None:
+ assert global_pool in ('', 'avg')
+ self.global_pool = global_pool
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.stem(x)
+ if self.grad_checkpointing and not torch.jit.is_scripting():
+ x = checkpoint_seq(self.blocks, x)
+ else:
+ x = self.blocks(x)
+ x = self.norm(x)
+ return x
+
+ def forward_head(self, x, pre_logits: bool = False):
+ if self.global_pool == 'avg':
+ x = x.mean(dim=1)
+ x = self.head_drop(x)
+ return x if pre_logits else self.head(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.forward_head(x)
+ return x
+
+
+def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
+ """ Mixer weight initialization (trying to match Flax defaults)
+ """
+ if isinstance(module, nn.Linear):
+ if name.startswith('head'):
+ nn.init.zeros_(module.weight)
+ nn.init.constant_(module.bias, head_bias)
+ else:
+ if flax:
+ # Flax defaults
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ else:
+ # like MLP init in vit (my original init)
+ nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ if 'mlp' in name:
+ nn.init.normal_(module.bias, std=1e-6)
+ else:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Conv2d):
+ lecun_normal_(module.weight)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
+ nn.init.ones_(module.weight)
+ nn.init.zeros_(module.bias)
+ elif hasattr(module, 'init_weights'):
+ # NOTE if a parent module contains init_weights method, it can override the init of the
+ # child modules as this will be called in depth-first order.
+ module.init_weights()
+
+
+def checkpoint_filter_fn(state_dict, model):
+ """ Remap checkpoints if needed """
+ if 'patch_embed.proj.weight' in state_dict:
+ # Remap FB ResMlp models -> timm
+ out_dict = {}
+ for k, v in state_dict.items():
+ k = k.replace('patch_embed.', 'stem.')
+ k = k.replace('attn.', 'linear_tokens.')
+ k = k.replace('mlp.', 'mlp_channels.')
+ k = k.replace('gamma_', 'ls')
+ if k.endswith('.alpha') or k.endswith('.beta'):
+ v = v.reshape(1, 1, -1)
+ out_dict[k] = v
+ return out_dict
+ return state_dict
+
+
+def _create_mixer(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for MLP-Mixer models.')
+
+ model = build_model_with_cfg(
+ MlpMixer,
+ variant,
+ pretrained,
+ pretrained_filter_fn=checkpoint_filter_fn,
+ **kwargs,
+ )
+ return model
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url,
+ 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
+ 'crop_pct': 0.875, 'interpolation': 'bicubic', 'fixed_input_size': True,
+ 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
+ 'first_conv': 'stem.proj', 'classifier': 'head',
+ **kwargs
+ }
+
+
+default_cfgs = generate_default_cfgs({
+ 'mixer_s32_224.untrained': _cfg(),
+ 'mixer_s16_224.untrained': _cfg(),
+ 'mixer_b32_224.untrained': _cfg(),
+ 'mixer_b16_224.goog_in21k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224-76587d61.pth',
+ ),
+ 'mixer_b16_224.goog_in21k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_b16_224_in21k-617b3de2.pth',
+ num_classes=21843
+ ),
+ 'mixer_l32_224.untrained': _cfg(),
+ 'mixer_l16_224.goog_in21k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224-92f9adc4.pth',
+ ),
+ 'mixer_l16_224.goog_in21k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
+ num_classes=21843
+ ),
+
+ # Mixer ImageNet-21K-P pretraining
+ 'mixer_b16_224.miil_in21k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil_in21k-2a558a71.pth',
+ mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
+ ),
+ 'mixer_b16_224.miil_in21k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mixer_b16_224_miil-9229a591.pth',
+ mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear',
+ ),
+
+ 'gmixer_12_224.untrained': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'gmixer_24_224.ra3_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmixer_24_224_raa-7daf7ae6.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ 'resmlp_12_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_12_no_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'resmlp_24_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_24_no_dist.pth',
+ #url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resmlp_24_224_raa-a8256759.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'resmlp_36_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_36_no_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'resmlp_big_24_224.fb_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_no_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ 'resmlp_12_224.fb_distilled_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'resmlp_24_224.fb_distilled_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'resmlp_36_224.fb_distilled_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_36_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'resmlp_big_24_224.fb_distilled_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_dist.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ 'resmlp_big_24_224.fb_in22k_ft_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlpB_24_22k.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ 'resmlp_12_224.fb_dino': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_12_dino.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+ 'resmlp_24_224.fb_dino': _cfg(
+ hf_hub_id='timm/',
+ url='https://dl.fbaipublicfiles.com/deit/resmlp_24_dino.pth',
+ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
+
+ 'gmlp_ti16_224.untrained': _cfg(),
+ 'gmlp_s16_224.ra3_in1k': _cfg(
+ hf_hub_id='timm/',
+ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gmlp_s16_224_raa-10536d42.pth',
+ ),
+ 'gmlp_b16_224.untrained': _cfg(),
+})
+
+
+@register_model
+def mixer_s32_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Mixer-S/32 224x224
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=32, num_blocks=8, embed_dim=512, **kwargs)
+ model = _create_mixer('mixer_s32_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_s16_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Mixer-S/16 224x224
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=8, embed_dim=512, **kwargs)
+ model = _create_mixer('mixer_s16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_b32_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Mixer-B/32 224x224
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=32, num_blocks=12, embed_dim=768, **kwargs)
+ model = _create_mixer('mixer_b32_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_b16_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Mixer-B/16 224x224. ImageNet-1k pretrained weights.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=12, embed_dim=768, **kwargs)
+ model = _create_mixer('mixer_b16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_l32_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Mixer-L/32 224x224.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=32, num_blocks=24, embed_dim=1024, **kwargs)
+ model = _create_mixer('mixer_l32_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def mixer_l16_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Mixer-L/16 224x224. ImageNet-1k pretrained weights.
+ Paper: 'MLP-Mixer: An all-MLP Architecture for Vision' - https://arxiv.org/abs/2105.01601
+ """
+ model_args = dict(patch_size=16, num_blocks=24, embed_dim=1024, **kwargs)
+ model = _create_mixer('mixer_l16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmixer_12_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Glu-Mixer-12 224x224
+ Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=(1.0, 4.0),
+ mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
+ model = _create_mixer('gmixer_12_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmixer_24_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ Glu-Mixer-24 224x224
+ Experiment by Ross Wightman, adding SwiGLU to MLP-Mixer
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=(1.0, 4.0),
+ mlp_layer=GluMlp, act_layer=nn.SiLU, **kwargs)
+ model = _create_mixer('gmixer_24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_12_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ ResMLP-12
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=12, embed_dim=384, mlp_ratio=4, block_layer=ResBlock, norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_12_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_24_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ ResMLP-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=24, embed_dim=384, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-5), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_36_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ ResMLP-36
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=36, embed_dim=384, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_36_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def resmlp_big_24_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ ResMLP-B-24
+ Paper: `ResMLP: Feedforward networks for image classification...` - https://arxiv.org/abs/2105.03404
+ """
+ model_args = dict(
+ patch_size=8, num_blocks=24, embed_dim=768, mlp_ratio=4,
+ block_layer=partial(ResBlock, init_values=1e-6), norm_layer=Affine, **kwargs)
+ model = _create_mixer('resmlp_big_24_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmlp_ti16_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ gMLP-Tiny
+ Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=30, embed_dim=128, mlp_ratio=6, block_layer=SpatialGatingBlock,
+ mlp_layer=GatedMlp, **kwargs)
+ model = _create_mixer('gmlp_ti16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmlp_s16_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ gMLP-Small
+ Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=30, embed_dim=256, mlp_ratio=6, block_layer=SpatialGatingBlock,
+ mlp_layer=GatedMlp, **kwargs)
+ model = _create_mixer('gmlp_s16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+@register_model
+def gmlp_b16_224(pretrained=False, **kwargs) -> MlpMixer:
+ """ gMLP-Base
+ Paper: `Pay Attention to MLPs` - https://arxiv.org/abs/2105.08050
+ """
+ model_args = dict(
+ patch_size=16, num_blocks=30, embed_dim=512, mlp_ratio=6, block_layer=SpatialGatingBlock,
+ mlp_layer=GatedMlp, **kwargs)
+ model = _create_mixer('gmlp_b16_224', pretrained=pretrained, **model_args)
+ return model
+
+
+register_model_deprecations(__name__, {
+ 'mixer_b16_224_in21k': 'mixer_b16_224.goog_in21k_ft_in1k',
+ 'mixer_l16_224_in21k': 'mixer_l16_224.goog_in21k_ft_in1k',
+ 'mixer_b16_224_miil': 'mixer_b16_224.miil_in21k_ft_in1k',
+ 'mixer_b16_224_miil_in21k': 'mixer_b16_224.miil_in21k',
+ 'resmlp_12_distilled_224': 'resmlp_12_224.fb_distilled_in1k',
+ 'resmlp_24_distilled_224': 'resmlp_24_224.fb_distilled_in1k',
+ 'resmlp_36_distilled_224': 'resmlp_36_224.fb_distilled_in1k',
+ 'resmlp_big_24_distilled_224': 'resmlp_big_24_224.fb_distilled_in1k',
+ 'resmlp_big_24_224_in22ft1k': 'resmlp_big_24_224.fb_in22k_ft_in1k',
+ 'resmlp_12_224_dino': 'resmlp_12_224',
+ 'resmlp_24_224_dino': 'resmlp_24_224',
+})
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mobilevit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mobilevit.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c84871e6daeca044865574437f9e0294a6f3e7d
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mobilevit.py
@@ -0,0 +1,681 @@
+""" MobileViT
+
+Paper:
+V1: `MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer` - https://arxiv.org/abs/2110.02178
+V2: `Separable Self-attention for Mobile Vision Transformers` - https://arxiv.org/abs/2206.02680
+
+MobileVitBlock and checkpoints adapted from https://github.com/apple/ml-cvnets (original copyright below)
+License: https://github.com/apple/ml-cvnets/blob/main/LICENSE (Apple open source)
+
+Rest of code, ByobNet, and Transformer block hacked together by / Copyright 2022, Ross Wightman
+"""
+#
+# For licensing see accompanying LICENSE file.
+# Copyright (C) 2020 Apple Inc. All Rights Reserved.
+#
+import math
+from typing import Callable, Tuple, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from timm.layers import to_2tuple, make_divisible, GroupNorm1, ConvMlp, DropPath, is_exportable
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations
+from .byobnet import register_block, ByoBlockCfg, ByoModelCfg, ByobNet, LayerFn, num_groups
+from .vision_transformer import Block as TransformerBlock
+
+__all__ = []
+
+
def _inverted_residual_block(d, c, s, br=4.0):
    """Config for an MBConv-style inverted residual block.

    Bottleneck with expansion ratio `br` applied to in_chs (`bottle_in=True`),
    depthwise mid conv (gs=1), and a linear (no-activation) output projection.
    """
    return ByoBlockCfg(
        type='bottle',
        d=d,
        c=c,
        s=s,
        gs=1,
        br=br,
        block_kwargs=dict(bottle_in=True, linear_out=True),
    )
+
+
def _mobilevit_block(d, c, s, transformer_dim, transformer_depth, patch_size=4, br=4.0):
    """MobileViT v1 stage pattern: inverted residual followed by a 'mobilevit' block."""
    ir_cfg = _inverted_residual_block(d=d, c=c, s=s, br=br)
    vit_cfg = ByoBlockCfg(
        type='mobilevit', d=1, c=c, s=1,
        block_kwargs=dict(
            transformer_dim=transformer_dim,
            transformer_depth=transformer_depth,
            patch_size=patch_size,
        ),
    )
    return ir_cfg, vit_cfg
+
+
def _mobilevitv2_block(d, c, s, transformer_depth, patch_size=2, br=2.0, transformer_br=0.5):
    """MobileViT-V2 stage pattern: inverted residual followed by a 'mobilevit2'
    block (linear / separable self-attention)."""
    ir_cfg = _inverted_residual_block(d=d, c=c, s=s, br=br)
    vit_cfg = ByoBlockCfg(
        type='mobilevit2', d=1, c=c, s=1, br=transformer_br, gs=1,
        block_kwargs=dict(
            transformer_depth=transformer_depth,
            patch_size=patch_size,
        ),
    )
    return ir_cfg, vit_cfg
+
+
def _mobilevitv2_cfg(multiplier=1.0):
    """Build the ByoModelCfg for a MobileViT-V2 at the given width multiplier."""
    base_chs = (64, 128, 256, 384, 512)
    if multiplier != 1.0:
        chs = tuple(int(c * multiplier) for c in base_chs)
    else:
        chs = base_chs
    return ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=chs[0], s=1, br=2.0),
            _inverted_residual_block(d=2, c=chs[1], s=2, br=2.0),
            _mobilevitv2_block(d=1, c=chs[2], s=2, transformer_depth=2),
            _mobilevitv2_block(d=1, c=chs[3], s=2, transformer_depth=4),
            _mobilevitv2_block(d=1, c=chs[4], s=2, transformer_depth=3),
        ),
        stem_chs=int(32 * multiplier),
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
    )
+
+
# Architecture configs. v1 variants (xxs/xs/s) spell out per-stage channels and
# transformer dims; v2 variants are generated from a width multiplier.
model_cfgs = dict(
    mobilevit_xxs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=16, s=1, br=2.0),
            _inverted_residual_block(d=3, c=24, s=2, br=2.0),
            _mobilevit_block(d=1, c=48, s=2, transformer_dim=64, transformer_depth=2, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=80, transformer_depth=4, patch_size=2, br=2.0),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=96, transformer_depth=3, patch_size=2, br=2.0),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=320,
    ),

    mobilevit_xs=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=48, s=2),
            _mobilevit_block(d=1, c=64, s=2, transformer_dim=96, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=80, s=2, transformer_dim=120, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=384,
    ),

    mobilevit_s=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=64, s=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        act_layer='silu',
        num_features=640,
    ),

    # Same topology as mobilevit_s plus squeeze-excite attention.
    # NOTE(review): unlike the other variants this cfg does not set act_layer='silu'.
    semobilevit_s=ByoModelCfg(
        blocks=(
            _inverted_residual_block(d=1, c=32, s=1),
            _inverted_residual_block(d=3, c=64, s=2),
            _mobilevit_block(d=1, c=96, s=2, transformer_dim=144, transformer_depth=2, patch_size=2),
            _mobilevit_block(d=1, c=128, s=2, transformer_dim=192, transformer_depth=4, patch_size=2),
            _mobilevit_block(d=1, c=160, s=2, transformer_dim=240, transformer_depth=3, patch_size=2),
        ),
        stem_chs=16,
        stem_type='3x3',
        stem_pool='',
        downsample='',
        attn_layer='se',
        attn_kwargs=dict(rd_ratio=1/8),
        num_features=640,
    ),

    mobilevitv2_050=_mobilevitv2_cfg(.50),
    mobilevitv2_075=_mobilevitv2_cfg(.75),
    mobilevitv2_125=_mobilevitv2_cfg(1.25),
    mobilevitv2_100=_mobilevitv2_cfg(1.0),
    mobilevitv2_150=_mobilevitv2_cfg(1.5),
    mobilevitv2_175=_mobilevitv2_cfg(1.75),
    mobilevitv2_200=_mobilevitv2_cfg(2.0),
)
+
+
@register_notrace_module
class MobileVitBlock(nn.Module):
    """ MobileViT block
    Paper: https://arxiv.org/abs/2110.02178?context=cs.LG

    Local kxk conv + 1x1 projection, unfold into non-overlapping patches,
    run standard ViT transformer blocks over patches (global mixing), fold
    back to a feature map, project, then optionally fuse with the block
    input via concat + conv.
    """
    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,  # defaults to in_chs
            kernel_size: int = 3,
            stride: int = 1,
            bottle_ratio: float = 1.0,  # used to derive transformer_dim when not given
            group_size: Optional[int] = None,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            num_heads: int = 4,
            attn_drop: float = 0.,
            drop: float = 0.,
            no_fusion: bool = False,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Callable = nn.LayerNorm,
            **kwargs,  # eat unused args
    ):
        super(MobileVitBlock, self).__init__()

        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        # Local representation: kxk conv (optionally strided/grouped) + 1x1 to transformer dim.
        self.conv_kxk = layers.conv_norm_act(
            in_chs, in_chs, kernel_size=kernel_size,
            stride=stride, groups=groups, dilation=dilation[0])
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)

        # Standard ViT transformer blocks applied over unfolded patch tokens.
        self.transformer = nn.Sequential(*[
            TransformerBlock(
                transformer_dim,
                mlp_ratio=mlp_ratio,
                num_heads=num_heads,
                qkv_bias=True,
                attn_drop=attn_drop,
                proj_drop=drop,
                drop_path=drop_path_rate,
                act_layer=layers.act,
                norm_layer=transformer_norm_layer,
            )
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim)

        # Project transformer output back to out_chs.
        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1)

        if no_fusion:
            self.conv_fusion = None
        else:
            # Fuse the (unmodified) block input with the transformer output.
            self.conv_fusion = layers.conv_norm_act(in_chs + out_chs, out_chs, kernel_size=kernel_size, stride=1)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches)
        patch_h, patch_w = self.patch_size
        B, C, H, W = x.shape
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        interpolate = False
        if new_h != H or new_w != W:
            # Resize (rather than pad) so H/W divide evenly by the patch size.
            # Note: Padding can be done, but then it needs to be handled in attention function.
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=False)
            interpolate = True

        # [B, C, H, W] --> [B * C * n_h, n_w, p_h, p_w]
        x = x.reshape(B * C * num_patch_h, patch_h, num_patch_w, patch_w).transpose(1, 2)
        # [B * C * n_h, n_w, p_h, p_w] --> [BP, N, C] where P = p_h * p_w and N = n_h * n_w
        x = x.reshape(B, C, num_patches, self.patch_area).transpose(1, 3).reshape(B * self.patch_area, num_patches, -1)

        # Global representations: attention mixes across patches (N dim).
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patch -> feature map)
        # [B, P, N, C] --> [B*C*n_h, n_w, p_h, p_w]
        x = x.contiguous().view(B, self.patch_area, num_patches, -1)
        x = x.transpose(1, 3).reshape(B * C * num_patch_h, num_patch_w, patch_h, patch_w)
        # [B*C*n_h, n_w, p_h, p_w] --> [B*C*n_h, p_h, n_w, p_w] --> [B, C, H, W]
        x = x.transpose(1, 2).reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)
        if interpolate:
            # Undo the divisibility resize so output matches input spatial size.
            x = F.interpolate(x, size=(H, W), mode="bilinear", align_corners=False)

        x = self.conv_proj(x)
        if self.conv_fusion is not None:
            x = self.conv_fusion(torch.cat((shortcut, x), dim=1))
        return x
+
+
class LinearSelfAttention(nn.Module):
    """
    This layer applies a self-attention with linear complexity, as described in `https://arxiv.org/abs/2206.02680`
    This layer can be used for self- as well as cross-attention.
    Args:
        embed_dim (int): :math:`C` from an expected input of size :math:`(N, C, H, W)`
        attn_drop (float): Dropout value for context scores. Default: 0.0
        proj_drop (float): Dropout value for the output projection. Default: 0.0
        bias (bool): Use bias in learnable layers. Default: True
    Shape:
        - Input: :math:`(N, C, P, N)` where :math:`N` is the batch size, :math:`C` is the input channels,
        :math:`P` is the number of pixels in the patch, and :math:`N` is the number of patches
        - Output: same as the input
    .. note::
        For MobileViTv2, we unfold the feature map [B, C, H, W] into [B, C, P, N] where P is the number of pixels
        in a patch and N is the number of patches. Because channel is the first dimension in this unfolded tensor,
        we use point-wise convolution (instead of a linear layer). This avoids a transpose operation (which may be
        expensive on resource-constrained devices) that may be required to convert the unfolded tensor from
        channel-first to channel-last format in case of a linear layer.
    """

    def __init__(
            self,
            embed_dim: int,
            attn_drop: float = 0.0,
            proj_drop: float = 0.0,
            bias: bool = True,
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim

        # Produces query (1 channel), key (embed_dim), value (embed_dim) in one conv.
        self.qkv_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=1 + (2 * embed_dim),
            bias=bias,
            kernel_size=1,
        )
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_proj = nn.Conv2d(
            in_channels=embed_dim,
            out_channels=embed_dim,
            bias=bias,
            kernel_size=1,
        )
        self.out_drop = nn.Dropout(proj_drop)

    def _forward_self_attn(self, x: torch.Tensor) -> torch.Tensor:
        # [B, C, P, N] --> [B, h + 2d, P, N]
        qkv = self.qkv_proj(x)

        # Project x into query, key and value
        # Query --> [B, 1, P, N]
        # value, key --> [B, d, P, N]
        query, key, value = qkv.split([1, self.embed_dim, self.embed_dim], dim=1)

        # apply softmax along N dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_drop(context_scores)

        # Compute context vector
        # [B, d, P, N] x [B, 1, P, N] -> [B, d, P, N] --> [B, d, P, 1]
        context_vector = (key * context_scores).sum(dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        out = self.out_drop(out)
        return out

    @torch.jit.ignore()
    def _forward_cross_attn(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        # x --> [B, C, P, N]   (key/value source)
        # x_prev = [B, C, P, M]  (query source)
        batch_size, in_dim, kv_patch_area, kv_num_patches = x.shape
        # FIX: query dims must come from x_prev (was x.shape, making the assert below vacuous).
        q_patch_area, q_num_patches = x_prev.shape[-2:]

        assert (
            kv_patch_area == q_patch_area
        ), "The number of pixels in a patch for query and key_value should be the same"

        # compute query, key, and value
        # [B, C, P, M] --> [B, 1 + d, P, M]
        qk = F.conv2d(
            x_prev,
            weight=self.qkv_proj.weight[:self.embed_dim + 1],
            bias=self.qkv_proj.bias[:self.embed_dim + 1],
        )

        # [B, 1 + d, P, M] --> [B, 1, P, M], [B, d, P, M]
        query, key = qk.split([1, self.embed_dim], dim=1)
        # [B, C, P, N] --> [B, d, P, N]
        # FIX: slice (not index) the remaining value rows; F.conv2d requires a 4-D weight,
        # and indexing with `[self.embed_dim + 1]` selected a single 3-D filter.
        value = F.conv2d(
            x,
            weight=self.qkv_proj.weight[self.embed_dim + 1:],
            bias=self.qkv_proj.bias[self.embed_dim + 1:] if self.qkv_proj.bias is not None else None,
        )

        # apply softmax along M dimension
        context_scores = F.softmax(query, dim=-1)
        context_scores = self.attn_drop(context_scores)

        # compute context vector
        # [B, d, P, M] * [B, 1, P, M] -> [B, d, P, M] --> [B, d, P, 1]
        context_vector = (key * context_scores).sum(dim=-1, keepdim=True)

        # combine context vector with values
        # [B, d, P, N] * [B, d, P, 1] --> [B, d, P, N]
        out = F.relu(value) * context_vector.expand_as(value)
        out = self.out_proj(out)
        out = self.out_drop(out)
        return out

    def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        if x_prev is None:
            return self._forward_self_attn(x)
        else:
            return self._forward_cross_attn(x, x_prev=x_prev)
+
+
class LinearTransformerBlock(nn.Module):
    """Pre-norm transformer encoder block built on linear (separable) self-attention.

    Operates channel-first on unfolded tensors of shape (B, C, P, N), where P is
    pixels-per-patch and N is number of patches. Supports optional cross-attention
    via an `x_prev` input. Output shape equals input shape.
    """

    def __init__(
            self,
            embed_dim: int,
            mlp_ratio: float = 2.0,
            drop: float = 0.0,
            attn_drop: float = 0.0,
            drop_path: float = 0.0,
            act_layer=None,
            norm_layer=None,
    ) -> None:
        super().__init__()
        # Fall back to the MobileViT-V2 defaults when layers are not specified.
        if act_layer is None:
            act_layer = nn.SiLU
        if norm_layer is None:
            norm_layer = GroupNorm1

        hidden_dim = int(embed_dim * mlp_ratio)

        # Attention sub-block (pre-norm).
        self.norm1 = norm_layer(embed_dim)
        self.attn = LinearSelfAttention(embed_dim=embed_dim, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path1 = DropPath(drop_path)

        # FFN sub-block (pre-norm, conv-based MLP to stay channel-first).
        self.norm2 = norm_layer(embed_dim)
        self.mlp = ConvMlp(
            in_features=embed_dim,
            hidden_features=hidden_dim,
            act_layer=act_layer,
            drop=drop)
        self.drop_path2 = DropPath(drop_path)

    def forward(self, x: torch.Tensor, x_prev: Optional[torch.Tensor] = None) -> torch.Tensor:
        shortcut = x
        if x_prev is None:
            # Self-attention path.
            x = self.drop_path1(self.attn(self.norm1(x))) + shortcut
        else:
            # Cross-attention path: queries come from x_prev.
            x = self.drop_path1(self.attn(self.norm1(x), x_prev)) + shortcut

        # Feed-forward network.
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
+
+
@register_notrace_module
class MobileVitV2Block(nn.Module):
    """MobileViT-V2 block (https://arxiv.org/abs/2206.02680).

    Like MobileVitBlock but using linear (separable) self-attention over a
    channel-first [B, C, P, N] unfolding, and no input fusion conv.
    """

    def __init__(
            self,
            in_chs: int,
            out_chs: Optional[int] = None,  # defaults to in_chs
            kernel_size: int = 3,
            bottle_ratio: float = 1.0,  # used to derive transformer_dim when not given
            group_size: Optional[int] = 1,
            dilation: Tuple[int, int] = (1, 1),
            mlp_ratio: float = 2.0,
            transformer_dim: Optional[int] = None,
            transformer_depth: int = 2,
            patch_size: int = 8,
            attn_drop: float = 0.,
            drop: float = 0.,
            drop_path_rate: float = 0.,
            layers: LayerFn = None,
            transformer_norm_layer: Callable = GroupNorm1,
            **kwargs,  # eat unused args
    ):
        super(MobileVitV2Block, self).__init__()
        layers = layers or LayerFn()
        groups = num_groups(group_size, in_chs)
        out_chs = out_chs or in_chs
        transformer_dim = transformer_dim or make_divisible(bottle_ratio * in_chs)

        # Local representation: kxk conv + 1x1 to transformer dim.
        self.conv_kxk = layers.conv_norm_act(
            in_chs, in_chs, kernel_size=kernel_size,
            stride=1, groups=groups, dilation=dilation[0])
        self.conv_1x1 = nn.Conv2d(in_chs, transformer_dim, kernel_size=1, bias=False)

        self.transformer = nn.Sequential(*[
            LinearTransformerBlock(
                transformer_dim,
                mlp_ratio=mlp_ratio,
                attn_drop=attn_drop,
                drop=drop,
                drop_path=drop_path_rate,
                act_layer=layers.act,
                norm_layer=transformer_norm_layer
            )
            for _ in range(transformer_depth)
        ])
        self.norm = transformer_norm_layer(transformer_dim)

        # Project back to out_chs; no activation (linear projection).
        self.conv_proj = layers.conv_norm_act(transformer_dim, out_chs, kernel_size=1, stride=1, apply_act=False)

        self.patch_size = to_2tuple(patch_size)
        self.patch_area = self.patch_size[0] * self.patch_size[1]
        # Use fold/unfold ops that export cleanly (e.g. CoreML) when in export mode.
        self.coreml_exportable = is_exportable()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        patch_h, patch_w = self.patch_size
        new_h, new_w = math.ceil(H / patch_h) * patch_h, math.ceil(W / patch_w) * patch_w
        num_patch_h, num_patch_w = new_h // patch_h, new_w // patch_w  # n_h, n_w
        num_patches = num_patch_h * num_patch_w  # N
        if new_h != H or new_w != W:
            # Resize so H/W divide evenly by the patch size (no inverse resize at the end).
            x = F.interpolate(x, size=(new_h, new_w), mode="bilinear", align_corners=True)

        # Local representation
        x = self.conv_kxk(x)
        x = self.conv_1x1(x)

        # Unfold (feature map -> patches), [B, C, H, W] -> [B, C, P, N]
        C = x.shape[1]  # channel count changed by conv_1x1 (now transformer_dim)
        if self.coreml_exportable:
            x = F.unfold(x, kernel_size=(patch_h, patch_w), stride=(patch_h, patch_w))
        else:
            x = x.reshape(B, C, num_patch_h, patch_h, num_patch_w, patch_w).permute(0, 1, 3, 5, 2, 4)
            x = x.reshape(B, C, -1, num_patches)

        # Global representations
        x = self.transformer(x)
        x = self.norm(x)

        # Fold (patches -> feature map), [B, C, P, N] --> [B, C, H, W]
        if self.coreml_exportable:
            # adopted from https://github.com/apple/ml-cvnets/blob/main/cvnets/modules/mobilevit_block.py#L609-L624
            x = x.reshape(B, C * patch_h * patch_w, num_patch_h, num_patch_w)
            x = F.pixel_shuffle(x, upscale_factor=patch_h)
        else:
            x = x.reshape(B, C, patch_h, patch_w, num_patch_h, num_patch_w).permute(0, 1, 4, 2, 5, 3)
            x = x.reshape(B, C, num_patch_h * patch_h, num_patch_w * patch_w)

        x = self.conv_proj(x)
        return x
+
+
+register_block('mobilevit', MobileVitBlock)
+register_block('mobilevit2', MobileVitV2Block)
+
+
def _create_mobilevit(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Instantiate a MobileViT ByobNet; `cfg_variant` overrides which arch cfg is used."""
    cfg = model_cfgs[cfg_variant] if cfg_variant else model_cfgs[variant]
    return build_model_with_cfg(
        ByobNet,
        variant,
        pretrained,
        model_cfg=cfg,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )
+
+
def _create_mobilevit2(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Instantiate a MobileViT-V2 ByobNet.

    This was a byte-for-byte duplicate of :func:`_create_mobilevit`; delegate to it
    so there is a single place to maintain the build logic. Interface unchanged.
    """
    return _create_mobilevit(variant, cfg_variant=cfg_variant, pretrained=pretrained, **kwargs)
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
+ 'crop_pct': 0.9, 'interpolation': 'bicubic',
+ 'mean': (0., 0., 0.), 'std': (1., 1., 1.),
+ 'first_conv': 'stem.conv', 'classifier': 'head.fc',
+ 'fixed_input_size': False,
+ **kwargs
+ }
+
+
# Pretrained weight configs, keyed as '<model>.<tag>' (weights hosted on the HF hub).
default_cfgs = generate_default_cfgs({
    'mobilevit_xxs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
    'mobilevit_xs.cvnets_in1k': _cfg(hf_hub_id='timm/'),
    'mobilevit_s.cvnets_in1k': _cfg(hf_hub_id='timm/'),

    'mobilevitv2_050.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_075.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_100.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_125.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_150.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_175.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_200.cvnets_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),

    'mobilevitv2_150.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_175.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),
    'mobilevitv2_200.cvnets_in22k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.888),

    # 384x384 fine-tuned variants.
    'mobilevitv2_150.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_175.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
    'mobilevitv2_200.cvnets_in22k_ft_in1k_384': _cfg(
        hf_hub_id='timm/',
        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
})
+
+
@register_model
def mobilevit_xxs(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-XXS """
    return _create_mobilevit('mobilevit_xxs', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevit_xs(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-XS """
    return _create_mobilevit('mobilevit_xs', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevit_s(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-S """
    return _create_mobilevit('mobilevit_s', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevitv2_050(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-V2 w/ 0.5x width multiplier """
    return _create_mobilevit('mobilevitv2_050', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevitv2_075(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-V2 w/ 0.75x width multiplier """
    return _create_mobilevit('mobilevitv2_075', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevitv2_100(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-V2 w/ 1.0x width multiplier """
    return _create_mobilevit('mobilevitv2_100', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevitv2_125(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-V2 w/ 1.25x width multiplier """
    return _create_mobilevit('mobilevitv2_125', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevitv2_150(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-V2 w/ 1.5x width multiplier """
    return _create_mobilevit('mobilevitv2_150', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevitv2_175(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-V2 w/ 1.75x width multiplier """
    return _create_mobilevit('mobilevitv2_175', pretrained=pretrained, **kwargs)
+
+
@register_model
def mobilevitv2_200(pretrained=False, **kwargs) -> ByobNet:
    """ MobileViT-V2 w/ 2.0x width multiplier """
    return _create_mobilevit('mobilevitv2_200', pretrained=pretrained, **kwargs)
+
+
# Map legacy model names to their current name + pretrained-tag equivalents.
register_model_deprecations(__name__, {
    'mobilevitv2_150_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k',
    'mobilevitv2_175_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k',
    'mobilevitv2_200_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k',

    'mobilevitv2_150_384_in22ft1k': 'mobilevitv2_150.cvnets_in22k_ft_in1k_384',
    'mobilevitv2_175_384_in22ft1k': 'mobilevitv2_175.cvnets_in22k_ft_in1k_384',
    'mobilevitv2_200_384_in22ft1k': 'mobilevitv2_200.cvnets_in22k_ft_in1k_384',
})
\ No newline at end of file
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mvitv2.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mvitv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d035fd65aa6f92d68e64a5fbf41acbdab841ffb
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/mvitv2.py
@@ -0,0 +1,1041 @@
+""" Multi-Scale Vision Transformer v2
+
+@inproceedings{li2021improved,
+ title={MViTv2: Improved multiscale vision transformers for classification and detection},
+ author={Li, Yanghao and Wu, Chao-Yuan and Fan, Haoqi and Mangalam, Karttikeya and Xiong, Bo and Malik, Jitendra and Feichtenhofer, Christoph},
+ booktitle={CVPR},
+ year={2022}
+}
+
+Code adapted from original Apache 2.0 licensed impl at https://github.com/facebookresearch/mvit
+Original copyright below.
+
+Modifications and timm support by / Copyright 2022, Ross Wightman
+"""
+# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved. All Rights Reserved.
+import operator
+from collections import OrderedDict
+from dataclasses import dataclass
+from functools import partial, reduce
+from typing import Union, List, Tuple, Optional
+
+import torch
+import torch.utils.checkpoint as checkpoint
+from torch import nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import Mlp, DropPath, trunc_normal_tf_, get_norm_layer, to_2tuple
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_function
+from ._registry import register_model, register_model_deprecations, generate_default_cfgs
+
+__all__ = ['MultiScaleVit', 'MultiScaleVitCfg'] # model_registry will add each entrypoint fn to this
+
+
@dataclass
class MultiScaleVitCfg:
    """Architecture configuration for MultiScaleVit (MViTv2)."""
    depths: Tuple[int, ...] = (2, 3, 16, 3)  # blocks per stage
    embed_dim: Union[int, Tuple[int, ...]] = 96  # scalar is expanded per-stage (x2 each stage)
    num_heads: Union[int, Tuple[int, ...]] = 1  # scalar is expanded per-stage (x2 each stage)
    mlp_ratio: float = 4.
    pool_first: bool = False  # if True, pool q/k/v before the linear projections
    expand_attn: bool = True
    qkv_bias: bool = True
    use_cls_token: bool = False
    use_abs_pos: bool = False
    residual_pooling: bool = True
    mode: str = 'conv'  # pooling mode: 'conv', 'conv_unshared', 'avg', or 'max'
    kernel_qkv: Tuple[int, int] = (3, 3)
    stride_q: Optional[Tuple[Tuple[int, int]]] = ((1, 1), (2, 2), (2, 2), (2, 2))
    stride_kv: Optional[Tuple[Tuple[int, int]]] = None  # derived in __post_init__ if None
    stride_kv_adaptive: Optional[Tuple[int, int]] = (4, 4)
    patch_kernel: Tuple[int, int] = (7, 7)
    patch_stride: Tuple[int, int] = (4, 4)
    patch_padding: Tuple[int, int] = (3, 3)
    pool_type: str = 'max'
    rel_pos_type: str = 'spatial'  # 'spatial' enables decomposed rel-pos embeddings
    act_layer: Union[str, Tuple[str, str]] = 'gelu'
    norm_layer: Union[str, Tuple[str, str]] = 'layernorm'
    norm_eps: float = 1e-6

    def __post_init__(self):
        """Expand scalar embed_dim/num_heads per stage and derive stride_kv."""
        num_stages = len(self.depths)
        if not isinstance(self.embed_dim, (tuple, list)):
            self.embed_dim = tuple(self.embed_dim * 2 ** i for i in range(num_stages))
        assert len(self.embed_dim) == num_stages

        if not isinstance(self.num_heads, (tuple, list)):
            self.num_heads = tuple(self.num_heads * 2 ** i for i in range(num_stages))
        assert len(self.num_heads) == num_stages

        if self.stride_kv_adaptive is not None and self.stride_kv is None:
            # Shrink the kv stride each time q is strided, keeping total kv
            # downsampling roughly constant across stages (min stride 1).
            _stride_kv = self.stride_kv_adaptive
            pool_kv_stride = []
            for i in range(num_stages):
                if min(self.stride_q[i]) > 1:
                    _stride_kv = [
                        max(_stride_kv[d] // self.stride_q[i][d], 1)
                        for d in range(len(_stride_kv))
                    ]
                pool_kv_stride.append(tuple(_stride_kv))
            self.stride_kv = tuple(pool_kv_stride)
+
+
def prod(iterable):
    """Multiply together all elements of `iterable` (1 for an empty iterable)."""
    result = 1
    for value in iterable:
        result = result * value
    return result
+
+
class PatchEmbed(nn.Module):
    """
    PatchEmbed.

    Image to overlapping-patch embedding via a strided convolution. Returns
    flattened tokens [B, H*W, C] plus the (H, W) of the conv output.
    """

    def __init__(
            self,
            dim_in=3,
            dim_out=768,
            kernel=(7, 7),
            stride=(4, 4),
            padding=(3, 3),
    ):
        super().__init__()

        self.proj = nn.Conv2d(
            dim_in,
            dim_out,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
        )

    def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
        x = self.proj(x)
        # B C H W -> B HW C
        # NOTE(review): second return is x.shape[-2:], a torch.Size (tuple subclass),
        # though annotated List[int] upstream.
        return x.flatten(2).transpose(1, 2), x.shape[-2:]
+
+
@register_notrace_function
def reshape_pre_pool(
        x,
        feat_size: List[int],
        has_cls_token: bool = True
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Reshape a [B, heads, N, C] token tensor to [B*heads, C, H, W] for 2d pooling.

    The cls token (if present) is split off and returned separately so it is
    excluded from spatial pooling.
    """
    H, W = feat_size
    if has_cls_token:
        cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
    else:
        cls_tok = None
    x = x.reshape(-1, H, W, x.shape[-1]).permute(0, 3, 1, 2).contiguous()
    return x, cls_tok
+
+
@register_notrace_function
def reshape_post_pool(
        x,
        num_heads: int,
        cls_tok: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, List[int]]:
    """Inverse of reshape_pre_pool: [B*heads, C, H', W'] back to [B, heads, N', C].

    Re-attaches the cls token (if given) and returns the pooled (H', W') feat size.
    """
    feat_size = [x.shape[2], x.shape[3]]
    L_pooled = x.shape[2] * x.shape[3]
    x = x.reshape(-1, num_heads, x.shape[1], L_pooled).transpose(2, 3)
    if cls_tok is not None:
        x = torch.cat((cls_tok, x), dim=2)
    return x, feat_size
+
+
@register_notrace_function
def cal_rel_pos_type(
        attn: torch.Tensor,
        q: torch.Tensor,
        has_cls_token: bool,
        q_size: List[int],
        k_size: List[int],
        rel_pos_h: torch.Tensor,
        rel_pos_w: torch.Tensor,
):
    """
    Spatial Relative Positional Embeddings.

    Adds decomposed (separate height/width) relative position terms to the
    attention logits in-place, skipping the cls token row/column if present.
    """
    sp_idx = 1 if has_cls_token else 0
    q_h, q_w = q_size
    k_h, k_w = k_size

    # Scale up rel pos if shapes for q and k are different.
    q_h_ratio = max(k_h / q_h, 1.0)
    k_h_ratio = max(q_h / k_h, 1.0)
    dist_h = (
        torch.arange(q_h, device=q.device).unsqueeze(-1) * q_h_ratio -
        torch.arange(k_h, device=q.device).unsqueeze(0) * k_h_ratio
    )
    dist_h += (k_h - 1) * k_h_ratio  # shift so indices are non-negative
    q_w_ratio = max(k_w / q_w, 1.0)
    k_w_ratio = max(q_w / k_w, 1.0)
    dist_w = (
        torch.arange(q_w, device=q.device).unsqueeze(-1) * q_w_ratio -
        torch.arange(k_w, device=q.device).unsqueeze(0) * k_w_ratio
    )
    dist_w += (k_w - 1) * k_w_ratio

    # Look up the learned embeddings by relative distance.
    rel_h = rel_pos_h[dist_h.long()]
    rel_w = rel_pos_w[dist_w.long()]

    B, n_head, q_N, dim = q.shape

    # Contract query features against the rel-pos embeddings per axis.
    r_q = q[:, :, sp_idx:].reshape(B, n_head, q_h, q_w, dim)
    rel_h = torch.einsum("byhwc,hkc->byhwk", r_q, rel_h)
    rel_w = torch.einsum("byhwc,wkc->byhwk", r_q, rel_w)

    # Broadcast-add the h and w terms over the spatial attention logits.
    attn[:, :, sp_idx:, sp_idx:] = (
        attn[:, :, sp_idx:, sp_idx:].view(B, -1, q_h, q_w, k_h, k_w)
        + rel_h.unsqueeze(-1)
        + rel_w.unsqueeze(-2)
    ).view(B, -1, q_h * q_w, k_h * k_w)

    return attn
+
+
class MultiScaleAttentionPoolFirst(nn.Module):
    """Multi-scale attention with 'pool first' ordering.

    q/k/v are spatially pooled (conv / max / avg) on the input features *before*
    the linear q/k/v projections. Supports an optional cls token (excluded from
    pooling) and decomposed spatial relative position embeddings.
    """
    def __init__(
            self,
            dim,
            dim_out,
            feat_size,
            num_heads=8,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        self.q = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.k = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.v = nn.Linear(dim, dim_out, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        if mode in ("avg", "max"):
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            # Depthwise conv pooling; per-head channels when weights are shared across heads.
            dim_conv = dim // num_heads if mode == "conv" else dim
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_q = norm_layer(dim_conv)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_k = norm_layer(dim_conv)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_v = norm_layer(dim_conv)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        B, N, _ = x.shape

        # Split channels into heads pre-projection (unless conv_unshared mode).
        fold_dim = 1 if self.unshared else self.num_heads
        x = x.reshape(B, N, fold_dim, -1).permute(0, 2, 1, 3)
        q = k = v = x

        # Pool q/k/v in the input dim, tracking each resulting spatial size.
        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            v, v_size = reshape_post_pool(v, self.num_heads, v_tok)
        else:
            v_size = feat_size
        if self.norm_v is not None:
            v = self.norm_v(v)

        # Linear projections after pooling, reshaped to [B, heads, tokens, head_dim].
        q_N = q_size[0] * q_size[1] + int(self.has_cls_token)
        q = q.transpose(1, 2).reshape(B, q_N, -1)
        q = self.q(q).reshape(B, q_N, self.num_heads, -1).transpose(1, 2)

        k_N = k_size[0] * k_size[1] + int(self.has_cls_token)
        k = k.transpose(1, 2).reshape(B, k_N, -1)
        # FIX: permute k to [B, heads, head_dim, k_N] so `(q * scale) @ k` yields
        # [B, heads, q_N, k_N] attention logits. Without the permute, k stays
        # [B, k_N, heads, head_dim] and the matmul's batch dims cannot broadcast.
        k = self.k(k).reshape(B, k_N, self.num_heads, -1).permute(0, 2, 3, 1)

        v_N = v_size[0] * v_size[1] + int(self.has_cls_token)
        v = v.transpose(1, 2).reshape(B, v_N, -1)
        v = self.v(v).reshape(B, v_N, self.num_heads, -1).transpose(1, 2)

        attn = (q * self.scale) @ k
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        # MViTv2-style residual connection with the (pooled) query.
        if self.residual_pooling:
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size
+
+
class MultiScaleAttention(nn.Module):
    """MViTv2 multi-scale attention.

    A fused qkv projection is applied first, then q/k/v are independently
    spatially pooled (by max/avg pooling or depthwise conv) before the
    attention product. Supports decomposed spatial relative position bias
    and the MViTv2 residual-pooling connection.
    """

    def __init__(
            self,
            dim,
            dim_out,
            feat_size,
            num_heads=8,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
    ):
        """
        Args:
            dim: input embedding dim.
            dim_out: output embedding dim (also the fused qkv width / 3).
            feat_size: (H, W) token grid, used to size the rel-pos tables.
            num_heads: number of attention heads.
            qkv_bias: add bias to the qkv projection.
            mode: pooling mode — 'avg', 'max', 'conv' (per-head depthwise
                conv) or 'conv_unshared' (depthwise conv over all channels).
            kernel_q / kernel_kv: pooling kernel for queries / keys+values.
            stride_q / stride_kv: pooling stride for queries / keys+values.
            has_cls_token: first token is a class token and is excluded from
                spatial pooling.
            rel_pos_type: 'spatial' enables decomposed relative position bias.
            residual_pooling: add pooled q back onto the attention output.
            norm_layer: norm applied after conv pooling.
        """
        super().__init__()
        self.num_heads = num_heads
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.scale = self.head_dim ** -0.5
        self.has_cls_token = has_cls_token
        # 'same'-style padding derived from the (pre-reduction) kernel sizes
        padding_q = tuple([int(q // 2) for q in kernel_q])
        padding_kv = tuple([int(kv // 2) for kv in kernel_kv])

        self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim_out, dim_out)

        # Skip pooling with kernel and stride size of (1, 1, 1).
        if prod(kernel_q) == 1 and prod(stride_q) == 1:
            kernel_q = None
        if prod(kernel_kv) == 1 and prod(stride_kv) == 1:
            kernel_kv = None
        self.mode = mode
        self.unshared = mode == 'conv_unshared'
        self.norm_q, self.norm_k, self.norm_v = None, None, None
        self.pool_q, self.pool_k, self.pool_v = None, None, None
        if mode in ("avg", "max"):
            # plain pooling: no learned params, no post-pool norm
            pool_op = nn.MaxPool2d if mode == "max" else nn.AvgPool2d
            if kernel_q:
                self.pool_q = pool_op(kernel_q, stride_q, padding_q)
            if kernel_kv:
                self.pool_k = pool_op(kernel_kv, stride_kv, padding_kv)
                self.pool_v = pool_op(kernel_kv, stride_kv, padding_kv)
        elif mode == "conv" or mode == "conv_unshared":
            # depthwise conv pooling; 'conv' pools per head-dim group,
            # 'conv_unshared' pools the full channel dim at once
            dim_conv = dim_out // num_heads if mode == "conv" else dim_out
            if kernel_q:
                self.pool_q = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_q,
                    stride=stride_q,
                    padding=padding_q,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_q = norm_layer(dim_conv)
            if kernel_kv:
                self.pool_k = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_k = norm_layer(dim_conv)
                self.pool_v = nn.Conv2d(
                    dim_conv,
                    dim_conv,
                    kernel_kv,
                    stride=stride_kv,
                    padding=padding_kv,
                    groups=dim_conv,
                    bias=False,
                )
                self.norm_v = norm_layer(dim_conv)
        else:
            raise NotImplementedError(f"Unsupported model {mode}")

        # relative pos embedding
        self.rel_pos_type = rel_pos_type
        if self.rel_pos_type == 'spatial':
            # decomposed rel-pos tables assume a square token grid
            assert feat_size[0] == feat_size[1]
            size = feat_size[0]
            q_size = size // stride_q[1] if len(stride_q) > 0 else size
            kv_size = size // stride_kv[1] if len(stride_kv) > 0 else size
            # number of distinct relative offsets along one axis
            rel_sp_dim = 2 * max(q_size, kv_size) - 1

            self.rel_pos_h = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(rel_sp_dim, self.head_dim))
            trunc_normal_tf_(self.rel_pos_h, std=0.02)
            trunc_normal_tf_(self.rel_pos_w, std=0.02)

        self.residual_pooling = residual_pooling

    def forward(self, x, feat_size: List[int]):
        """Attend over tokens x of shape (B, N, C); returns (out, q_feat_size)."""
        B, N, _ = x.shape

        # fused projection -> q, k, v each of shape (B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(dim=0)

        if self.pool_q is not None:
            q, q_tok = reshape_pre_pool(q, feat_size, self.has_cls_token)
            q = self.pool_q(q)
            q, q_size = reshape_post_pool(q, self.num_heads, q_tok)
        else:
            q_size = feat_size
        if self.norm_q is not None:
            q = self.norm_q(q)

        if self.pool_k is not None:
            k, k_tok = reshape_pre_pool(k, feat_size, self.has_cls_token)
            k = self.pool_k(k)
            k, k_size = reshape_post_pool(k, self.num_heads, k_tok)
        else:
            k_size = feat_size
        if self.norm_k is not None:
            k = self.norm_k(k)

        if self.pool_v is not None:
            v, v_tok = reshape_pre_pool(v, feat_size, self.has_cls_token)
            v = self.pool_v(v)
            # pooled v size is not needed downstream (output follows q_size)
            v, _ = reshape_post_pool(v, self.num_heads, v_tok)
        if self.norm_v is not None:
            v = self.norm_v(v)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if self.rel_pos_type == 'spatial':
            attn = cal_rel_pos_type(
                attn,
                q,
                self.has_cls_token,
                q_size,
                k_size,
                self.rel_pos_h,
                self.rel_pos_w,
            )
        attn = attn.softmax(dim=-1)
        x = attn @ v

        if self.residual_pooling:
            # residual pooling connection: pooled queries added to output
            x = x + q

        x = x.transpose(1, 2).reshape(B, -1, self.dim_out)
        x = self.proj(x)

        return x, q_size
+
+
class MultiScaleBlock(nn.Module):
    """One MViTv2 transformer block: pooled attention + MLP, each with a
    (possibly projected and pooled) shortcut connection.

    With ``expand_attn`` the channel expansion (dim -> dim_out) happens inside
    the attention sub-block; otherwise it happens in the MLP sub-block.
    """

    def __init__(
            self,
            dim,
            dim_out,
            num_heads,
            feat_size,
            mlp_ratio=4.0,
            qkv_bias=True,
            drop_path=0.0,
            norm_layer=nn.LayerNorm,
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            mode="conv",
            has_cls_token=True,
            expand_attn=False,
            pool_first=False,
            rel_pos_type='spatial',
            residual_pooling=True,
    ):
        super().__init__()
        proj_needed = dim != dim_out
        self.dim = dim
        self.dim_out = dim_out
        self.has_cls_token = has_cls_token

        self.norm1 = norm_layer(dim)

        # linear shortcut projection only needed when widths differ and the
        # attention sub-block performs the expansion
        self.shortcut_proj_attn = nn.Linear(dim, dim_out) if proj_needed and expand_attn else None
        if stride_q and prod(stride_q) > 1:
            # the shortcut must be pooled to match the attention output's grid;
            # kernel is stride + 1 for overlapping max pooling
            kernel_skip = [s + 1 if s > 1 else s for s in stride_q]
            stride_skip = stride_q
            padding_skip = [int(skip // 2) for skip in kernel_skip]
            self.shortcut_pool_attn = nn.MaxPool2d(kernel_skip, stride_skip, padding_skip)
        else:
            self.shortcut_pool_attn = None

        att_dim = dim_out if expand_attn else dim
        attn_layer = MultiScaleAttentionPoolFirst if pool_first else MultiScaleAttention
        self.attn = attn_layer(
            dim,
            att_dim,
            num_heads=num_heads,
            feat_size=feat_size,
            qkv_bias=qkv_bias,
            kernel_q=kernel_q,
            kernel_kv=kernel_kv,
            stride_q=stride_q,
            stride_kv=stride_kv,
            norm_layer=norm_layer,
            has_cls_token=has_cls_token,
            mode=mode,
            rel_pos_type=rel_pos_type,
            residual_pooling=residual_pooling,
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(att_dim)
        mlp_dim_out = dim_out
        # when expansion happens in the MLP, its shortcut needs the projection
        self.shortcut_proj_mlp = nn.Linear(dim, dim_out) if proj_needed and not expand_attn else None
        self.mlp = Mlp(
            in_features=att_dim,
            hidden_features=int(att_dim * mlp_ratio),
            out_features=mlp_dim_out,
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def _shortcut_pool(self, x, feat_size: List[int]):
        """Max-pool the shortcut path to the attention output's grid size.

        The class token (if present) bypasses pooling and is re-attached.
        """
        if self.shortcut_pool_attn is None:
            return x
        if self.has_cls_token:
            cls_tok, x = x[:, :1, :], x[:, 1:, :]
        else:
            cls_tok = None
        B, L, C = x.shape
        H, W = feat_size
        # tokens -> NCHW for 2d pooling, then back to tokens
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()
        x = self.shortcut_pool_attn(x)
        x = x.reshape(B, C, -1).transpose(1, 2)
        if cls_tok is not None:
            x = torch.cat((cls_tok, x), dim=1)
        return x

    def forward(self, x, feat_size: List[int]):
        """Returns (tokens, new_feat_size) — the grid shrinks if stride_q > 1."""
        x_norm = self.norm1(x)
        # NOTE as per the original impl, this seems odd, but shortcut uses un-normalized input if no proj
        x_shortcut = x if self.shortcut_proj_attn is None else self.shortcut_proj_attn(x_norm)
        x_shortcut = self._shortcut_pool(x_shortcut, feat_size)
        x, feat_size_new = self.attn(x_norm, feat_size)
        x = x_shortcut + self.drop_path1(x)

        x_norm = self.norm2(x)
        x_shortcut = x if self.shortcut_proj_mlp is None else self.shortcut_proj_mlp(x_norm)
        x = x_shortcut + self.drop_path2(self.mlp(x_norm))
        return x, feat_size_new
+
+
class MultiScaleVitStage(nn.Module):
    """A stack of MultiScaleBlocks forming one resolution stage.

    Only the first block in the stage applies the query stride (and thus the
    spatial downsampling); all subsequent blocks run at the reduced grid.
    """

    def __init__(
            self,
            dim,
            dim_out,
            depth,
            num_heads,
            feat_size,
            mlp_ratio=4.0,
            qkv_bias=True,
            mode="conv",
            kernel_q=(1, 1),
            kernel_kv=(1, 1),
            stride_q=(1, 1),
            stride_kv=(1, 1),
            has_cls_token=True,
            expand_attn=False,
            pool_first=False,
            rel_pos_type='spatial',
            residual_pooling=True,
            norm_layer=nn.LayerNorm,
            drop_path=0.0,
    ):
        super().__init__()
        self.grad_checkpointing = False

        # With expand_attn every block runs at dim_out; otherwise the width
        # change is deferred to the final block of the stage.
        out_dims = (dim_out,) * depth if expand_attn else (dim,) * (depth - 1) + (dim_out,)

        self.blocks = nn.ModuleList()
        for i, blk_dim_out in enumerate(out_dims):
            self.blocks.append(MultiScaleBlock(
                dim=dim,
                dim_out=blk_dim_out,
                num_heads=num_heads,
                feat_size=feat_size,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                kernel_q=kernel_q,
                kernel_kv=kernel_kv,
                stride_q=stride_q if i == 0 else (1, 1),
                stride_kv=stride_kv,
                mode=mode,
                has_cls_token=has_cls_token,
                pool_first=pool_first,
                rel_pos_type=rel_pos_type,
                residual_pooling=residual_pooling,
                expand_attn=expand_attn,
                norm_layer=norm_layer,
                drop_path=drop_path[i] if isinstance(drop_path, (list, tuple)) else drop_path,
            ))
            dim = blk_dim_out
            if i == 0:
                # the first block's query pooling shrinks the token grid
                feat_size = tuple(s // st for s, st in zip(feat_size, stride_q))

        self.feat_size = feat_size

    def forward(self, x, feat_size: List[int]):
        """Run every block, threading the evolving feature grid size through."""
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x, feat_size = checkpoint.checkpoint(blk, x, feat_size)
            else:
                x, feat_size = blk(x, feat_size)
        return x, feat_size
+
+
class MultiScaleVit(nn.Module):
    """
    Improved Multiscale Vision Transformers for Classification and Detection
    Yanghao Li*, Chao-Yuan Wu*, Haoqi Fan, Karttikeya Mangalam, Bo Xiong, Jitendra Malik,
    Christoph Feichtenhofer*
    https://arxiv.org/abs/2112.01526

    Multiscale Vision Transformers
    Haoqi Fan*, Bo Xiong*, Karttikeya Mangalam*, Yanghao Li*, Zhicheng Yan, Jitendra Malik,
    Christoph Feichtenhofer*
    https://arxiv.org/abs/2104.11227
    """

    def __init__(
            self,
            cfg: MultiScaleVitCfg,
            img_size: Tuple[int, int] = (224, 224),
            in_chans: int = 3,
            global_pool: Optional[str] = None,
            num_classes: int = 1000,
            drop_path_rate: float = 0.,
            drop_rate: float = 0.,
    ):
        """
        Args:
            cfg: architecture configuration (depths, dims, heads, pooling, ...).
            img_size: input image size (H, W); fixed, used to size rel-pos tables.
            in_chans: number of input image channels.
            global_pool: 'token' or 'avg'; defaults based on cfg.use_cls_token.
            num_classes: classifier output size (0 disables the head fc).
            drop_path_rate: stochastic depth rate, linearly scheduled per block.
            drop_rate: dropout rate before the classifier fc.
        """
        super().__init__()
        img_size = to_2tuple(img_size)
        norm_layer = partial(get_norm_layer(cfg.norm_layer), eps=cfg.norm_eps)
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        if global_pool is None:
            # class-token pooling when a cls token exists, mean pooling otherwise
            global_pool = 'token' if cfg.use_cls_token else 'avg'
        self.global_pool = global_pool
        self.depths = tuple(cfg.depths)
        self.expand_attn = cfg.expand_attn

        embed_dim = cfg.embed_dim[0]
        self.patch_embed = PatchEmbed(
            dim_in=in_chans,
            dim_out=embed_dim,
            kernel=cfg.patch_kernel,
            stride=cfg.patch_stride,
            padding=cfg.patch_padding,
        )
        patch_dims = (img_size[0] // cfg.patch_stride[0], img_size[1] // cfg.patch_stride[1])
        num_patches = prod(patch_dims)

        if cfg.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            self.num_prefix_tokens = 1
            pos_embed_dim = num_patches + 1
        else:
            self.num_prefix_tokens = 0
            self.cls_token = None
            pos_embed_dim = num_patches

        if cfg.use_abs_pos:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_dim, embed_dim))
        else:
            self.pos_embed = None

        num_stages = len(cfg.embed_dim)
        feat_size = patch_dims
        # per-block stochastic-depth rates, split into per-stage lists
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
        self.stages = nn.ModuleList()
        for i in range(num_stages):
            if cfg.expand_attn:
                dim_out = cfg.embed_dim[i]
            else:
                # width expands at the end of each stage into the next stage's dim
                dim_out = cfg.embed_dim[min(i + 1, num_stages - 1)]
            stage = MultiScaleVitStage(
                dim=embed_dim,
                dim_out=dim_out,
                depth=cfg.depths[i],
                num_heads=cfg.num_heads[i],
                feat_size=feat_size,
                mlp_ratio=cfg.mlp_ratio,
                qkv_bias=cfg.qkv_bias,
                mode=cfg.mode,
                pool_first=cfg.pool_first,
                expand_attn=cfg.expand_attn,
                kernel_q=cfg.kernel_qkv,
                kernel_kv=cfg.kernel_qkv,
                stride_q=cfg.stride_q[i],
                stride_kv=cfg.stride_kv[i],
                has_cls_token=cfg.use_cls_token,
                rel_pos_type=cfg.rel_pos_type,
                residual_pooling=cfg.residual_pooling,
                norm_layer=norm_layer,
                drop_path=dpr[i],
            )
            embed_dim = dim_out
            feat_size = stage.feat_size
            self.stages.append(stage)

        self.num_features = embed_dim
        self.norm = norm_layer(embed_dim)
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        ]))

        if self.pos_embed is not None:
            trunc_normal_tf_(self.pos_embed, std=0.02)
        if self.cls_token is not None:
            trunc_normal_tf_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal init for all linear weights, zero bias
        if isinstance(m, nn.Linear):
            trunc_normal_tf_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names excluded from weight decay (pos/rel-pos/cls tables)."""
        return {k for k, _ in self.named_parameters()
                if any(n in k for n in ["pos_embed", "rel_pos_h", "rel_pos_w", "cls_token"])}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        """Regex grouping of parameters for layer-wise lr decay / freezing."""
        matcher = dict(
            stem=r'^patch_embed',  # stem and embed
            blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classifier head (e.g. for fine-tuning)."""
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = nn.Sequential(OrderedDict([
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
        ]))

    def forward_features(self, x):
        """Image (B, C, H, W) -> normalized token features (B, N, num_features)."""
        x, feat_size = self.patch_embed(x)
        B, N, C = x.shape

        if self.cls_token is not None:
            cls_tokens = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        if self.pos_embed is not None:
            x = x + self.pos_embed

        for stage in self.stages:
            x, feat_size = stage(x, feat_size)

        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        """Pool tokens ('avg' over patches or cls token) and apply the head."""
        if self.global_pool:
            if self.global_pool == 'avg':
                x = x[:, self.num_prefix_tokens:].mean(1)
            else:
                x = x[:, 0]
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def checkpoint_filter_fn(state_dict, model):
    """Remap original FB MViTv2 checkpoint keys to this model's naming.

    Flat ``blocks.<i>`` indices are converted to ``stages.<s>.blocks.<j>``
    using the model's per-stage depths, the shortcut projection is renamed
    according to ``expand_attn``, and the head projection becomes ``head.fc``.
    Checkpoints already in timm format are returned unchanged.
    """
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        return state_dict  # already remapped

    import re
    if 'model_state' in state_dict:
        state_dict = state_dict['model_state']

    depths = getattr(model, 'depths', None)
    expand_attn = getattr(model, 'expand_attn', True)
    assert depths is not None, 'model requires depth attribute to remap checkpoints'

    # flat block index -> (stage index, block index within stage)
    depth_map = {}
    block_idx = 0
    for stage_idx, d in enumerate(depths):
        for j in range(d):
            depth_map[block_idx + j] = (stage_idx, j)
        block_idx += d

    def _remap_block(match):
        stage_i, block_i = depth_map[int(match.group(1))]
        return 'stages.{}.blocks.{}'.format(stage_i, block_i)

    # shortcut projection lives in the attn path when expand_attn, else the mlp path
    proj_name = 'shortcut_proj_attn' if expand_attn else 'shortcut_proj_mlp'

    remapped = {}
    for k, v in state_dict.items():
        k = re.sub(r'blocks\.(\d+)', _remap_block, k)
        k = re.sub(r'stages\.(\d+).blocks\.(\d+).proj', r'stages.\1.blocks.\2.' + proj_name, k)
        if 'head' in k:
            k = k.replace('head.projection', 'head.fc')
        remapped[k] = v
    return remapped
+
+
# Architecture configs for the MViTv2 family. All fields not listed fall back
# to MultiScaleVitCfg defaults; '_cls' variants add a class token, and the
# larger models adjust base embed_dim / num_heads / expand_attn.
model_cfgs = dict(
    mvitv2_tiny=MultiScaleVitCfg(
        depths=(1, 2, 5, 2),
    ),
    mvitv2_small=MultiScaleVitCfg(
        depths=(1, 2, 11, 2),
    ),
    mvitv2_base=MultiScaleVitCfg(
        depths=(2, 3, 16, 3),
    ),
    mvitv2_large=MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        expand_attn=False,
    ),

    # class-token variants (used for the IN-21k pretrained weights)
    mvitv2_small_cls=MultiScaleVitCfg(
        depths=(1, 2, 11, 2),
        use_cls_token=True,
    ),
    mvitv2_base_cls=MultiScaleVitCfg(
        depths=(2, 3, 16, 3),
        use_cls_token=True,
    ),
    mvitv2_large_cls=MultiScaleVitCfg(
        depths=(2, 6, 36, 4),
        embed_dim=144,
        num_heads=2,
        use_cls_token=True,
        expand_attn=True,
    ),
    mvitv2_huge_cls=MultiScaleVitCfg(
        depths=(4, 8, 60, 8),
        embed_dim=192,
        num_heads=3,
        use_cls_token=True,
        expand_attn=True,
    ),
)
+
+
def _create_mvitv2(variant, cfg_variant=None, pretrained=False, **kwargs):
    """Instantiate a MultiScaleVit variant through the timm model builder.

    ``cfg_variant`` selects an alternate architecture config; when falsy the
    config named after ``variant`` is used.
    """
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Multiscale Vision Transformer models.')

    cfg_key = cfg_variant or variant
    return build_model_with_cfg(
        MultiScaleVit,
        variant,
        pretrained,
        model_cfg=model_cfgs[cfg_key],
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-config dict; kwargs override any default."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
        'fixed_input_size': True,
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs. 'fb_in1k' tags are the original FB ImageNet-1k
# weights; 'fb_inw21k' are ImageNet-21k (winter21) weights with 19168 classes.
default_cfgs = generate_default_cfgs({
    'mvitv2_tiny.fb_in1k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_T_in1k.pyth',
        hf_hub_id='timm/'),
    'mvitv2_small.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_S_in1k.pyth',
        hf_hub_id='timm/'),
    'mvitv2_base.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in1k.pyth',
        hf_hub_id='timm/'),
    'mvitv2_large.fb_in1k': _cfg(url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in1k.pyth',
        hf_hub_id='timm/'),

    # no released weights for this variant
    'mvitv2_small_cls': _cfg(url=''),
    'mvitv2_base_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_B_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168),
    'mvitv2_large_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_L_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168),
    'mvitv2_huge_cls.fb_inw21k': _cfg(
        url='https://dl.fbaipublicfiles.com/mvit/mvitv2_models/MViTv2_H_in21k.pyth',
        hf_hub_id='timm/',
        num_classes=19168),
})
+
+
@register_model
def mvitv2_tiny(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Tiny (stage depths 1-2-5-2 per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_tiny', pretrained=pretrained, **kwargs)
+
+
@register_model
def mvitv2_small(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Small (stage depths 1-2-11-2 per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_small', pretrained=pretrained, **kwargs)
+
+
@register_model
def mvitv2_base(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Base (stage depths 2-3-16-3 per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_base', pretrained=pretrained, **kwargs)
+
+
@register_model
def mvitv2_large(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Large (stage depths 2-6-36-4, embed_dim 144, per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_large', pretrained=pretrained, **kwargs)
+
+
@register_model
def mvitv2_small_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Small with class token (per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
+
+
@register_model
def mvitv2_base_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Base with class token (per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_base_cls', pretrained=pretrained, **kwargs)
+
+
@register_model
def mvitv2_large_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Large with class token and expand_attn (per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_large_cls', pretrained=pretrained, **kwargs)
+
+
@register_model
def mvitv2_huge_cls(pretrained=False, **kwargs) -> MultiScaleVit:
    """MViTv2-Huge with class token (stage depths 4-8-60-8, per ``model_cfgs``)."""
    return _create_mvitv2('mvitv2_huge_cls', pretrained=pretrained, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/nasnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/nasnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..954ee176b0685008f75f7bc750f08a9e2c8bd2bc
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/nasnet.py
@@ -0,0 +1,600 @@
+""" NasNet-A (Large)
+ nasnetalarge implementation grabbed from Cadene's pretrained models
+ https://github.com/Cadene/pretrained-models.pytorch
+"""
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['NASNetALarge']
+
+
+
class ActConvBn(nn.Module):
    """Pre-activation conv block: ReLU, then conv, then batch norm."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
        super(ActConvBn, self).__init__()
        self.act = nn.ReLU()
        self.conv = create_conv2d(
            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)

    def forward(self, x):
        # act -> conv -> bn, as a single chained expression
        return self.bn(self.conv(self.act(x)))
+
+
class SeparableConv2d(nn.Module):
    """Depthwise-separable conv: per-channel spatial conv + 1x1 pointwise conv."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
        super().__init__()
        # spatial mixing, one filter per input channel
        self.depthwise_conv2d = create_conv2d(
            in_channels, in_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, groups=in_channels)
        # channel mixing at 1x1
        self.pointwise_conv2d = create_conv2d(
            in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        return self.pointwise_conv2d(self.depthwise_conv2d(x))
+
+
class BranchSeparables(nn.Module):
    """Two stacked (ReLU -> separable conv -> BN) units.

    The stride is applied only in the first separable conv. Stem cells widen
    to ``out_channels`` immediately; regular cells keep ``in_channels`` as the
    intermediate width.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, pad_type='', stem_cell=False):
        super().__init__()
        mid_chs = out_channels if stem_cell else in_channels
        self.act_1 = nn.ReLU()
        self.separable_1 = SeparableConv2d(
            in_channels, mid_chs, kernel_size, stride=stride, padding=pad_type)
        self.bn_sep_1 = nn.BatchNorm2d(mid_chs, eps=0.001, momentum=0.1)
        self.act_2 = nn.ReLU(inplace=True)
        self.separable_2 = SeparableConv2d(
            mid_chs, out_channels, kernel_size, stride=1, padding=pad_type)
        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1)

    def forward(self, x):
        # run the six sub-layers sequentially
        for layer in (self.act_1, self.separable_1, self.bn_sep_1,
                      self.act_2, self.separable_2, self.bn_sep_2):
            x = layer(x)
        return x
+
+
class CellStem0(nn.Module):
    """First NASNet stem reduction cell (operates on the raw stem output).

    Builds five 'comb iter' branch pairs; note that comb_iter_0 feeds the
    later iters but is *not* part of the concatenated output.
    """

    def __init__(self, stem_size, num_channels=42, pad_type=''):
        super(CellStem0, self).__init__()
        self.num_channels = num_channels
        self.stem_size = stem_size
        self.conv_1x1 = ActConvBn(self.stem_size, self.num_channels, 1, stride=1)

        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(self.stem_size, self.num_channels, 7, 2, pad_type, stem_cell=True)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(self.stem_size, self.num_channels, 5, 2, pad_type, stem_cell=True)

        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x):
        # x1: 1x1-projected view of the stem output; branches take x or x1
        x1 = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x1)
        x_comb_iter_0_right = self.comb_iter_0_right(x)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x1)
        x_comb_iter_1_right = self.comb_iter_1_right(x)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x1)
        x_comb_iter_2_right = self.comb_iter_2_right(x)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        # iters 3 and 4 consume earlier iter results rather than the inputs
        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x1)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        # comb_iter_0 is intentionally excluded from the output concat
        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
+
+
class CellStem1(nn.Module):
    """Second NASNet stem reduction cell.

    Takes two inputs: the raw conv stem output (``x_conv0``) and the previous
    stem cell's output (``x_stem_0``). ``x_conv0`` is downsampled via a
    two-path factorized reduction before the branch pairs.
    """

    def __init__(self, stem_size, num_channels, pad_type=''):
        super(CellStem1, self).__init__()
        self.num_channels = num_channels
        self.stem_size = stem_size
        self.conv_1x1 = ActConvBn(2 * self.num_channels, self.num_channels, 1, stride=1)

        self.act = nn.ReLU()
        # two-path stride-2 reduction of x_conv0; each path produces half the channels
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))

        self.path_2 = nn.Sequential()
        # negative ZeroPad2d crops one row/col at top-left and pads one at
        # bottom-right, shifting the grid by a pixel before the strided pool
        self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(self.stem_size, self.num_channels // 2, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(self.num_channels, eps=0.001, momentum=0.1)

        self.comb_iter_0_left = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(self.num_channels, self.num_channels, 7, 2, pad_type)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(self.num_channels, self.num_channels, 5, 2, pad_type)

        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(self.num_channels, self.num_channels, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x_conv0, x_stem_0):
        x_left = self.conv_1x1(x_stem_0)

        x_relu = self.act(x_conv0)
        # path 1
        x_path1 = self.path_1(x_relu)
        # path 2
        x_path2 = self.path_2(x_relu)
        # final path
        x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))

        x_comb_iter_0_left = self.comb_iter_0_left(x_left)
        x_comb_iter_0_right = self.comb_iter_0_right(x_right)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_right)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_left)
        x_comb_iter_2_right = self.comb_iter_2_right(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        # iters 3 and 4 consume earlier iter results
        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_left)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        # comb_iter_0 is excluded from the output concat
        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
+
+
class FirstCell(nn.Module):
    """First NASNet normal cell after a reduction.

    The previous-layer input ``x_prev`` passes through a two-path factorized
    reduction (to match the new resolution) while the current input ``x``
    gets a 1x1 projection; five branch pairs then mix the two.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(FirstCell, self).__init__()
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1)

        self.act = nn.ReLU()
        self.path_1 = nn.Sequential()
        self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_1.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))

        self.path_2 = nn.Sequential()
        # negative padding crops one row/col to shift path 2 by a pixel
        self.path_2.add_module('pad', nn.ZeroPad2d((-1, 1, -1, 1)))
        self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
        self.path_2.add_module('conv', nn.Conv2d(in_chs_left, out_chs_left, 1, stride=1, bias=False))

        self.final_path_bn = nn.BatchNorm2d(out_chs_left * 2, eps=0.001, momentum=0.1)

        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

        self.comb_iter_1_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

        self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

    def forward(self, x, x_prev):
        # factorized reduction of the previous-layer features
        x_relu = self.act(x_prev)
        x_path1 = self.path_1(x_relu)
        x_path2 = self.path_2(x_relu)
        x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        # unlike the stem cells, the left input is included in the output
        x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
+
+
class NormalCell(nn.Module):
    """NASNet normal (stride-1) cell.

    Same branch topology as FirstCell, but both inputs are at the same
    resolution so ``x_prev`` only needs a 1x1 projection.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(NormalCell, self).__init__()
        self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)

        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 1, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)

        self.comb_iter_1_left = BranchSeparables(out_chs_left, out_chs_left, 5, 1, pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_left, out_chs_left, 3, 1, pad_type)

        self.comb_iter_2_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_3_left = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_left)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2 = x_comb_iter_2_left + x_left

        x_comb_iter_3_left = self.comb_iter_3_left(x_left)
        x_comb_iter_3_right = self.comb_iter_3_right(x_left)
        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right

        x_comb_iter_4_left = self.comb_iter_4_left(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_right

        # left input is included in the concat (6-way output)
        x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
+
+
class ReductionCell0(nn.Module):
    """NASNet-A reduction cell (first variant): halves spatial resolution.

    All primary branches use stride-2 ops; branches 3 and 4 build on earlier
    branch outputs (already at the reduced resolution) with stride-1 ops.
    Unlike NormalCell, the projected left path is NOT part of the output
    concat — only four branch sums are, so output width = 4 * out_chs.

    NOTE: ``comb_iter_*`` attribute names are checkpoint keys; do not rename.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(ReductionCell0, self).__init__()
        self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)

        # Branch 0: sep5x5/2(right) + sep7x7/2(left)
        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        # Branch 1: maxpool3x3/2(right) + sep7x7/2(left)
        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        # Branch 2: avgpool3x3/2(right) + sep5x5/2(left)
        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)

        # Branch 3: avgpool3x3/1(branch0) + branch1 — operates post-reduction.
        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        # Branch 4: sep3x3/1(branch0) + maxpool3x3/2(right)
        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        # Branches 3/4 consume earlier branch results (already downsampled).
        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        # Branch 0 is only used as an intermediate; it is not concatenated.
        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
+
+
class ReductionCell1(nn.Module):
    """NASNet-A reduction cell (second variant).

    NOTE(review): the definition is line-for-line identical to
    ``ReductionCell0`` — presumably kept as a distinct class so checkpoint
    parameter names for the two reduction stages stay distinct; confirm
    against the reference implementation before merging the two.
    See ReductionCell0 for branch-by-branch commentary.
    """

    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
        super(ReductionCell1, self).__init__()
        self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, 1, stride=1, padding=pad_type)
        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, 1, stride=1, padding=pad_type)

        self.comb_iter_0_left = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)
        self.comb_iter_0_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        self.comb_iter_1_left = create_pool2d('max', 3, 2, padding=pad_type)
        self.comb_iter_1_right = BranchSeparables(out_chs_right, out_chs_right, 7, 2, pad_type)

        self.comb_iter_2_left = create_pool2d('avg', 3, 2, count_include_pad=False, padding=pad_type)
        self.comb_iter_2_right = BranchSeparables(out_chs_right, out_chs_right, 5, 2, pad_type)

        self.comb_iter_3_right = create_pool2d('avg', 3, 1, count_include_pad=False, padding=pad_type)

        self.comb_iter_4_left = BranchSeparables(out_chs_right, out_chs_right, 3, 1, pad_type)
        self.comb_iter_4_right = create_pool2d('max', 3, 2, padding=pad_type)

    def forward(self, x, x_prev):
        x_left = self.conv_prev_1x1(x_prev)
        x_right = self.conv_1x1(x)

        x_comb_iter_0_left = self.comb_iter_0_left(x_right)
        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right

        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
        x_comb_iter_1_right = self.comb_iter_1_right(x_left)
        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right

        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
        x_comb_iter_2_right = self.comb_iter_2_right(x_left)
        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right

        x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
        x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1

        x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
        x_comb_iter_4_right = self.comb_iter_4_right(x_right)
        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right

        x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
        return x_out
+
+
class NASNetALarge(nn.Module):
    """NASNetALarge (6 @ 4032)

    Three groups of six normal cells separated by two reduction cells, on top
    of a conv stem plus two stem cells. Cell wiring follows the original
    Cadene port; module attribute names are pretrained-checkpoint keys.
    """

    def __init__(
            self,
            num_classes=1000,
            in_chans=3,
            stem_size=96,
            channel_multiplier=2,
            num_features=4032,
            output_stride=32,
            drop_rate=0.,
            global_pool='avg',
            pad_type='same',
    ):
        super(NASNetALarge, self).__init__()
        self.num_classes = num_classes
        self.stem_size = stem_size
        self.num_features = num_features
        self.channel_multiplier = channel_multiplier
        # Dilated/reduced output strides are not implemented for this model.
        assert output_stride == 32

        channels = self.num_features // 24
        # 24 is default value for the architecture
        # (each cell group concatenates 6 branches, doubled twice -> 6*4=24).

        # Stride-2 stem conv, BN without activation (eps/momentum match TF port).
        self.conv0 = ConvNormAct(
            in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
            norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)

        # Stem cells run at 1/4 and 1/2 of the first group's channel width.
        self.cell_stem_0 = CellStem0(
            self.stem_size, num_channels=channels // (channel_multiplier ** 2), pad_type=pad_type)
        self.cell_stem_1 = CellStem1(
            self.stem_size, num_channels=channels // channel_multiplier, pad_type=pad_type)

        # Group 1: six cells at base width `channels` (output 6*channels each).
        self.cell_0 = FirstCell(
            in_chs_left=channels, out_chs_left=channels // 2,
            in_chs_right=2 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_1 = NormalCell(
            in_chs_left=2 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_2 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_3 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_4 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)
        self.cell_5 = NormalCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=6 * channels, out_chs_right=channels, pad_type=pad_type)

        # Group 2: width doubles after the first reduction cell.
        self.reduction_cell_0 = ReductionCell0(
            in_chs_left=6 * channels, out_chs_left=2 * channels,
            in_chs_right=6 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_6 = FirstCell(
            in_chs_left=6 * channels, out_chs_left=channels,
            in_chs_right=8 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_7 = NormalCell(
            in_chs_left=8 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_8 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_9 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_10 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)
        self.cell_11 = NormalCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=12 * channels, out_chs_right=2 * channels, pad_type=pad_type)

        # Group 3: width doubles again after the second reduction cell.
        self.reduction_cell_1 = ReductionCell1(
            in_chs_left=12 * channels, out_chs_left=4 * channels,
            in_chs_right=12 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_12 = FirstCell(
            in_chs_left=12 * channels, out_chs_left=2 * channels,
            in_chs_right=16 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_13 = NormalCell(
            in_chs_left=16 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_14 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_15 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_16 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.cell_17 = NormalCell(
            in_chs_left=24 * channels, out_chs_left=4 * channels,
            in_chs_right=24 * channels, out_chs_right=4 * channels, pad_type=pad_type)
        self.act = nn.ReLU(inplace=True)
        # Hook points for feature extraction (num_chs assume default 4032/96).
        self.feature_info = [
            dict(num_chs=96, reduction=2, module='conv0'),
            dict(num_chs=168, reduction=4, module='cell_stem_1.conv_1x1.act'),
            dict(num_chs=1008, reduction=8, module='reduction_cell_0.conv_1x1.act'),
            dict(num_chs=2016, reduction=16, module='reduction_cell_1.conv_1x1.act'),
            dict(num_chs=4032, reduction=32, module='act'),
        ]

        # With drop_rate passed, create_classifier also returns the dropout.
        self.global_pool, self.head_drop, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # Parameter grouping for layer-wise LR decay / fine-tuning utilities.
        matcher = dict(
            stem=r'^conv0|cell_stem_[01]',
            blocks=[
                (r'^cell_(\d+)', None),
                (r'^reduction_cell_0', (6,)),
                (r'^reduction_cell_1', (12,)),
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.last_linear

    def reset_classifier(self, num_classes, global_pool='avg'):
        # NOTE(review): drop_rate is not passed here, so only (pool, fc) are
        # returned and the existing self.head_drop is reused — confirm this
        # matches create_classifier's return convention.
        self.num_classes = num_classes
        self.global_pool, self.last_linear = create_classifier(
            self.num_features, self.num_classes, pool_type=global_pool)

    def forward_features(self, x):
        """Run the stem and all cells; returns the final ReLU'd feature map.

        Each cell consumes (current output, output from two steps back); the
        skip input across a reduction comes from before that reduction.
        """
        x_conv0 = self.conv0(x)

        x_stem_0 = self.cell_stem_0(x_conv0)
        x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)

        x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
        x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
        x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
        x_cell_3 = self.cell_3(x_cell_2, x_cell_1)
        x_cell_4 = self.cell_4(x_cell_3, x_cell_2)
        x_cell_5 = self.cell_5(x_cell_4, x_cell_3)

        x_reduction_cell_0 = self.reduction_cell_0(x_cell_5, x_cell_4)
        x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_4)
        x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
        x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
        x_cell_9 = self.cell_9(x_cell_8, x_cell_7)
        x_cell_10 = self.cell_10(x_cell_9, x_cell_8)
        x_cell_11 = self.cell_11(x_cell_10, x_cell_9)

        x_reduction_cell_1 = self.reduction_cell_1(x_cell_11, x_cell_10)
        x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_10)
        x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
        x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
        x_cell_15 = self.cell_15(x_cell_14, x_cell_13)
        x_cell_16 = self.cell_16(x_cell_15, x_cell_14)
        x_cell_17 = self.cell_17(x_cell_16, x_cell_15)
        x = self.act(x_cell_17)
        return x

    def forward_head(self, x):
        """Pool, apply head dropout, and classify."""
        x = self.global_pool(x)
        x = self.head_drop(x)
        x = self.last_linear(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def _create_nasnet(variant, pretrained=False, **kwargs):
    """Instantiate a NASNet variant through the shared timm model builder."""
    # Feature extraction must use forward hooks; this graph cannot be rewritten.
    feature_cfg = dict(feature_cls='hook', no_rewrite=True)
    return build_model_with_cfg(
        NASNetALarge,
        variant,
        pretrained,
        feature_cfg=feature_cfg,
        **kwargs,
    )
+
+
# Pretrained-weight configs, keyed '<model>.<tag>'; consumed by
# build_model_with_cfg when pretrained=True.
default_cfgs = generate_default_cfgs({
    'nasnetalarge.tf_in1k': {
        'hf_hub_id': 'timm/',
        'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/nasnetalarge-dc4a7b8b.pth',
        'input_size': (3, 331, 331),
        'pool_size': (11, 11),
        'crop_pct': 0.911,
        'interpolation': 'bicubic',
        'mean': (0.5, 0.5, 0.5),
        'std': (0.5, 0.5, 0.5),
        'num_classes': 1000,
        'first_conv': 'conv0.conv',
        'classifier': 'last_linear',
    },
})
+
+
@register_model
def nasnetalarge(pretrained=False, **kwargs) -> NASNetALarge:
    """NASNet-A large model architecture.
    """
    # 'same' padding by default; caller kwargs win (duplicate pad_type raises,
    # matching the original dict(pad_type=..., **kwargs) behavior).
    return _create_nasnet('nasnetalarge', pretrained, **dict(pad_type='same', **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pit.py
new file mode 100644
index 0000000000000000000000000000000000000000..993606d518c6abad89550c81632b5faf86f2a520
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pit.py
@@ -0,0 +1,458 @@
+""" Pooling-based Vision Transformer (PiT) in PyTorch
+
+A PyTorch implementation of Pooling-based Vision Transformers as described in
+'Rethinking Spatial Dimensions of Vision Transformers' - https://arxiv.org/abs/2103.16302
+
+This code was adapted from the original version at https://github.com/naver-ai/pit, original copyright below.
+
+Modifications for timm by / Copyright 2020 Ross Wightman
+"""
+# PiT
+# Copyright 2021-present NAVER Corp.
+# Apache License v2.0
+
+import math
+import re
+from functools import partial
+from typing import Sequence, Tuple
+
+import torch
+from torch import nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import trunc_normal_, to_2tuple, LayerNorm
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+from .vision_transformer import Block
+
+
+__all__ = ['PoolingVisionTransformer'] # model_registry will add each entrypoint fn to this
+
+
class SequentialTuple(nn.Sequential):
    """nn.Sequential whose forward threads a (Tensor, Tensor) pair through
    every child module; exists only to satisfy torchscript's strict typing
    (plain Sequential cannot be scripted with tuple inputs).
    """

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        for stage in self:
            x = stage(x)
        return x
+
+
class Transformer(nn.Module):
    """One PiT stage: an optional spatial pooling step followed by `depth`
    standard ViT blocks.

    Data flows through the stage as a (spatial NCHW map, cls tokens) pair and
    is flattened into a token sequence only inside forward.
    """

    def __init__(
            self,
            base_dim,
            depth,
            heads,
            mlp_ratio,
            pool=None,
            proj_drop=.0,
            attn_drop=.0,
            drop_path_prob=None,
            norm_layer=None,
    ):
        super(Transformer, self).__init__()
        embed_dim = base_dim * heads

        self.pool = pool
        # Optional stage-level pre-norm; the blocks carry their own LayerNorms.
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
        blocks = []
        for blk_idx in range(depth):
            blocks.append(Block(
                dim=embed_dim,
                num_heads=heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=True,
                proj_drop=proj_drop,
                attn_drop=attn_drop,
                drop_path=drop_path_prob[blk_idx],  # per-block stochastic depth
                norm_layer=partial(nn.LayerNorm, eps=1e-6)
            ))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
        x, cls_tokens = x
        num_cls = cls_tokens.shape[1]
        if self.pool is not None:
            x, cls_tokens = self.pool(x, cls_tokens)

        B, C, H, W = x.shape
        # NCHW -> (B, H*W, C) sequence with cls tokens prepended.
        seq = torch.cat((cls_tokens, x.flatten(2).transpose(1, 2)), dim=1)
        seq = self.blocks(self.norm(seq))

        # Split cls tokens back off and restore the spatial map.
        cls_tokens = seq[:, :num_cls]
        x = seq[:, num_cls:].transpose(1, 2).reshape(B, C, H, W)

        return x, cls_tokens
+
+
class Pooling(nn.Module):
    """PiT stage-transition pooling.

    Downsamples the spatial token map with a grouped (per-input-channel)
    strided conv and projects the class token(s) to the new width with a
    linear layer. kernel = stride + 1, padding = stride // 2.
    """

    def __init__(self, in_feature, out_feature, stride, padding_mode='zeros'):
        super(Pooling, self).__init__()

        self.conv = nn.Conv2d(
            in_feature,
            out_feature,
            kernel_size=stride + 1,
            stride=stride,
            padding=stride // 2,
            groups=in_feature,
            padding_mode=padding_mode,
        )
        self.fc = nn.Linear(in_feature, out_feature)

    def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]:
        # Spatial map and class tokens are reduced independently.
        return self.conv(x), self.fc(cls_token)
+
+
class ConvEmbedding(nn.Module):
    """Overlapping patch embedding: one strided conv producing the stage-0
    feature map. grid_size is precomputed via the usual conv output-size
    arithmetic so the model can size its positional embedding.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            img_size: int = 224,
            patch_size: int = 16,
            stride: int = 8,
            padding: int = 0,
    ):
        super(ConvEmbedding, self).__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        # floor((in + 2p - k) / s + 1) per axis.
        self.height = math.floor((self.img_size[0] + 2 * padding - self.patch_size[0]) / stride + 1)
        self.width = math.floor((self.img_size[1] + 2 * padding - self.patch_size[1]) / stride + 1)
        self.grid_size = (self.height, self.width)

        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=patch_size,
            stride=stride, padding=padding, bias=True)

    def forward(self, x):
        return self.conv(x)
+
+
class PoolingVisionTransformer(nn.Module):
    """ Pooling-based Vision Transformer

    A PyTorch implementation of 'Rethinking Spatial Dimensions of Vision Transformers'
    - https://arxiv.org/abs/2103.16302

    Tokens travel through the stages as a (spatial NCHW map, cls/dist tokens)
    pair; each stage after the first pools the map down and widens the
    embedding. When ``distilled`` is set a second (distillation) token and a
    second head are used, following DeiT.
    """
    def __init__(
            self,
            img_size: int = 224,
            patch_size: int = 16,
            stride: int = 8,
            stem_type: str = 'overlap',  # NOTE(review): currently unused — confirm intended
            base_dims: Sequence[int] = (48, 48, 48),
            depth: Sequence[int] = (2, 6, 4),
            heads: Sequence[int] = (2, 4, 8),
            mlp_ratio: float = 4,
            num_classes=1000,
            in_chans=3,
            global_pool='token',
            distilled=False,
            drop_rate=0.,
            pos_drop_drate=0.,  # (sic) typo kept — renaming would break keyword callers
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
    ):
        super(PoolingVisionTransformer, self).__init__()
        assert global_pool in ('token',)

        self.base_dims = base_dims
        self.heads = heads
        embed_dim = base_dims[0] * heads[0]
        self.num_classes = num_classes
        self.global_pool = global_pool
        # One cls token, plus a distillation token when distilled.
        self.num_tokens = 2 if distilled else 1
        self.feature_info = []

        self.patch_embed = ConvEmbedding(in_chans, embed_dim, img_size, patch_size, stride)
        # Positional embedding is spatial (NCHW) — added before flattening.
        self.pos_embed = nn.Parameter(torch.randn(1, embed_dim, self.patch_embed.height, self.patch_embed.width))
        self.cls_token = nn.Parameter(torch.randn(1, self.num_tokens, embed_dim))
        self.pos_drop = nn.Dropout(p=pos_drop_drate)

        transformers = []
        # stochastic depth decay rule: linear ramp over all blocks, split per stage
        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depth)).split(depth)]
        prev_dim = embed_dim
        for i in range(len(depth)):
            pool = None
            embed_dim = base_dims[i] * heads[i]
            if i > 0:
                # Stages after the first start with a spatial pooling step.
                pool = Pooling(
                    prev_dim,
                    embed_dim,
                    stride=2,
                )
            transformers += [Transformer(
                base_dims[i],
                depth[i],
                heads[i],
                mlp_ratio,
                pool=pool,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path_prob=dpr[i],
            )]
            prev_dim = embed_dim
            self.feature_info += [dict(num_chs=prev_dim, reduction=(stride - 1) * 2**i, module=f'transformers.{i}')]

        self.transformers = SequentialTuple(*transformers)
        self.norm = nn.LayerNorm(base_dims[-1] * heads[-1], eps=1e-6)
        self.num_features = self.embed_dim = embed_dim

        # Classifier head
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self.head_dist = None
        if distilled:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
        self.distilled_training = False  # must set this True to train w/ distillation token

        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Only LayerNorm gets explicit init; conv/linear keep torch defaults.
        if isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed', 'cls_token'}

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        self.distilled_training = enable

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    def get_classifier(self):
        if self.head_dist is not None:
            return self.head, self.head_dist
        else:
            return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        if self.head_dist is not None:
            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed, add positional encoding, run all stages; returns the
        normalized cls (and dist, if distilled) tokens."""
        x = self.patch_embed(x)
        x = self.pos_drop(x + self.pos_embed)
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x, cls_tokens = self.transformers((x, cls_tokens))
        cls_tokens = self.norm(cls_tokens)
        return cls_tokens

    def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
        if self.head_dist is not None:
            assert self.global_pool == 'token'
            x, x_dist = x[:, 0], x[:, 1]
            x = self.head_drop(x)
            # FIX: was `x_dist = self.head_drop(x)`, which overwrote the
            # distillation token with (a dropout of) the class token, so
            # head_dist classified the wrong token even at eval time.
            x_dist = self.head_drop(x_dist)
            if not pre_logits:
                x = self.head(x)
                x_dist = self.head_dist(x_dist)
            if self.distilled_training and self.training and not torch.jit.is_scripting():
                # only return separate classification predictions when training in distilled mode
                return x, x_dist
            else:
                # during standard train / finetune, inference average the classifier predictions
                return (x + x_dist) / 2
        else:
            if self.global_pool == 'token':
                x = x[:, 0]
            x = self.head_drop(x)
            if not pre_logits:
                x = self.head(x)
            return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def checkpoint_filter_fn(state_dict, model):
    """Remap original PiT checkpoint keys to this implementation's layout.

    The reference impl kept the pooling layers in a separate `pools` list;
    here pool `i` lives inside transformer stage `i + 1`, so
    'pools.<i>.' -> 'transformers.<i+1>.pool.'. All other keys pass through.
    """
    pool_key = re.compile(r'pools\.(\d)\.')
    # FIXME need to update resize for PiT impl
    # if k == 'pos_embed' and v.shape != model.pos_embed.shape:
    #     # To resize pos embedding when using model at different size from pretrained weights
    #     v = resize_pos_embed(v, model.pos_embed)
    return {
        pool_key.sub(lambda m: f'transformers.{int(m.group(1)) + 1}.pool.', k): v
        for k, v in state_dict.items()
    }
+
+
def _create_pit(variant, pretrained=False, **kwargs):
    """Build a PoolingVisionTransformer through the shared timm builder."""
    # Default to exposing all three stages as feature outputs.
    out_indices = kwargs.pop('out_indices', (0, 1, 2))
    return build_model_with_cfg(
        PoolingVisionTransformer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        # Features via hooks only; the pooled-tuple graph can't be rewritten.
        feature_cfg=dict(feature_cls='hook', no_rewrite=True, out_indices=out_indices),
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Default pretrained-config dict for PiT models; kwargs override."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'patch_embed.conv', 'classifier': 'head',
    }
    cfg.update(kwargs)
    return cfg
+
+
default_cfgs = generate_default_cfgs({
    # PiT models (NAVER impl weights); distilled variants expose both heads.
    'pit_ti_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_xs_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_s_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_b_224.in1k': _cfg(hf_hub_id='timm/'),
    'pit_ti_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
    'pit_xs_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
    'pit_s_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
    'pit_b_distilled_224.in1k': _cfg(
        hf_hub_id='timm/',
        classifier=('head', 'head_dist')),
})
+
+
@register_model
def pit_b_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-Base @ 224x224."""
    model_args = dict(patch_size=14, stride=7, base_dims=[64, 64, 64], depth=[3, 6, 4], heads=[4, 8, 16], mlp_ratio=4)
    return _create_pit('pit_b_224', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def pit_s_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-Small @ 224x224."""
    model_args = dict(patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[3, 6, 12], mlp_ratio=4)
    return _create_pit('pit_s_224', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def pit_xs_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-XSmall @ 224x224."""
    model_args = dict(patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4)
    return _create_pit('pit_xs_224', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def pit_ti_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-Tiny @ 224x224."""
    model_args = dict(patch_size=16, stride=8, base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8], mlp_ratio=4)
    return _create_pit('pit_ti_224', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def pit_b_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-Base @ 224x224, DeiT-style distillation token."""
    model_args = dict(
        patch_size=14, stride=7, base_dims=[64, 64, 64], depth=[3, 6, 4], heads=[4, 8, 16],
        mlp_ratio=4, distilled=True)
    return _create_pit('pit_b_distilled_224', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def pit_s_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-Small @ 224x224, DeiT-style distillation token."""
    model_args = dict(
        patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[3, 6, 12],
        mlp_ratio=4, distilled=True)
    return _create_pit('pit_s_distilled_224', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def pit_xs_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-XSmall @ 224x224, DeiT-style distillation token."""
    model_args = dict(
        patch_size=16, stride=8, base_dims=[48, 48, 48], depth=[2, 6, 4], heads=[2, 4, 8],
        mlp_ratio=4, distilled=True)
    return _create_pit('pit_xs_distilled_224', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def pit_ti_distilled_224(pretrained=False, **kwargs) -> PoolingVisionTransformer:
    """PiT-Tiny @ 224x224, DeiT-style distillation token."""
    model_args = dict(
        patch_size=16, stride=8, base_dims=[32, 32, 32], depth=[2, 6, 4], heads=[2, 4, 8],
        mlp_ratio=4, distilled=True)
    return _create_pit('pit_ti_distilled_224', pretrained, **dict(model_args, **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pnasnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pnasnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bee18604c3d2ffc1e58763775a4b7b86c1857b7c
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pnasnet.py
@@ -0,0 +1,378 @@
+"""
+ pnasnet5large implementation grabbed from Cadene's pretrained models
+ Additional credit to https://github.com/creafz
+
+ https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/pnasnet.py
+
+"""
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['PNASNet5Large']
+
+
+class SeparableConv2d(nn.Module):
+    """Depthwise-separable conv: per-channel depthwise conv followed by a 1x1 pointwise conv."""
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=''):
+        super(SeparableConv2d, self).__init__()
+        # groups=in_channels makes this a depthwise convolution.
+        self.depthwise_conv2d = create_conv2d(
+            in_channels, in_channels, kernel_size=kernel_size,
+            stride=stride, padding=padding, groups=in_channels)
+        self.pointwise_conv2d = create_conv2d(
+            in_channels, out_channels, kernel_size=1, padding=padding)
+
+    def forward(self, x):
+        x = self.depthwise_conv2d(x)
+        x = self.pointwise_conv2d(x)
+        return x
+
+
+class BranchSeparables(nn.Module):
+    """Two stacked (ReLU -> separable conv -> BN) units; only the first may stride."""
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, stem_cell=False, padding=''):
+        super(BranchSeparables, self).__init__()
+        # In the stem cell the intermediate width jumps to out_channels immediately.
+        middle_channels = out_channels if stem_cell else in_channels
+        self.act_1 = nn.ReLU()
+        self.separable_1 = SeparableConv2d(
+            in_channels, middle_channels, kernel_size, stride=stride, padding=padding)
+        self.bn_sep_1 = nn.BatchNorm2d(middle_channels, eps=0.001)
+        self.act_2 = nn.ReLU()
+        self.separable_2 = SeparableConv2d(
+            middle_channels, out_channels, kernel_size, stride=1, padding=padding)
+        self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        x = self.act_1(x)
+        x = self.separable_1(x)
+        x = self.bn_sep_1(x)
+        x = self.act_2(x)
+        x = self.separable_2(x)
+        x = self.bn_sep_2(x)
+        return x
+
+
+class ActConvBn(nn.Module):
+    """ReLU -> conv -> BatchNorm block (pre-activation ordering)."""
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=''):
+        super(ActConvBn, self).__init__()
+        self.act = nn.ReLU()
+        self.conv = create_conv2d(
+            in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        x = self.act(x)
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+
+class FactorizedReduction(nn.Module):
+    """Halve spatial size via two stride-2 avg-pool+1x1-conv paths (the second
+    spatially offset by one pixel), concatenated on channels and batch-normed.
+    Each path produces out_channels // 2 channels.
+    """
+
+    def __init__(self, in_channels, out_channels, padding=''):
+        super(FactorizedReduction, self).__init__()
+        self.act = nn.ReLU()
+        self.path_1 = nn.Sequential(OrderedDict([
+            ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
+            ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
+        ]))
+        self.path_2 = nn.Sequential(OrderedDict([
+            ('pad', nn.ZeroPad2d((-1, 1, -1, 1))),  # shift input by one pixel before pooling
+            ('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False)),
+            ('conv', create_conv2d(in_channels, out_channels // 2, kernel_size=1, padding=padding)),
+        ]))
+        self.final_path_bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        x = self.act(x)
+        x_path1 = self.path_1(x)
+        x_path2 = self.path_2(x)
+        out = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
+        return out
+
+
+class CellBase(nn.Module):
+    """Shared forward for PNASNet cells: five left/right branch pairs are summed
+    pairwise and the five results concatenated on the channel dim. Subclasses
+    define the comb_iter_* branch modules.
+    """
+
+    def cell_forward(self, x_left, x_right):
+        x_comb_iter_0_left = self.comb_iter_0_left(x_left)
+        x_comb_iter_0_right = self.comb_iter_0_right(x_left)
+        x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
+
+        x_comb_iter_1_left = self.comb_iter_1_left(x_right)
+        x_comb_iter_1_right = self.comb_iter_1_right(x_right)
+        x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
+
+        x_comb_iter_2_left = self.comb_iter_2_left(x_right)
+        x_comb_iter_2_right = self.comb_iter_2_right(x_right)
+        x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
+
+        # Branch 3's left input is branch 2's output, not a cell input.
+        x_comb_iter_3_left = self.comb_iter_3_left(x_comb_iter_2)
+        x_comb_iter_3_right = self.comb_iter_3_right(x_right)
+        x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
+
+        x_comb_iter_4_left = self.comb_iter_4_left(x_left)
+        if self.comb_iter_4_right is not None:
+            x_comb_iter_4_right = self.comb_iter_4_right(x_right)
+        else:
+            # Non-reduction cells pass x_right through unchanged.
+            x_comb_iter_4_right = x_right
+        x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
+
+        x_out = torch.cat([x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
+        return x_out
+
+
+class CellStem0(CellBase):
+    """First stem cell: takes a single input and derives its own right input
+    via a 1x1 ActConvBn; all striding branches use stride 2.
+    """
+
+    def __init__(self, in_chs_left, out_chs_left, in_chs_right, out_chs_right, pad_type=''):
+        super(CellStem0, self).__init__()
+        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)
+
+        self.comb_iter_0_left = BranchSeparables(
+            in_chs_left, out_chs_left, kernel_size=5, stride=2, stem_cell=True, padding=pad_type)
+        self.comb_iter_0_right = nn.Sequential(OrderedDict([
+            ('max_pool', create_pool2d('max', 3, stride=2, padding=pad_type)),
+            ('conv', create_conv2d(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)),
+            ('bn', nn.BatchNorm2d(out_chs_left, eps=0.001)),
+        ]))
+
+        self.comb_iter_1_left = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=7, stride=2, padding=pad_type)
+        self.comb_iter_1_right = create_pool2d('max', 3, stride=2, padding=pad_type)
+
+        self.comb_iter_2_left = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=5, stride=2, padding=pad_type)
+        self.comb_iter_2_right = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=3, stride=2, padding=pad_type)
+
+        self.comb_iter_3_left = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=3, padding=pad_type)
+        self.comb_iter_3_right = create_pool2d('max', 3, stride=2, padding=pad_type)
+
+        self.comb_iter_4_left = BranchSeparables(
+            in_chs_right, out_chs_right, kernel_size=3, stride=2, stem_cell=True, padding=pad_type)
+        self.comb_iter_4_right = ActConvBn(
+            out_chs_right, out_chs_right, kernel_size=1, stride=2, padding=pad_type)
+
+    def forward(self, x_left):
+        # Right input is derived from the single (left) input.
+        x_right = self.conv_1x1(x_left)
+        x_out = self.cell_forward(x_left, x_right)
+        return x_out
+
+
+class Cell(CellBase):
+    """Regular/reduction PNASNet cell over two inputs (previous and current features)."""
+
+    def __init__(
+            self,
+            in_chs_left,
+            out_chs_left,
+            in_chs_right,
+            out_chs_right,
+            pad_type='',
+            is_reduction=False,
+            match_prev_layer_dims=False,
+    ):
+        super(Cell, self).__init__()
+
+        # If `is_reduction` is set to `True` stride 2 is used for
+        # convolution and pooling layers to reduce the spatial size of
+        # the output of a cell approximately by a factor of 2.
+        stride = 2 if is_reduction else 1
+
+        # If `match_prev_layer_dimensions` is set to `True`
+        # `FactorizedReduction` is used to reduce the spatial size
+        # of the left input of a cell approximately by a factor of 2.
+        self.match_prev_layer_dimensions = match_prev_layer_dims
+        if match_prev_layer_dims:
+            self.conv_prev_1x1 = FactorizedReduction(in_chs_left, out_chs_left, padding=pad_type)
+        else:
+            self.conv_prev_1x1 = ActConvBn(in_chs_left, out_chs_left, kernel_size=1, padding=pad_type)
+        self.conv_1x1 = ActConvBn(in_chs_right, out_chs_right, kernel_size=1, padding=pad_type)
+
+        self.comb_iter_0_left = BranchSeparables(
+            out_chs_left, out_chs_left, kernel_size=5, stride=stride, padding=pad_type)
+        self.comb_iter_0_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+        self.comb_iter_1_left = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=7, stride=stride, padding=pad_type)
+        self.comb_iter_1_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+        self.comb_iter_2_left = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=5, stride=stride, padding=pad_type)
+        self.comb_iter_2_right = BranchSeparables(
+            out_chs_right, out_chs_right, kernel_size=3, stride=stride, padding=pad_type)
+
+        # NOTE(review): branch 3's left conv takes no pad_type, unlike its siblings —
+        # this matches the original reference implementation.
+        self.comb_iter_3_left = BranchSeparables(out_chs_right, out_chs_right, kernel_size=3)
+        self.comb_iter_3_right = create_pool2d('max', 3, stride=stride, padding=pad_type)
+
+        self.comb_iter_4_left = BranchSeparables(
+            out_chs_left, out_chs_left, kernel_size=3, stride=stride, padding=pad_type)
+        if is_reduction:
+            self.comb_iter_4_right = ActConvBn(
+                out_chs_right, out_chs_right, kernel_size=1, stride=stride, padding=pad_type)
+        else:
+            # cell_forward passes x_right straight through when this is None.
+            self.comb_iter_4_right = None
+
+    def forward(self, x_left, x_right):
+        x_left = self.conv_prev_1x1(x_left)
+        x_right = self.conv_1x1(x_right)
+        x_out = self.cell_forward(x_left, x_right)
+        return x_out
+
+
+class PNASNet5Large(nn.Module):
+    """PNASNet-5 Large backbone with classification head.
+
+    Stem conv + two stem cells, then three groups of four cells; each group
+    begins with a reduction cell. Head: global pool -> dropout -> linear.
+    """
+    def __init__(
+            self,
+            num_classes=1000,
+            in_chans=3,
+            output_stride=32,
+            drop_rate=0.,
+            global_pool='avg',
+            pad_type='',
+    ):
+        super(PNASNet5Large, self).__init__()
+        self.num_classes = num_classes
+        self.num_features = 4320
+        # Only the native stride of 32 is supported (no dilation support).
+        assert output_stride == 32
+
+        self.conv_0 = ConvNormAct(
+            in_chans, 96, kernel_size=3, stride=2, padding=0,
+            norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)
+
+        self.cell_stem_0 = CellStem0(
+            in_chs_left=96, out_chs_left=54, in_chs_right=96, out_chs_right=54, pad_type=pad_type)
+
+        self.cell_stem_1 = Cell(
+            in_chs_left=96, out_chs_left=108, in_chs_right=270, out_chs_right=108, pad_type=pad_type,
+            match_prev_layer_dims=True, is_reduction=True)
+        self.cell_0 = Cell(
+            in_chs_left=270, out_chs_left=216, in_chs_right=540, out_chs_right=216, pad_type=pad_type,
+            match_prev_layer_dims=True)
+        self.cell_1 = Cell(
+            in_chs_left=540, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+        self.cell_2 = Cell(
+            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+        self.cell_3 = Cell(
+            in_chs_left=1080, out_chs_left=216, in_chs_right=1080, out_chs_right=216, pad_type=pad_type)
+
+        self.cell_4 = Cell(
+            in_chs_left=1080, out_chs_left=432, in_chs_right=1080, out_chs_right=432, pad_type=pad_type,
+            is_reduction=True)
+        self.cell_5 = Cell(
+            in_chs_left=1080, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type,
+            match_prev_layer_dims=True)
+        self.cell_6 = Cell(
+            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
+        self.cell_7 = Cell(
+            in_chs_left=2160, out_chs_left=432, in_chs_right=2160, out_chs_right=432, pad_type=pad_type)
+
+        self.cell_8 = Cell(
+            in_chs_left=2160, out_chs_left=864, in_chs_right=2160, out_chs_right=864, pad_type=pad_type,
+            is_reduction=True)
+        self.cell_9 = Cell(
+            in_chs_left=2160, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type,
+            match_prev_layer_dims=True)
+        self.cell_10 = Cell(
+            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
+        self.cell_11 = Cell(
+            in_chs_left=4320, out_chs_left=864, in_chs_right=4320, out_chs_right=864, pad_type=pad_type)
+        self.act = nn.ReLU()
+        # Feature taps for the features_only/hook-based extraction path.
+        self.feature_info = [
+            dict(num_chs=96, reduction=2, module='conv_0'),
+            dict(num_chs=270, reduction=4, module='cell_stem_1.conv_1x1.act'),
+            dict(num_chs=1080, reduction=8, module='cell_4.conv_1x1.act'),
+            dict(num_chs=2160, reduction=16, module='cell_8.conv_1x1.act'),
+            dict(num_chs=4320, reduction=32, module='act'),
+        ]
+
+        self.global_pool, self.head_drop, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+    @torch.jit.ignore
+    def group_matcher(self, coarse=False):
+        return dict(stem=r'^conv_0|cell_stem_[01]', blocks=r'^cell_(\d+)')
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        assert not enable, 'gradient checkpointing not supported'
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.last_linear
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        """Replace the classification head for `num_classes` outputs."""
+        self.num_classes = num_classes
+        self.global_pool, self.last_linear = create_classifier(
+            self.num_features, self.num_classes, pool_type=global_pool)
+
+    def forward_features(self, x):
+        # Each cell takes (previous-previous, previous) outputs as (left, right).
+        x_conv_0 = self.conv_0(x)
+        x_stem_0 = self.cell_stem_0(x_conv_0)
+        x_stem_1 = self.cell_stem_1(x_conv_0, x_stem_0)
+        x_cell_0 = self.cell_0(x_stem_0, x_stem_1)
+        x_cell_1 = self.cell_1(x_stem_1, x_cell_0)
+        x_cell_2 = self.cell_2(x_cell_0, x_cell_1)
+        x_cell_3 = self.cell_3(x_cell_1, x_cell_2)
+        x_cell_4 = self.cell_4(x_cell_2, x_cell_3)
+        x_cell_5 = self.cell_5(x_cell_3, x_cell_4)
+        x_cell_6 = self.cell_6(x_cell_4, x_cell_5)
+        x_cell_7 = self.cell_7(x_cell_5, x_cell_6)
+        x_cell_8 = self.cell_8(x_cell_6, x_cell_7)
+        x_cell_9 = self.cell_9(x_cell_7, x_cell_8)
+        x_cell_10 = self.cell_10(x_cell_8, x_cell_9)
+        x_cell_11 = self.cell_11(x_cell_9, x_cell_10)
+        x = self.act(x_cell_11)
+        return x
+
+    def forward_head(self, x, pre_logits: bool = False):
+        x = self.global_pool(x)
+        x = self.head_drop(x)
+        return x if pre_logits else self.last_linear(x)
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.forward_head(x)
+        return x
+
+
+def _create_pnasnet(variant, pretrained=False, **kwargs):
+    """Build a PNASNet5Large via the timm model builder (hook-based feature extraction)."""
+    return build_model_with_cfg(
+        PNASNet5Large,
+        variant,
+        pretrained,
+        feature_cfg=dict(feature_cls='hook', no_rewrite=True),  # not possible to re-write this model
+        **kwargs,
+    )
+
+
+# Pretrained weight configs, keyed '<model>.<tag>'; weights resolved via the HF hub.
+default_cfgs = generate_default_cfgs({
+    'pnasnet5large.tf_in1k': {
+        'hf_hub_id': 'timm/',
+        'input_size': (3, 331, 331),
+        'pool_size': (11, 11),
+        'crop_pct': 0.911,
+        'interpolation': 'bicubic',
+        'mean': (0.5, 0.5, 0.5),
+        'std': (0.5, 0.5, 0.5),
+        'num_classes': 1000,
+        'first_conv': 'conv_0.conv',
+        'classifier': 'last_linear',
+    },
+})
+
+
+@register_model
+def pnasnet5large(pretrained=False, **kwargs) -> PNASNet5Large:
+    r"""PNASNet-5 model architecture from the
+    `"Progressive Neural Architecture Search"
+    <https://arxiv.org/abs/1712.00559>`_ paper.
+    """
+    model_kwargs = dict(pad_type='same', **kwargs)
+    return _create_pnasnet('pnasnet5large', pretrained, **model_kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pvt_v2.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pvt_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..16302002eb2703e4f7b2081036f228c8c8df75a4
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/pvt_v2.py
@@ -0,0 +1,503 @@
+""" Pyramid Vision Transformer v2
+
+@misc{wang2021pvtv2,
+ title={PVTv2: Improved Baselines with Pyramid Vision Transformer},
+ author={Wenhai Wang and Enze Xie and Xiang Li and Deng-Ping Fan and Kaitao Song and Ding Liang and
+ Tong Lu and Ping Luo and Ling Shao},
+ year={2021},
+ eprint={2106.13797},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+
+Based on Apache 2.0 licensed code at https://github.com/whai362/PVT
+
+Modifications and timm support by / Copyright 2022, Ross Wightman
+"""
+
+import math
+from typing import Tuple, List, Callable, Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import DropPath, to_2tuple, to_ntuple, trunc_normal_, LayerNorm, use_fused_attn
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['PyramidVisionTransformerV2']
+
+
+class MlpWithDepthwiseConv(nn.Module):
+    """MLP whose hidden features pass through a 3x3 depthwise conv.
+
+    Operates on token sequences (B, N, C); `feat_size` (H, W) is needed to
+    reshape tokens to 2D for the conv. `extra_relu` inserts a ReLU before the
+    depthwise conv (used by the linear-attention variant).
+    """
+    def __init__(
+            self,
+            in_features,
+            hidden_features=None,
+            out_features=None,
+            act_layer=nn.GELU,
+            drop=0.,
+            extra_relu=False,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.relu = nn.ReLU() if extra_relu else nn.Identity()
+        self.dwconv = nn.Conv2d(hidden_features, hidden_features, 3, 1, 1, bias=True, groups=hidden_features)
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.act = act_layer()
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x, feat_size: List[int]):
+        x = self.fc1(x)
+        B, N, C = x.shape
+        # Tokens -> 2D feature map for the depthwise conv, then back to tokens.
+        x = x.transpose(1, 2).view(B, C, feat_size[0], feat_size[1])
+        x = self.relu(x)
+        x = self.dwconv(x)
+        x = x.flatten(2).transpose(1, 2)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
+
+
+class Attention(nn.Module):
+    """PVTv2 spatial-reduction attention.
+
+    Regular mode (`linear_attn=False`): if sr_ratio > 1, K/V are computed from a
+    strided-conv-downsampled copy of the input. Linear mode: K/V come from an
+    adaptive 7x7 average pool + 1x1 conv + GELU, giving complexity independent
+    of input resolution.
+    """
+    fused_attn: torch.jit.Final[bool]
+
+    def __init__(
+            self,
+            dim,
+            num_heads=8,
+            sr_ratio=1,
+            linear_attn=False,
+            qkv_bias=True,
+            attn_drop=0.,
+            proj_drop=0.
+    ):
+        super().__init__()
+        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
+
+        self.dim = dim
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+        self.fused_attn = use_fused_attn()
+
+        self.q = nn.Linear(dim, dim, bias=qkv_bias)
+        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        if not linear_attn:
+            self.pool = None
+            if sr_ratio > 1:
+                # Strided conv reduces K/V token count by sr_ratio**2.
+                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
+                self.norm = nn.LayerNorm(dim)
+            else:
+                self.sr = None
+                self.norm = None
+            self.act = None
+        else:
+            self.pool = nn.AdaptiveAvgPool2d(7)
+            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
+            self.norm = nn.LayerNorm(dim)
+            self.act = nn.GELU()
+
+    def forward(self, x, feat_size: List[int]):
+        B, N, C = x.shape
+        H, W = feat_size
+        q = self.q(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+
+        if self.pool is not None:
+            # Linear attention: pool K/V source to a fixed 7x7 grid.
+            x = x.permute(0, 2, 1).reshape(B, C, H, W)
+            x = self.sr(self.pool(x)).reshape(B, C, -1).permute(0, 2, 1)
+            x = self.norm(x)
+            x = self.act(x)
+            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        else:
+            if self.sr is not None:
+                # Regular SRA: strided-conv downsample of K/V source.
+                x = x.permute(0, 2, 1).reshape(B, C, H, W)
+                x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
+                x = self.norm(x)
+                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+            else:
+                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+        k, v = kv.unbind(0)
+
+        if self.fused_attn:
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+            attn = attn.softmax(dim=-1)
+            attn = self.attn_drop(attn)
+            x = attn @ v
+
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class Block(nn.Module):
+    """PVTv2 transformer block: pre-norm SRA attention + pre-norm depthwise-conv MLP,
+    each wrapped in a residual with stochastic depth.
+    """
+
+    def __init__(
+            self,
+            dim,
+            num_heads,
+            mlp_ratio=4.,
+            sr_ratio=1,
+            linear_attn=False,
+            qkv_bias=False,
+            proj_drop=0.,
+            attn_drop=0.,
+            drop_path=0.,
+            act_layer=nn.GELU,
+            norm_layer=LayerNorm,
+    ):
+        super().__init__()
+        self.norm1 = norm_layer(dim)
+        self.attn = Attention(
+            dim,
+            num_heads=num_heads,
+            sr_ratio=sr_ratio,
+            linear_attn=linear_attn,
+            qkv_bias=qkv_bias,
+            attn_drop=attn_drop,
+            proj_drop=proj_drop,
+        )
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        self.mlp = MlpWithDepthwiseConv(
+            in_features=dim,
+            hidden_features=int(dim * mlp_ratio),
+            act_layer=act_layer,
+            drop=proj_drop,
+            extra_relu=linear_attn,
+        )
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x, feat_size: List[int]):
+        x = x + self.drop_path1(self.attn(self.norm1(x), feat_size))
+        x = x + self.drop_path2(self.mlp(self.norm2(x), feat_size))
+
+        return x
+
+
+class OverlapPatchEmbed(nn.Module):
+    """ Image to Patch Embedding
+
+    Overlapping conv patch embedding (kernel > stride). Input is NCHW;
+    output is permuted to NHWC before LayerNorm.
+    """
+    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
+        super().__init__()
+        patch_size = to_2tuple(patch_size)
+        assert max(patch_size) > stride, "Set larger patch_size than stride"
+        self.patch_size = patch_size
+        # Padding of patch_size // 2 keeps the output grid at input_size / stride.
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, patch_size,
+            stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
+        self.norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, x):
+        x = self.proj(x)
+        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC for LayerNorm over channels
+        x = self.norm(x)
+        return x
+
+
+class PyramidVisionTransformerStage(nn.Module):
+    """One PVTv2 stage: optional overlapping-patch downsample, a sequence of
+    transformer Blocks over flattened tokens, and a final LayerNorm.
+    Accepts NCHW when downsampling (all stages but the first), NHWC otherwise;
+    always returns NCHW.
+    """
+    def __init__(
+            self,
+            dim: int,
+            dim_out: int,
+            depth: int,
+            downsample: bool = True,
+            num_heads: int = 8,
+            sr_ratio: int = 1,
+            linear_attn: bool = False,
+            mlp_ratio: float = 4.0,
+            qkv_bias: bool = True,
+            proj_drop: float = 0.,
+            attn_drop: float = 0.,
+            drop_path: Union[List[float], float] = 0.0,
+            norm_layer: Callable = LayerNorm,
+    ):
+        super().__init__()
+        self.grad_checkpointing = False
+
+        if downsample:
+            self.downsample = OverlapPatchEmbed(
+                patch_size=3,
+                stride=2,
+                in_chans=dim,
+                embed_dim=dim_out,
+            )
+        else:
+            assert dim == dim_out
+            self.downsample = None
+
+        self.blocks = nn.ModuleList([Block(
+            dim=dim_out,
+            num_heads=num_heads,
+            sr_ratio=sr_ratio,
+            linear_attn=linear_attn,
+            mlp_ratio=mlp_ratio,
+            qkv_bias=qkv_bias,
+            proj_drop=proj_drop,
+            attn_drop=attn_drop,
+            drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+            norm_layer=norm_layer,
+        ) for i in range(depth)])
+
+        self.norm = norm_layer(dim_out)
+
+    def forward(self, x):
+        # x is either B, C, H, W (if downsample) or B, H, W, C if not
+        if self.downsample is not None:
+            # input to downsample is B, C, H, W
+            x = self.downsample(x)  # output B, H, W, C
+        B, H, W, C = x.shape
+        feat_size = (H, W)
+        x = x.reshape(B, -1, C)  # flatten spatial dims to a token sequence
+        for blk in self.blocks:
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint.checkpoint(blk, x, feat_size)
+            else:
+                x = blk(x, feat_size)
+        x = self.norm(x)
+        # Restore NCHW for the next stage / feature consumers.
+        x = x.reshape(B, feat_size[0], feat_size[1], -1).permute(0, 3, 1, 2).contiguous()
+        return x
+
+
+class PyramidVisionTransformerV2(nn.Module):
+ def __init__(
+ self,
+ in_chans=3,
+ num_classes=1000,
+ global_pool='avg',
+ depths=(3, 4, 6, 3),
+ embed_dims=(64, 128, 256, 512),
+ num_heads=(1, 2, 4, 8),
+ sr_ratios=(8, 4, 2, 1),
+ mlp_ratios=(8., 8., 4., 4.),
+ qkv_bias=True,
+ linear=False,
+ drop_rate=0.,
+ proj_drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.,
+ norm_layer=LayerNorm,
+ ):
+ super().__init__()
+ self.num_classes = num_classes
+ assert global_pool in ('avg', '')
+ self.global_pool = global_pool
+ self.depths = depths
+ num_stages = len(depths)
+ mlp_ratios = to_ntuple(num_stages)(mlp_ratios)
+ num_heads = to_ntuple(num_stages)(num_heads)
+ sr_ratios = to_ntuple(num_stages)(sr_ratios)
+ assert(len(embed_dims)) == num_stages
+ self.feature_info = []
+
+ self.patch_embed = OverlapPatchEmbed(
+ patch_size=7,
+ stride=4,
+ in_chans=in_chans,
+ embed_dim=embed_dims[0],
+ )
+
+ dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+ cur = 0
+ prev_dim = embed_dims[0]
+ stages = []
+ for i in range(num_stages):
+ stages += [PyramidVisionTransformerStage(
+ dim=prev_dim,
+ dim_out=embed_dims[i],
+ depth=depths[i],
+ downsample=i > 0,
+ num_heads=num_heads[i],
+ sr_ratio=sr_ratios[i],
+ mlp_ratio=mlp_ratios[i],
+ linear_attn=linear,
+ qkv_bias=qkv_bias,
+ proj_drop=proj_drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[i],
+ norm_layer=norm_layer,
+ )]
+ prev_dim = embed_dims[i]
+ cur += depths[i]
+ self.feature_info += [dict(num_chs=prev_dim, reduction=4 * 2**i, module=f'stages.{i}')]
+ self.stages = nn.Sequential(*stages)
+
+ # classification head
+ self.num_features = embed_dims[-1]
+ self.head_drop = nn.Dropout(drop_rate)
+ self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ fan_out //= m.groups
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
+ if m.bias is not None:
+ m.bias.data.zero_()
+
+ def freeze_patch_emb(self):
+ self.patch_embed.requires_grad = False
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {}
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse=False):
+ matcher = dict(
+ stem=r'^patch_embed', # stem and embed
+ blocks=r'^stages\.(\d+)'
+ )
+ return matcher
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ for s in self.stages:
+ s.grad_checkpointing = enable
+
+ def get_classifier(self):
+ return self.head
+
+ def reset_classifier(self, num_classes, global_pool=None):
+ self.num_classes = num_classes
+ if global_pool is not None:
+ assert global_pool in ('avg', '')
+ self.global_pool = global_pool
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
+
+ def forward_features(self, x):
+ x = self.patch_embed(x)
+ x = self.stages(x)
+ return x
+
+ def forward_head(self, x, pre_logits: bool = False):
+ if self.global_pool:
+ x = x.mean(dim=(-1, -2))
+ x = self.head_drop(x)
+ return x if pre_logits else self.head(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.forward_head(x)
+ return x
+
+
+def _checkpoint_filter_fn(state_dict, model):
+    """ Remap original checkpoints -> timm """
+    if 'patch_embed.proj.weight' in state_dict:
+        return state_dict  # non-original checkpoint, no remapping needed
+
+    out_dict = {}
+    import re
+    for k, v in state_dict.items():
+        # patch_embed1 stays as the stem; patch_embed2..4 become stage downsamples.
+        if k.startswith('patch_embed'):
+            k = k.replace('patch_embed1', 'patch_embed')
+            k = k.replace('patch_embed2', 'stages.1.downsample')
+            k = k.replace('patch_embed3', 'stages.2.downsample')
+            k = k.replace('patch_embed4', 'stages.3.downsample')
+        k = k.replace('dwconv.dwconv', 'dwconv')
+        # Original keys are 1-indexed (block1, norm1); timm stages are 0-indexed.
+        k = re.sub(r'block(\d+).(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.blocks.{x.group(2)}', k)
+        k = re.sub(r'^norm(\d+)', lambda x: f'stages.{int(x.group(1)) - 1}.norm', k)
+        out_dict[k] = v
+    return out_dict
+
+
+def _create_pvt2(variant, pretrained=False, **kwargs):
+    """Build a PyramidVisionTransformerV2 variant via the timm model builder."""
+    default_out_indices = tuple(range(4))  # expose all four stage outputs by default
+    out_indices = kwargs.pop('out_indices', default_out_indices)
+    model = build_model_with_cfg(
+        PyramidVisionTransformerV2,
+        variant,
+        pretrained,
+        pretrained_filter_fn=_checkpoint_filter_fn,
+        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
+        **kwargs,
+    )
+    return model
+
+
+def _cfg(url='', **kwargs):
+    """Default pretrained config for PVTv2 weights; kwargs override any field."""
+    return {
+        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+        'crop_pct': 0.9, 'interpolation': 'bicubic',
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head', 'fixed_input_size': False,
+        **kwargs
+    }
+
+
+# Pretrained weight configs, keyed '<model>.<tag>'; weights resolved via the HF hub.
+default_cfgs = generate_default_cfgs({
+    'pvt_v2_b0.in1k': _cfg(hf_hub_id='timm/'),
+    'pvt_v2_b1.in1k': _cfg(hf_hub_id='timm/'),
+    'pvt_v2_b2.in1k': _cfg(hf_hub_id='timm/'),
+    'pvt_v2_b3.in1k': _cfg(hf_hub_id='timm/'),
+    'pvt_v2_b4.in1k': _cfg(hf_hub_id='timm/'),
+    'pvt_v2_b5.in1k': _cfg(hf_hub_id='timm/'),
+    'pvt_v2_b2_li.in1k': _cfg(hf_hub_id='timm/'),
+})
+
+
+@register_model
+def pvt_v2_b0(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
+    """PVTv2-B0. Caller kwargs override model_args."""
+    model_args = dict(depths=(2, 2, 2, 2), embed_dims=(32, 64, 160, 256), num_heads=(1, 2, 5, 8))
+    return _create_pvt2('pvt_v2_b0', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def pvt_v2_b1(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
+    """PVTv2-B1. Caller kwargs override model_args."""
+    model_args = dict(depths=(2, 2, 2, 2), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
+    return _create_pvt2('pvt_v2_b1', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def pvt_v2_b2(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
+    """PVTv2-B2. Caller kwargs override model_args."""
+    model_args = dict(depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
+    return _create_pvt2('pvt_v2_b2', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def pvt_v2_b3(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
+    """PVTv2-B3. Caller kwargs override model_args."""
+    model_args = dict(depths=(3, 4, 18, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
+    return _create_pvt2('pvt_v2_b3', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def pvt_v2_b4(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
+    """PVTv2-B4. Caller kwargs override model_args."""
+    model_args = dict(depths=(3, 8, 27, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8))
+    return _create_pvt2('pvt_v2_b4', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def pvt_v2_b5(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
+    """PVTv2-B5 (uses uniform mlp_ratio=4 unlike smaller variants). Caller kwargs override model_args."""
+    model_args = dict(
+        depths=(3, 6, 40, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), mlp_ratios=(4, 4, 4, 4))
+    return _create_pvt2('pvt_v2_b5', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def pvt_v2_b2_li(pretrained=False, **kwargs) -> PyramidVisionTransformerV2:
+    """PVTv2-B2 with linear (pooled) attention. Caller kwargs override model_args."""
+    model_args = dict(
+        depths=(3, 4, 6, 3), embed_dims=(64, 128, 320, 512), num_heads=(1, 2, 5, 8), linear=True)
+    return _create_pvt2('pvt_v2_b2_li', pretrained=pretrained, **dict(model_args, **kwargs))
+
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/registry.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..58e2e1f41add9182c630fd9e7ad5ad0877ea1155
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/registry.py
@@ -0,0 +1,4 @@
+from ._registry import *
+
+import warnings
+# Deprecation shim: re-exports everything from ._registry and warns on import.
+warnings.warn(f"Importing from {__name__} is deprecated, please import via timm.models", DeprecationWarning)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/repghost.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/repghost.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2528805034152ff89ca90dd0faa2c81373e8ac8
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/repghost.py
@@ -0,0 +1,479 @@
+"""
+An implementation of RepGhostNet Model as defined in:
+RepGhost: A Hardware-Efficient Ghost Module via Re-parameterization. https://arxiv.org/abs/2211.06088
+
+Original implementation: https://github.com/ChengpengChen/RepGhost
+"""
+import copy
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SelectAdaptivePool2d, Linear, make_divisible
+from ._builder import build_model_with_cfg
+from ._efficientnet_blocks import SqueezeExcite, ConvBnAct
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['RepGhostNet']
+
+
+_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
+
+
class RepGhostModule(nn.Module):
    # RepGhost module: a pointwise 'primary' conv followed by a depthwise
    # 'cheap' conv. When reparam=True, extra identity+BN fusion branches run
    # in parallel with the cheap conv at train time and are folded into it
    # for deployment via switch_to_deploy().

    def __init__(
            self,
            in_chs,
            out_chs,
            kernel_size=1,
            dw_size=3,
            stride=1,
            relu=True,
            reparam=True,
    ):
        super(RepGhostModule, self).__init__()
        self.out_chs = out_chs
        init_chs = out_chs
        new_chs = out_chs

        # in_chs -> init_chs, kernel_size x kernel_size (pointwise by default)
        self.primary_conv = nn.Sequential(
            nn.Conv2d(in_chs, init_chs, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_chs),
            nn.ReLU(inplace=True) if relu else nn.Identity(),
        )

        # Re-parameterizable branches (identity + BN), summed with the cheap
        # conv output in forward() during training only.
        fusion_conv = []
        fusion_bn = []
        if reparam:
            fusion_conv.append(nn.Identity())
            fusion_bn.append(nn.BatchNorm2d(init_chs))

        self.fusion_conv = nn.Sequential(*fusion_conv)
        self.fusion_bn = nn.Sequential(*fusion_bn)

        # Depthwise dw_size x dw_size conv; activation is applied after the
        # branch sum in forward(), hence commented out here.
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_chs, new_chs, dw_size, 1, dw_size // 2, groups=init_chs, bias=False),
            nn.BatchNorm2d(new_chs),
            # nn.ReLU(inplace=True) if relu else nn.Identity(),
        )
        self.relu = nn.ReLU(inplace=False) if relu else nn.Identity()

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        # Extra branches; no-op after switch_to_deploy() empties the lists.
        for conv, bn in zip(self.fusion_conv, self.fusion_bn):
            x2 = x2 + bn(conv(x1))
        return self.relu(x2)

    def get_equivalent_kernel_bias(self):
        # Fold the cheap conv's BN into its weights, then add each fusion
        # branch (BN-folded identity, zero-padded from 1x1 to 3x3) on top.
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.cheap_operation[0], self.cheap_operation[1])
        for conv, bn in zip(self.fusion_conv, self.fusion_bn):
            kernel, bias = self._fuse_bn_tensor(conv, bn, kernel3x3.shape[0], kernel3x3.device)
            kernel3x3 += self._pad_1x1_to_3x3_tensor(kernel)
            bias3x3 += bias
        return kernel3x3, bias3x3

    @staticmethod
    def _pad_1x1_to_3x3_tensor(kernel1x1):
        # Zero-pad a 1x1 kernel to 3x3 so it can be summed with the depthwise
        # kernel. NOTE(review): assumes dw_size == 3 — confirm for other sizes.
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    @staticmethod
    def _fuse_bn_tensor(conv, bn, in_channels=None, device=None):
        # Return (kernel, bias) of a conv+bn pair collapsed into one conv.
        # An nn.Identity 'conv' is modeled as a depthwise all-ones 1x1 kernel.
        in_channels = in_channels if in_channels else bn.running_mean.shape[0]
        device = device if device else bn.weight.device
        if isinstance(conv, nn.Conv2d):
            kernel = conv.weight
            assert conv.bias is None
        else:
            assert isinstance(conv, nn.Identity)
            kernel = torch.ones(in_channels, 1, 1, 1, device=device)

        if isinstance(bn, nn.BatchNorm2d):
            running_mean = bn.running_mean
            running_var = bn.running_var
            gamma = bn.weight
            beta = bn.bias
            eps = bn.eps
            std = (running_var + eps).sqrt()
            t = (gamma / std).reshape(-1, 1, 1, 1)
            # standard BN folding: w' = w*gamma/std, b' = beta - mean*gamma/std
            return kernel * t, beta - running_mean * gamma / std
        assert isinstance(bn, nn.Identity)
        return kernel, torch.zeros(in_channels).to(kernel.device)

    def switch_to_deploy(self):
        # Replace cheap_operation (conv+bn) and the fusion branches with one
        # fused conv carrying an explicit bias. Idempotent: the early return
        # makes repeated calls safe.
        if len(self.fusion_conv) == 0 and len(self.fusion_bn) == 0:
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.cheap_operation = nn.Conv2d(
            in_channels=self.cheap_operation[0].in_channels,
            out_channels=self.cheap_operation[0].out_channels,
            kernel_size=self.cheap_operation[0].kernel_size,
            padding=self.cheap_operation[0].padding,
            dilation=self.cheap_operation[0].dilation,
            groups=self.cheap_operation[0].groups,
            bias=True)
        self.cheap_operation.weight.data = kernel
        self.cheap_operation.bias.data = bias
        # Unregister the branch modules, then leave plain empty lists so the
        # zip() in forward() still iterates (zero times).
        self.__delattr__('fusion_conv')
        self.__delattr__('fusion_bn')
        self.fusion_conv = []
        self.fusion_bn = []

    def reparameterize(self):
        # Alias kept for API compatibility with the original RepGhost repo.
        self.switch_to_deploy()
+
+
class RepGhostBottleneck(nn.Module):
    """ RepGhost bottleneck w/ optional SE"""

    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            dw_kernel_size=3,
            stride=1,
            act_layer=nn.ReLU,
            se_ratio=0.,
            reparam=True,
    ):
        super(RepGhostBottleneck, self).__init__()
        has_se = se_ratio is not None and se_ratio > 0.
        self.stride = stride

        # Point-wise expansion
        self.ghost1 = RepGhostModule(in_chs, mid_chs, relu=True, reparam=reparam)

        # Depth-wise convolution (only present when downsampling)
        if self.stride > 1:
            self.conv_dw = nn.Conv2d(
                mid_chs, mid_chs, dw_kernel_size, stride=stride,
                padding=(dw_kernel_size - 1) // 2, groups=mid_chs, bias=False)
            self.bn_dw = nn.BatchNorm2d(mid_chs)
        else:
            self.conv_dw = None
            self.bn_dw = None

        # Squeeze-and-excitation
        self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None

        # Point-wise linear projection (no activation on the second ghost)
        self.ghost2 = RepGhostModule(mid_chs, out_chs, relu=False, reparam=reparam)

        # shortcut: identity when shape is preserved; otherwise a depthwise
        # downsample followed by a 1x1 channel projection.
        if in_chs == out_chs and self.stride == 1:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_chs, in_chs, dw_kernel_size, stride=stride,
                    padding=(dw_kernel_size - 1) // 2, groups=in_chs, bias=False),
                nn.BatchNorm2d(in_chs),
                nn.Conv2d(in_chs, out_chs, 1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_chs),
            )

    def forward(self, x):
        shortcut = x

        # 1st ghost bottleneck
        x = self.ghost1(x)

        # Depth-wise convolution
        if self.conv_dw is not None:
            x = self.conv_dw(x)
            x = self.bn_dw(x)

        # Squeeze-and-excitation
        if self.se is not None:
            x = self.se(x)

        # 2nd ghost bottleneck
        x = self.ghost2(x)

        # residual add (in-place)
        x += self.shortcut(shortcut)
        return x
+
+
class RepGhostNet(nn.Module):
    # RepGhostNet backbone built from RepGhostBottleneck stages.
    # `cfgs` is a list of stages, each stage a list of
    # [kernel, expansion, channels, se_ratio, stride] block specs.

    def __init__(
            self,
            cfgs,
            num_classes=1000,
            width=1.0,
            in_chans=3,
            output_stride=32,
            global_pool='avg',
            drop_rate=0.2,
            reparam=True,
    ):
        super(RepGhostNet, self).__init__()
        # setting of inverted residual blocks
        assert output_stride == 32, 'only output_stride==32 is valid, dilation not supported'
        self.cfgs = cfgs
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.feature_info = []

        # building first layer
        stem_chs = make_divisible(16 * width, 4)
        self.conv_stem = nn.Conv2d(in_chans, stem_chs, 3, 2, 1, bias=False)
        self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=f'conv_stem'))
        self.bn1 = nn.BatchNorm2d(stem_chs)
        self.act1 = nn.ReLU(inplace=True)
        prev_chs = stem_chs

        # building inverted residual blocks
        stages = nn.ModuleList([])
        block = RepGhostBottleneck
        stage_idx = 0
        net_stride = 2
        for cfg in self.cfgs:
            layers = []
            s = 1
            for k, exp_size, c, se_ratio, s in cfg:
                out_chs = make_divisible(c * width, 4)
                mid_chs = make_divisible(exp_size * width, 4)
                layers.append(block(prev_chs, mid_chs, out_chs, k, s, se_ratio=se_ratio, reparam=reparam))
                prev_chs = out_chs
            # a striding stage marks a new feature level for feature extraction
            if s > 1:
                net_stride *= 2
                self.feature_info.append(dict(
                    num_chs=prev_chs, reduction=net_stride, module=f'blocks.{stage_idx}'))
            stages.append(nn.Sequential(*layers))
            stage_idx += 1

        # final pointwise expansion; exp_size deliberately carries over from
        # the last block spec of the loop above
        out_chs = make_divisible(exp_size * width * 2, 4)
        stages.append(nn.Sequential(ConvBnAct(prev_chs, out_chs, 1)))
        self.pool_dim = prev_chs = out_chs

        self.blocks = nn.Sequential(*stages)

        # building last several layers ("efficient head": pool before conv_head)
        self.num_features = out_chs = 1280
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.conv_head = nn.Conv2d(prev_chs, out_chs, 1, 1, 0, bias=True)
        self.act2 = nn.ReLU(inplace=True)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(out_chs, num_classes) if num_classes > 0 else nn.Identity()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # Parameter grouping used for layer-wise LR decay / freezing.
        matcher = dict(
            stem=r'^conv_stem|bn1',
            blocks=[
                (r'^blocks\.(\d+)' if coarse else r'^blocks\.(\d+)\.(\d+)', None),
                (r'conv_head', (99999,))
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        self.num_classes = num_classes
        # cannot meaningfully change pooling of efficient head after creation
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
        self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act1(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.blocks, x, flatten=True)
        else:
            x = self.blocks(x)
        return x

    def forward_head(self, x):
        # Efficient head: pool first, then run conv_head on the pooled map.
        x = self.global_pool(x)
        x = self.conv_head(x)
        x = self.act2(x)
        x = self.flatten(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        x = self.classifier(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    def convert_to_deploy(self):
        # Fuse all re-parameterizable branches in-place for inference.
        repghost_model_convert(self, do_copy=False)
+
+
def repghost_model_convert(model: torch.nn.Module, save_path=None, do_copy=True):
    """Fuse all re-parameterizable branches of `model` for inference.

    Walks every sub-module and invokes its `switch_to_deploy` hook when one
    exists. Pattern taken from https://github.com/DingXiaoH/RepVGG/blob/main/repvgg.py

    Args:
        model: network to convert.
        save_path: optional path; when given, the fused state_dict is saved there.
        do_copy: operate on a deep copy, leaving the input model untouched.

    Returns:
        The fused model (a new object when `do_copy` is True).
    """
    target = copy.deepcopy(model) if do_copy else model
    for sub_module in target.modules():
        if hasattr(sub_module, 'switch_to_deploy'):
            sub_module.switch_to_deploy()
    if save_path is not None:
        torch.save(target.state_dict(), save_path)
    return target
+
+
def _create_repghostnet(variant, width=1.0, pretrained=False, **kwargs):
    """
    Constructs a RepGhostNet model

    Args:
        variant: registry name used to resolve the pretrained config.
        width: channel width multiplier (channels rounded to multiples of 4).
        pretrained: load pretrained weights when True.
    """
    # Per-stage block specs: kernel, expansion, channels, se_ratio, stride.
    cfgs = [
        # k, t, c, SE, s
        # stage1
        [[3, 8, 16, 0, 1]],
        # stage2
        [[3, 24, 24, 0, 2]],
        [[3, 36, 24, 0, 1]],
        # stage3
        [[5, 36, 40, 0.25, 2]],
        [[5, 60, 40, 0.25, 1]],
        # stage4
        [[3, 120, 80, 0, 2]],
        [[3, 100, 80, 0, 1],
         [3, 120, 80, 0, 1],
         [3, 120, 80, 0, 1],
         [3, 240, 112, 0.25, 1],
         [3, 336, 112, 0.25, 1]
         ],
        # stage5
        [[5, 336, 160, 0.25, 2]],
        [[5, 480, 160, 0, 1],
         [5, 480, 160, 0.25, 1],
         [5, 480, 160, 0, 1],
         [5, 480, 160, 0.25, 1]
         ]
    ]
    model_kwargs = dict(
        cfgs=cfgs,
        width=width,
        **kwargs,
    )
    return build_model_with_cfg(
        RepGhostNet,
        variant,
        pretrained,
        feature_cfg=dict(flatten_sequential=True),
        **model_kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Default pretrained-weight config for RepGhostNet; kwargs override fields."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.875,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv_stem',
        'classifier': 'classifier',
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs. Weights are served from the HF hub ('timm/'
# prefix); the original GitHub release URLs are kept commented for reference.
default_cfgs = generate_default_cfgs({
    'repghostnet_050.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_5x_43M_66.95.pth.tar'
    ),
    'repghostnet_058.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_58x_60M_68.94.pth.tar'
    ),
    'repghostnet_080.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_0_8x_96M_72.24.pth.tar'
    ),
    'repghostnet_100.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_0x_142M_74.22.pth.tar'
    ),
    'repghostnet_111.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_11x_170M_75.07.pth.tar'
    ),
    'repghostnet_130.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_3x_231M_76.37.pth.tar'
    ),
    'repghostnet_150.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_1_5x_301M_77.45.pth.tar'
    ),
    'repghostnet_200.in1k': _cfg(
        hf_hub_id='timm/',
        # url='https://github.com/ChengpengChen/RepGhost/releases/download/RepGhost/repghostnet_2_0x_516M_78.81.pth.tar'
    ),
})
+
+
@register_model
def repghostnet_050(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-0.5x """
    return _create_repghostnet('repghostnet_050', width=0.5, pretrained=pretrained, **kwargs)


@register_model
def repghostnet_058(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-0.58x """
    return _create_repghostnet('repghostnet_058', width=0.58, pretrained=pretrained, **kwargs)


@register_model
def repghostnet_080(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-0.8x """
    return _create_repghostnet('repghostnet_080', width=0.8, pretrained=pretrained, **kwargs)


@register_model
def repghostnet_100(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-1.0x """
    return _create_repghostnet('repghostnet_100', width=1.0, pretrained=pretrained, **kwargs)


@register_model
def repghostnet_111(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-1.11x """
    return _create_repghostnet('repghostnet_111', width=1.11, pretrained=pretrained, **kwargs)


@register_model
def repghostnet_130(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-1.3x """
    return _create_repghostnet('repghostnet_130', width=1.3, pretrained=pretrained, **kwargs)


@register_model
def repghostnet_150(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-1.5x """
    return _create_repghostnet('repghostnet_150', width=1.5, pretrained=pretrained, **kwargs)


@register_model
def repghostnet_200(pretrained=False, **kwargs) -> RepGhostNet:
    """ RepGhostNet-2.0x """
    return _create_repghostnet('repghostnet_200', width=2.0, pretrained=pretrained, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/repvit.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/repvit.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4b2f46dc51448ca96b7df61edc44dca5568fe8c
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/repvit.py
@@ -0,0 +1,512 @@
+""" RepViT
+
+Paper: `RepViT: Revisiting Mobile CNN From ViT Perspective`
+ - https://arxiv.org/abs/2307.09283
+
+@misc{wang2023repvit,
+ title={RepViT: Revisiting Mobile CNN From ViT Perspective},
+ author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding},
+ year={2023},
+ eprint={2307.09283},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV}
+}
+
+Adapted from official impl at https://github.com/jameslahm/RepViT
+"""
+
+__all__ = ['RepVit']
+
+import torch.nn as nn
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from ._registry import register_model, generate_default_cfgs
+from ._builder import build_model_with_cfg
+from timm.layers import SqueezeExcite, trunc_normal_, to_ntuple, to_2tuple
+from ._manipulate import checkpoint_seq
+
+import torch
+
+
class ConvNorm(nn.Sequential):
    # Conv2d (no bias) + BatchNorm2d pair that can be fused into a single
    # biased Conv2d for inference via fuse().

    def __init__(self, in_dim, out_dim, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c', nn.Conv2d(in_dim, out_dim, ks, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', nn.BatchNorm2d(out_dim))
        # bn_weight_init=0 lets a residual branch start as an identity mapping
        nn.init.constant_(self.bn.weight, bn_weight_init)
        nn.init.constant_(self.bn.bias, 0)

    @torch.no_grad()
    def fuse(self):
        # Fold the BN statistics into the conv weights and return a
        # standalone Conv2d computing the same function.
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        m = nn.Conv2d(
            w.size(1) * self.c.groups,
            w.size(0),
            w.shape[2:],
            stride=self.c.stride,
            padding=self.c.padding,
            dilation=self.c.dilation,
            groups=self.c.groups,
            device=c.weight.device,
        )
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
+
+
class NormLinear(nn.Sequential):
    # BatchNorm1d + Linear pair that can be fused into a single Linear.

    def __init__(self, in_dim, out_dim, bias=True, std=0.02):
        super().__init__()
        self.add_module('bn', nn.BatchNorm1d(in_dim))
        self.add_module('l', nn.Linear(in_dim, out_dim, bias=bias))
        trunc_normal_(self.l.weight, std=std)
        if bias:
            nn.init.constant_(self.l.bias, 0)

    @torch.no_grad()
    def fuse(self):
        # Fold the preceding BN into the linear layer; returns a plain
        # nn.Linear (always with bias) computing the same function.
        # Note: `bn`/`self.bn` and `l`/`self.l` refer to the same modules.
        bn, l = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = l.weight * w[None, :]
        if l.bias is None:
            b = b @ self.l.weight.T
        else:
            b = (l.weight @ b[:, None]).view(-1) + self.l.bias
        m = nn.Linear(w.size(1), w.size(0), device=l.weight.device)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m
+
+
class RepVggDw(nn.Module):
    # Re-parameterizable depthwise token mixer: kxk depthwise conv + 1x1
    # depthwise conv + identity, summed; non-legacy layout applies one BN
    # after the sum, legacy keeps per-branch BN (ConvNorm) instead.

    def __init__(self, ed, kernel_size, legacy=False):
        super().__init__()
        self.conv = ConvNorm(ed, ed, kernel_size, 1, (kernel_size - 1) // 2, groups=ed)
        if legacy:
            self.conv1 = ConvNorm(ed, ed, 1, 1, 0, groups=ed)
            # Make torchscript happy.
            self.bn = nn.Identity()
        else:
            self.conv1 = nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
            self.bn = nn.BatchNorm2d(ed)
        self.dim = ed
        self.legacy = legacy

    def forward(self, x):
        return self.bn(self.conv(x) + self.conv1(x) + x)

    @torch.no_grad()
    def fuse(self):
        # Collapse all three branches into self.conv's fused Conv2d.
        # NOTE(review): the fixed [1, 1, 1, 1] padding assumes kernel_size == 3;
        # confirm before using other kernel sizes.
        conv = self.conv.fuse()

        if self.legacy:
            conv1 = self.conv1.fuse()
        else:
            conv1 = self.conv1

        conv_w = conv.weight
        conv_b = conv.bias
        conv1_w = conv1.weight
        conv1_b = conv1.bias

        # pad the 1x1 branch kernel up to 3x3
        conv1_w = nn.functional.pad(conv1_w, [1, 1, 1, 1])

        # identity branch expressed as a depthwise 3x3 kernel with a centered 1
        identity = nn.functional.pad(
            torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1, 1, 1, 1]
        )

        final_conv_w = conv_w + conv1_w + identity
        final_conv_b = conv_b + conv1_b

        conv.weight.data.copy_(final_conv_w)
        conv.bias.data.copy_(final_conv_b)

        if not self.legacy:
            # finally fold the post-sum BN into the fused conv
            bn = self.bn
            w = bn.weight / (bn.running_var + bn.eps) ** 0.5
            w = conv.weight * w[:, None, None, None]
            b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / (bn.running_var + bn.eps) ** 0.5
            conv.weight.data.copy_(w)
            conv.bias.data.copy_(b)
        return conv
+
+
class RepVitMlp(nn.Module):
    """Channel MLP: 1x1 conv-bn, activation, then 1x1 conv-bn with zero-init
    BN gamma so the block starts as an identity on the residual path."""

    def __init__(self, in_dim, hidden_dim, act_layer):
        super().__init__()
        self.conv1 = ConvNorm(in_dim, hidden_dim, 1, 1, 0)
        self.act = act_layer()
        self.conv2 = ConvNorm(hidden_dim, in_dim, 1, 1, 0, bn_weight_init=0)

    def forward(self, x):
        hidden = self.act(self.conv1(x))
        return self.conv2(hidden)
+
+
class RepViTBlock(nn.Module):
    """RepViT block: depthwise token mixer, optional SE, then a channel MLP
    with a residual connection around the MLP only."""

    def __init__(self, in_dim, mlp_ratio, kernel_size, use_se, act_layer, legacy=False):
        super(RepViTBlock, self).__init__()
        self.token_mixer = RepVggDw(in_dim, kernel_size, legacy)
        self.se = SqueezeExcite(in_dim, 0.25) if use_se else nn.Identity()
        self.channel_mixer = RepVitMlp(in_dim, in_dim * mlp_ratio, act_layer)

    def forward(self, x):
        mixed = self.se(self.token_mixer(x))
        return mixed + self.channel_mixer(mixed)
+
+
class RepVitStem(nn.Module):
    """Stem: two stride-2 3x3 conv-bn layers (overall stride 4) with an
    activation between them."""

    def __init__(self, in_chs, out_chs, act_layer):
        super().__init__()
        self.conv1 = ConvNorm(in_chs, out_chs // 2, 3, 2, 1)
        self.act1 = act_layer()
        self.conv2 = ConvNorm(out_chs // 2, out_chs, 3, 2, 1)
        self.stride = 4

    def forward(self, x):
        out = self.conv1(x)
        out = self.act1(out)
        return self.conv2(out)
+
+
class RepVitDownsample(nn.Module):
    """Stage transition: a pre-block, depthwise stride-2 spatial downsample,
    1x1 channel projection, then an FFN with a residual connection."""

    def __init__(self, in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy=False):
        super().__init__()
        self.pre_block = RepViTBlock(in_dim, mlp_ratio, kernel_size, use_se=False, act_layer=act_layer, legacy=legacy)
        self.spatial_downsample = ConvNorm(in_dim, in_dim, kernel_size, 2, (kernel_size - 1) // 2, groups=in_dim)
        self.channel_downsample = ConvNorm(in_dim, out_dim, 1, 1)
        self.ffn = RepVitMlp(out_dim, out_dim * mlp_ratio, act_layer)

    def forward(self, x):
        reduced = self.channel_downsample(self.spatial_downsample(self.pre_block(x)))
        return reduced + self.ffn(reduced)
+
+
class RepVitClassifier(nn.Module):
    # Classification head with an optional distillation branch. When
    # distilled training is enabled, forward returns both logit sets during
    # training; otherwise the two heads are averaged.

    def __init__(self, dim, num_classes, distillation=False, drop=0.0):
        super().__init__()
        self.head_drop = nn.Dropout(drop)
        self.head = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()
        self.distillation = distillation
        self.distilled_training = False
        self.num_classes = num_classes
        if distillation:
            self.head_dist = NormLinear(dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward(self, x):
        x = self.head_drop(x)
        if self.distillation:
            x1, x2 = self.head(x), self.head_dist(x)
            if self.training and self.distilled_training and not torch.jit.is_scripting():
                # separate (cls, dist) logits for the distillation loss
                return x1, x2
            else:
                return (x1 + x2) / 2
        else:
            x = self.head(x)
        return x

    @torch.no_grad()
    def fuse(self):
        # Fold BN into the linear head(s); with distillation the two fused
        # heads are averaged into a single Linear.
        if not self.num_classes > 0:
            return nn.Identity()
        head = self.head.fuse()
        if self.distillation:
            head_dist = self.head_dist.fuse()
            head.weight += head_dist.weight
            head.bias += head_dist.bias
            head.weight /= 2
            head.bias /= 2
            return head
        else:
            return head
+
+
class RepVitStage(nn.Module):
    """A stack of RepViTBlocks, optionally preceded by a downsample block.
    SE is enabled on alternating blocks, starting with the first."""

    def __init__(self, in_dim, out_dim, depth, mlp_ratio, act_layer, kernel_size=3, downsample=True, legacy=False):
        super().__init__()
        if downsample:
            self.downsample = RepVitDownsample(in_dim, mlp_ratio, out_dim, kernel_size, act_layer, legacy)
        else:
            assert in_dim == out_dim
            self.downsample = nn.Identity()

        # even-indexed blocks get SE, odd-indexed ones do not
        self.blocks = nn.Sequential(*[
            RepViTBlock(out_dim, mlp_ratio, kernel_size, idx % 2 == 0, act_layer, legacy)
            for idx in range(depth)
        ])

    def forward(self, x):
        return self.blocks(self.downsample(x))
+
+
class RepVit(nn.Module):
    # RepViT backbone: stride-4 conv stem followed by a sequence of
    # RepVitStages; per-stage channels/depths come from embed_dim/depth.

    def __init__(
            self,
            in_chans=3,
            img_size=224,
            embed_dim=(48,),
            depth=(2,),
            mlp_ratio=2,
            global_pool='avg',
            kernel_size=3,
            num_classes=1000,
            act_layer=nn.GELU,
            distillation=True,
            drop_rate=0.0,
            legacy=False,
    ):
        super(RepVit, self).__init__()
        self.grad_checkpointing = False
        self.global_pool = global_pool
        self.embed_dim = embed_dim
        self.num_classes = num_classes

        in_dim = embed_dim[0]
        self.stem = RepVitStem(in_chans, in_dim, act_layer)
        stride = self.stem.stride
        resolution = tuple([i // p for i, p in zip(to_2tuple(img_size), to_2tuple(stride))])

        num_stages = len(embed_dim)
        mlp_ratios = to_ntuple(num_stages)(mlp_ratio)

        self.feature_info = []
        stages = []
        for i in range(num_stages):
            # every stage except the first begins with a 2x downsample
            downsample = True if i != 0 else False
            stages.append(
                RepVitStage(
                    in_dim,
                    embed_dim[i],
                    depth[i],
                    mlp_ratio=mlp_ratios[i],
                    act_layer=act_layer,
                    kernel_size=kernel_size,
                    downsample=downsample,
                    legacy=legacy,
                )
            )
            stage_stride = 2 if downsample else 1
            stride *= stage_stride
            # ceil-divide the running resolution by the stage stride
            resolution = tuple([(r - 1) // stage_stride + 1 for r in resolution])
            self.feature_info += [dict(num_chs=embed_dim[i], reduction=stride, module=f'stages.{i}')]
            in_dim = embed_dim[i]
        self.stages = nn.Sequential(*stages)

        self.num_features = embed_dim[-1]
        self.head_drop = nn.Dropout(drop_rate)
        self.head = RepVitClassifier(embed_dim[-1], num_classes, distillation)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # Parameter grouping for layer-wise LR decay / freezing.
        matcher = dict(stem=r'^stem', blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))])  # stem and embed
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None, distillation=False):
        self.num_classes = num_classes
        if global_pool is not None:
            self.global_pool = global_pool
        self.head = (
            RepVitClassifier(self.embed_dim[-1], num_classes, distillation) if num_classes > 0 else nn.Identity()
        )

    @torch.jit.ignore
    def set_distilled_training(self, enable=True):
        # Toggle returning separate (cls, dist) logits while training.
        self.head.distilled_training = enable

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_seq(self.stages, x)
        else:
            x = self.stages(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        # NOTE(review): pre_logits is accepted but currently ignored — confirm
        # whether an early return before self.head was intended.
        if self.global_pool == 'avg':
            x = x.mean((2, 3), keepdim=False)
        x = self.head_drop(x)
        return self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    @torch.no_grad()
    def fuse(self):
        # Recursively replace every submodule exposing fuse() with its fused
        # (re-parameterized) equivalent.
        def fuse_children(net):
            for child_name, child in net.named_children():
                if hasattr(child, 'fuse'):
                    fused = child.fuse()
                    setattr(net, child_name, fused)
                    fuse_children(fused)
                else:
                    fuse_children(child)

        fuse_children(self)
+
+
def _cfg(url='', **kwargs):
    """Default pretrained-weight config for RepViT; kwargs override fields."""
    cfg = {
        'url': url,
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'pool_size': (7, 7),
        'crop_pct': 0.95,
        'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN,
        'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv1.c',
        'classifier': ('head.head.l', 'head.head_dist.l'),
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs (distilled ImageNet-1k weights from the HF hub).
default_cfgs = generate_default_cfgs(
    {
        'repvit_m1.dist_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m2.dist_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m3.dist_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m0_9.dist_300e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m0_9.dist_450e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m1_0.dist_300e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m1_0.dist_450e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m1_1.dist_300e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m1_1.dist_450e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m1_5.dist_300e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m1_5.dist_450e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m2_3.dist_300e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
        'repvit_m2_3.dist_450e_in1k': _cfg(
            hf_hub_id='timm/',
        ),
    }
)
+
+
def _create_repvit(variant, pretrained=False, **kwargs):
    """Instantiate a RepVit via build_model_with_cfg.

    'out_indices' (popped from kwargs, default all four stages) selects which
    stages are exposed for feature extraction.
    """
    feature_cfg = dict(
        flatten_sequential=True,
        out_indices=kwargs.pop('out_indices', (0, 1, 2, 3)),
    )
    return build_model_with_cfg(RepVit, variant, pretrained, feature_cfg=feature_cfg, **kwargs)
+
+
@register_model
def repvit_m1(pretrained=False, **kwargs):
    """Constructs a RepViT-M1 model (legacy per-branch BN layout)."""
    args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2), legacy=True)
    args.update(kwargs)
    return _create_repvit('repvit_m1', pretrained=pretrained, **args)


@register_model
def repvit_m2(pretrained=False, **kwargs):
    """Constructs a RepViT-M2 model (legacy per-branch BN layout)."""
    args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2), legacy=True)
    args.update(kwargs)
    return _create_repvit('repvit_m2', pretrained=pretrained, **args)


@register_model
def repvit_m3(pretrained=False, **kwargs):
    """Constructs a RepViT-M3 model (legacy per-branch BN layout)."""
    args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 18, 2), legacy=True)
    args.update(kwargs)
    return _create_repvit('repvit_m3', pretrained=pretrained, **args)


@register_model
def repvit_m0_9(pretrained=False, **kwargs):
    """Constructs a RepViT-M0.9 model."""
    args = dict(embed_dim=(48, 96, 192, 384), depth=(2, 2, 14, 2))
    args.update(kwargs)
    return _create_repvit('repvit_m0_9', pretrained=pretrained, **args)


@register_model
def repvit_m1_0(pretrained=False, **kwargs):
    """Constructs a RepViT-M1.0 model."""
    args = dict(embed_dim=(56, 112, 224, 448), depth=(2, 2, 14, 2))
    args.update(kwargs)
    return _create_repvit('repvit_m1_0', pretrained=pretrained, **args)


@register_model
def repvit_m1_1(pretrained=False, **kwargs):
    """Constructs a RepViT-M1.1 model."""
    args = dict(embed_dim=(64, 128, 256, 512), depth=(2, 2, 12, 2))
    args.update(kwargs)
    return _create_repvit('repvit_m1_1', pretrained=pretrained, **args)


@register_model
def repvit_m1_5(pretrained=False, **kwargs):
    """Constructs a RepViT-M1.5 model."""
    args = dict(embed_dim=(64, 128, 256, 512), depth=(4, 4, 24, 4))
    args.update(kwargs)
    return _create_repvit('repvit_m1_5', pretrained=pretrained, **args)


@register_model
def repvit_m2_3(pretrained=False, **kwargs):
    """Constructs a RepViT-M2.3 model."""
    args = dict(embed_dim=(80, 160, 320, 640), depth=(6, 6, 34, 2))
    args.update(kwargs)
    return _create_repvit('repvit_m2_3', pretrained=pretrained, **args)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/res2net.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/res2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..691f929b91db626f09f30e28a48b1cc5c8ab200e
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/res2net.py
@@ -0,0 +1,227 @@
+""" Res2Net and Res2NeXt
+Adapted from Official Pytorch impl at: https://github.com/gasvn/Res2Net/
+Paper: `Res2Net: A New Multi-scale Backbone Architecture` - https://arxiv.org/abs/1904.01169
+"""
+import math
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+from .resnet import ResNet
+
+__all__ = []
+
+
class Bottle2neck(nn.Module):
    """ Res2Net/Res2NeXT Bottleneck
    Adapted from https://github.com/gasvn/Res2Net/blob/master/res2net.py

    A ResNet bottleneck whose middle 3x3 conv is replaced by a hierarchy of
    `num_scales` 3x3 convs over channel splits of `width` channels each.
    Split i (for i > 0, non-first blocks) receives the previous split's output
    added to its own input, so later splits see progressively larger
    receptive fields within a single block.
    """
    expansion = 4  # output channels = planes * expansion, as in a standard bottleneck

    def __init__(
            self,
            inplanes,
            planes,
            stride=1,
            downsample=None,
            cardinality=1,
            base_width=26,
            scale=4,
            dilation=1,
            first_dilation=None,
            act_layer=nn.ReLU,
            norm_layer=None,
            attn_layer=None,
            **_,
    ):
        """
        Args:
            inplanes: Input channel count.
            planes: Base channel count; block output is planes * 4.
            stride: Stride applied in the split 3x3 convs (and pool branch).
            downsample: Optional module for the shortcut path.
            cardinality: Group count for the split convs (Res2NeXt).
            base_width: Per-split width factor relative to 64.
            scale: Number of channel splits.
            dilation: Block dilation.
            first_dilation: Dilation of the split convs; defaults to `dilation`.
            act_layer: Activation layer class.
            norm_layer: Normalization layer class. NOTE(review): the default of
                None would fail at bn1 construction — callers are expected to
                always supply one; confirm against ResNet's block instantiation.
            attn_layer: Optional attention (e.g. SE) module applied after conv3.
            **_: Extra block args ignored for interface compatibility.
        """
        super(Bottle2neck, self).__init__()
        self.scale = scale
        # 'first' blocks (strided, or with a projection shortcut) pool the last
        # split instead of chaining it through the conv hierarchy
        self.is_first = stride > 1 or downsample is not None
        # only scale - 1 splits get their own conv; the last split is passed
        # through unchanged (or pooled when is_first)
        self.num_scales = max(1, scale - 1)
        width = int(math.floor(planes * (base_width / 64.0))) * cardinality
        self.width = width
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

        self.conv1 = nn.Conv2d(inplanes, width * scale, kernel_size=1, bias=False)
        self.bn1 = norm_layer(width * scale)

        convs = []
        bns = []
        for i in range(self.num_scales):
            convs.append(nn.Conv2d(
                width, width, kernel_size=3, stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, bias=False))
            bns.append(norm_layer(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        if self.is_first:
            # FIXME this should probably have count_include_pad=False, but hurts original weights
            self.pool = nn.AvgPool2d(kernel_size=3, stride=stride, padding=1)
        else:
            self.pool = None

        self.conv3 = nn.Conv2d(width * scale, outplanes, kernel_size=1, bias=False)
        self.bn3 = norm_layer(outplanes)
        self.se = attn_layer(outplanes) if attn_layer is not None else None

        self.relu = act_layer(inplace=True)
        self.downsample = downsample

    def zero_init_last(self):
        # Zero the final BN gamma so the residual branch starts as an identity.
        if getattr(self.bn3, 'weight', None) is not None:
            nn.init.zeros_(self.bn3.weight)

    def forward(self, x):
        shortcut = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        # split channels into `scale` groups of `width` channels each
        spx = torch.split(out, self.width, 1)
        spo = []
        sp = spx[0]  # redundant, for torchscript
        for i, (conv, bn) in enumerate(zip(self.convs, self.bns)):
            if i == 0 or self.is_first:
                sp = spx[i]
            else:
                # hierarchical residual: previous split's output feeds this one
                sp = sp + spx[i]
            sp = conv(sp)
            sp = bn(sp)
            sp = self.relu(sp)
            spo.append(sp)
        if self.scale > 1:
            if self.pool is not None:  # self.is_first == True, None check for torchscript
                spo.append(self.pool(spx[-1]))
            else:
                spo.append(spx[-1])
        out = torch.cat(spo, 1)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.se is not None:
            out = self.se(out)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        out += shortcut
        out = self.relu(out)

        return out
+
+
def _create_res2net(variant, pretrained=False, **kwargs):
    """Instantiate a Res2Net/Res2NeXt variant on the shared ResNet chassis."""
    return build_model_with_cfg(
        ResNet,
        variant,
        pretrained,
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build the default pretrained config for Res2Net models; kwargs override."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bilinear',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'conv1', 'classifier': 'fc',
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs; all in1k weights are hosted on the HF hub ('timm' org).
# The 'd' (deep stem) variants use a 3-conv stem, hence first_conv='conv1.0'.
default_cfgs = generate_default_cfgs({
    'res2net50_26w_4s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_48w_2s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_14w_8s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_26w_6s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50_26w_8s.in1k': _cfg(hf_hub_id='timm/'),
    'res2net101_26w_4s.in1k': _cfg(hf_hub_id='timm/'),
    'res2next50.in1k': _cfg(hf_hub_id='timm/'),
    'res2net50d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'),
    'res2net101d.in1k': _cfg(hf_hub_id='timm/', first_conv='conv1.0'),
})
+
+
@register_model
def res2net50_26w_4s(pretrained=False, **kwargs) -> ResNet:
    """Res2Net-50, base width 26, scale 4."""
    args = dict(block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
    args.update(kwargs)
    return _create_res2net('res2net50_26w_4s', pretrained, **args)
+
+
@register_model
def res2net101_26w_4s(pretrained=False, **kwargs) -> ResNet:
    """Res2Net-101, base width 26, scale 4."""
    args = dict(block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))
    args.update(kwargs)
    return _create_res2net('res2net101_26w_4s', pretrained, **args)
+
+
@register_model
def res2net50_26w_6s(pretrained=False, **kwargs) -> ResNet:
    """Res2Net-50, base width 26, scale 6."""
    args = dict(block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))
    args.update(kwargs)
    return _create_res2net('res2net50_26w_6s', pretrained, **args)
+
+
@register_model
def res2net50_26w_8s(pretrained=False, **kwargs) -> ResNet:
    """Res2Net-50, base width 26, scale 8."""
    args = dict(block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))
    args.update(kwargs)
    return _create_res2net('res2net50_26w_8s', pretrained, **args)
+
+
@register_model
def res2net50_48w_2s(pretrained=False, **kwargs) -> ResNet:
    """Res2Net-50, base width 48, scale 2."""
    args = dict(block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))
    args.update(kwargs)
    return _create_res2net('res2net50_48w_2s', pretrained, **args)
+
+
@register_model
def res2net50_14w_8s(pretrained=False, **kwargs) -> ResNet:
    """Res2Net-50, base width 14, scale 8."""
    args = dict(block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))
    args.update(kwargs)
    return _create_res2net('res2net50_14w_8s', pretrained, **args)
+
+
@register_model
def res2next50(pretrained=False, **kwargs) -> ResNet:
    """Res2NeXt-50: cardinality 8, base width 4, scale 4."""
    args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
    args.update(kwargs)
    return _create_res2net('res2next50', pretrained, **args)
+
+
@register_model
def res2net50d(pretrained=False, **kwargs) -> ResNet:
    """Res2Net-50-D: deep stem and average-pool downsample, base width 26, scale 4."""
    args = dict(
        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, stem_type='deep',
        avg_down=True, stem_width=32, block_args=dict(scale=4))
    args.update(kwargs)
    return _create_res2net('res2net50d', pretrained, **args)
+
+
@register_model
def res2net101d(pretrained=False, **kwargs) -> ResNet:
    """Construct Res2Net-101-D.

    Deep-stem, average-pool-downsample variant of Res2Net-101 (layers
    [3, 4, 23, 3]), base width 26, scale 4.
    (Fix: docstring previously mislabeled this as Res2Net-50.)
    """
    model_args = dict(
        block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, stem_type='deep',
        avg_down=True, stem_width=32, block_args=dict(scale=4))
    return _create_res2net('res2net101d', pretrained, **dict(model_args, **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/rexnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/rexnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d34933f782ddc47a822c25e4f4c5703d7143601a
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/rexnet.py
@@ -0,0 +1,356 @@
+""" ReXNet
+
+A PyTorch impl of `ReXNet: Diminishing Representational Bottleneck on Convolutional Neural Network` -
+https://arxiv.org/abs/2007.00992
+
+Adapted from original impl at https://github.com/clovaai/rexnet
+Copyright (c) 2020-present NAVER Corp. MIT license
+
+Changes for timm, feature extraction, and rounded channel variant hacked together by Ross Wightman
+Copyright 2020 Ross Wightman
+"""
+
+from functools import partial
+from math import ceil
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
+from ._builder import build_model_with_cfg
+from ._efficientnet_builder import efficientnet_init_weights
+from ._manipulate import checkpoint_seq
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['RexNet'] # model_registry will add each entrypoint fn to this
+
+
# SE module variant with a BatchNorm between the squeeze/excite convs (ReXNet flavor).
SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
+
+
class LinearBottleneck(nn.Module):
    """ReXNet linear bottleneck block.

    Inverted-residual style: optional 1x1 expansion -> 3x3 depthwise conv
    (strided) -> optional SE -> activation -> 1x1 linear (no act) projection.
    When the shortcut applies, only the first `in_chs` output channels receive
    the residual add (ReXNet's partial-residual scheme).
    """

    def __init__(
            self,
            in_chs,
            out_chs,
            stride,
            dilation=(1, 1),
            exp_ratio=1.0,
            se_ratio=0.,
            ch_div=1,
            act_layer='swish',
            dw_act_layer='relu6',
            drop_path=None,
    ):
        super(LinearBottleneck, self).__init__()
        # shortcut only for stride-1, same-dilation blocks that don't shrink channels
        self.use_shortcut = stride == 1 and dilation[0] == dilation[1] and in_chs <= out_chs
        self.in_channels = in_chs
        self.out_channels = out_chs

        if exp_ratio != 1.:
            dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
            self.conv_exp = ConvNormAct(in_chs, dw_chs, act_layer=act_layer)
        else:
            dw_chs = in_chs
            self.conv_exp = None

        self.conv_dw = ConvNormAct(
            dw_chs,
            dw_chs,
            kernel_size=3,
            stride=stride,
            dilation=dilation[0],
            groups=dw_chs,  # groups == channels -> depthwise
            apply_act=False,
        )
        if se_ratio > 0:
            self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
        else:
            self.se = None
        # depthwise activation is applied after SE, not inside conv_dw
        self.act_dw = create_act_layer(dw_act_layer)

        # 'linear' pointwise projection: no activation after the last conv
        self.conv_pwl = ConvNormAct(dw_chs, out_chs, 1, apply_act=False)
        self.drop_path = drop_path

    def feat_channels(self, exp=False):
        # expanded (depthwise) channel count, or the block's output channels
        return self.conv_dw.out_channels if exp else self.out_channels

    def forward(self, x):
        shortcut = x
        if self.conv_exp is not None:
            x = self.conv_exp(x)
        x = self.conv_dw(x)
        if self.se is not None:
            x = self.se(x)
        x = self.act_dw(x)
        x = self.conv_pwl(x)
        if self.use_shortcut:
            if self.drop_path is not None:
                x = self.drop_path(x)
            # partial residual: shortcut added only to the first in_channels channels
            x = torch.cat([x[:, 0:self.in_channels] + shortcut, x[:, self.in_channels:]], dim=1)
        return x
+
+
def _block_cfg(
        width_mult=1.0,
        depth_mult=1.0,
        initial_chs=16,
        final_chs=180,
        se_ratio=0.,
        ch_div=1,
):
    """Generate the per-block ReXNet config.

    Args:
        width_mult: Channel width multiplier.
        depth_mult: Per-stage depth multiplier (ceil-scaled).
        initial_chs: Channel count of the first block.
        final_chs: Channel increment target used to ramp channels linearly.
        se_ratio: SE reduction ratio for blocks in the later stages.
        ch_div: Divisor for make_divisible channel rounding.

    Returns:
        List of (out_chs, exp_ratio, stride, se_ratio) tuples, one per block.
    """
    layers = [1, 2, 2, 3, 3, 5]
    strides = [1, 2, 2, 2, 1, 2]
    layers = [ceil(element * depth_mult) for element in layers]
    # expand per-stage strides to per-block strides: stage stride first, then 1s
    strides = sum([[element] + [1] * (layers[idx] - 1) for idx, element in enumerate(strides)], [])
    # only the first stage uses expansion ratio 1; all later blocks expand 6x
    exp_ratios = [1] * layers[0] + [6] * sum(layers[1:])
    depth = sum(layers) * 3  # fix: was sum(layers[:]) — redundant list copy
    base_chs = initial_chs / width_mult if width_mult < 1.0 else initial_chs

    # The following channel configuration is a simple instance to make each layer become an expand layer.
    out_chs_list = []
    for i in range(depth // 3):
        out_chs_list.append(make_divisible(round(base_chs * width_mult), divisor=ch_div))
        base_chs += final_chs / (depth // 3 * 1.0)

    # SE only in the later, lower-resolution stages (skips the first two stages)
    se_ratios = [0.] * (layers[0] + layers[1]) + [se_ratio] * sum(layers[2:])

    return list(zip(out_chs_list, exp_ratios, strides, se_ratios))
+
+
def _build_blocks(
        block_cfg,
        prev_chs,
        width_mult,
        ch_div=1,
        output_stride=32,
        act_layer='swish',
        dw_act_layer='relu6',
        drop_path_rate=0.,
):
    """Instantiate ReXNet LinearBottleneck blocks from a block config.

    Args:
        block_cfg: Sequence of (out_chs, exp_ratio, stride, se_ratio) tuples.
        prev_chs: Channels entering the first block (stem output).
        width_mult: Width multiplier, used to size the final 1x1 conv.
        ch_div: Channel divisor for make_divisible rounding.
        output_stride: Max total network stride; strides beyond it become dilation.
        act_layer: Activation name for expansion/final convs.
        dw_act_layer: Activation name after the depthwise conv.
        drop_path_rate: Max stochastic-depth rate, linearly scaled over blocks.

    Returns:
        (features, feature_info): the block modules (incl. the final 1x1 conv)
        and the per-resolution feature metadata for the feature-extraction API.
    """
    feat_chs = [prev_chs]
    feature_info = []
    curr_stride = 2  # the stem has already downsampled by 2
    dilation = 1
    features = []
    num_blocks = len(block_cfg)
    for block_idx, (chs, exp_ratio, stride, se_ratio) in enumerate(block_cfg):
        next_dilation = dilation
        if stride > 1:
            # record a feature tap at the module just before each resolution change
            fname = 'stem' if block_idx == 0 else f'features.{block_idx - 1}'
            feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=fname)]
            if curr_stride >= output_stride:
                # stride budget exhausted: convert further striding into dilation
                next_dilation = dilation * stride
                stride = 1
        block_dpr = drop_path_rate * block_idx / (num_blocks - 1)  # stochastic depth linear decay rule
        drop_path = DropPath(block_dpr) if block_dpr > 0. else None
        features.append(LinearBottleneck(
            in_chs=prev_chs,
            out_chs=chs,
            exp_ratio=exp_ratio,
            stride=stride,
            dilation=(dilation, next_dilation),
            se_ratio=se_ratio,
            ch_div=ch_div,
            act_layer=act_layer,
            dw_act_layer=dw_act_layer,
            drop_path=drop_path,
        ))
        curr_stride *= stride
        dilation = next_dilation
        prev_chs = chs
        feat_chs += [features[-1].feat_channels()]
    # final 1x1 conv to the penultimate (pre-classifier) channel count
    pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
    feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')]
    features.append(ConvNormAct(prev_chs, pen_chs, act_layer=act_layer))
    return features, feature_info
+
+
class RexNet(nn.Module):
    """ReXNet: conv stem + stack of LinearBottleneck blocks + classifier head.

    Paper: `ReXNet: Diminishing Representational Bottleneck on Convolutional
    Neural Network` - https://arxiv.org/abs/2007.00992
    """

    def __init__(
            self,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            output_stride=32,
            initial_chs=16,
            final_chs=180,
            width_mult=1.0,
            depth_mult=1.0,
            se_ratio=1/12.,
            ch_div=1,
            act_layer='swish',
            dw_act_layer='relu6',
            drop_rate=0.2,
            drop_path_rate=0.,
    ):
        super(RexNet, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.grad_checkpointing = False

        assert output_stride in (32, 16, 8)
        # pre-divide by width_mult for narrow nets so the stem stays at >= 32 chs
        stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
        stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
        self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)

        block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
        features, self.feature_info = _build_blocks(
            block_cfg,
            stem_chs,
            width_mult,
            ch_div,
            output_stride,
            act_layer,
            dw_act_layer,
            drop_path_rate,
        )
        self.num_features = features[-1].out_channels
        self.features = nn.Sequential(*features)

        self.head = ClassifierHead(self.num_features, num_classes, global_pool, drop_rate)

        efficientnet_init_weights(self)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # parameter-group regexes used by optimizer layer-decay utilities
        matcher = dict(
            stem=r'^stem',
            blocks=r'^features\.(\d+)',
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        # rebuild the head for a new class count / pooling mode
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # trade compute for activation memory during training
            x = checkpoint_seq(self.features, x, flatten=True)
        else:
            x = self.features(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def _create_rexnet(variant, pretrained, **kwargs):
    """Instantiate a ReXNet variant via the timm model builder."""
    return build_model_with_cfg(
        RexNet,
        variant,
        pretrained,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build the default pretrained config for ReXNet models; kwargs override."""
    cfg = {
        'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.875, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        'license': 'mit',
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs. 'nav' tags are the original NAVER weights (MIT);
# 'sw' tags are timm-trained weights (Apache-2.0), some fine-tuned from in12k.
default_cfgs = generate_default_cfgs({
    'rexnet_100.nav_in1k': _cfg(hf_hub_id='timm/'),
    'rexnet_130.nav_in1k': _cfg(hf_hub_id='timm/'),
    'rexnet_150.nav_in1k': _cfg(hf_hub_id='timm/'),
    'rexnet_200.nav_in1k': _cfg(hf_hub_id='timm/'),
    'rexnet_300.nav_in1k': _cfg(hf_hub_id='timm/'),
    'rexnetr_100.untrained': _cfg(),
    'rexnetr_130.untrained': _cfg(),
    'rexnetr_150.untrained': _cfg(),
    'rexnetr_200.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
    'rexnetr_300.sw_in12k_ft_in1k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
    'rexnetr_200.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
        crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
    'rexnetr_300.sw_in12k': _cfg(
        hf_hub_id='timm/',
        num_classes=11821,
        crop_pct=0.95, test_crop_pct=1.0, test_input_size=(3, 288, 288), license='apache-2.0'),
})
+
+
@register_model
def rexnet_100(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1, width multiplier 1.0x."""
    return _create_rexnet('rexnet_100', pretrained, **kwargs)
+
+
@register_model
def rexnet_130(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1, width multiplier 1.3x."""
    variant_args = dict(width_mult=1.3)
    return _create_rexnet('rexnet_130', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnet_150(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1, width multiplier 1.5x."""
    variant_args = dict(width_mult=1.5)
    return _create_rexnet('rexnet_150', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnet_200(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1, width multiplier 2.0x."""
    variant_args = dict(width_mult=2.0)
    return _create_rexnet('rexnet_200', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnet_300(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1, width multiplier 3.0x."""
    variant_args = dict(width_mult=3.0)
    return _create_rexnet('rexnet_300', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnetr_100(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1 1.0x with channels rounded to multiples of 8."""
    variant_args = dict(ch_div=8)
    return _create_rexnet('rexnetr_100', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnetr_130(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1 1.3x with channels rounded to multiples of 8."""
    variant_args = dict(width_mult=1.3, ch_div=8)
    return _create_rexnet('rexnetr_130', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnetr_150(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1 1.5x with channels rounded to multiples of 8."""
    variant_args = dict(width_mult=1.5, ch_div=8)
    return _create_rexnet('rexnetr_150', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnetr_200(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1 2.0x with channels rounded to multiples of 8."""
    variant_args = dict(width_mult=2.0, ch_div=8)
    return _create_rexnet('rexnetr_200', pretrained, **variant_args, **kwargs)
+
+
@register_model
def rexnetr_300(pretrained=False, **kwargs) -> RexNet:
    """ReXNet V1 3.0x with channels rounded to multiples of 16."""
    variant_args = dict(width_mult=3.0, ch_div=16)
    return _create_rexnet('rexnetr_300', pretrained, **variant_args, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/sequencer.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/sequencer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb0f15f385b92e58125a8e54d32e85495a6add68
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/sequencer.py
@@ -0,0 +1,540 @@
+""" Sequencer
+
+Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf
+
+"""
+# Copyright (c) 2022. Yuki Tatsunami
+# Licensed under the Apache License, Version 2.0 (the "License");
+
+import math
+from functools import partial
+from itertools import accumulate
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
+from timm.layers import lecun_normal_, DropPath, Mlp, PatchEmbed, ClassifierHead
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['Sequencer2d'] # model_registry will add each entrypoint fn to this
+
+
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
    """Initialize a single module by type; applied recursively via named_apply.

    Head Linear layers get zero weights and a constant bias (nlhb support);
    other Linears get flax lecun-normal or xavier-uniform init; RNN-family
    weights use the uniform(-1/sqrt(hidden_size), +1/sqrt(hidden_size)) scheme.
    Dispatch order of the isinstance checks matters (e.g. LSTM before the
    generic init_weights fallback).
    """
    if isinstance(module, nn.Linear):
        if name.startswith('head'):
            nn.init.zeros_(module.weight)
            nn.init.constant_(module.bias, head_bias)
        else:
            if flax:
                # Flax defaults
                lecun_normal_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            else:
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    if 'mlp' in name:
                        # tiny noise on MLP biases, per the original impl
                        nn.init.normal_(module.bias, std=1e-6)
                    else:
                        nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        lecun_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.ones_(module.weight)
        nn.init.zeros_(module.bias)
    elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)):
        # standard RNN init: uniform in [-1/sqrt(hidden), 1/sqrt(hidden)]
        stdv = 1.0 / math.sqrt(module.hidden_size)
        for weight in module.parameters():
            nn.init.uniform_(weight, -stdv, stdv)
    elif hasattr(module, 'init_weights'):
        # defer to a module's own custom initializer when it defines one
        module.init_weights()
+
+
class RNNIdentity(nn.Module):
    """No-op stand-in for an RNN: returns the input unchanged with a None state.

    Accepts and ignores any constructor arguments so it can be swapped in for
    a real RNN layer.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
        return x, None
+
+
class RNN2dBase(nn.Module):
    """Base module for 2D RNN token mixing over NHWC feature maps.

    Subclasses assign real RNNs to `rnn_v` / `rnn_h`. The vertical pass runs
    an RNN along the H axis, the horizontal pass along the W axis, and the two
    outputs are merged according to `union` ("cat", "add", "vertical",
    "horizontal"), optionally followed by a Linear projection back to
    `input_size` channels when `with_fc` is set.
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: bool = True,
            bidirectional: bool = True,
            union="cat",
            with_fc=True,
    ):
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        # a bidirectional RNN emits hidden features for both directions
        self.output_size = 2 * hidden_size if bidirectional else hidden_size
        self.union = union

        self.with_vertical = True
        self.with_horizontal = True
        self.with_fc = with_fc

        self.fc = None
        if with_fc:
            if union == "cat":
                self.fc = nn.Linear(2 * self.output_size, input_size)
            elif union == "add":
                self.fc = nn.Linear(self.output_size, input_size)
            elif union == "vertical":
                self.fc = nn.Linear(self.output_size, input_size)
                self.with_horizontal = False
            elif union == "horizontal":
                self.fc = nn.Linear(self.output_size, input_size)
                self.with_vertical = False
            else:
                raise ValueError("Unrecognized union: " + union)
        # without an fc, the merged RNN output must already match input_size
        elif union == "cat":
            pass  # NOTE(review): stray `pass` kept from original impl; harmless
            if 2 * self.output_size != input_size:
                raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.")
        elif union == "add":
            pass  # NOTE(review): stray `pass` kept from original impl; harmless
            if self.output_size != input_size:
                raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
        elif union == "vertical":
            if self.output_size != input_size:
                raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
            self.with_horizontal = False
        elif union == "horizontal":
            if self.output_size != input_size:
                raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
            self.with_vertical = False
        else:
            raise ValueError("Unrecognized union: " + union)

        # placeholders; subclasses replace these with real RNN modules
        self.rnn_v = RNNIdentity()
        self.rnn_h = RNNIdentity()

    def forward(self, x):
        B, H, W, C = x.shape

        if self.with_vertical:
            # fold (B, W) into the batch dim and run the RNN along the H axis
            v = x.permute(0, 2, 1, 3)
            v = v.reshape(-1, H, C)
            v, _ = self.rnn_v(v)
            v = v.reshape(B, W, H, -1)
            v = v.permute(0, 2, 1, 3)
        else:
            v = None

        if self.with_horizontal:
            # fold (B, H) into the batch dim and run the RNN along the W axis
            h = x.reshape(-1, W, C)
            h, _ = self.rnn_h(h)
            h = h.reshape(B, H, W, -1)
        else:
            h = None

        if v is not None and h is not None:
            if self.union == "cat":
                x = torch.cat([v, h], dim=-1)
            else:
                x = v + h
        elif v is not None:
            x = v
        elif h is not None:
            x = h

        if self.fc is not None:
            x = self.fc(x)

        return x
+
+
class LSTM2d(RNN2dBase):
    """RNN2dBase specialization using nn.LSTM for the vertical and/or
    horizontal passes (which passes exist is decided by `union` in the base)."""

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            num_layers: int = 1,
            bias: bool = True,
            bidirectional: bool = True,
            union="cat",
            with_fc=True,
    ):
        super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
        if self.with_vertical:
            self.rnn_v = nn.LSTM(
                input_size,
                hidden_size,
                num_layers,
                batch_first=True,  # inputs arrive as (batch, seq, channels)
                bias=bias,
                bidirectional=bidirectional,
            )
        if self.with_horizontal:
            self.rnn_h = nn.LSTM(
                input_size,
                hidden_size,
                num_layers,
                batch_first=True,
                bias=bias,
                bidirectional=bidirectional,
            )
+
+
class Sequencer2dBlock(nn.Module):
    """One Sequencer block: pre-normed RNN token mixing plus pre-normed channel
    MLP, each with a residual connection and optional stochastic depth."""

    def __init__(
            self,
            dim,
            hidden_size,
            mlp_ratio=3.0,
            rnn_layer=LSTM2d,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            num_layers=1,
            bidirectional=True,
            union="cat",
            with_fc=True,
            drop=0.,
            drop_path=0.,
    ):
        super().__init__()
        channels_dim = int(mlp_ratio * dim)
        self.norm1 = norm_layer(dim)
        # spatial token mixing via the 2D RNN (vertical/horizontal passes)
        self.rnn_tokens = rnn_layer(
            dim,
            hidden_size,
            num_layers=num_layers,
            bidirectional=bidirectional,
            union=union,
            with_fc=with_fc,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.rnn_tokens(self.norm1(x)))
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x
+
+
class Shuffle(nn.Module):
    """Randomly permute the spatial (H*W) positions of an NHWC tensor.

    Active only in training mode; a no-op during evaluation.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        if not self.training:
            return x
        batch, height, width, chans = x.shape
        perm = torch.randperm(height * width)
        flat = x.reshape(batch, -1, chans)
        return flat[:, perm, :].reshape(batch, height, width, -1)
+
+
class Downsample2d(nn.Module):
    """Strided-conv spatial downsampling for NHWC tensors.

    Permutes to NCHW for the conv, then back to NHWC.
    """

    def __init__(self, input_dim, output_dim, patch_size):
        super().__init__()
        self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        nchw = x.permute(0, 3, 1, 2)
        nchw = self.down(nchw)
        return nchw.permute(0, 2, 3, 1)
+
+
class Sequencer2dStage(nn.Module):
    """A Sequencer stage: optional conv downsample followed by `depth` blocks."""

    def __init__(
            self,
            dim,
            dim_out,
            depth,
            patch_size,
            hidden_size,
            mlp_ratio,
            downsample=False,
            block_layer=Sequencer2dBlock,
            rnn_layer=LSTM2d,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            num_layers=1,
            bidirectional=True,
            union="cat",
            with_fc=True,
            drop=0.,
            drop_path=0.,
    ):
        super().__init__()
        if downsample:
            self.downsample = Downsample2d(dim, dim_out, patch_size)
        else:
            # without a downsample the stage must not change channel count
            assert dim == dim_out
            self.downsample = nn.Identity()

        blocks = []
        for block_idx in range(depth):
            blocks.append(block_layer(
                dim_out,
                hidden_size,
                mlp_ratio=mlp_ratio,
                rnn_layer=rnn_layer,
                mlp_layer=mlp_layer,
                norm_layer=norm_layer,
                act_layer=act_layer,
                num_layers=num_layers,
                bidirectional=bidirectional,
                union=union,
                with_fc=with_fc,
                drop=drop,
                # accepts either a per-block schedule or a single scalar rate
                drop_path=drop_path[block_idx] if isinstance(drop_path, (list, tuple)) else drop_path,
            ))
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        x = self.downsample(x)
        x = self.blocks(x)
        return x
+
+
class Sequencer2d(nn.Module):
    """Sequencer: LSTM-based token mixing for image classification.

    Paper: `Sequencer: Deep LSTM for Image Classification`
    - https://arxiv.org/pdf/2205.01972.pdf

    Patch-embed stem -> stages of (downsample + Sequencer2dBlocks) operating
    on NHWC feature maps -> final LayerNorm -> classifier head.
    """

    def __init__(
            self,
            num_classes=1000,
            img_size=224,
            in_chans=3,
            global_pool='avg',
            layers=(4, 3, 8, 3),
            patch_sizes=(7, 2, 2, 1),
            embed_dims=(192, 384, 384, 384),
            hidden_sizes=(48, 96, 96, 96),
            mlp_ratios=(3.0, 3.0, 3.0, 3.0),
            block_layer=Sequencer2dBlock,
            rnn_layer=LSTM2d,
            mlp_layer=Mlp,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            act_layer=nn.GELU,
            num_rnn_layers=1,
            bidirectional=True,
            union="cat",
            with_fc=True,
            drop_rate=0.,
            drop_path_rate=0.,
            nlhb=False,
            stem_norm=False,
    ):
        super().__init__()
        assert global_pool in ('', 'avg')
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = embed_dims[-1]  # num_features for consistency with other models
        self.feature_dim = -1  # channel dim index for feature outputs (rank 4, NHWC)
        self.output_fmt = 'NHWC'
        self.feature_info = []

        self.stem = PatchEmbed(
            img_size=None,  # no fixed size check; accepts any divisible input
            patch_size=patch_sizes[0],
            in_chans=in_chans,
            embed_dim=embed_dims[0],
            norm_layer=norm_layer if stem_norm else None,
            flatten=False,
            output_fmt='NHWC',
        )

        assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
        # cumulative product of patch sizes -> total reduction at each stage
        reductions = list(accumulate(patch_sizes, lambda x, y: x * y))
        stages = []
        prev_dim = embed_dims[0]
        for i, _ in enumerate(embed_dims):
            stages += [Sequencer2dStage(
                prev_dim,
                embed_dims[i],
                depth=layers[i],
                downsample=i > 0,  # stage 0 is already downsampled by the stem
                patch_size=patch_sizes[i],
                hidden_size=hidden_sizes[i],
                mlp_ratio=mlp_ratios[i],
                block_layer=block_layer,
                rnn_layer=rnn_layer,
                mlp_layer=mlp_layer,
                norm_layer=norm_layer,
                act_layer=act_layer,
                num_layers=num_rnn_layers,
                bidirectional=bidirectional,
                union=union,
                with_fc=with_fc,
                drop=drop_rate,
                drop_path=drop_path_rate,
            )]
            prev_dim = embed_dims[i]
            self.feature_info += [dict(num_chs=prev_dim, reduction=reductions[i], module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)
        self.norm = norm_layer(embed_dims[-1])
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
            input_fmt=self.output_fmt,
        )

        self.init_weights(nlhb=nlhb)

    def init_weights(self, nlhb=False):
        # nlhb: negative-log head bias, initializes head bias to -log(num_classes)
        head_bias = -math.log(self.num_classes) if nlhb else 0.
        named_apply(partial(_init_weights, head_bias=head_bias), module=self)  # depth-first

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # parameter-group regexes used by optimizer layer-decay utilities
        return dict(
            stem=r'^stem',
            blocks=[
                (r'^stages\.(\d+)', None),
                (r'^norm', (99999,))
            ] if coarse else [
                (r'^stages\.(\d+)\.blocks\.(\d+)', None),
                (r'^stages\.(\d+)\.downsample', (0,)),
                (r'^norm', (99999,))
            ]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        self.num_classes = num_classes
        self.head.reset(num_classes, pool_type=global_pool)

    def forward_features(self, x):
        x = self.stem(x)
        x = self.stages(x)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=True) if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def checkpoint_filter_fn(state_dict, model):
    """ Remap original checkpoints -> timm """
    if 'stages.0.blocks.0.norm1.weight' in state_dict:
        # checkpoint already uses the timm naming scheme
        return state_dict
    if 'model' in state_dict:
        state_dict = state_dict['model']

    import re

    def _remap(key):
        # stage-transition downsample convs live one stage later in timm
        key = re.sub(
            r'blocks.([0-9]+).([0-9]+).down',
            lambda m: f'stages.{int(m.group(1)) + 1}.downsample.down',
            key,
        )
        key = re.sub(r'blocks.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', key)
        return key.replace('head.', 'head.fc.')

    return {_remap(k): v for k, v in state_dict.items()}
+
+
def _create_sequencer2d(variant, pretrained=False, **kwargs):
    """Instantiate a Sequencer2d variant, remapping original checkpoints."""
    out_indices = kwargs.pop('out_indices', tuple(range(3)))
    return build_model_with_cfg(
        Sequencer2d,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build the default pretrained config for Sequencer2d models; kwargs override."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.proj', 'classifier': 'head.fc',
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs for the S/M/L variants, hosted on the HF hub.
default_cfgs = generate_default_cfgs({
    'sequencer2d_s.in1k': _cfg(hf_hub_id='timm/'),
    'sequencer2d_m.in1k': _cfg(hf_hub_id='timm/'),
    'sequencer2d_l.in1k': _cfg(hf_hub_id='timm/'),
})
+
+
@register_model
def sequencer2d_s(pretrained=False, **kwargs) -> Sequencer2d:
    """Sequencer2d-S: stage depths (4, 3, 8, 3)."""
    model_args = dict(
        layers=[4, 3, 8, 3],
        patch_sizes=[7, 2, 1, 1],
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
    )
    model_args.update(kwargs)
    return _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args)
+
+
@register_model
def sequencer2d_m(pretrained=False, **kwargs) -> Sequencer2d:
    """Sequencer2d-M: stage depths (4, 3, 14, 3).

    Fix: kwargs were previously merged into model_args AND again via
    dict(model_args, **kwargs), applying them twice; now merged once,
    matching sequencer2d_s. Behavior is unchanged.
    """
    model_args = dict(
        layers=[4, 3, 14, 3],
        patch_sizes=[7, 2, 1, 1],
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
    )
    model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def sequencer2d_l(pretrained=False, **kwargs) -> Sequencer2d:
    """Sequencer2d-L: stage depths (8, 8, 16, 4).

    Fix: kwargs were previously merged into model_args AND again via
    dict(model_args, **kwargs), applying them twice; now merged once,
    matching sequencer2d_s. Behavior is unchanged.
    """
    model_args = dict(
        layers=[8, 8, 16, 4],
        patch_sizes=[7, 2, 1, 1],
        embed_dims=[192, 384, 384, 384],
        hidden_sizes=[48, 96, 96, 96],
        mlp_ratios=[3.0, 3.0, 3.0, 3.0],
        rnn_layer=LSTM2d,
        bidirectional=True,
        union="cat",
        with_fc=True,
    )
    model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/sknet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/sknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..01565875cb5a6284c04b57ab557169bfe6ea8a60
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/sknet.py
@@ -0,0 +1,240 @@
+""" Selective Kernel Networks (ResNet base)
+
+Paper: Selective Kernel Networks (https://arxiv.org/abs/1903.06586)
+
+This was inspired by reading 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268)
+and a streamlined impl at https://github.com/clovaai/assembled-cnn but I ended up building something closer
+to the original paper with some modifications of my own to better balance param count vs accuracy.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+import math
+
+from torch import nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import SelectiveKernel, ConvNormAct, create_attn
+from ._builder import build_model_with_cfg
+from ._registry import register_model, generate_default_cfgs
+from .resnet import ResNet
+
+
class SelectiveKernelBasic(nn.Module):
    """ResNet BasicBlock variant whose first 3x3 conv is a SelectiveKernel conv."""
    expansion = 1

    def __init__(
            self,
            inplanes,
            planes,
            stride=1,
            downsample=None,
            cardinality=1,
            base_width=64,
            sk_kwargs=None,
            reduce_first=1,
            dilation=1,
            first_dilation=None,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_layer=None,
            aa_layer=None,
            drop_block=None,
            drop_path=None,
    ):
        super(SelectiveKernelBasic, self).__init__()
        assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
        assert base_width == 64, 'BasicBlock doest not support changing base width'

        layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
        mid_chs = planes // reduce_first
        out_chs = planes * self.expansion
        if first_dilation is None:
            first_dilation = dilation

        # Selective-kernel conv stands in for the usual first 3x3 conv of a BasicBlock.
        self.conv1 = SelectiveKernel(
            inplanes, mid_chs, stride=stride, dilation=first_dilation,
            aa_layer=aa_layer, drop_layer=drop_block, **layer_args, **(sk_kwargs or {}))
        # Second conv has no activation; it is applied after the residual add below.
        self.conv2 = ConvNormAct(
            mid_chs, out_chs, kernel_size=3, dilation=dilation, apply_act=False, **layer_args)
        self.se = create_attn(attn_layer, out_chs)
        self.act = act_layer(inplace=True)
        self.downsample = downsample
        self.drop_path = drop_path

    def zero_init_last(self):
        """Zero the last norm's gamma (if affine) so the residual branch starts as identity."""
        if getattr(self.conv2.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv2.bn.weight)

    def forward(self, x):
        shortcut = x
        out = self.conv2(self.conv1(x))
        if self.se is not None:
            out = self.se(out)
        if self.drop_path is not None:
            out = self.drop_path(out)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        out += shortcut
        return self.act(out)
+
+
class SelectiveKernelBottleneck(nn.Module):
    """ResNet Bottleneck variant whose middle 3x3 conv is a SelectiveKernel conv."""
    expansion = 4

    def __init__(
            self,
            inplanes,
            planes,
            stride=1,
            downsample=None,
            cardinality=1,
            base_width=64,
            sk_kwargs=None,
            reduce_first=1,
            dilation=1,
            first_dilation=None,
            act_layer=nn.ReLU,
            norm_layer=nn.BatchNorm2d,
            attn_layer=None,
            aa_layer=None,
            drop_block=None,
            drop_path=None,
    ):
        super(SelectiveKernelBottleneck, self).__init__()

        sk_kwargs = sk_kwargs or {}
        conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
        # ResNeXt-style width scaling: base_width/cardinality control the 3x3 conv width.
        width = int(math.floor(planes * (base_width / 64)) * cardinality)
        first_planes = width // reduce_first
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

        # 1x1 reduce -> selective-kernel 3x3 (carries stride/dilation/groups) -> 1x1 expand.
        self.conv1 = ConvNormAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
        self.conv2 = SelectiveKernel(
            first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
            aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs)
        # No activation on conv3; the block activation is applied after the residual add.
        self.conv3 = ConvNormAct(width, outplanes, kernel_size=1, apply_act=False, **conv_kwargs)
        self.se = create_attn(attn_layer, outplanes)
        self.act = act_layer(inplace=True)
        self.downsample = downsample
        self.drop_path = drop_path

    def zero_init_last(self):
        """Zero the last norm's gamma (if affine) so the residual branch starts as identity."""
        if getattr(self.conv3.bn, 'weight', None) is not None:
            nn.init.zeros_(self.conv3.bn.weight)

    def forward(self, x):
        shortcut = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.se is not None:
            x = self.se(x)
        if self.drop_path is not None:
            x = self.drop_path(x)
        if self.downsample is not None:
            shortcut = self.downsample(shortcut)
        x += shortcut
        x = self.act(x)
        return x
+
+
def _create_skresnet(variant, pretrained=False, **kwargs):
    """Instantiate a Selective-Kernel ResNet variant via the shared timm model builder."""
    return build_model_with_cfg(ResNet, variant, pretrained, **kwargs)
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-weight config dict; kwargs override the defaults."""
    cfg = dict(
        url=url,
        num_classes=1000, input_size=(3, 224, 224), pool_size=(7, 7),
        crop_pct=0.875, interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        first_conv='conv1', classifier='fc',
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs: hub-hosted entries resolve via hf_hub_id;
# '.untrained' entries have no pretrained url configured.
default_cfgs = generate_default_cfgs({
    'skresnet18.ra_in1k': _cfg(hf_hub_id='timm/'),
    'skresnet34.ra_in1k': _cfg(hf_hub_id='timm/'),
    'skresnet50.untrained': _cfg(),
    'skresnet50d.untrained': _cfg(
        first_conv='conv1.0'),
    'skresnext50_32x4d.ra_in1k': _cfg(hf_hub_id='timm/'),
})
+
+
@register_model
def skresnet18(pretrained=False, **kwargs) -> ResNet:
    """Constructs a Selective Kernel ResNet-18 model.

    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides/extras forwarded to the ResNet constructor.
    """
    sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
    model_args = dict(
        block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last=False)
    # Merge user kwargs AFTER the defaults so callers can override them; expanding
    # **kwargs inside model_args raised TypeError on any overlapping key
    # (e.g. zero_init_last=True).
    return _create_skresnet('skresnet18', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def skresnet34(pretrained=False, **kwargs) -> ResNet:
    """Constructs a Selective Kernel ResNet-34 model.

    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides/extras forwarded to the ResNet constructor.
    """
    sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
    model_args = dict(
        block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last=False)
    # Merge user kwargs AFTER the defaults so callers can override them without a
    # duplicate-keyword TypeError.
    return _create_skresnet('skresnet34', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def skresnet50(pretrained=False, **kwargs) -> ResNet:
    """Constructs a Select Kernel ResNet-50 model.

    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides/extras forwarded to the ResNet constructor.
    """
    sk_kwargs = dict(split_input=True)
    model_args = dict(
        block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last=False)
    # Merge user kwargs AFTER the defaults so callers can override them without a
    # duplicate-keyword TypeError.
    return _create_skresnet('skresnet50', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def skresnet50d(pretrained=False, **kwargs) -> ResNet:
    """Constructs a Select Kernel ResNet-50-D model.

    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides/extras forwarded to the ResNet constructor.
    """
    sk_kwargs = dict(split_input=True)
    model_args = dict(
        block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False)
    # Merge user kwargs AFTER the defaults so callers can override them without a
    # duplicate-keyword TypeError.
    return _create_skresnet('skresnet50d', pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def skresnext50_32x4d(pretrained=False, **kwargs) -> ResNet:
    """Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
    the SKNet-50 model in the Select Kernel Paper

    Args:
        pretrained: Load pretrained weights if True.
        **kwargs: Overrides/extras forwarded to the ResNet constructor.
    """
    sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
    model_args = dict(
        block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
        block_args=dict(sk_kwargs=sk_kwargs), zero_init_last=False)
    # Merge user kwargs AFTER the defaults so callers can override them without a
    # duplicate-keyword TypeError.
    return _create_skresnet('skresnext50_32x4d', pretrained, **dict(model_args, **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer_v2_cr.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer_v2_cr.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aae86459fc347d7d911190a5ed3d6789d696275
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/swin_transformer_v2_cr.py
@@ -0,0 +1,1024 @@
+""" Swin Transformer V2
+
+A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
+ - https://arxiv.org/pdf/2111.09883
+
+Code adapted from https://github.com/ChristophReich1996/Swin-Transformer-V2, original copyright/license info below
+
+This implementation is experimental and subject to change in manners that will break weight compat:
+* Size of the pos embed MLP are not spelled out in paper in terms of dim, fixed for all models? vary with num_heads?
+ * currently dim is fixed, I feel it may make sense to scale with num_heads (dim per head)
+* The specifics of the memory saving 'sequential attention' are not detailed, Christoph Reich has an impl at
+ GitHub link above. It needs further investigation as throughput vs mem tradeoff doesn't appear beneficial.
+* num_heads per stage is not detailed for Huge and Giant model variants
+* 'Giant' is 3B params in paper but ~2.6B here despite matching paper dim + block counts
+* experiments are ongoing wrt to 'main branch' norm layer use and weight init scheme
+
+Noteworthy additions over official Swin v1:
+* MLP relative position embedding is looking promising and adapts to different image/window sizes
+* This impl has been designed to allow easy change of image size with matching window size changes
+* Non-square image size and window size are supported
+
+Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
+"""
+# --------------------------------------------------------
+# Swin Transformer V2 reimplementation
+# Copyright (c) 2021 Christoph Reich
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Christoph Reich
+# --------------------------------------------------------
+import logging
+import math
+from typing import Tuple, Optional, List, Union, Any, Type
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import DropPath, Mlp, ClassifierHead, to_2tuple, _assert, ndgrid
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_function
+from ._manipulate import named_apply
+from ._registry import generate_default_cfgs, register_model
+
+__all__ = ['SwinTransformerV2Cr'] # model_registry will add each entrypoint fn to this
+
+_logger = logging.getLogger(__name__)
+
+
def bchw_to_bhwc(x: torch.Tensor) -> torch.Tensor:
    """Move the channel axis from position 1 to last: (B, C, H, W) -> (B, H, W, C)."""
    return x.movedim(1, -1)
+
+
def bhwc_to_bchw(x: torch.Tensor) -> torch.Tensor:
    """Move the channel axis from last to position 1: (B, H, W, C) -> (B, C, H, W)."""
    return x.movedim(-1, 1)
+
+
def window_partition(x, window_size: Tuple[int, int]):
    """Split a feature map into non-overlapping windows.

    Args:
        x: (B, H, W, C) input; H and W must be divisible by the window dims.
        window_size: (window_h, window_w)

    Returns:
        windows: (num_windows * B, window_h, window_w, C)
    """
    B, H, W, C = x.shape
    wh, ww = window_size
    x = x.view(B, H // wh, wh, W // ww, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, wh, ww, C)
+
+
@register_notrace_function  # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
    """Reassemble partitioned windows back into a full feature map.

    Args:
        windows: (num_windows * B, window_size[0], window_size[1], C)
        window_size (Tuple[int, int]): Window size
        img_size (Tuple[int, int]): Image size

    Returns:
        x: (B, H, W, C)
    """
    H, W = img_size
    wh, ww = window_size
    C = windows.shape[-1]
    x = windows.view(-1, H // wh, W // ww, wh, ww, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, H, W, C)
+
+
class WindowMultiHeadAttention(nn.Module):
    r"""This class implements window-based Multi-Head-Attention with log-spaced continuous position bias.

    Args:
        dim (int): Number of input features
        window_size (int): Window size
        num_heads (int): Number of attention heads
        drop_attn (float): Dropout rate of attention map
        drop_proj (float): Dropout rate after projection
        meta_hidden_dim (int): Number of hidden features in the two layer MLP meta network
        sequential_attn (bool): If true sequential self-attention is performed
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            window_size: Tuple[int, int],
            drop_attn: float = 0.0,
            drop_proj: float = 0.0,
            meta_hidden_dim: int = 384,  # FIXME what's the optimal value?
            sequential_attn: bool = False,
    ) -> None:
        super(WindowMultiHeadAttention, self).__init__()
        assert dim % num_heads == 0, \
            "The number of input features (in_features) are not divisible by the number of heads (num_heads)."
        self.in_features: int = dim
        self.window_size: Tuple[int, int] = window_size
        self.num_heads: int = num_heads
        self.sequential_attn: bool = sequential_attn

        # Combined q/k/v projection and output projection.
        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=True)
        self.attn_drop = nn.Dropout(drop_attn)
        self.proj = nn.Linear(in_features=dim, out_features=dim, bias=True)
        self.proj_drop = nn.Dropout(drop_proj)
        # meta network for positional encodings: maps log-spaced 2D relative
        # coordinates -> per-head bias (the 'continuous position bias' of Swin V2)
        self.meta_mlp = Mlp(
            2,  # x, y
            hidden_features=meta_hidden_dim,
            out_features=num_heads,
            act_layer=nn.ReLU,
            drop=(0.125, 0.)  # FIXME should there be stochasticity, appears to 'overfit' without?
        )
        # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
        self._make_pair_wise_relative_positions()

    def _make_pair_wise_relative_positions(self) -> None:
        """Method initializes the pair-wise relative positions to compute the positional biases."""
        device = self.logit_scale.device
        # (2, window_area) grid of (y, x) coordinates for every position in the window
        coordinates = torch.stack(ndgrid(
            torch.arange(self.window_size[0], device=device),
            torch.arange(self.window_size[1], device=device)
        ), dim=0).flatten(1)
        # pairwise differences -> (window_area**2, 2), then sign-preserving log scaling
        relative_coordinates = coordinates[:, :, None] - coordinates[:, None, :]
        relative_coordinates = relative_coordinates.permute(1, 2, 0).reshape(-1, 2).float()
        relative_coordinates_log = torch.sign(relative_coordinates) * torch.log(
            1.0 + relative_coordinates.abs())
        # non-persistent: recomputed on load / input-size change, never stored in checkpoints
        self.register_buffer("relative_coordinates_log", relative_coordinates_log, persistent=False)

    def update_input_size(self, new_window_size: int, **kwargs: Any) -> None:
        """Method updates the window size and so the pair-wise relative positions

        Args:
            new_window_size (int): New window size
            kwargs (Any): Unused
        """
        # Set new window size and new pair-wise relative positions
        self.window_size: Tuple[int, int] = new_window_size
        self._make_pair_wise_relative_positions()

    def _relative_positional_encodings(self) -> torch.Tensor:
        """Method computes the relative positional encodings

        Returns:
            relative_position_bias (torch.Tensor): Relative positional encodings
            (1, number of heads, window size ** 2, window size ** 2)
        """
        window_area = self.window_size[0] * self.window_size[1]
        # meta MLP output: (window_area**2, num_heads) -> (1, num_heads, area, area)
        relative_position_bias = self.meta_mlp(self.relative_coordinates_log)
        relative_position_bias = relative_position_bias.transpose(1, 0).reshape(
            self.num_heads, window_area, window_area
        )
        relative_position_bias = relative_position_bias.unsqueeze(0)
        return relative_position_bias

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """ Forward pass.
        Args:
            x (torch.Tensor): Input tensor of the shape (B * windows, N, C)
            mask (Optional[torch.Tensor]): Attention mask for the shift case

        Returns:
            Output tensor of the shape [B * windows, N, C]
        """
        Bw, L, C = x.shape

        # (3, Bw, num_heads, L, head_dim)
        qkv = self.qkv(x).view(Bw, L, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        query, key, value = qkv.unbind(0)

        # compute attention map with scaled cosine attention
        attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
        # per-head learned temperature, clamped per Swin V2 (max 1/0.01 = 100)
        logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
        attn = attn * logit_scale
        attn = attn + self._relative_positional_encodings()

        if mask is not None:
            # Apply mask if utilized (shifted-window case): broadcast the per-window
            # mask over the batch, then restore the (Bw, heads, L, L) shape.
            num_win: int = mask.shape[0]
            attn = attn.view(Bw // num_win, num_win, self.num_heads, L, L)
            attn = attn + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, L, L)
        # softmax + dropout apply in both the masked and unmasked paths
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ value).transpose(1, 2).reshape(Bw, L, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
class SwinTransformerV2CrBlock(nn.Module):
    r"""This class implements the Swin transformer block.

    Args:
        dim (int): Number of input channels
        num_heads (int): Number of attention heads to be utilized
        feat_size (Tuple[int, int]): Input resolution
        window_size (Tuple[int, int]): Window size to be utilized
        shift_size (int): Shifting size to be used
        mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        proj_drop (float): Dropout in input mapping
        drop_attn (float): Dropout rate of attention map
        drop_path (float): Dropout in main path
        extra_norm (bool): Insert extra norm on 'main' branch if True
        sequential_attn (bool): If true sequential self-attention is performed
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            feat_size: Tuple[int, int],
            window_size: Tuple[int, int],
            shift_size: Tuple[int, int] = (0, 0),
            mlp_ratio: float = 4.0,
            init_values: Optional[float] = 0,
            proj_drop: float = 0.0,
            drop_attn: float = 0.0,
            drop_path: float = 0.0,
            extra_norm: bool = False,
            sequential_attn: bool = False,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
    ) -> None:
        super(SwinTransformerV2CrBlock, self).__init__()
        self.dim: int = dim
        self.feat_size: Tuple[int, int] = feat_size
        self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
        # Effective window/shift may be clamped when the feature map is smaller
        # than the requested window (see _calc_window_shift).
        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
        self.window_area = self.window_size[0] * self.window_size[1]
        self.init_values: Optional[float] = init_values

        # attn branch
        self.attn = WindowMultiHeadAttention(
            dim=dim,
            num_heads=num_heads,
            window_size=self.window_size,
            drop_attn=drop_attn,
            drop_proj=proj_drop,
            sequential_attn=sequential_attn,
        )
        self.norm1 = norm_layer(dim)
        self.drop_path1 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()

        # mlp branch
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            drop=proj_drop,
            out_features=dim,
        )
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()

        # Extra main branch norm layer mentioned for Huge/Giant models in V2 paper.
        # Also being used as final network norm and optional stage ending norm while still in a C-last format.
        self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()

        self._make_attention_mask()
        self.init_weights()

    def _calc_window_shift(self, target_window_size):
        # Clamp the window to the feature-map size; disable the shift along any
        # axis where the whole feature map fits in a single window.
        window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
        shift_size = [0 if f <= w else s for f, w, s in zip(self.feat_size, window_size, self.target_shift_size)]
        return tuple(window_size), tuple(shift_size)

    def _make_attention_mask(self) -> None:
        """Method generates the attention mask used in shift case."""
        # Make masks for shift case
        if any(self.shift_size):
            # calculate attention mask for SW-MSA: label the 9 shift regions,
            # then forbid attention (-100 bias) between positions from different regions
            H, W = self.feat_size
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            cnt = 0
            for h in (
                    slice(0, -self.window_size[0]),
                    slice(-self.window_size[0], -self.shift_size[0]),
                    slice(-self.shift_size[0], None)):
                for w in (
                        slice(0, -self.window_size[1]),
                        slice(-self.window_size[1], -self.shift_size[1]),
                        slice(-self.shift_size[1], None)):
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            mask_windows = window_partition(img_mask, self.window_size)  # num_windows, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_area)
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None
        # non-persistent: rebuilt on input-size change, not stored in checkpoints
        self.register_buffer("attn_mask", attn_mask, persistent=False)

    def init_weights(self):
        # extra, module specific weight init: small gamma on the post-norms
        # (layer-scale-like damping of the residual branches)
        if self.init_values is not None:
            nn.init.constant_(self.norm1.weight, self.init_values)
            nn.init.constant_(self.norm2.weight, self.init_values)

    def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
        """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            new_feat_size (Tuple[int, int]): New input resolution
        """
        # Update input resolution
        self.feat_size: Tuple[int, int] = new_feat_size
        self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(new_window_size))
        self.window_area = self.window_size[0] * self.window_size[1]
        self.attn.update_input_size(new_window_size=self.window_size)
        self._make_attention_mask()

    def _shifted_window_attn(self, x):
        # x: (B, H, W, C) channels-last
        B, H, W, C = x.shape

        # cyclic shift
        sh, sw = self.shift_size
        do_shift: bool = any(self.shift_size)
        if do_shift:
            # FIXME PyTorch XLA needs cat impl, roll not lowered
            # x = torch.cat([x[:, sh:], x[:, :sh]], dim=1)
            # x = torch.cat([x[:, :, sw:], x[:, :, :sw]], dim=2)
            x = torch.roll(x, shifts=(-sh, -sw), dims=(1, 2))

        # partition windows
        x_windows = window_partition(x, self.window_size)  # num_windows * B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # num_windows * B, window_size * window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
        x = window_reverse(attn_windows, self.window_size, self.feat_size)  # B H' W' C

        # reverse cyclic shift
        if do_shift:
            # FIXME PyTorch XLA needs cat impl, roll not lowered
            # x = torch.cat([x[:, -sh:], x[:, :-sh]], dim=1)
            # x = torch.cat([x[:, :, -sw:], x[:, :, :-sw]], dim=2)
            x = torch.roll(x, shifts=(sh, sw), dims=(1, 2))

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x (torch.Tensor): Input tensor of the shape [B, C, H, W]

        Returns:
            output (torch.Tensor): Output tensor of the shape [B, C, H, W]
        """
        # post-norm branches (op -> norm -> drop): norm applied AFTER attn/mlp, per Swin V2
        x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))

        B, H, W, C = x.shape
        x = x.reshape(B, -1, C)
        x = x + self.drop_path2(self.norm2(self.mlp(x)))
        x = self.norm3(x)  # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
        x = x.reshape(B, H, W, C)
        return x
+
+
class PatchMerging(nn.Module):
    """Patch merging: concatenate each 2x2 neighborhood (4*C channels), normalize,
    then linearly reduce to 2*C channels.

    Args:
        dim (int): Number of input channels
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized.
    """

    def __init__(self, dim: int, norm_layer: Type[nn.Module] = nn.LayerNorm) -> None:
        super(PatchMerging, self).__init__()
        self.norm = norm_layer(4 * dim)
        self.reduction = nn.Linear(in_features=4 * dim, out_features=2 * dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Merge (B, H, W, C) channels-last input into (B, H // 2, W // 2, 2 * C)."""
        B, H, W, C = x.shape
        # Group each 2x2 spatial neighborhood into the channel dim: (B, H/2, W/2, 4*C)
        grouped = x.reshape(B, H // 2, 2, W // 2, 2, C).permute(0, 1, 3, 4, 2, 5).flatten(3)
        return self.reduction(self.norm(grouped))
+
+
class PatchEmbed(nn.Module):
    """ 2D Image to Patch Embedding: strided conv projection, optional channels-last norm. """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (self.img_size[0] // self.patch_size[0], self.img_size[1] // self.patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]

        # Non-overlapping patches: kernel == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, C, H, W = x.shape
        _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
        _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
        x = self.proj(x)
        # norm layer expects channels-last, so permute around it, returning BCHW
        x = x.permute(0, 2, 3, 1)
        x = self.norm(x)
        return x.permute(0, 3, 1, 2)
+
+
class SwinTransformerV2CrStage(nn.Module):
    r"""This class implements a stage of the Swin transformer including multiple layers.

    Args:
        embed_dim (int): Number of input channels
        depth (int): Depth of the stage (number of layers)
        downscale (bool): If true input is downsampled (see Fig. 3 or V1 paper)
        feat_size (Tuple[int, int]): input feature map size (H, W)
        num_heads (int): Number of attention heads to be utilized
        window_size (int): Window size to be utilized
        mlp_ratio (int): Ratio of the hidden dimension in the FFN to the input channels
        proj_drop (float): Dropout in input mapping
        drop_attn (float): Dropout rate of attention map
        drop_path (float): Dropout in main path
        norm_layer (Type[nn.Module]): Type of normalization layer to be utilized. Default: nn.LayerNorm
        extra_norm_period (int): Insert extra norm layer on main branch every N (period) blocks
        extra_norm_stage (bool): End each stage with an extra norm layer in main branch
        sequential_attn (bool): If true sequential self-attention is performed
    """

    def __init__(
            self,
            embed_dim: int,
            depth: int,
            downscale: bool,
            num_heads: int,
            feat_size: Tuple[int, int],
            window_size: Tuple[int, int],
            mlp_ratio: float = 4.0,
            init_values: Optional[float] = 0.0,
            proj_drop: float = 0.0,
            drop_attn: float = 0.0,
            drop_path: Union[List[float], float] = 0.0,
            norm_layer: Type[nn.Module] = nn.LayerNorm,
            extra_norm_period: int = 0,
            extra_norm_stage: bool = False,
            sequential_attn: bool = False,
    ) -> None:
        super(SwinTransformerV2CrStage, self).__init__()
        self.downscale: bool = downscale
        self.grad_checkpointing: bool = False
        # Downscaling halves the spatial resolution before the blocks run.
        self.feat_size: Tuple[int, int] = (feat_size[0] // 2, feat_size[1] // 2) if downscale else feat_size

        if downscale:
            self.downsample = PatchMerging(embed_dim, norm_layer=norm_layer)
            embed_dim = embed_dim * 2  # PatchMerging doubles the channel dim
        else:
            self.downsample = nn.Identity()

        def _extra_norm(index):
            # True when this block should end with an extra main-branch norm:
            # every `extra_norm_period` blocks, and/or the stage's last block.
            i = index + 1
            if extra_norm_period and i % extra_norm_period == 0:
                return True
            return i == depth if extra_norm_stage else False

        self.blocks = nn.Sequential(*[
            SwinTransformerV2CrBlock(
                dim=embed_dim,
                num_heads=num_heads,
                feat_size=self.feat_size,
                window_size=window_size,
                # alternate W-MSA (no shift) and SW-MSA (half-window shift) blocks
                shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
                mlp_ratio=mlp_ratio,
                init_values=init_values,
                proj_drop=proj_drop,
                drop_attn=drop_attn,
                drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
                extra_norm=_extra_norm(index),
                sequential_attn=sequential_attn,
                norm_layer=norm_layer,
            )
            for index in range(depth)]
        )

    def update_input_size(self, new_window_size: int, new_feat_size: Tuple[int, int]) -> None:
        """Method updates the resolution to utilize and the window size and so the pair-wise relative positions.

        Args:
            new_window_size (int): New window size
            new_feat_size (Tuple[int, int]): New input resolution
        """
        self.feat_size: Tuple[int, int] = (
            (new_feat_size[0] // 2, new_feat_size[1] // 2) if self.downscale else new_feat_size
        )
        for block in self.blocks:
            block.update_input_size(new_window_size=new_window_size, new_feat_size=self.feat_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.
        Args:
            x (torch.Tensor): Input tensor of the shape [B, C, H, W] or [B, L, C]
        Returns:
            output (torch.Tensor): Output tensor of the shape [B, 2 * C, H // 2, W // 2]
        """
        # Blocks operate channels-last; convert in and out of BCHW at the stage boundary.
        x = bchw_to_bhwc(x)
        x = self.downsample(x)
        for block in self.blocks:
            # Perform checkpointing if utilized (trades compute for activation memory)
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint.checkpoint(block, x)
            else:
                x = block(x)
        x = bhwc_to_bchw(x)
        return x
+
+
+class SwinTransformerV2Cr(nn.Module):
+ r""" Swin Transformer V2
+ A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` -
+ https://arxiv.org/pdf/2111.09883
+
+ Args:
+ img_size: Input resolution.
+ window_size: Window size. If None, img_size // window_div
+ img_window_ratio: Window size to image size ratio.
+ patch_size: Patch size.
+ in_chans: Number of input channels.
+ depths: Depth of the stage (number of layers).
+ num_heads: Number of attention heads to be utilized.
+ embed_dim: Patch embedding dimension.
+ num_classes: Number of output classes.
+ mlp_ratio: Ratio of the hidden dimension in the FFN to the input channels.
+ drop_rate: Dropout rate.
+ proj_drop_rate: Projection dropout rate.
+ attn_drop_rate: Dropout rate of attention map.
+ drop_path_rate: Stochastic depth rate.
+ norm_layer: Type of normalization layer to be utilized.
+ extra_norm_period: Insert extra norm layer on main branch every N (period) blocks in stage
+ extra_norm_stage: End each stage with an extra norm layer in main branch
+ sequential_attn: If true sequential self-attention is performed.
+ """
+
+ def __init__(
+ self,
+ img_size: Tuple[int, int] = (224, 224),
+ patch_size: int = 4,
+ window_size: Optional[int] = None,
+ img_window_ratio: int = 32,
+ in_chans: int = 3,
+ num_classes: int = 1000,
+ embed_dim: int = 96,
+ depths: Tuple[int, ...] = (2, 2, 6, 2),
+ num_heads: Tuple[int, ...] = (3, 6, 12, 24),
+ mlp_ratio: float = 4.0,
+ init_values: Optional[float] = 0.,
+ drop_rate: float = 0.0,
+ proj_drop_rate: float = 0.0,
+ attn_drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
+ extra_norm_period: int = 0,
+ extra_norm_stage: bool = False,
+ sequential_attn: bool = False,
+ global_pool: str = 'avg',
+ weight_init='skip',
+ **kwargs: Any
+ ) -> None:
+ super(SwinTransformerV2Cr, self).__init__()
+ img_size = to_2tuple(img_size)
+ window_size = tuple([
+ s // img_window_ratio for s in img_size]) if window_size is None else to_2tuple(window_size)
+
+ self.num_classes: int = num_classes
+ self.patch_size: int = patch_size
+ self.img_size: Tuple[int, int] = img_size
+ self.window_size: int = window_size
+ self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
+ self.feature_info = []
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer,
+ )
+ patch_grid_size: Tuple[int, int] = self.patch_embed.grid_size
+
+ dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
+ stages = []
+ in_dim = embed_dim
+ in_scale = 1
+ for stage_idx, (depth, num_heads) in enumerate(zip(depths, num_heads)):
+ stages += [SwinTransformerV2CrStage(
+ embed_dim=in_dim,
+ depth=depth,
+ downscale=stage_idx != 0,
+ feat_size=(
+ patch_grid_size[0] // in_scale,
+ patch_grid_size[1] // in_scale
+ ),
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ init_values=init_values,
+ proj_drop=proj_drop_rate,
+ drop_attn=attn_drop_rate,
+ drop_path=dpr[stage_idx],
+ extra_norm_period=extra_norm_period,
+ extra_norm_stage=extra_norm_stage or (stage_idx + 1) == len(depths), # last stage ends w/ norm
+ sequential_attn=sequential_attn,
+ norm_layer=norm_layer,
+ )]
+ if stage_idx != 0:
+ in_dim *= 2
+ in_scale *= 2
+ self.feature_info += [dict(num_chs=in_dim, reduction=4 * in_scale, module=f'stages.{stage_idx}')]
+ self.stages = nn.Sequential(*stages)
+
+ self.head = ClassifierHead(
+ self.num_features,
+ num_classes,
+ pool_type=global_pool,
+ drop_rate=drop_rate,
+ )
+
+ # current weight init skips custom init and uses pytorch layer defaults, seems to work well
+ # FIXME more experiments needed
+ if weight_init != 'skip':
+ named_apply(init_weights, self)
+
    def update_input_size(
            self,
            new_img_size: Optional[Tuple[int, int]] = None,
            new_window_size: Optional[int] = None,
            img_window_ratio: int = 32,
    ) -> None:
        """Update the image resolution to be processed and the window size (and with it the
        pair-wise relative positions cached inside each stage).

        Args:
            new_img_size (Optional[Tuple[int, int]]): New input resolution, if None current resolution is used
            new_window_size (Optional[int]): New window size, if None computed as new_img_size // img_window_ratio
            img_window_ratio (int): divisor for calculating window size from image size
        """
        # Check parameters
        if new_img_size is None:
            new_img_size = self.img_size
        else:
            new_img_size = to_2tuple(new_img_size)
        if new_window_size is None:
            # same per-axis derivation rule used at construction time
            new_window_size = tuple([s // img_window_ratio for s in new_img_size])
        # Compute new patch resolution & update resolution of each stage
        new_patch_grid_size = (new_img_size[0] // self.patch_size, new_img_size[1] // self.patch_size)
        for index, stage in enumerate(self.stages):
            # stage 0 keeps the full patch grid; every later stage has downsampled once more
            stage_scale = 2 ** max(index - 1, 0)
            stage.update_input_size(
                new_window_size=new_window_size,
                new_img_size=(new_patch_grid_size[0] // stage_scale, new_patch_grid_size[1] // stage_scale),
            )
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse=False):
+ return dict(
+ stem=r'^patch_embed', # stem and embed
+ blocks=r'^stages\.(\d+)' if coarse else [
+ (r'^stages\.(\d+).downsample', (0,)),
+ (r'^stages\.(\d+)\.\w+\.(\d+)', None),
+ ]
+ )
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ for s in self.stages:
+ s.grad_checkpointing = enable
+
+ @torch.jit.ignore()
+ def get_classifier(self) -> nn.Module:
+ """Method returns the classification head of the model.
+ Returns:
+ head (nn.Module): Current classification head
+ """
+ return self.head.fc
+
+ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None) -> None:
+ """Method results the classification head
+
+ Args:
+ num_classes (int): Number of classes to be predicted
+ global_pool (str): Unused
+ """
+ self.num_classes = num_classes
+ self.head.reset(num_classes, global_pool)
+
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.patch_embed(x)
+ x = self.stages(x)
+ return x
+
+ def forward_head(self, x, pre_logits: bool = False):
+ return self.head(x, pre_logits=True) if pre_logits else self.head(x)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.forward_features(x)
+ x = self.forward_head(x)
+ return x
+
+
def init_weights(module: nn.Module, name: str = ''):
    """Experimental weight init for SwinTransformerV2Cr submodules.

    Linear layers are initialized depending on their role in the module tree
    ('qkv' / 'head' / other); any module exposing its own ``init_weights``
    method is deferred to instead.
    """
    # FIXME WIP determining if there's a better weight init
    if isinstance(module, nn.Linear):
        if 'qkv' in name:
            # treat the weights of Q, K, V separately
            fan = module.weight.shape[0] // 3 + module.weight.shape[1]
            bound = math.sqrt(6. / float(fan))
            nn.init.uniform_(module.weight, -bound, bound)
        elif 'head' in name:
            nn.init.zeros_(module.weight)
        else:
            nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif hasattr(module, 'init_weights'):
        module.init_weights()
+
+
def checkpoint_filter_fn(state_dict, model):
    """Remap pretrained checkpoints to the current state-dict layout.

    Handles old 'tau'-based attention checkpoints (converted to 'logit_scale'
    via log(1/tau)) and the 'head.' -> 'head.fc.' classifier rename.
    Checkpoints already in the new layout are returned untouched.
    """
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    if 'head.fc.weight' in state_dict:
        # already converted
        return state_dict
    remapped = {}
    for key, value in state_dict.items():
        if 'tau' in key:
            # convert old tau based checkpoints -> logit_scale (inverse)
            value = torch.log(1 / value)
            key = key.replace('tau', 'logit_scale')
        key = key.replace('head.', 'head.fc.')
        remapped[key] = value
    return remapped
+
+
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
    """Instantiate a SwinTransformerV2Cr variant through build_model_with_cfg.

    out_indices defaults to one feature tap per stage (one per entry in depths).
    """
    depths = kwargs.get('depths', (1, 1, 1, 1))
    out_indices = kwargs.pop('out_indices', tuple(range(len(depths))))

    return build_model_with_cfg(
        SwinTransformerV2Cr, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        **kwargs
    )
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained cfg dict for swinv2_cr models; kwargs override defaults."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj',
        classifier='head.fc',
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs for all swinv2_cr variants. Entries tagged
# '.untrained' have no released weights (empty url); '.sw_in1k' entries
# point at ImageNet-1k weights (also mirrored on the HF hub via hf_hub_id).
default_cfgs = generate_default_cfgs({
    'swinv2_cr_tiny_384.untrained': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
    'swinv2_cr_tiny_224.untrained': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_tiny_ns_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
        input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_small_384.untrained': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
    'swinv2_cr_small_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
        input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_small_ns_224.sw_in1k': _cfg(
        hf_hub_id='timm/',
        url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
        input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_small_ns_256.untrained': _cfg(
        url="", input_size=(3, 256, 256), crop_pct=1.0, pool_size=(8, 8)),
    'swinv2_cr_base_384.untrained': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
    'swinv2_cr_base_224.untrained': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_base_ns_224.untrained': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_large_384.untrained': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
    'swinv2_cr_large_224.untrained': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_huge_384.untrained': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
    'swinv2_cr_huge_224.untrained': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=0.9),
    'swinv2_cr_giant_384.untrained': _cfg(
        url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
    'swinv2_cr_giant_224.untrained': _cfg(
        url="", input_size=(3, 224, 224), crop_pct=0.9),
})
+
+
# ---------------------------------------------------------------------------
# Model registrations (entrypoints). Each builds a fixed-size variant config
# and defers to _create_swin_transformer_v2_cr; user kwargs override defaults.
# NOTE(review): several docstrings below say 'trained ImageNet-1k' although
# the corresponding default cfg above is registered '.untrained' — verify.
# ---------------------------------------------------------------------------
@register_model
def swinv2_cr_tiny_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_tiny_384', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_tiny_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-T V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_tiny_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_tiny_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.
    ** Experimental, may make default if results are improved. **
    """
    model_args = dict(
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        extra_norm_stage=True,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_tiny_ns_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_small_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-S V2 CR @ 384x384, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=96,
        depths=(2, 2, 18, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_small_384', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_small_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=96,
        depths=(2, 2, 18, 2),
        num_heads=(3, 6, 12, 24),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_small_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_small_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=96,
        depths=(2, 2, 18, 2),
        num_heads=(3, 6, 12, 24),
        extra_norm_stage=True,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_small_ns_256(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-S V2 CR @ 256x256, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=96,
        depths=(2, 2, 18, 2),
        num_heads=(3, 6, 12, 24),
        extra_norm_stage=True,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_256', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_base_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_base_384', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_base_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_base_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_base_ns_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=128,
        depths=(2, 2, 18, 2),
        num_heads=(4, 8, 16, 32),
        extra_norm_stage=True,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_base_ns_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_large_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-L V2 CR @ 384x384, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=192,
        depths=(2, 2, 18, 2),
        num_heads=(6, 12, 24, 48),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_large_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-L V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=192,
        depths=(2, 2, 18, 2),
        num_heads=(6, 12, 24, 48),
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_large_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_huge_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-H V2 CR @ 384x384, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=352,
        depths=(2, 2, 18, 2),
        num_heads=(11, 22, 44, 88),  # head count not certain for Huge, 384 & 224 trying diff values
        extra_norm_period=6,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_huge_384', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_huge_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-H V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=352,
        depths=(2, 2, 18, 2),
        num_heads=(8, 16, 32, 64),  # head count not certain for Huge, 384 & 224 trying diff values
        extra_norm_period=6,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_huge_224', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_giant_384(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-G V2 CR @ 384x384, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=512,
        depths=(2, 2, 42, 2),
        num_heads=(16, 32, 64, 128),
        extra_norm_period=6,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_giant_384', pretrained=pretrained, **dict(model_args, **kwargs))


@register_model
def swinv2_cr_giant_224(pretrained=False, **kwargs) -> SwinTransformerV2Cr:
    """Swin-G V2 CR @ 224x224, trained ImageNet-1k"""
    model_args = dict(
        embed_dim=512,
        depths=(2, 2, 42, 2),
        num_heads=(16, 32, 64, 128),
        extra_norm_period=6,
    )
    return _create_swin_transformer_v2_cr('swinv2_cr_giant_224', pretrained=pretrained, **dict(model_args, **kwargs))
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tnt.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tnt.py
new file mode 100644
index 0000000000000000000000000000000000000000..c35901876459d94160dcdd69b44f8684625d6fd5
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tnt.py
@@ -0,0 +1,372 @@
+""" Transformer in Transformer (TNT) in PyTorch
+
A PyTorch implementation of TNT as described in
+'Transformer in Transformer' - https://arxiv.org/abs/2103.00112
+
+The official mindspore code is released and available at
+https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT
+"""
+import math
+
+import torch
+import torch.nn as nn
+from torch.utils.checkpoint import checkpoint
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import Mlp, DropPath, trunc_normal_, _assert, to_2tuple
+from ._builder import build_model_with_cfg
+from ._registry import register_model
+from .vision_transformer import resize_pos_embed
+
+__all__ = ['TNT'] # model_registry will add each entrypoint fn to this
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained cfg dict for TNT models; kwargs override defaults."""
    cfg = {
        'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
        'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'pixel_embed.proj', 'classifier': 'head',
    }
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained cfgs: tnt_s points at released weights; tnt_b is registered
# without weights (empty url).
default_cfgs = {
    'tnt_s_patch16_224': _cfg(
        url='https://github.com/contrastive/pytorch-image-models/releases/download/TNT/tnt_s_patch16_224.pth.tar',
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
    'tnt_b_patch16_224': _cfg(
        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
    ),
}
+
+
+class Attention(nn.Module):
+ """ Multi-Head Attention
+ """
+ def __init__(self, dim, hidden_dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.num_heads = num_heads
+ head_dim = hidden_dim // num_heads
+ self.head_dim = head_dim
+ self.scale = head_dim ** -0.5
+
+ self.qk = nn.Linear(dim, hidden_dim * 2, bias=qkv_bias)
+ self.v = nn.Linear(dim, dim, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop, inplace=True)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop, inplace=True)
+
+ def forward(self, x):
+ B, N, C = x.shape
+ qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
+ q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
+ v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+
+ attn = (q @ k.transpose(-2, -1)) * self.scale
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
class Block(nn.Module):
    """ TNT Block

    Pairs an inner transformer over pixel embeddings (width `dim`) with an
    outer transformer over patch embeddings (width `dim_out`); each block the
    flattened pixel tokens are projected and added into the patch tokens
    (the class token at index 0 is left untouched by the projection).
    """
    def __init__(
            self,
            dim,
            dim_out,
            num_pixel,
            num_heads_in=4,
            num_heads_out=12,
            mlp_ratio=4.,
            qkv_bias=False,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        # Inner transformer (pre-norm attention + MLP over pixel tokens)
        self.norm_in = norm_layer(dim)
        self.attn_in = Attention(
            dim,
            dim,
            num_heads=num_heads_in,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

        self.norm_mlp_in = norm_layer(dim)
        # NOTE: inner MLP ratio is fixed at 4, independent of the outer mlp_ratio
        self.mlp_in = Mlp(
            in_features=dim,
            hidden_features=int(dim * 4),
            out_features=dim,
            act_layer=act_layer,
            drop=proj_drop,
        )

        # projects a patch's flattened pixel tokens (dim * num_pixel) to dim_out
        self.norm1_proj = norm_layer(dim)
        self.proj = nn.Linear(dim * num_pixel, dim_out, bias=True)

        # Outer transformer (pre-norm attention + MLP over patch tokens)
        self.norm_out = norm_layer(dim_out)
        self.attn_out = Attention(
            dim_out,
            dim_out,
            num_heads=num_heads_out,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm_mlp = norm_layer(dim_out)
        self.mlp = Mlp(
            in_features=dim_out,
            hidden_features=int(dim_out * mlp_ratio),
            out_features=dim_out,
            act_layer=act_layer,
            drop=proj_drop,
        )

    def forward(self, pixel_embed, patch_embed):
        # inner: update pixel tokens
        pixel_embed = pixel_embed + self.drop_path(self.attn_in(self.norm_in(pixel_embed)))
        pixel_embed = pixel_embed + self.drop_path(self.mlp_in(self.norm_mlp_in(pixel_embed)))
        # outer: fold pixel info into patch tokens (cls token at index 0 is skipped)
        B, N, C = patch_embed.size()
        patch_embed = torch.cat(
            [patch_embed[:, 0:1], patch_embed[:, 1:] + self.proj(self.norm1_proj(pixel_embed).reshape(B, N - 1, -1))],
            dim=1)
        patch_embed = patch_embed + self.drop_path(self.attn_out(self.norm_out(patch_embed)))
        patch_embed = patch_embed + self.drop_path(self.mlp(self.norm_mlp(patch_embed)))
        return pixel_embed, patch_embed
+
+
class PixelEmbed(nn.Module):
    """ Image to Pixel Embedding

    Projects the image with a strided 7x7 conv, then unfolds the result into
    per-patch pixel-token grids of size `new_patch_size`, adding a pixel
    position embedding.
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        # grid_size property necessary for resizing positional embedding
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        num_patches = (self.grid_size[0]) * (self.grid_size[1])
        self.img_size = img_size
        self.num_patches = num_patches
        self.in_dim = in_dim
        # spatial size of each patch after the strided projection
        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
        self.new_patch_size = new_patch_size

        self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
        self.unfold = nn.Unfold(kernel_size=new_patch_size, stride=new_patch_size)

    def forward(self, x, pixel_pos):
        # x: (B, C, H, W) at the fixed training resolution (asserted below);
        # pixel_pos: (1, in_dim, ph, pw) pixel position embedding, broadcast-added
        B, C, H, W = x.shape
        _assert(H == self.img_size[0],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        _assert(W == self.img_size[1],
                f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]}).")
        x = self.proj(x)
        x = self.unfold(x)
        # -> (B * num_patches, in_dim, ph, pw)
        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
        x = x + pixel_pos
        # -> (B * num_patches, ph * pw, in_dim) pixel tokens
        x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
        return x
+
+
class TNT(nn.Module):
    """ Transformer in Transformer - https://arxiv.org/abs/2103.00112

    Maintains two token streams: inner pixel embeddings per patch and outer
    patch embeddings (with a class token); each Block updates both streams.
    """
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            in_chans=3,
            num_classes=1000,
            global_pool='token',
            embed_dim=768,
            inner_dim=48,
            depth=12,
            num_heads_inner=4,
            num_heads_outer=12,
            mlp_ratio=4.,
            qkv_bias=False,
            drop_rate=0.,
            pos_drop_rate=0.,
            proj_drop_rate=0.,
            attn_drop_rate=0.,
            drop_path_rate=0.,
            norm_layer=nn.LayerNorm,
            first_stride=4,
    ):
        super().__init__()
        assert global_pool in ('', 'token', 'avg')
        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.grad_checkpointing = False

        self.pixel_embed = PixelEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            in_dim=inner_dim,
            stride=first_stride,
        )
        num_patches = self.pixel_embed.num_patches
        self.num_patches = num_patches
        new_patch_size = self.pixel_embed.new_patch_size
        num_pixel = new_patch_size[0] * new_patch_size[1]

        # projection from a patch's flattened pixel tokens to the outer width
        self.norm1_proj = norm_layer(num_pixel * inner_dim)
        self.proj = nn.Linear(num_pixel * inner_dim, embed_dim)
        self.norm2_proj = norm_layer(embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pixel_pos = nn.Parameter(torch.zeros(1, inner_dim, new_patch_size[0], new_patch_size[1]))
        self.pos_drop = nn.Dropout(p=pos_drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        blocks = []
        for i in range(depth):
            blocks.append(Block(
                dim=inner_dim,
                dim_out=embed_dim,
                num_pixel=num_pixel,
                num_heads_in=num_heads_inner,
                num_heads_out=num_heads_outer,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
            ))
        self.blocks = nn.ModuleList(blocks)
        self.norm = norm_layer(embed_dim)

        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        trunc_normal_(self.cls_token, std=.02)
        trunc_normal_(self.patch_pos, std=.02)
        trunc_normal_(self.pixel_pos, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # truncated-normal linear weights, zero biases, unit LayerNorm
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'patch_pos', 'pixel_pos', 'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        matcher = dict(
            stem=r'^cls_token|patch_pos|pixel_pos|pixel_embed|norm[12]_proj|proj',  # stem and embed / pos
            blocks=[
                (r'^blocks\.(\d+)', None),
                (r'^norm', (99999,)),
            ]
        )
        return matcher

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=None):
        """Replace the classification head (and optionally the pooling mode).

        Args:
            num_classes (int): new number of output classes (0 -> Identity head)
            global_pool (Optional[str]): if given, new pooling mode ('', 'token', 'avg')
        """
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'token', 'avg')
            # FIX: previously validated but never stored, so forward_head kept
            # using the old pooling mode and the assert was dead code
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        pixel_embed = self.pixel_embed(x, self.pixel_pos)

        # project pixel tokens to patch tokens, prepend cls token, add pos embed
        patch_embed = self.norm2_proj(self.proj(self.norm1_proj(pixel_embed.reshape(B, self.num_patches, -1))))
        patch_embed = torch.cat((self.cls_token.expand(B, -1, -1), patch_embed), dim=1)
        patch_embed = patch_embed + self.patch_pos
        patch_embed = self.pos_drop(patch_embed)

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for blk in self.blocks:
                pixel_embed, patch_embed = checkpoint(blk, pixel_embed, patch_embed)
        else:
            for blk in self.blocks:
                pixel_embed, patch_embed = blk(pixel_embed, patch_embed)

        patch_embed = self.norm(patch_embed)
        return patch_embed

    def forward_head(self, x, pre_logits: bool = False):
        # pool: mean of patch tokens ('avg') or the cls token ('token')
        if self.global_pool:
            x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
def checkpoint_filter_fn(state_dict, model):
    """Resize the checkpoint's patch position embedding when its shape does not
    match the model (e.g. different input resolution); otherwise pass through.

    (Previous docstring was copy-pasted from ViT and described a patch-embed
    weight conversion this function does not perform.)
    """
    if state_dict['patch_pos'].shape != model.patch_pos.shape:
        # interpolate the checkpoint pos embed onto the model's patch grid
        state_dict['patch_pos'] = resize_pos_embed(
            state_dict['patch_pos'], model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
    return state_dict
+
+
def _create_tnt(variant, pretrained=False, **kwargs):
    """Instantiate a TNT variant through build_model_with_cfg (no features_only support)."""
    if kwargs.get('features_only', None):
        raise RuntimeError('features_only not implemented for Vision Transformer models.')

    return build_model_with_cfg(
        TNT, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        **kwargs)
+
+
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs) -> TNT:
    """TNT-Small, patch size 16, 224x224."""
    model_cfg = dict(
        patch_size=16, embed_dim=384, inner_dim=24, depth=12, num_heads_outer=6,
        qkv_bias=False)
    model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model


@register_model
def tnt_b_patch16_224(pretrained=False, **kwargs) -> TNT:
    """TNT-Base, patch size 16, 224x224."""
    model_cfg = dict(
        patch_size=16, embed_dim=640, inner_dim=40, depth=12, num_heads_outer=10,
        qkv_bias=False)
    model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **dict(model_cfg, **kwargs))
    return model
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tresnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..33d375f7bb7800dfdb0dc9d7889da7b4022072d5
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/tresnet.py
@@ -0,0 +1,355 @@
+"""
+TResNet: High Performance GPU-Dedicated Architecture
+https://arxiv.org/pdf/2003.13630.pdf
+
+Original model: https://github.com/mrT23/TResNet
+
+"""
+from collections import OrderedDict
+from functools import partial
+
+import torch
+import torch.nn as nn
+
+from timm.layers import SpaceToDepth, BlurPool2d, ClassifierHead, SEModule,\
+ ConvNormActAa, ConvNormAct, DropPath
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs, register_model_deprecations
+
+__all__ = ['TResNet'] # model_registry will add each entrypoint fn to this
+
+
class BasicBlock(nn.Module):
    """TResNet basic residual block (3x3 + 3x3) with optional SE and anti-aliased stride-2 downsampling."""
    expansion = 1

    def __init__(
            self,
            inplanes,
            planes,
            stride=1,
            downsample=None,
            use_se=True,
            aa_layer=None,
            drop_path_rate=0.
    ):
        super(BasicBlock, self).__init__()
        self.downsample = downsample
        self.stride = stride
        act_layer = partial(nn.LeakyReLU, negative_slope=1e-3)

        # stride-2 path uses the anti-aliasing (blur-pool) conv variant
        if stride == 1:
            self.conv1 = ConvNormAct(inplanes, planes, kernel_size=3, stride=1, act_layer=act_layer)
        else:
            self.conv1 = ConvNormActAa(
                inplanes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)

        # second conv has no activation; ReLU is applied after the residual add
        self.conv2 = ConvNormAct(planes, planes, kernel_size=3, stride=1, apply_act=False, act_layer=None)
        self.act = nn.ReLU(inplace=True)

        rd_chs = max(planes * self.expansion // 4, 64)
        self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.conv2(self.conv1(x))
        if self.se is not None:
            out = self.se(out)
        out = self.drop_path(out) + shortcut
        return self.act(out)
+
+
class Bottleneck(nn.Module):
    """TResNet bottleneck residual block (1x1 -> 3x3 -> 1x1) with optional SE
    and anti-aliased stride-2 downsampling."""
    expansion = 4

    def __init__(
            self,
            inplanes,
            planes,
            stride=1,
            downsample=None,
            use_se=True,
            act_layer=None,
            aa_layer=None,
            drop_path_rate=0.,
    ):
        super(Bottleneck, self).__init__()
        self.downsample = downsample
        self.stride = stride
        act_layer = act_layer or partial(nn.LeakyReLU, negative_slope=1e-3)

        self.conv1 = ConvNormAct(
            inplanes, planes, kernel_size=1, stride=1, act_layer=act_layer)
        # stride-2 path uses the anti-aliasing (blur-pool) conv variant
        if stride == 1:
            self.conv2 = ConvNormAct(
                planes, planes, kernel_size=3, stride=1, act_layer=act_layer)
        else:
            self.conv2 = ConvNormActAa(
                planes, planes, kernel_size=3, stride=2, act_layer=act_layer, aa_layer=aa_layer)

        # SE applied at the inner (pre-expansion) width
        reduction_chs = max(planes * self.expansion // 8, 64)
        self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None

        # expansion conv has no activation; ReLU is applied after the residual add
        self.conv3 = ConvNormAct(
            planes, planes * self.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        if self.downsample is not None:
            shortcut = self.downsample(x)
        else:
            shortcut = x
        out = self.conv1(x)
        out = self.conv2(out)
        if self.se is not None:
            out = self.se(out)
        out = self.conv3(out)
        out = self.drop_path(out) + shortcut
        out = self.act(out)
        return out
+
+
+class TResNet(nn.Module):
+ def __init__(
+ self,
+ layers,
+ in_chans=3,
+ num_classes=1000,
+ width_factor=1.0,
+ v2=False,
+ global_pool='fast',
+ drop_rate=0.,
+ drop_path_rate=0.,
+ ):
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ self.grad_checkpointing = False
+ super(TResNet, self).__init__()
+
+ aa_layer = BlurPool2d
+ act_layer = nn.LeakyReLU
+
+ # TResnet stages
+ self.inplanes = int(64 * width_factor)
+ self.planes = int(64 * width_factor)
+ if v2:
+ self.inplanes = self.inplanes // 8 * 8
+ self.planes = self.planes // 8 * 8
+
+ dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
+ conv1 = ConvNormAct(in_chans * 16, self.planes, stride=1, kernel_size=3, act_layer=act_layer)
+ layer1 = self._make_layer(
+ Bottleneck if v2 else BasicBlock,
+ self.planes, layers[0], stride=1, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[0])
+ layer2 = self._make_layer(
+ Bottleneck if v2 else BasicBlock,
+ self.planes * 2, layers[1], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[1])
+ layer3 = self._make_layer(
+ Bottleneck,
+ self.planes * 4, layers[2], stride=2, use_se=True, aa_layer=aa_layer, drop_path_rate=dpr[2])
+ layer4 = self._make_layer(
+ Bottleneck,
+ self.planes * 8, layers[3], stride=2, use_se=False, aa_layer=aa_layer, drop_path_rate=dpr[3])
+
+ # body
+ self.body = nn.Sequential(OrderedDict([
+ ('s2d', SpaceToDepth()),
+ ('conv1', conv1),
+ ('layer1', layer1),
+ ('layer2', layer2),
+ ('layer3', layer3),
+ ('layer4', layer4),
+ ]))
+
+ self.feature_info = [
+ dict(num_chs=self.planes, reduction=2, module=''), # Not with S2D?
+ dict(num_chs=self.planes * (Bottleneck.expansion if v2 else 1), reduction=4, module='body.layer1'),
+ dict(num_chs=self.planes * 2 * (Bottleneck.expansion if v2 else 1), reduction=8, module='body.layer2'),
+ dict(num_chs=self.planes * 4 * Bottleneck.expansion, reduction=16, module='body.layer3'),
+ dict(num_chs=self.planes * 8 * Bottleneck.expansion, reduction=32, module='body.layer4'),
+ ]
+
+ # head
+ self.num_features = (self.planes * 8) * Bottleneck.expansion
+ self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
+
+ # model initialization
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
+ if isinstance(m, nn.Linear):
+ m.weight.data.normal_(0, 0.01)
+
+ # residual connections special initialization
+ for m in self.modules():
+ if isinstance(m, BasicBlock):
+ nn.init.zeros_(m.conv2.bn.weight)
+ if isinstance(m, Bottleneck):
+ nn.init.zeros_(m.conv3.bn.weight)
+
    def _make_layer(self, block, planes, blocks, stride=1, use_se=True, aa_layer=None, drop_path_rate=0.):
        """Build one TResNet stage of `blocks` residual blocks.

        Args:
            block: block class (BasicBlock or Bottleneck); its `expansion`
                attribute scales the stage's output channel count.
            planes: base channel count for the stage.
            blocks: number of residual blocks stacked in this stage.
            stride: stride of the first block; 2 downsamples the shortcut via
                avg-pool followed by a 1x1 conv.
            use_se: enable squeeze-excitation inside the blocks.
            aa_layer: optional anti-aliasing layer passed through to blocks.
            drop_path_rate: scalar applied to every block, or a per-block list.

        Returns:
            nn.Sequential of the stage's blocks. Side effect: advances
            self.inplanes to the stage's output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            layers = []
            if stride == 2:
                # avg pooling before 1x1 conv: anti-aliased shortcut downsampling
                layers.append(nn.AvgPool2d(kernel_size=2, stride=2, ceil_mode=True, count_include_pad=False))
            layers += [ConvNormAct(
                self.inplanes, planes * block.expansion, kernel_size=1, stride=1, apply_act=False, act_layer=None)]
            downsample = nn.Sequential(*layers)

        layers = []
        for i in range(blocks):
            # only the first block carries the stride and the downsample shortcut
            layers.append(block(
                self.inplanes,
                planes,
                stride=stride if i == 0 else 1,
                downsample=downsample if i == 0 else None,
                use_se=use_se,
                aa_layer=aa_layer,
                drop_path_rate=drop_path_rate[i] if isinstance(drop_path_rate, list) else drop_path_rate,
            ))
            self.inplanes = planes * block.expansion  # must update inside the loop for block 0 -> 1 transition
        return nn.Sequential(*layers)
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse=False):
+ matcher = dict(stem=r'^body\.conv1', blocks=r'^body\.layer(\d+)' if coarse else r'^body\.layer(\d+)\.(\d+)')
+ return matcher
+
    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Toggle activation checkpointing; consumed by forward_features.
        self.grad_checkpointing = enable
+
    @torch.jit.ignore
    def get_classifier(self):
        # Return the final classification layer inside the ClassifierHead.
        return self.head.fc
+
    def reset_classifier(self, num_classes, global_pool=None):
        # Rebuild the classifier for a new class count; pooling kept unless overridden.
        self.head.reset(num_classes, pool_type=global_pool)
+
    def forward_features(self, x):
        """Run the feature body; optionally checkpoints the residual stages."""
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # run the stem eagerly, checkpoint the four residual stages to save memory
            x = self.body.s2d(x)
            x = self.body.conv1(x)
            x = checkpoint_seq([
                self.body.layer1,
                self.body.layer2,
                self.body.layer3,
                self.body.layer4],
                x, flatten=True)
        else:
            x = self.body(x)
        return x
+
+ def forward_head(self, x, pre_logits: bool = False):
+ return x if pre_logits else self.head(x)
+
+ def forward(self, x):
+ x = self.forward_features(x)
+ x = self.forward_head(x)
+ return x
+
+
def checkpoint_filter_fn(state_dict, model):
    """Remap original TResNet (inplace-ABN) checkpoints to this model's layout.

    Renames the conv/bn submodule keys produced by the original repo's fused
    conv + inplace_abn sequences, and converts inplace-ABN scale weights
    (which may be negative) into valid BatchNorm weights.

    Args:
        state_dict: raw checkpoint dict, possibly nested under a 'model' or
            'state_dict' key.
        model: target model (unused; kept for the filter-fn signature).

    Returns:
        New dict with remapped keys, or `state_dict` unchanged if it already
        uses the current naming scheme.
    """
    if 'body.conv1.conv.weight' in state_dict:
        # already converted
        return state_dict

    import re
    state_dict = state_dict.get('model', state_dict)
    state_dict = state_dict.get('state_dict', state_dict)
    out_dict = {}
    for k, v in state_dict.items():
        # Escape the literal dots (the original patterns used a bare '.', which
        # matches any character). '.0.0'/'.0.1' must be tried before '.0'/'.1'.
        k = re.sub(r'conv(\d+)\.0\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
        k = re.sub(r'conv(\d+)\.0\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
        k = re.sub(r'conv(\d+)\.0', lambda x: f'conv{int(x.group(1))}.conv', k)
        k = re.sub(r'conv(\d+)\.1', lambda x: f'conv{int(x.group(1))}.bn', k)
        k = re.sub(r'downsample\.(\d+)\.0', lambda x: f'downsample.{int(x.group(1))}.conv', k)
        k = re.sub(r'downsample\.(\d+)\.1', lambda x: f'downsample.{int(x.group(1))}.bn', k)
        if k.endswith('bn.weight'):
            # convert weight from inplace_abn (may be negative) to batchnorm
            v = v.abs().add(1e-5)
        out_dict[k] = v
    return out_dict
+
+
def _create_tresnet(variant, pretrained=False, **kwargs):
    """Instantiate a TResNet variant through the timm model builder."""
    feature_cfg = dict(out_indices=(1, 2, 3, 4), flatten_sequential=True)
    return build_model_with_cfg(
        TResNet,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
        feature_cfg=feature_cfg,
        **kwargs,
    )
+
+
+def _cfg(url='', **kwargs):
+ return {
+ 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
+ 'crop_pct': 0.875, 'interpolation': 'bilinear',
+ 'mean': (0., 0., 0.), 'std': (1., 1., 1.),
+ 'first_conv': 'body.conv1.conv', 'classifier': 'head.fc',
+ **kwargs
+ }
+
+
# Pretrained weight configs; the tag suffix encodes training data / resolution.
default_cfgs = generate_default_cfgs({
    'tresnet_m.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_m.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
    'tresnet_m.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_l.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_m.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),
    'tresnet_l.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),
    'tresnet_xl.miil_in1k_448': _cfg(
        input_size=(3, 448, 448), pool_size=(14, 14),
        hf_hub_id='timm/'),

    'tresnet_v2_l.miil_in21k_ft_in1k': _cfg(hf_hub_id='timm/'),
    'tresnet_v2_l.miil_in21k': _cfg(hf_hub_id='timm/', num_classes=11221),
})
+
+
@register_model
def tresnet_m(pretrained=False, **kwargs) -> TResNet:
    """TResNet-M: medium variant, default width."""
    args = dict(layers=[3, 4, 11, 3])
    args.update(kwargs)
    return _create_tresnet('tresnet_m', pretrained=pretrained, **args)
+
+
@register_model
def tresnet_l(pretrained=False, **kwargs) -> TResNet:
    """TResNet-L: deeper and 1.2x wider than TResNet-M."""
    args = dict(layers=[4, 5, 18, 3], width_factor=1.2)
    args.update(kwargs)
    return _create_tresnet('tresnet_l', pretrained=pretrained, **args)
+
+
@register_model
def tresnet_xl(pretrained=False, **kwargs) -> TResNet:
    """TResNet-XL: deepest variant, 1.3x width."""
    args = dict(layers=[4, 5, 24, 3], width_factor=1.3)
    args.update(kwargs)
    return _create_tresnet('tresnet_xl', pretrained=pretrained, **args)
+
+
@register_model
def tresnet_v2_l(pretrained=False, **kwargs) -> TResNet:
    """TResNet-V2-L: v2 block layout (Bottleneck expansion in early stages)."""
    args = dict(layers=[3, 4, 23, 3], width_factor=1.0, v2=True)
    args.update(kwargs)
    return _create_tresnet('tresnet_v2_l', pretrained=pretrained, **args)
+
+
# Map legacy (pre-tag) model names to their current tagged equivalents.
register_model_deprecations(__name__, {
    'tresnet_m_miil_in21k': 'tresnet_m.miil_in21k',
    'tresnet_m_448': 'tresnet_m.miil_in1k_448',
    'tresnet_l_448': 'tresnet_l.miil_in1k_448',
    'tresnet_xl_448': 'tresnet_xl.miil_in1k_448',
})
\ No newline at end of file
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vgg.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba12c9a92b2735b75a58ca96a2a29b4ee0941ac
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vgg.py
@@ -0,0 +1,301 @@
+"""VGG
+
+Adapted from https://github.com/pytorch/vision 'vgg.py' (BSD-3-Clause) with a few changes for
+timm functionality.
+
+Copyright 2021 Ross Wightman
+"""
+from typing import Union, List, Dict, Any, cast
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import ClassifierHead
+from ._builder import build_model_with_cfg
+from ._features_fx import register_notrace_module
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['VGG']
+
+
# VGG layer specs: ints are 3x3-conv output channels, 'M' marks a 2x2 max-pool.
# Keys correspond to the original configurations A (11), B (13), D (16), E (19).
cfgs: Dict[str, List[Union[str, int]]] = {
    'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
+
+
@register_notrace_module  # reason: FX can't symbolically trace control flow in forward method
class ConvMlp(nn.Module):
    """VGG 'classifier' MLP implemented with convolutions (fc6/fc7 style)."""

    def __init__(
            self,
            in_features=512,
            out_features=4096,
            kernel_size=7,
            mlp_ratio=1.0,
            drop_rate: float = 0.2,
            act_layer: nn.Module = None,
            conv_layer: nn.Module = None,
    ):
        super().__init__()
        self.input_kernel_size = kernel_size
        mid_features = int(out_features * mlp_ratio)
        # fc6-style conv with a large kernel, then dropout, then 1x1 fc7-style conv
        self.fc1 = conv_layer(in_features, mid_features, kernel_size, bias=True)
        self.act1 = act_layer(True)
        self.drop = nn.Dropout(drop_rate)
        self.fc2 = conv_layer(mid_features, out_features, 1, bias=True)
        self.act2 = act_layer(True)

    def forward(self, x):
        h, w = x.shape[-2], x.shape[-1]
        min_size = self.input_kernel_size
        if h < min_size or w < min_size:
            # upscale via adaptive pooling so fc1's kernel always fits (>= 7x7)
            x = F.adaptive_avg_pool2d(x, (max(min_size, h), max(min_size, w)))
        return self.act2(self.fc2(self.drop(self.act1(self.fc1(x)))))
+
+
class VGG(nn.Module):
    """VGG (configurations A/B/D/E) with a conv-MLP pre-logits head.

    The feature extractor is built from `cfg` (ints = conv channels, 'M' =
    max-pool); the classifier is a ConvMlp followed by a ClassifierHead.
    """

    def __init__(
            self,
            cfg: List[Any],
            num_classes: int = 1000,
            in_chans: int = 3,
            output_stride: int = 32,
            mlp_ratio: float = 1.0,
            act_layer: nn.Module = nn.ReLU,
            conv_layer: nn.Module = nn.Conv2d,
            norm_layer: nn.Module = None,
            global_pool: str = 'avg',
            drop_rate: float = 0.,
    ) -> None:
        """
        Args:
            cfg: layer spec; ints are conv output channels, 'M' is a 2x2 max-pool.
            num_classes: number of classifier outputs.
            in_chans: input image channels.
            output_stride: only 32 is supported (asserted below).
            mlp_ratio: hidden width multiplier of the pre-logits ConvMlp.
            act_layer: activation class used after each conv.
            conv_layer: conv class for feature and head convs.
            norm_layer: optional norm class inserted after each conv (BN variants).
            global_pool: pooling type for the classifier head.
            drop_rate: dropout rate used in the ConvMlp and head.
        """
        super(VGG, self).__init__()
        assert output_stride == 32
        self.num_classes = num_classes
        self.num_features = 4096
        self.drop_rate = drop_rate
        self.grad_checkpointing = False
        self.use_norm = norm_layer is not None
        self.feature_info = []
        prev_chs = in_chans
        net_stride = 1
        pool_layer = nn.MaxPool2d
        layers: List[nn.Module] = []
        for v in cfg:
            # index of the last module added so far (the activation preceding this pool)
            last_idx = len(layers) - 1
            if v == 'M':
                self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{last_idx}'))
                layers += [pool_layer(kernel_size=2, stride=2)]
                net_stride *= 2
            else:
                v = cast(int, v)
                conv2d = conv_layer(prev_chs, v, kernel_size=3, padding=1)
                if norm_layer is not None:
                    layers += [conv2d, norm_layer(v), act_layer(inplace=True)]
                else:
                    layers += [conv2d, act_layer(inplace=True)]
                prev_chs = v
        self.features = nn.Sequential(*layers)
        # final feature taken after the last pool
        self.feature_info.append(dict(num_chs=prev_chs, reduction=net_stride, module=f'features.{len(layers) - 1}'))

        # fc6/fc7 of the original VGG, expressed as convs
        self.pre_logits = ConvMlp(
            prev_chs,
            self.num_features,
            7,
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate,
            act_layer=act_layer,
            conv_layer=conv_layer,
        )
        self.head = ClassifierHead(
            self.num_features,
            num_classes,
            pool_type=global_pool,
            drop_rate=drop_rate,
        )

        self._initialize_weights()

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # this treats BN layers as separate groups for bn variants, a lot of effort to fix that
        return dict(stem=r'^features\.0', blocks=r'^features\.(\d+)')

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        assert not enable, 'gradient checkpointing not supported'

    @torch.jit.ignore
    def get_classifier(self):
        # Return the final classification layer inside the ClassifierHead.
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        # Replace the classifier head for a new class count / pooling type.
        self.num_classes = num_classes
        self.head = ClassifierHead(
            self.num_features,
            self.num_classes,
            pool_type=global_pool,
            drop_rate=self.drop_rate,
        )

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        return x

    def forward_head(self, x: torch.Tensor, pre_logits: bool = False):
        # pre_logits returns the ConvMlp output (before pooling + fc)
        x = self.pre_logits(x)
        return x if pre_logits else self.head(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x

    def _initialize_weights(self) -> None:
        """Kaiming init for convs, unit/zero init for BN, small normal for linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)
+
+
+def _filter_fn(state_dict):
+ """ convert patch embedding weight from manual patchify + linear proj to conv"""
+ out_dict = {}
+ for k, v in state_dict.items():
+ k_r = k
+ k_r = k_r.replace('classifier.0', 'pre_logits.fc1')
+ k_r = k_r.replace('classifier.3', 'pre_logits.fc2')
+ k_r = k_r.replace('classifier.6', 'head.fc')
+ if 'classifier.0.weight' in k:
+ v = v.reshape(-1, 512, 7, 7)
+ if 'classifier.3.weight' in k:
+ v = v.reshape(-1, 4096, 1, 1)
+ out_dict[k_r] = v
+ return out_dict
+
+
def _create_vgg(variant: str, pretrained: bool, **kwargs: Any) -> VGG:
    """Build a VGG variant; the config key is the part of `variant` before '_'."""
    cfg_key = variant.split('_')[0]
    # NOTE: VGG is one of few models with stride==1 features w/ 6 out_indices [0..5]
    out_indices = kwargs.pop('out_indices', (0, 1, 2, 3, 4, 5))
    return build_model_with_cfg(
        VGG,
        variant,
        pretrained,
        model_cfg=cfgs[cfg_key],
        feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
        pretrained_filter_fn=_filter_fn,
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Default pretrained config for VGG models; `kwargs` override defaults."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bilinear',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='features.0',
        classifier='head.fc',
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs; all weights ported from torchvision ('tv') ImageNet-1k.
default_cfgs = generate_default_cfgs({
    'vgg11.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg13.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg16.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg19.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg11_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg13_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg16_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
    'vgg19_bn.tv_in1k': _cfg(hf_hub_id='timm/'),
})
+
+
@register_model
def vgg11(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(**kwargs)
    return _create_vgg('vgg11', pretrained=pretrained, **model_args)
+
+
@register_model
def vgg11_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 11-layer model (configuration "A") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
    return _create_vgg('vgg11_bn', pretrained=pretrained, **model_args)
+
+
@register_model
def vgg13(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(**kwargs)
    return _create_vgg('vgg13', pretrained=pretrained, **model_args)
+
+
@register_model
def vgg13_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 13-layer model (configuration "B") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
    return _create_vgg('vgg13_bn', pretrained=pretrained, **model_args)
+
+
@register_model
def vgg16(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(**kwargs)
    return _create_vgg('vgg16', pretrained=pretrained, **model_args)
+
+
@register_model
def vgg16_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 16-layer model (configuration "D") with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
    return _create_vgg('vgg16_bn', pretrained=pretrained, **model_args)
+
+
@register_model
def vgg19(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration "E")
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(**kwargs)
    return _create_vgg('vgg19', pretrained=pretrained, **model_args)
+
+
@register_model
def vgg19_bn(pretrained: bool = False, **kwargs: Any) -> VGG:
    r"""VGG 19-layer model (configuration 'E') with batch normalization
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/abs/1409.1556>`_.
    """
    model_args = dict(norm_layer=nn.BatchNorm2d, **kwargs)
    return _create_vgg('vgg19_bn', pretrained=pretrained, **model_args)
\ No newline at end of file
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vision_transformer_relpos.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vision_transformer_relpos.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8cf0ea1d2b0c4cde221d5ba0123807c24d3753
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vision_transformer_relpos.py
@@ -0,0 +1,615 @@
+""" Relative Position Vision Transformer (ViT) in PyTorch
+
+NOTE: these models are experimental / WIP, expect changes
+
+Hacked together by / Copyright 2022, Ross Wightman
+"""
+import logging
+import math
+from functools import partial
+from typing import Optional, Tuple, Type, Union
+
+try:
+ from typing import Literal
+except ImportError:
+ from typing_extensions import Literal
+
+import torch
+import torch.nn as nn
+from torch.jit import Final
+from torch.utils.checkpoint import checkpoint
+
+from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
+from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias, use_fused_attn, LayerType
+from ._builder import build_model_with_cfg
+from ._manipulate import named_apply
+from ._registry import generate_default_cfgs, register_model
+from .vision_transformer import get_init_weights_vit
+
+__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
+
+_logger = logging.getLogger(__name__)
+
+
class RelPosAttention(nn.Module):
    """Multi-head self-attention with optional relative position bias.

    The bias can come from a per-block module built by `rel_pos_cls`, or be
    computed once per model and passed into forward as `shared_rel_pos`.
    """
    # resolved at construction; True routes through scaled_dot_product_attention
    fused_attn: Final[bool]

    def __init__(
            self,
            dim,
            num_heads=8,
            qkv_bias=False,
            qk_norm=False,
            rel_pos_cls=None,
            attn_drop=0.,
            proj_drop=0.,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # optional per-head-dim norm of queries/keys (qk_norm)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)

        if self.fused_attn:
            # per-block rel-pos takes precedence over the shared bias
            if self.rel_pos is not None:
                attn_bias = self.rel_pos.get_bias()
            elif shared_rel_pos is not None:
                attn_bias = shared_rel_pos
            else:
                attn_bias = None

            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_bias,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            # manual attention path: scale, bias, softmax, dropout
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            if self.rel_pos is not None:
                attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
            elif shared_rel_pos is not None:
                attn = attn + shared_rel_pos
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
+
+
class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (LayerScale)."""

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        # gamma starts small so residual branches contribute little at init
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
+
+
class RelPosBlock(nn.Module):
    """Pre-norm transformer block using RelPosAttention and optional LayerScale."""

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            rel_pos_cls=None,
            init_values=None,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = RelPosAttention(
            dim,
            num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            rel_pos_cls=rel_pos_cls,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        # LayerScale only when init_values given (truthy)
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        # residual attention then residual MLP, both pre-normed
        x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
+
+
class ResPostRelPosBlock(nn.Module):
    """Residual post-norm transformer block (norm applied after attn/MLP).

    Instead of LayerScale, `init_values` initializes the post-norm weights
    to a small constant, achieving a similar residual-branch damping.
    """

    def __init__(
            self,
            dim,
            num_heads,
            mlp_ratio=4.,
            qkv_bias=False,
            qk_norm=False,
            rel_pos_cls=None,
            init_values=None,
            proj_drop=0.,
            attn_drop=0.,
            drop_path=0.,
            act_layer=nn.GELU,
            norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.init_values = init_values

        self.attn = RelPosAttention(
            dim,
            num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            rel_pos_cls=rel_pos_cls,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )
        self.norm1 = norm_layer(dim)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.norm2 = norm_layer(dim)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.init_weights()

    def init_weights(self):
        # NOTE this init overrides that base model init with specific changes for the block type
        if self.init_values is not None:
            nn.init.constant_(self.norm1.weight, self.init_values)
            nn.init.constant_(self.norm2.weight, self.init_values)

    def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
        # post-norm: normalize the branch output, then add the residual
        x = x + self.drop_path1(self.norm1(self.attn(x, shared_rel_pos=shared_rel_pos)))
        x = x + self.drop_path2(self.norm2(self.mlp(x)))
        return x
+
+
class VisionTransformerRelPos(nn.Module):
    """ Vision Transformer w/ Relative Position Bias

    Differing from classic vit, this impl
    * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
    * defaults to no class token (can be enabled)
    * defaults to global avg pool for head (can be changed)
    * layer-scale (residual branch gain) enabled
    """

    def __init__(
            self,
            img_size: Union[int, Tuple[int, int]] = 224,
            patch_size: Union[int, Tuple[int, int]] = 16,
            in_chans: int = 3,
            num_classes: int = 1000,
            global_pool: Literal['', 'avg', 'token', 'map'] = 'avg',
            embed_dim: int = 768,
            depth: int = 12,
            num_heads: int = 12,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            qk_norm: bool = False,
            init_values: Optional[float] = 1e-6,
            class_token: bool = False,
            fc_norm: bool = False,
            rel_pos_type: str = 'mlp',
            rel_pos_dim: Optional[int] = None,
            shared_rel_pos: bool = False,
            drop_rate: float = 0.,
            proj_drop_rate: float = 0.,
            attn_drop_rate: float = 0.,
            drop_path_rate: float = 0.,
            weight_init: Literal['skip', 'jax', 'moco', ''] = 'skip',
            fix_init: bool = False,
            embed_layer: Type[nn.Module] = PatchEmbed,
            norm_layer: Optional[LayerType] = None,
            act_layer: Optional[LayerType] = None,
            block_fn: Type[nn.Module] = RelPosBlock
    ):
        """
        Args:
            img_size: input image size
            patch_size: patch size
            in_chans: number of input channels
            num_classes: number of classes for classification head
            global_pool: type of global pooling for final sequence (default: 'avg')
            embed_dim: embedding dimension
            depth: depth of transformer
            num_heads: number of attention heads
            mlp_ratio: ratio of mlp hidden dim to embedding dim
            qkv_bias: enable bias for qkv if True
            qk_norm: Enable normalization of query and key in attention
            init_values: layer-scale init values
            class_token: use class token (default: False)
            fc_norm: use pre classifier norm instead of pre-pool
            rel_pos_type: type of relative position
            shared_rel_pos: share relative pos across all blocks
            drop_rate: dropout rate
            proj_drop_rate: projection dropout rate
            attn_drop_rate: attention dropout rate
            drop_path_rate: stochastic depth rate
            weight_init: weight init scheme
            fix_init: apply weight initialization fix (scaling w/ layer index)
            embed_layer: patch embedding layer
            norm_layer: normalization layer
            act_layer: MLP activation layer
        """
        super().__init__()
        # NOTE(review): the Literal type hints 'map' but the assert below does
        # not accept it — only '', 'avg', 'token' are actually supported.
        assert global_pool in ('', 'avg', 'token')
        assert class_token or global_pool != 'token'
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU

        self.num_classes = num_classes
        self.global_pool = global_pool
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_prefix_tokens = 1 if class_token else 0
        self.grad_checkpointing = False

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        feat_size = self.patch_embed.grid_size

        # choose the relative position module: MLP-based (swin-v2 style) or
        # learned bias table (swin-v1 / beit style)
        rel_pos_args = dict(window_size=feat_size, prefix_tokens=self.num_prefix_tokens)
        if rel_pos_type.startswith('mlp'):
            if rel_pos_dim:
                rel_pos_args['hidden_dim'] = rel_pos_dim
            if 'swin' in rel_pos_type:
                rel_pos_args['mode'] = 'swin'
            rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
        else:
            rel_pos_cls = partial(RelPosBias, **rel_pos_args)
        self.shared_rel_pos = None
        if shared_rel_pos:
            self.shared_rel_pos = rel_pos_cls(num_heads=num_heads)
            # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
            rel_pos_cls = None

        self.cls_token = nn.Parameter(torch.zeros(1, self.num_prefix_tokens, embed_dim)) if class_token else None

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList([
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_norm=qk_norm,
                rel_pos_cls=rel_pos_cls,
                init_values=init_values,
                proj_drop=proj_drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
            )
            for i in range(depth)])
        # final norm before pooling, unless fc_norm moves it after pooling
        self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()

        # Classifier Head
        self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        if weight_init != 'skip':
            self.init_weights(weight_init)
        if fix_init:
            self.fix_init_weight()

    def init_weights(self, mode=''):
        """Apply ViT weight init; `mode` selects the jax/moco/default schemes."""
        assert mode in ('jax', 'moco', '')
        if self.cls_token is not None:
            nn.init.normal_(self.cls_token, std=1e-6)
        named_apply(get_init_weights_vit(mode), self)

    def fix_init_weight(self):
        # rescale residual projections by depth (as in BEiT/ViT init fix)
        def rescale(param, _layer_id):
            param.div_(math.sqrt(2.0 * _layer_id))

        for layer_id, layer in enumerate(self.blocks):
            rescale(layer.attn.proj.weight.data, layer_id + 1)
            rescale(layer.mlp.fc2.weight.data, layer_id + 1)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameters excluded from weight decay
        return {'cls_token'}

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        return dict(
            stem=r'^cls_token|patch_embed',  # stem and embed
            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes: int, global_pool=None):
        # Replace classifier; optionally change the pooling mode as well.
        self.num_classes = num_classes
        if global_pool is not None:
            assert global_pool in ('', 'avg', 'token')
            self.global_pool = global_pool
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.cls_token is not None:
            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        # shared bias is computed once per forward and fed to every block
        shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None
        for blk in self.blocks:
            if self.grad_checkpointing and not torch.jit.is_scripting():
                x = checkpoint(blk, x, shared_rel_pos=shared_rel_pos)
            else:
                x = blk(x, shared_rel_pos=shared_rel_pos)
        x = self.norm(x)
        return x

    def forward_head(self, x, pre_logits: bool = False):
        if self.global_pool:
            # 'avg' pools patch tokens (excluding any prefix tokens); otherwise take token 0
            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
+def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
+ if kwargs.get('features_only', None):
+ raise RuntimeError('features_only not implemented for Vision Transformer models.')
+
+ model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs)
+ return model
+
+
def _cfg(url='', **kwargs):
    """Default pretrained config for relpos ViTs; `kwargs` override defaults."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=None,
        crop_pct=.9,
        interpolation='bicubic',
        fixed_input_size=True,
        mean=IMAGENET_INCEPTION_MEAN,
        std=IMAGENET_INCEPTION_STD,
        first_conv='patch_embed.proj',
        classifier='head',
    )
    cfg.update(kwargs)
    return cfg
+
+
# Pretrained weight configs; '.sw_in1k' tags are ImageNet-1k weights trained
# with a swin-like recipe, '.untrained' entries have no published weights.
default_cfgs = generate_default_cfgs({
    'vit_relpos_base_patch32_plus_rpn_256.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
        hf_hub_id='timm/',
        input_size=(3, 256, 256)),
    'vit_relpos_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240)),

    'vit_relpos_small_patch16_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth',
        hf_hub_id='timm/'),
    'vit_relpos_medium_patch16_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth',
        hf_hub_id='timm/'),
    'vit_relpos_base_patch16_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth',
        hf_hub_id='timm/'),

    'vit_srelpos_small_patch16_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth',
        hf_hub_id='timm/'),
    'vit_srelpos_medium_patch16_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth',
        hf_hub_id='timm/'),

    'vit_relpos_medium_patch16_cls_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth',
        hf_hub_id='timm/'),
    'vit_relpos_base_patch16_cls_224.untrained': _cfg(),
    'vit_relpos_base_patch16_clsgap_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth',
        hf_hub_id='timm/'),

    'vit_relpos_small_patch16_rpn_224.untrained': _cfg(),
    'vit_relpos_medium_patch16_rpn_224.sw_in1k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth',
        hf_hub_id='timm/'),
    'vit_relpos_base_patch16_rpn_224.untrained': _cfg(),
})
+
+
@register_model
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
    """
    model_args = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock)
    model = _create_vision_transformer_relpos(
        'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
    """
    model_args = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14)
    model = _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_plus_240', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Small (ViT-S/16) w/ relative log-coord position, no class token
    """
    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=True)
    model = _create_vision_transformer_relpos(
        'vit_relpos_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Medium (ViT-M/16) w/ relative log-coord position, no class token
    """
    model_args = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=True)
    model = _create_vision_transformer_relpos(
        'vit_relpos_medium_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_base_patch16_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
    """
    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True)
    model_kwargs.update(kwargs)
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
+
+
@register_model
def vit_srelpos_small_patch16_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Small (ViT-S/16) w/ shared relative log-coord position, no class token
    """
    model_args = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=False,
        rel_pos_dim=384, shared_rel_pos=True)
    model = _create_vision_transformer_relpos(
        'vit_srelpos_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_srelpos_medium_patch16_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Medium (ViT-M/16) w/ shared relative log-coord position, no class token
    """
    model_args = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
        rel_pos_dim=512, shared_rel_pos=True)
    model = _create_vision_transformer_relpos(
        'vit_srelpos_medium_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_medium_patch16_cls_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Medium (ViT-M/16) w/ relative log-coord position, class token present
    """
    model_args = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=False,
        rel_pos_dim=256, class_token=True, global_pool='token')
    model = _create_vision_transformer_relpos(
        'vit_relpos_medium_patch16_cls_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
    """
    args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, class_token=True, global_pool='token')
    args.update(kwargs)
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_cls_224', pretrained=pretrained, **args)
+
+
@register_model
def vit_relpos_base_patch16_clsgap_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
    NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
    Leaving here for comparisons w/ a future re-train as it performs quite well.
    """
    merged_args = {
        **dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True),
        **kwargs,
    }
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_clsgap_224', pretrained=pretrained, **merged_args)
+
+
@register_model
def vit_relpos_small_patch16_rpn_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Small (ViT-S/16) w/ relative log-coord position and residual post-norm, no class token
    """
    model_args = dict(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, block_fn=ResPostRelPosBlock)
    model = _create_vision_transformer_relpos(
        'vit_relpos_small_patch16_rpn_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_medium_patch16_rpn_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Medium (ViT-M/16) w/ relative log-coord position and residual post-norm, no class token
    """
    model_args = dict(
        patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, block_fn=ResPostRelPosBlock)
    model = _create_vision_transformer_relpos(
        'vit_relpos_medium_patch16_rpn_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
+
+
@register_model
def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs) -> VisionTransformerRelPos:
    """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
    """
    base_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock)
    base_args.update(kwargs)
    return _create_vision_transformer_relpos(
        'vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **base_args)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vovnet.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vovnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4a06065e82a1cb23798335f7b1e370155d19014
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/models/vovnet.py
@@ -0,0 +1,470 @@
+""" VoVNet (V1 & V2)
+
+Papers:
+* `An Energy and GPU-Computation Efficient Backbone Network` - https://arxiv.org/abs/1904.09730
+* `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
+
+Looked at https://github.com/youngwanLEE/vovnet-detectron2 &
+https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
+for some reference, rewrote most of the code.
+
+Hacked together by / Copyright 2020 Ross Wightman
+"""
+
+from typing import List
+
+import torch
+import torch.nn as nn
+
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+from timm.layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath, \
+ create_attn, create_norm_act_layer
+from ._builder import build_model_with_cfg
+from ._manipulate import checkpoint_seq
+from ._registry import register_model, generate_default_cfgs
+
+__all__ = ['VovNet'] # model_registry will add each entrypoint fn to this
+
+
class SequentialAppendList(nn.Sequential):
    """Sequential container that records every intermediate output.

    Each child module consumes the previous child's output (the first consumes
    ``x``); every output is appended to ``concat_list``, and the final result is
    the channel-dim concatenation of everything in ``concat_list``.
    """

    def __init__(self, *args):
        super(SequentialAppendList, self).__init__(*args)

    def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
        for idx, layer in enumerate(self):
            # first layer sees the (possibly reduced) input, the rest chain off the last output
            prev = x if idx == 0 else concat_list[-1]
            concat_list.append(layer(prev))
        return torch.cat(concat_list, dim=1)
+
+
class OsaBlock(nn.Module):
    """One-Shot Aggregation (OSA) block (VoVNet V1/V2).

    Runs `layer_per_block` convs sequentially, then concatenates the block
    input plus every intermediate conv output along the channel dim and fuses
    them with a 1x1 conv.  V2 extras: optional attention, identity residual,
    and stochastic depth.
    """

    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            layer_per_block,
            residual=False,
            depthwise=False,
            attn='',
            norm_layer=BatchNormAct2d,
            act_layer=nn.ReLU,
            drop_path=None,
    ):
        """
        Args:
            in_chs: Input channels.
            mid_chs: Channels of each conv in the aggregation chain.
            out_chs: Output channels of the 1x1 concat projection.
            layer_per_block: Number of convs in the chain.
            residual: Add identity residual around the whole block
                (callers must ensure in_chs == out_chs).
            depthwise: Use separable convs in the chain.
            attn: Attention module name for `create_attn` ('' disables).
            norm_layer: Norm(+act) layer used by the conv blocks.
            act_layer: Activation layer.
            drop_path: Optional DropPath module for stochastic depth.
        """
        super(OsaBlock, self).__init__()

        self.residual = residual
        self.depthwise = depthwise
        conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer)

        next_in_chs = in_chs
        if self.depthwise and next_in_chs != mid_chs:
            # depthwise convs can't change channel count, so reduce with a 1x1 first
            assert not residual
            self.conv_reduction = ConvNormAct(next_in_chs, mid_chs, 1, **conv_kwargs)
        else:
            self.conv_reduction = None

        mid_convs = []
        for i in range(layer_per_block):
            if self.depthwise:
                conv = SeparableConvNormAct(mid_chs, mid_chs, **conv_kwargs)
            else:
                conv = ConvNormAct(next_in_chs, mid_chs, 3, **conv_kwargs)
            next_in_chs = mid_chs
            mid_convs.append(conv)
        self.conv_mid = SequentialAppendList(*mid_convs)

        # feature aggregation: block input + every mid conv output get concatenated
        next_in_chs = in_chs + layer_per_block * mid_chs
        self.conv_concat = ConvNormAct(next_in_chs, out_chs, **conv_kwargs)

        self.attn = create_attn(attn, out_chs) if attn else None

        self.drop_path = drop_path

    def forward(self, x):
        output = [x]  # seed the aggregation list with the unreduced block input
        if self.conv_reduction is not None:
            x = self.conv_reduction(x)
        x = self.conv_mid(x, output)
        x = self.conv_concat(x)
        if self.attn is not None:
            x = self.attn(x)
        if self.drop_path is not None:
            x = self.drop_path(x)
        if self.residual:
            x = x + output[0]
        return x
+
+
class OsaStage(nn.Module):
    """A stage of OSA blocks, optionally preceded by 2x max-pool downsampling."""

    def __init__(
            self,
            in_chs,
            mid_chs,
            out_chs,
            block_per_stage,
            layer_per_block,
            downsample=True,
            residual=True,
            depthwise=False,
            attn='ese',
            norm_layer=BatchNormAct2d,
            act_layer=nn.ReLU,
            drop_path_rates=None,
    ):
        """
        Args:
            in_chs: Input channels to the first block.
            mid_chs: Conv-chain channels for each OSA block.
            out_chs: Output channels of every block in the stage.
            block_per_stage: Number of OSA blocks.
            layer_per_block: Convs per OSA block.
            downsample: Apply stride-2 max-pool before the blocks.
            residual: Enable identity residuals (skipped on the first block,
                where in_chs may differ from out_chs).
            depthwise: Use separable convs in the blocks.
            attn: Attention module name; applied to the last block only.
            norm_layer: Norm(+act) layer for conv blocks.
            act_layer: Activation layer.
            drop_path_rates: Per-block stochastic depth rates (indexable), or None.
        """
        super(OsaStage, self).__init__()
        self.grad_checkpointing = False

        if downsample:
            self.pool = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        else:
            self.pool = None

        blocks = []
        for i in range(block_per_stage):
            last_block = i == block_per_stage - 1
            if drop_path_rates is not None and drop_path_rates[i] > 0.:
                drop_path = DropPath(drop_path_rates[i])
            else:
                drop_path = None
            # residual only once channels match (i > 0); attn only on the final block
            blocks += [OsaBlock(
                in_chs, mid_chs, out_chs, layer_per_block, residual=residual and i > 0, depthwise=depthwise,
                attn=attn if last_block else '', norm_layer=norm_layer, act_layer=act_layer, drop_path=drop_path)
            ]
            in_chs = out_chs
        self.blocks = nn.Sequential(*blocks)

    def forward(self, x):
        if self.pool is not None:
            x = self.pool(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # activation checkpointing trades compute for memory; unsupported under torchscript
            x = checkpoint_seq(self.blocks, x)
        else:
            x = self.blocks(x)
        return x
+
+
class VovNet(nn.Module):
    """VoVNet (V1/V2): a 3-conv stem followed by four OSA stages and a classifier head."""

    def __init__(
            self,
            cfg,
            in_chans=3,
            num_classes=1000,
            global_pool='avg',
            output_stride=32,
            norm_layer=BatchNormAct2d,
            act_layer=nn.ReLU,
            drop_rate=0.,
            drop_path_rate=0.,
            **kwargs,
    ):
        """
        Args:
            cfg (dict): Model architecture configuration
            in_chans (int): Number of input channels (default: 3)
            num_classes (int): Number of classifier classes (default: 1000)
            global_pool (str): Global pooling type (default: 'avg')
            output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
            norm_layer (Union[str, nn.Module]): normalization layer
            act_layer (Union[str, nn.Module]): activation layer
            drop_rate (float): Dropout rate (default: 0.)
            drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
            kwargs (dict): Extra kwargs overlayed onto cfg
        """
        super(VovNet, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        assert output_stride == 32  # FIXME support dilation

        # overlay runtime kwargs onto the static architecture cfg (makes a copy, no mutation)
        cfg = dict(cfg, **kwargs)
        stem_stride = cfg.get("stem_stride", 4)
        stem_chs = cfg["stem_chs"]
        stage_conv_chs = cfg["stage_conv_chs"]
        stage_out_chs = cfg["stage_out_chs"]
        block_per_stage = cfg["block_per_stage"]
        layer_per_block = cfg["layer_per_block"]
        conv_kwargs = dict(norm_layer=norm_layer, act_layer=act_layer)

        # Stem module: first conv always stride 2, last conv stride 2 or 1 depending on stem_stride (4 or 2)
        last_stem_stride = stem_stride // 2
        conv_type = SeparableConvNormAct if cfg["depthwise"] else ConvNormAct
        self.stem = nn.Sequential(*[
            ConvNormAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
            conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs),
            conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs),
        ])
        self.feature_info = [dict(
            num_chs=stem_chs[1], reduction=2, module=f'stem.{1 if stem_stride == 4 else 2}')]
        current_stride = stem_stride

        # OSA stages; per-block drop-path rates linearly spaced over the total block count
        stage_dpr = torch.split(torch.linspace(0, drop_path_rate, sum(block_per_stage)), block_per_stage)
        in_ch_list = stem_chs[-1:] + stage_out_chs[:-1]
        stage_args = dict(residual=cfg["residual"], depthwise=cfg["depthwise"], attn=cfg["attn"], **conv_kwargs)
        stages = []
        for i in range(4):  # num_stages
            downsample = stem_stride == 2 or i > 0  # first stage has no stride/downsample if stem_stride is 4
            stages += [OsaStage(
                in_ch_list[i],
                stage_conv_chs[i],
                stage_out_chs[i],
                block_per_stage[i],
                layer_per_block,
                downsample=downsample,
                drop_path_rates=stage_dpr[i],
                **stage_args,
            )]
            self.num_features = stage_out_chs[i]
            current_stride *= 2 if downsample else 1
            self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]

        self.stages = nn.Sequential(*stages)

        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)

        for n, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.Linear):
                # NOTE(review): assumes Linear layers have a bias (true for the default ClassifierHead fc)
                nn.init.zeros_(m.bias)

    @torch.jit.ignore
    def group_matcher(self, coarse=False):
        # parameter-group regexes for layer-wise LR decay / selective freezing
        return dict(
            stem=r'^stem',
            blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+).blocks\.(\d+)',
        )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # toggle activation checkpointing on every stage
        for s in self.stages:
            s.grad_checkpointing = enable

    @torch.jit.ignore
    def get_classifier(self):
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        # rebuild the head for a new class count / pooling type
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        x = self.stem(x)
        return self.stages(x)

    def forward_head(self, x, pre_logits: bool = False):
        return self.head(x, pre_logits=pre_logits)

    def forward(self, x):
        x = self.forward_features(x)
        x = self.forward_head(x)
        return x
+
+
# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
# https://github.com/stigma0617/VoVNet.pytorch/blob/master/models_vovnet/vovnet.py
model_cfgs = dict(
    # VoVNet V1 variants: no residual connection, no attention
    vovnet39a=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 2, 2],
        residual=False,
        depthwise=False,
        attn='',
    ),
    vovnet57a=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 4, 3],
        residual=False,
        depthwise=False,
        attn='',

    ),
    # VoVNet V2 variants: residual connections + effective squeeze-excite ('ese') attention
    ese_vovnet19b_slim_dw=dict(
        stem_chs=[64, 64, 64],
        stage_conv_chs=[64, 80, 96, 112],
        stage_out_chs=[112, 256, 384, 512],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=True,
        attn='ese',

    ),
    ese_vovnet19b_dw=dict(
        stem_chs=[64, 64, 64],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=True,
        attn='ese',
    ),
    ese_vovnet19b_slim=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[64, 80, 96, 112],
        stage_out_chs=[112, 256, 384, 512],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    ese_vovnet19b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=3,
        block_per_stage=[1, 1, 1, 1],
        residual=True,
        depthwise=False,
        attn='ese',

    ),
    ese_vovnet39b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 2, 2],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    ese_vovnet57b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 4, 3],
        residual=True,
        depthwise=False,
        attn='ese',

    ),
    ese_vovnet99b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 3, 9, 3],
        residual=True,
        depthwise=False,
        attn='ese',
    ),
    # V2 variant with efficient channel attention ('eca') instead of ese
    eca_vovnet39b=dict(
        stem_chs=[64, 64, 128],
        stage_conv_chs=[128, 160, 192, 224],
        stage_out_chs=[256, 512, 768, 1024],
        layer_per_block=5,
        block_per_stage=[1, 1, 2, 2],
        residual=True,
        depthwise=False,
        attn='eca',
    ),
)
# evos variant shares (aliases) the 39b cfg; safe because VovNet copies the cfg before overlaying kwargs
model_cfgs['ese_vovnet39b_evos'] = model_cfgs['ese_vovnet39b']
+
+
def _create_vovnet(variant, pretrained=False, **kwargs):
    """Instantiate a VovNet variant through the timm model builder."""
    cfg = model_cfgs[variant]
    return build_model_with_cfg(
        VovNet,
        variant,
        pretrained,
        model_cfg=cfg,
        feature_cfg=dict(flatten_sequential=True),
        **kwargs,
    )
+
+
def _cfg(url='', **kwargs):
    """Build a default pretrained-config dict; `kwargs` override the defaults."""
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='stem.0.conv',
        classifier='head.fc',
    )
    cfg.update(kwargs)
    return cfg
+
+
# pretrained weight configs; '.untrained' tags have no released weights
default_cfgs = generate_default_cfgs({
    'vovnet39a.untrained': _cfg(url=''),
    'vovnet57a.untrained': _cfg(url=''),
    'ese_vovnet19b_slim_dw.untrained': _cfg(url=''),
    'ese_vovnet19b_dw.ra_in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'ese_vovnet19b_slim.untrained': _cfg(url=''),
    'ese_vovnet39b.ra_in1k': _cfg(
        hf_hub_id='timm/',
        test_input_size=(3, 288, 288), test_crop_pct=0.95),
    'ese_vovnet57b.untrained': _cfg(url=''),
    'ese_vovnet99b.untrained': _cfg(url=''),
    'eca_vovnet39b.untrained': _cfg(url=''),
    'ese_vovnet39b_evos.untrained': _cfg(url=''),
})
+
+
@register_model
def vovnet39a(pretrained=False, **kwargs) -> VovNet:
    """VoVNet-39 V1 (no residual, no attention)."""
    return _create_vovnet('vovnet39a', pretrained=pretrained, **kwargs)
+
+
@register_model
def vovnet57a(pretrained=False, **kwargs) -> VovNet:
    """VoVNet-57 V1 (no residual, no attention)."""
    return _create_vovnet('vovnet57a', pretrained=pretrained, **kwargs)
+
+
@register_model
def ese_vovnet19b_slim_dw(pretrained=False, **kwargs) -> VovNet:
    """ESE-VoVNet-19b slim V2 w/ depthwise separable convs."""
    return _create_vovnet('ese_vovnet19b_slim_dw', pretrained=pretrained, **kwargs)
+
+
@register_model
def ese_vovnet19b_dw(pretrained=False, **kwargs) -> VovNet:
    """ESE-VoVNet-19b V2 w/ depthwise separable convs."""
    return _create_vovnet('ese_vovnet19b_dw', pretrained=pretrained, **kwargs)
+
+
@register_model
def ese_vovnet19b_slim(pretrained=False, **kwargs) -> VovNet:
    """ESE-VoVNet-19b slim V2."""
    return _create_vovnet('ese_vovnet19b_slim', pretrained=pretrained, **kwargs)
+
+
@register_model
def ese_vovnet39b(pretrained=False, **kwargs) -> VovNet:
    """ESE-VoVNet-39b V2."""
    return _create_vovnet('ese_vovnet39b', pretrained=pretrained, **kwargs)
+
+
@register_model
def ese_vovnet57b(pretrained=False, **kwargs) -> VovNet:
    """ESE-VoVNet-57b V2."""
    return _create_vovnet('ese_vovnet57b', pretrained=pretrained, **kwargs)
+
+
@register_model
def ese_vovnet99b(pretrained=False, **kwargs) -> VovNet:
    """ESE-VoVNet-99b V2."""
    return _create_vovnet('ese_vovnet99b', pretrained=pretrained, **kwargs)
+
+
@register_model
def eca_vovnet39b(pretrained=False, **kwargs) -> VovNet:
    """VoVNet-39b V2 w/ ECA attention instead of ESE."""
    return _create_vovnet('eca_vovnet39b', pretrained=pretrained, **kwargs)
+
+
+# Experimental Models
+
@register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs) -> VovNet:
    """Experimental ESE-VoVNet-39b using EvoNorm-S0 in place of BatchNorm+act."""
    def norm_act_fn(num_features, **nkwargs):
        # adapter so the plain norm_layer interface maps onto create_norm_act_layer
        return create_norm_act_layer('evonorms0', num_features, jit=False, **nkwargs)
    return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adahessian.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adahessian.py
new file mode 100644
index 0000000000000000000000000000000000000000..985c67ca686a65f61f5c5b1a7db3e5bba815a19b
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/adahessian.py
@@ -0,0 +1,156 @@
+""" AdaHessian Optimizer
+
+Lifted from https://github.com/davda54/ada-hessian/blob/master/ada_hessian.py
+Originally licensed MIT, Copyright 2020, David Samuel
+"""
+import torch
+
+
class Adahessian(torch.optim.Optimizer):
    """
    Implements the AdaHessian algorithm from "ADAHESSIAN: An Adaptive Second Order Optimizer for Machine Learning"

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        lr (float, optional): learning rate (default: 0.1)
        betas ((float, float), optional): coefficients used for computing running averages of gradient and the
            squared hessian trace (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0.0)
        hessian_power (float, optional): exponent of the hessian trace (default: 1.0)
        update_each (int, optional): compute the hessian trace approximation only after *this* number of steps
            (to save time) (default: 1)
        n_samples (int, optional): how many times to sample `z` for the approximation of the hessian trace (default: 1)
        avg_conv_kernel (bool, optional): average the Hessian estimate over the spatial dims of 4-d
            (conv) parameters before the update (default: False)
    """

    def __init__(self, params, lr=0.1, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0,
                 hessian_power=1.0, update_each=1, n_samples=1, avg_conv_kernel=False):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if not 0.0 <= hessian_power <= 1.0:
            raise ValueError(f"Invalid Hessian power value: {hessian_power}")

        self.n_samples = n_samples
        self.update_each = update_each
        self.avg_conv_kernel = avg_conv_kernel

        # use a separate generator that deterministically generates the same `z`s across all GPUs in case of distributed training
        self.seed = 2147483647
        self.generator = torch.Generator().manual_seed(self.seed)

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, hessian_power=hessian_power)
        super(Adahessian, self).__init__(params, defaults)

        for p in self.get_params():
            # Hessian-trace estimate is stored directly on each parameter; starts as the float 0.0
            # and becomes a tensor after the first set_hessian() accumulation
            p.hess = 0.0
            self.state[p]["hessian step"] = 0

    @property
    def is_second_order(self):
        # lets training loops know backward() must be called with create_graph=True
        return True

    def get_params(self):
        """
        Gets all parameters in all param_groups with gradients
        """

        return (p for group in self.param_groups for p in group['params'] if p.requires_grad)

    def zero_hessian(self):
        """
        Zeros out the accumulated hessian traces.
        """

        for p in self.get_params():
            # only reset tensors (not the initial 0.0 float) and only on steps where a fresh estimate is computed
            if not isinstance(p.hess, float) and self.state[p]["hessian step"] % self.update_each == 0:
                p.hess.zero_()

    @torch.no_grad()
    def set_hessian(self):
        """
        Computes the Hutchinson approximation of the hessian trace and accumulates it for each trainable parameter.
        """

        params = []
        for p in filter(lambda p: p.grad is not None, self.get_params()):
            if self.state[p]["hessian step"] % self.update_each == 0:  # compute the trace only each `update_each` step
                params.append(p)
            self.state[p]["hessian step"] += 1

        if len(params) == 0:
            return

        if self.generator.device != params[0].device:  # hackish way of casting the generator to the right device
            self.generator = torch.Generator(params[0].device).manual_seed(self.seed)

        grads = [p.grad for p in params]

        for i in range(self.n_samples):
            # Rademacher distribution {-1.0, 1.0}
            zs = [torch.randint(0, 2, p.size(), generator=self.generator, device=p.device) * 2.0 - 1.0 for p in params]
            # Hessian-vector products via double backward; retain graph for all but the last sample
            h_zs = torch.autograd.grad(
                grads, params, grad_outputs=zs, only_inputs=True, retain_graph=i < self.n_samples - 1)
            for h_z, z, p in zip(h_zs, zs, params):
                p.hess += h_z * z / self.n_samples  # approximate the expected values of z*(H@z)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.
        Arguments:
            closure (callable, optional) -- a closure that reevaluates the model and returns the loss (default: None)
        """

        loss = None
        if closure is not None:
            loss = closure()

        self.zero_hessian()
        self.set_hessian()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or p.hess is None:
                    continue

                if self.avg_conv_kernel and p.dim() == 4:
                    # smooth the per-weight estimate across conv spatial dims
                    p.hess = torch.abs(p.hess).mean(dim=[2, 3], keepdim=True).expand_as(p.hess).clone()

                # Perform correct stepweight decay as in AdamW
                p.mul_(1 - group['lr'] * group['weight_decay'])

                state = self.state[p]

                # State initialization: only the "hessian step" counter exists before the first update
                if len(state) == 1:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of Hessian diagonal square values
                    state['exp_hessian_diag_sq'] = torch.zeros_like(p)

                exp_avg, exp_hessian_diag_sq = state['exp_avg'], state['exp_hessian_diag_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(p.grad, alpha=1 - beta1)
                exp_hessian_diag_sq.mul_(beta2).addcmul_(p.hess, p.hess, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # hessian_power k < 1 softens the preconditioner toward SGD-like behavior
                k = group['hessian_power']
                denom = (exp_hessian_diag_sq / bias_correction2).pow_(k / 2).add_(group['eps'])

                # make update
                step_size = group['lr'] / bias_correction1
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nadam.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nadam.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e911420efc4326a537182e0f31633502d6b2026
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/nadam.py
@@ -0,0 +1,97 @@
+import math
+
+import torch
+from torch.optim.optimizer import Optimizer
+
+
class Nadam(Optimizer):
    """Nadam optimizer: Adam with a Nesterov momentum schedule.

    Proposed in `Incorporating Nesterov Momentum into Adam`__.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        schedule_decay (float, optional): momentum schedule decay (default: 4e-3)

    __ http://cs229.stanford.edu/proj2015/054_report.pdf
    __ http://www.cs.toronto.edu/~fritz/absps/momentum.pdf

    Originally taken from: https://github.com/pytorch/pytorch/pull/1408
    NOTE: Has potential issues but does work well on some problems.
    """

    def __init__(self, params, lr=2e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, schedule_decay=4e-3):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        defaults = dict(
            lr=lr, betas=betas, eps=eps,
            weight_decay=weight_decay, schedule_decay=schedule_decay)
        super(Nadam, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            beta1, beta2 = group['betas']
            lr = group['lr']
            eps = group['eps']
            wd = group['weight_decay']
            schedule_decay = group['schedule_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # lazy per-parameter state initialization
                if not state:
                    state['step'] = 0
                    state['m_schedule'] = 1.
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                t = state['step']
                m_schedule = state['m_schedule']
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                bias_correction2 = 1 - beta2 ** t

                if wd != 0:
                    grad = grad.add(p, alpha=wd)

                # Nesterov momentum schedule: mu_t for this step, mu_{t+1} for the look-ahead term
                mu_t = beta1 * (1. - 0.5 * (0.96 ** (t * schedule_decay)))
                mu_next = beta1 * (1. - 0.5 * (0.96 ** ((t + 1) * schedule_decay)))
                m_schedule_new = m_schedule * mu_t
                m_schedule_next = m_schedule * mu_t * mu_next
                state['m_schedule'] = m_schedule_new

                # first and second moment running averages
                exp_avg.mul_(beta1).add_(grad, alpha=1. - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1. - beta2)

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
                # combined update: raw gradient term + momentum look-ahead term
                p.addcdiv_(grad, denom, value=-lr * (1. - mu_t) / (1. - m_schedule_new))
                p.addcdiv_(exp_avg, denom, value=-lr * mu_next / (1. - m_schedule_next))

        return loss
diff --git a/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/optim_factory.py b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/optim_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..8187b55a1b9539213b7df2a7447b0b425e23b0aa
--- /dev/null
+++ b/my_container_sandbox/workspace/anaconda3/lib/python3.8/site-packages/timm/optim/optim_factory.py
@@ -0,0 +1,423 @@
+""" Optimizer Factory w/ Custom Weight Decay
+Hacked together by / Copyright 2021 Ross Wightman
+"""
+import logging
+from itertools import islice
+from typing import Optional, Callable, Tuple
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from timm.models import group_parameters
+
+from .adabelief import AdaBelief
+from .adafactor import Adafactor
+from .adahessian import Adahessian
+from .adamp import AdamP
+from .adan import Adan
+from .lamb import Lamb
+from .lars import Lars
+from .lion import Lion
+from .lookahead import Lookahead
+from .madgrad import MADGRAD
+from .nadam import Nadam
+from .nadamw import NAdamW
+from .nvnovograd import NvNovoGrad
+from .radam import RAdam
+from .rmsprop_tf import RMSpropTF
+from .sgdp import SGDP
+from .sgdw import SGDW
+
+
+_logger = logging.getLogger(__name__)
+
+
+# optimizers to default to multi-tensor
+_DEFAULT_FOREACH = {
+ 'lion',
+}
+
+
def param_groups_weight_decay(
        model: nn.Module,
        weight_decay=1e-5,
        no_weight_decay_list=()
):
    """Split a model's trainable parameters into two optimizer param groups.

    Parameters that are 1D (norms, biases), end in ``.bias``, or whose name is in
    *no_weight_decay_list* get ``weight_decay=0.``; everything else gets *weight_decay*.
    Frozen parameters (``requires_grad=False``) are excluded entirely.

    Returns:
        list of two param-group dicts: ``[no-decay group, decay group]``.
    """
    skip_names = set(no_weight_decay_list)
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1D params and biases are conventionally excluded from weight decay
        exclude = param.ndim <= 1 or name.endswith(".bias") or name in skip_names
        (no_decay_params if exclude else decay_params).append(param)

    return [
        {'params': no_decay_params, 'weight_decay': 0.},
        {'params': decay_params, 'weight_decay': weight_decay},
    ]
+
+
+def _group(it, size):
+ it = iter(it)
+ return iter(lambda: tuple(islice(it, size)), ())
+
+
def _layer_map(model, layers_per_group=12, num_groups=None):
    """Build a mapping from parameter name to a layer-group index for layer-wise LR decay.

    Trunk (non-head) parameters are chunked into groups of *layers_per_group*
    (or into exactly *num_groups* groups when given); head parameters — those whose
    name starts with the model's ``pretrained_cfg['classifier']`` prefix — all map to
    the final (highest) group index. If no classifier prefix is available, every
    parameter is treated as head.
    """
    head_prefix = getattr(model, 'pretrained_cfg', {}).get('classifier', None)

    def _in_head(name, prefixes):
        # No classifier prefix known -> treat every parameter as head.
        if not prefixes:
            return True
        if isinstance(prefixes, (tuple, list)):
            return any(name.startswith(p) for p in prefixes)
        return name.startswith(prefixes)

    names_trunk, names_head = [], []
    for name, _ in model.named_parameters():
        (names_head if _in_head(name, head_prefix) else names_trunk).append(name)

    # group non-head layers
    if num_groups is not None:
        # ceil-division via negated floor-division
        layers_per_group = -(len(names_trunk) // -num_groups)
    trunk_groups = list(_group(names_trunk, layers_per_group))

    layer_map = {name: idx for idx, grp in enumerate(trunk_groups) for name in grp}
    layer_map.update({name: len(trunk_groups) for name in names_head})
    return layer_map
+
+
def param_groups_layer_decay(
        model: nn.Module,
        weight_decay: float = 0.05,
        no_weight_decay_list: Tuple[str] = (),
        layer_decay: float = .75,
        end_layer_decay: Optional[float] = None,
        verbose: bool = False,
):
    """
    Parameter groups for layer-wise lr decay & weight decay
    Based on BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58

    Each group carries an ``lr_scale`` of ``layer_decay ** (max_layer - layer_id)``
    so earlier layers get smaller learning rates. 1D params and names listed in
    *no_weight_decay_list* get zero weight decay.
    """
    skip_names = set(no_weight_decay_list)
    group_names = {}  # NOTE for debugging
    groups = {}

    if hasattr(model, 'group_matcher'):
        # FIXME interface needs more work
        layer_map = group_parameters(model, model.group_matcher(coarse=False), reverse=True)
    else:
        # fallback
        layer_map = _layer_map(model)
    num_layers = max(layer_map.values()) + 1
    layer_max = num_layers - 1
    layer_scales = [layer_decay ** (layer_max - i) for i in range(num_layers)]

    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue

        # no decay: all 1D parameters and model specific ones
        if param.ndim == 1 or name in skip_names:
            g_decay, this_decay = "no_decay", 0.
        else:
            g_decay, this_decay = "decay", weight_decay

        # unmapped names default to the last (head) layer group
        layer_id = layer_map.get(name, layer_max)
        key = "layer_%d_%s" % (layer_id, g_decay)

        if key not in groups:
            this_scale = layer_scales[layer_id]
            group_names[key] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "param_names": [],
            }
            groups[key] = {
                "lr_scale": this_scale,
                "weight_decay": this_decay,
                "params": [],
            }

        group_names[key]["param_names"].append(name)
        groups[key]["params"].append(param)

    if verbose:
        import json
        _logger.info("parameter groups: \n%s" % json.dumps(group_names, indent=2))

    return list(groups.values())
+
+
def optimizer_kwargs(cfg):
    """ cfg/argparse to kwargs helper
    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.

    Required cfg attrs: ``opt``, ``lr``, ``weight_decay``, ``momentum``.
    Optional attrs (copied only when present and not None): ``opt_eps`` -> ``eps``,
    ``opt_betas`` -> ``betas``, ``layer_decay``, ``opt_args`` (merged in),
    ``opt_foreach`` -> ``foreach``.
    """
    kwargs = {
        'opt': cfg.opt,
        'lr': cfg.lr,
        'weight_decay': cfg.weight_decay,
        'momentum': cfg.momentum,
    }
    eps = getattr(cfg, 'opt_eps', None)
    if eps is not None:
        kwargs['eps'] = eps
    betas = getattr(cfg, 'opt_betas', None)
    if betas is not None:
        kwargs['betas'] = betas
    layer_decay = getattr(cfg, 'layer_decay', None)
    if layer_decay is not None:
        kwargs['layer_decay'] = layer_decay
    extra_args = getattr(cfg, 'opt_args', None)
    if extra_args is not None:
        kwargs.update(extra_args)
    foreach = getattr(cfg, 'opt_foreach', None)
    if foreach is not None:
        kwargs['foreach'] = foreach
    return kwargs
+
+
def create_optimizer(args, model, filter_bias_and_bn=True):
    """ Legacy optimizer factory for backwards compatibility.
    NOTE: Use create_optimizer_v2 for new code.
    """
    kwargs = optimizer_kwargs(cfg=args)
    return create_optimizer_v2(model, filter_bias_and_bn=filter_bias_and_bn, **kwargs)
+
+
def create_optimizer_v2(
        model_or_params,
        opt: str = 'sgd',
        lr: Optional[float] = None,
        weight_decay: float = 0.,
        momentum: float = 0.9,
        foreach: Optional[bool] = None,
        filter_bias_and_bn: bool = True,
        layer_decay: Optional[float] = None,
        param_group_fn: Optional[Callable] = None,
        **kwargs,
):
    """ Create an optimizer.

    TODO currently the model is passed in and all parameters are selected for optimization.
    For more general use an interface that allows selection of parameters to optimize and lr groups, one of:
      * a filter fn interface that further breaks params into groups in a weight_decay compatible fashion
      * expose the parameters interface and leave it up to caller

    Args:
        model_or_params (nn.Module): model containing parameters to optimize
        opt: name of optimizer to create (may be prefixed, e.g. 'lookahead_adamw')
        lr: initial learning rate
        weight_decay: weight decay to apply in optimizer
        momentum: momentum for momentum based optimizers (others may use betas via kwargs)
        foreach: Enable / disable foreach (multi-tensor) operation if True / False. Choose safe default if None
        filter_bias_and_bn: filter out bias, bn and other 1d params from weight decay
        layer_decay: layer-wise learning rate decay factor (param groups built when set)
        param_group_fn: optional callable mapping the model to optimizer param groups
        **kwargs: extra optimizer specific kwargs to pass through

    Returns:
        Optimizer

    Raises:
        ValueError: if *opt* does not name a known optimizer.
    """
    if isinstance(model_or_params, nn.Module):
        # a model was passed in, extract parameters and add weight decays to appropriate layers
        no_weight_decay = {}
        if hasattr(model_or_params, 'no_weight_decay'):
            no_weight_decay = model_or_params.no_weight_decay()

        if param_group_fn:
            parameters = param_group_fn(model_or_params)
        elif layer_decay is not None:
            parameters = param_groups_layer_decay(
                model_or_params,
                weight_decay=weight_decay,
                layer_decay=layer_decay,
                no_weight_decay_list=no_weight_decay,
            )
            # decay is applied per-group above; zero the global value
            weight_decay = 0.
        elif weight_decay and filter_bias_and_bn:
            parameters = param_groups_weight_decay(model_or_params, weight_decay, no_weight_decay)
            weight_decay = 0.
        else:
            parameters = model_or_params.parameters()
    else:
        # iterable of parameters or param groups passed in
        parameters = model_or_params

    # resolve optional prefix (e.g. 'lookahead_') and the base optimizer name
    opt_lower = opt.lower()
    opt_split = opt_lower.split('_')
    opt_lower = opt_split[-1]

    if opt_lower.startswith('fused'):
        try:
            from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
            has_apex = True
        except ImportError:
            has_apex = False
        assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'

    if opt_lower.startswith('bnb'):
        try:
            import bitsandbytes as bnb
            has_bnb = True
        except ImportError:
            has_bnb = False
        assert has_bnb and torch.cuda.is_available(), 'bitsandbytes and CUDA required for bnb optimizers'

    opt_args = dict(weight_decay=weight_decay, **kwargs)

    if lr is not None:
        opt_args.setdefault('lr', lr)

    if foreach is None:
        # FIX: check the resolved base name (was `opt`) so prefixed variants such as
        # 'lookahead_lion' also pick up the multi-tensor default
        if opt_lower in _DEFAULT_FOREACH:
            opt_args.setdefault('foreach', True)
    else:
        opt_args['foreach'] = foreach

    # basic SGD & related
    if opt_lower == 'sgd' or opt_lower == 'nesterov':
        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentum':
        opt_args.pop('eps', None)
        optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'sgdp':
        optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'sgdw' or opt_lower == 'nesterovw':
        # NOTE 'sgdw' refers to SGDW + nesterov momentum for legacy / backwards compat reasons
        opt_args.pop('eps', None)
        optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'momentumw':
        opt_args.pop('eps', None)
        optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args)

    # adaptive
    elif opt_lower == 'adam':
        optimizer = optim.Adam(parameters, **opt_args)
    elif opt_lower == 'adamw':
        optimizer = optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'adamp':
        optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
    elif opt_lower == 'nadam':
        try:
            # NOTE PyTorch >= 1.10 has a native implementation; FIX: the class is
            # `optim.NAdam` (was `optim.Nadam`, which always raised AttributeError
            # and silently fell back to the timm implementation)
            optimizer = optim.NAdam(parameters, **opt_args)
        except AttributeError:
            optimizer = Nadam(parameters, **opt_args)
    elif opt_lower == 'nadamw':
        optimizer = NAdamW(parameters, **opt_args)
    elif opt_lower == 'radam':
        optimizer = RAdam(parameters, **opt_args)
    elif opt_lower == 'adamax':
        optimizer = optim.Adamax(parameters, **opt_args)
    elif opt_lower == 'adabelief':
        optimizer = AdaBelief(parameters, rectify=False, **opt_args)
    elif opt_lower == 'radabelief':
        optimizer = AdaBelief(parameters, rectify=True, **opt_args)
    elif opt_lower == 'adadelta':
        optimizer = optim.Adadelta(parameters, **opt_args)
    elif opt_lower == 'adagrad':
        opt_args.setdefault('eps', 1e-8)
        optimizer = optim.Adagrad(parameters, **opt_args)
    elif opt_lower == 'adafactor':
        optimizer = Adafactor(parameters, **opt_args)
    elif opt_lower == 'adanp':
        optimizer = Adan(parameters, no_prox=False, **opt_args)
    elif opt_lower == 'adanw':
        optimizer = Adan(parameters, no_prox=True, **opt_args)
    elif opt_lower == 'lamb':
        optimizer = Lamb(parameters, **opt_args)
    elif opt_lower == 'lambc':
        optimizer = Lamb(parameters, trust_clip=True, **opt_args)
    elif opt_lower == 'larc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, **opt_args)
    elif opt_lower == 'lars':
        optimizer = Lars(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'nlarc':
        optimizer = Lars(parameters, momentum=momentum, trust_clip=True, nesterov=True, **opt_args)
    elif opt_lower == 'nlars':
        optimizer = Lars(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'madgrad':
        optimizer = MADGRAD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'madgradw':
        optimizer = MADGRAD(parameters, momentum=momentum, decoupled_decay=True, **opt_args)
    elif opt_lower == 'novograd' or opt_lower == 'nvnovograd':
        optimizer = NvNovoGrad(parameters, **opt_args)
    elif opt_lower == 'rmsprop':
        optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'rmsproptf':
        optimizer = RMSpropTF(parameters, alpha=0.9, momentum=momentum, **opt_args)
    elif opt_lower == 'lion':
        opt_args.pop('eps', None)
        optimizer = Lion(parameters, **opt_args)

    # second order
    elif opt_lower == 'adahessian':
        optimizer = Adahessian(parameters, **opt_args)

    # NVIDIA fused optimizers, require APEX to be installed
    elif opt_lower == 'fusedsgd':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'fusedmomentum':
        opt_args.pop('eps', None)
        optimizer = FusedSGD(parameters, momentum=momentum, nesterov=False, **opt_args)
    elif opt_lower == 'fusedadam':
        optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
    elif opt_lower == 'fusedadamw':
        optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
    elif opt_lower == 'fusedlamb':
        optimizer = FusedLAMB(parameters, **opt_args)
    elif opt_lower == 'fusednovograd':
        opt_args.setdefault('betas', (0.95, 0.98))
        optimizer = FusedNovoGrad(parameters, **opt_args)

    # bitsandbytes optimizers, require bitsandbytes to be installed
    elif opt_lower == 'bnbsgd':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'bnbsgd8bit':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, nesterov=True, **opt_args)
    elif opt_lower == 'bnbmomentum':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'bnbmomentum8bit':
        opt_args.pop('eps', None)
        optimizer = bnb.optim.SGD8bit(parameters, momentum=momentum, **opt_args)
    elif opt_lower == 'bnbadam':
        optimizer = bnb.optim.Adam(parameters, **opt_args)
    elif opt_lower == 'bnbadam8bit':
        optimizer = bnb.optim.Adam8bit(parameters, **opt_args)
    elif opt_lower == 'bnbadamw':
        optimizer = bnb.optim.AdamW(parameters, **opt_args)
    elif opt_lower == 'bnbadamw8bit':
        optimizer = bnb.optim.AdamW8bit(parameters, **opt_args)
    elif opt_lower == 'bnblamb':
        optimizer = bnb.optim.LAMB(parameters, **opt_args)
    elif opt_lower == 'bnblamb8bit':
        optimizer = bnb.optim.LAMB8bit(parameters, **opt_args)
    elif opt_lower == 'bnblars':
        optimizer = bnb.optim.LARS(parameters, **opt_args)
    elif opt_lower == 'bnblarsb8bit':
        # FIX: was `bnb.optim.LAMB8bit` — copy-paste error; this key selects the 8-bit LARS
        optimizer = bnb.optim.LARS8bit(parameters, **opt_args)
    elif opt_lower == 'bnblion':
        optimizer = bnb.optim.Lion(parameters, **opt_args)
    elif opt_lower == 'bnblion8bit':
        optimizer = bnb.optim.Lion8bit(parameters, **opt_args)

    else:
        # FIX: was `assert False and "Invalid optimizer"` (short-circuits, so the
        # message never appeared; stripped entirely under -O) plus a message-less
        # `raise ValueError` — replaced with a single informative error.
        raise ValueError(f'Invalid optimizer: {opt}')

    if len(opt_split) > 1:
        if opt_split[0] == 'lookahead':
            optimizer = Lookahead(optimizer)

    return optimizer
diff --git a/tmp_inputs_32_31/case00004.nii.gz b/tmp_inputs_32_31/case00004.nii.gz
new file mode 100644
index 0000000000000000000000000000000000000000..6287341581f7f4ec63805968f7d5c6aac9085708
--- /dev/null
+++ b/tmp_inputs_32_31/case00004.nii.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ab604fb141c148047832364ede166971ee7c905ed9ce623ef6ab4ef581d64569
+size 42153294